무효 클릭 IP 추적 중...
머신러닝,딥러닝

[sklearn] 'stratify' 의 역할(train_test_split)

꼬예 2023. 1. 5.

[ic]train_test_split[/ic]에서 [ic]stratify[/ic]가 뭐 하는 녀석인지 헷갈리는가?

stratify 예시

 

그렇다면 잘 들어왔다.

 

이번 포스팅에서는 [ic]stratify[/ic]를 미적용했을 때 어떤 문제가 발생하는지 알아보고, [ic]stratify[/ic]를 통해 문제를 해결해볼 거다.

 

1) 예제 데이터 준비

df_2 = pd.DataFrame({'class_id': ['A', 'A', 'A', 'A', 'A','A' ,'B', 'B', 'B'],
                   'feature1': [1, 2, 3, 4, 5, 6, 7, 8, 9],})
                   
print(df_2)

 

df_2 출력

 

2) [ic]stratify[/ic] 미적용

from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(df_2, test_size=0.2, random_state=2021)

[ic]test_size=0.2[/ic] 비율로 train, valid 데이터를 나누었다.

 

print(len(train_df))
print(len(val_df))

#output
7
2

[ic]len[/ic]을 통해 개수를 확인해 보니 [ic]test_size[/ic] 비율이 잘 적용되었다.

 

문제는 다음에 발생한다.

 

 train 데이터, valid 데이터에 속한 클래스 분포를 확인해 보자.

print(train_df['class_id'].value_counts())
#output
A    6
B    1

print(val_df['class_id'].value_counts())
#output
B    2

[ic]A[/ic]클래스는 valid 데이터에 아예 할당되지 않았다.

다시 말해 클래스 분포에는 [ic]test_size[/ic] 비율이 적용되지 않았다.

 

이렇게 되면 뭐가 문제일까?

 

train 단계에서 [ic]A[/ic]클래스에 대해 아무리 학습해도 valid 단계에서 검증할 수 없다.

다시 말하면 overfitting 여부 확인이 불가하다.

 

그때 사용하는 것이 [ic]stratify[/ic]다.

 

[ic]stratify[/ic]는 기존 데이터를 나누는 것에 그치는 게 아니라, 클래스 분포 비율까지 맞춰 준다.

 

무슨 말인지 예시를 통해 확인하자.

 

3) [ic]stratify[/ic] 적용

train_df, val_df = train_test_split(df_2, test_size=0.2, stratify=df_2['class_id'], random_state=2021)

[ic]stratify[/ic] 에 [ic]df_2['class_id'][/ic]를 할당하였다

이는 [ic]'class_id'[/ic] 별 분포를 비율에 따라 맞춰주라고 요청하는 거다.

 

print(train_df['class_id'].value_counts())
#output
A    5
B    2

print(val_df['class_id'].value_counts())
#output
B    1
A    1

 

이전과 다르게 valid 데이터에도 [ic]A[/ic]가 포함되었다.

 

물론 실제 개수가 비율 0.2로 딱 맞는 것은 아니다.

왜냐하면 갯수 자체가 나눠 떨어지지 않기 때문이다.

 

하지만 확실한 건 그 비율을 맞추기 위해 최대한 노력했다는 점이다.

 

  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글