무효 클릭 IP 추적 중...
머신러닝,딥러닝/NLP

[keras] 패딩(padding) 하는법 with pad_sequences

꼬예 2023. 1. 2.
[keras] 패딩(padding) 하는법 with pad_sequences

자연어처리에서   Padding  을 왜 해줄까?

 

각 문장 길이가 다르면 병렬 연산이 어렵기 때문이다.

 

이때 패딩을 통해 고정된 길이를 맞춰주면 작업이 용이하다.

 

1) 데이터 준비

길이가 다른 각 문장을 정수 인코딩한 형태다.

data = [
                 [1, 2, 3], 
                 [4, 5], 
                 [6, 7, 8, 9, 10, 11]
        ]

 

 

2)   pad_sequences   사용법

keras에서 제공하는   pad_sequences   를 통해 패딩을 적용해보겠다.

from keras_preprocessing.sequence import pad_sequences

 

(1) default 세팅

data_padded = pad_sequences(sequences=data)
print(data_padded)

#output
'''
[[ 0  0  0  1  2  3]
 [ 0  0  0  0  4  5]
 [ 6  7  8  9 10 11]]
'''

 

제일 긴녀석(3행)을 기준으로 값(0)을 채운다.

[keras] 패딩(padding) 하는법 with pad_sequences - undefined - 2)   pad_sequences   사용법

 

(2)   padding   방향 설정

# padding post 적용
data_padded = pad_sequences(sequences=data, padding='post') #default padding='pre'
print(data_padded)

#output
'''
[[ 1  2  3  0  0  0]
 [ 4  5  0  0  0  0]
 [ 6  7  8  9 10 11]]
'''

  padding='post'  는 뒤에서부터 0을 채우는 기능이다.

default는   padding='pre'  이다.

 

(3)   value   적용

data_padded = pad_sequences(sequences=data, value=9999) 
print(data_padded)

#output
'''
[[9999 9999 9999    1    2    3]
 [9999 9999 9999 9999    4    5]
 [   6    7    8    9   10   11]]
'''

  value  값은 기존  '0'을 다른 값으로 변경하는 인자다.

 

예제에서는 '9999'로 변경해보았다.

 

(4)   maxlen   적용

지금까지는 제일 긴문장 기준으로 패딩하였다.

  maxlen  을 통해 임의로 기준을 변경할 수 있다.

data_padded = pad_sequences(sequences=data, maxlen=10) 
print(data_padded)

#output
'''
[[ 0  0  0  0  0  0  0  1  2  3]
 [ 0  0  0  0  0  0  0  0  4  5]
 [ 0  0  0  0  6  7  8  9 10 11]]
'''

  maxlen  을 10으로 적용한 코드다.

길이 10을 기준으로 부족한 수만큼 0으로 패딩하였다.

 

심지어 줄이는것도 가능하다.

data_padded = pad_sequences(sequences=data, maxlen=3) 
print(data_padded)

#output
'''
[[ 1  2  3]
 [ 0  4  5]
 [ 9 10 11]]
'''

줄일때 주의할 부분은 정보가 소실된다는거다.

[ 6 7 8 9 10 11] ⇒ [ 9 10 11] [6 7 8 ] 정보 소실.

앞 정보부터 없어졌다.

  truncating  를 통해 앞부터 없앨지, 뒤부터 없앨지 정할 수 있다.

 

(5)   truncating   적용

  truncating='post'   적용시 뒤에서부터 truncate한다.(default는   truncating='pre'  )

data_padded = pad_sequences(sequences=data, maxlen=3, truncating='post') 
print(data_padded) 

#output
'''
[[1 2 3]
 [0 4 5]
 [6 7 8]]
'''
[ 6 7 8 9 10 11] ⇒ [ 6 7 8] [9 10 11 ] 정보 소실.
  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글

꼬예님의
글이 좋았다면 응원을 보내주세요!