무효 클릭 IP 추적 중...
머신러닝,딥러닝

[sklearn] 'stratify' 의 역할(train_test_split)

꼬예 2023. 1. 5.
[sklearn] 'stratify' 의 역할(train_test_split)

  train_test_split  에서   stratify  가 뭐 하는 녀석인지 헷갈리는가?

stratify 예시

 

그렇다면 잘 들어왔다.

 

이번 포스팅에서는   stratify  미적용했을 때 어떤 문제가 발생하는지 알아보고,   stratify  를 통해 문제를 해결해볼 거다.

 

1) 예제 데이터 준비

df_2 = pd.DataFrame({'class_id': ['A', 'A', 'A', 'A', 'A','A' ,'B', 'B', 'B'],
                   'feature1': [1, 2, 3, 4, 5, 6, 7, 8, 9],})
                   
print(df_2)

 

df_2 출력

 

2)   stratify   미적용

from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(df_2, test_size=0.2, random_state=2021)

  test_size=0.2   비율로 train, valid 데이터를 나누었다.

 

print(len(train_df))
print(len(val_df))

#output
7
2

  len  을 통해 개수를 확인해 보니   test_size   비율이 잘 적용되었다.

 

문제는 다음에 발생한다.

 

 train 데이터, valid 데이터에 속한 클래스 분포를 확인해 보자.

print(train_df['class_id'].value_counts())
#output
A    6
B    1

print(val_df['class_id'].value_counts())
#output
B    2

  A  클래스는 valid 데이터에 아예 할당되지 않았다.

다시 말해 클래스 분포에는   test_size   비율이 적용되지 않았다.

 

이렇게 되면 뭐가 문제일까?

 

train 단계에서   A  클래스에 대해 아무리 학습해도 valid 단계에서 검증할 수 없다.

다시 말하면 overfitting 여부 확인이 불가하다.

 

그때 사용하는 것이   stratify  다.

 

  stratify  는 기존 데이터를 나누는 것에 그치는 게 아니라, 클래스 분포 비율까지 맞춰 준다.

 

무슨 말인지 예시를 통해 확인하자.

 

3)   stratify   적용

train_df, val_df = train_test_split(df_2, test_size=0.2, stratify=df_2['class_id'], random_state=2021)

  stratify    df_2['class_id']  를 할당하였다

이는   'class_id'   별 분포를 비율에 따라 맞춰주라고 요청하는 거다.

 

print(train_df['class_id'].value_counts())
#output
A    5
B    2

print(val_df['class_id'].value_counts())
#output
B    1
A    1

 

이전과 다르게 valid 데이터에도   A  가 포함되었다.

 

물론 실제 개수가 비율 0.2로 딱 맞는 것은 아니다.

왜냐하면 갯수 자체가 나눠 떨어지지 않기 때문이다.

 

하지만 확실한 건 그 비율을 맞추기 위해 최대한 노력했다는 점이다.

 

  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글

꼬예님의
글이 좋았다면 응원을 보내주세요!