무효 클릭 IP 추적 중...
머신러닝,딥러닝/NLP

[keras] 패딩(padding) 하는법 with pad_sequences

꼬예 2023. 1. 2.

자연어처리에서 [ic]Padding[/ic]을 왜 해줄까?

 

각 문장 길이가 다르면 병렬 연산이 어렵기 때문이다.

 

이때 패딩을 통해 고정된 길이를 맞춰주면 작업이 용이하다.

 

1) 데이터 준비

길이가 다른 각 문장을 정수 인코딩한 형태다.

data = [
                 [1, 2, 3], 
                 [4, 5], 
                 [6, 7, 8, 9, 10, 11]
        ]

 

 

2) [ic]pad_sequences[/ic] 사용법

keras에서 제공하는 [ic]pad_sequences[/ic] 를 통해 패딩을 적용해보겠다.

from keras_preprocessing.sequence import pad_sequences

 

(1) default 세팅

data_padded = pad_sequences(sequences=data)
print(data_padded)

#output
'''
[[ 0  0  0  1  2  3]
 [ 0  0  0  0  4  5]
 [ 6  7  8  9 10 11]]
'''

 

제일 긴녀석(3행)을 기준으로 값(0)을 채운다.

 

(2) [ic]padding[/ic] 방향 설정

# padding post 적용
data_padded = pad_sequences(sequences=data, padding='post') #default padding='pre'
print(data_padded)

#output
'''
[[ 1  2  3  0  0  0]
 [ 4  5  0  0  0  0]
 [ 6  7  8  9 10 11]]
'''

[ic]padding='post'[/ic]는 뒤에서부터 0을 채우는 기능이다.

default는 [ic]padding='pre'[/ic]이다.

 

(3) [ic]value[/ic] 적용

data_padded = pad_sequences(sequences=data, value=9999) 
print(data_padded)

#output
'''
[[9999 9999 9999    1    2    3]
 [9999 9999 9999 9999    4    5]
 [   6    7    8    9   10   11]]
'''

[ic]value[/ic]값은 기존  '0'을 다른 값으로 변경하는 인자다.

 

예제에서는 '9999'로 변경해보았다.

 

(4) [ic]maxlen[/ic] 적용

지금까지는 제일 긴문장 기준으로 패딩하였다.

[ic]maxlen[/ic]을 통해 임의로 기준을 변경할 수 있다.

data_padded = pad_sequences(sequences=data, maxlen=10) 
print(data_padded)

#output
'''
[[ 0  0  0  0  0  0  0  1  2  3]
 [ 0  0  0  0  0  0  0  0  4  5]
 [ 0  0  0  0  6  7  8  9 10 11]]
'''

[ic]maxlen[/ic]을 10으로 적용한 코드다.

길이 10을 기준으로 부족한 수만큼 0으로 패딩하였다.

 

심지어 줄이는것도 가능하다.

data_padded = pad_sequences(sequences=data, maxlen=3) 
print(data_padded)

#output
'''
[[ 1  2  3]
 [ 0  4  5]
 [ 9 10 11]]
'''

줄일때 주의할 부분은 정보가 소실된다는거다.

[ 6 7 8 9 10 11] ⇒ [ 9 10 11] [6 7 8 ] 정보 소실.

앞 정보부터 없어졌다.

[ic]truncating[/ic]를 통해 앞부터 없앨지, 뒤부터 없앨지 정할 수 있다.

 

(5)[ic] truncating[/ic] 적용

[ic]truncating='post'[/ic] 적용시 뒤에서부터 truncate한다.(default는 [ic]truncating='pre'[/ic])

data_padded = pad_sequences(sequences=data, maxlen=3, truncating='post') 
print(data_padded) 

#output
'''
[[1 2 3]
 [0 4 5]
 [6 7 8]]
'''
[ 6 7 8 9 10 11] ⇒ [ 6 7 8] [9 10 11 ] 정보 소실.
  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글