파이토치 모델을 저장하고, 저장된 모델을 다시 불러오는 방법에 대해 알아보자.
이 글과 읽으면 좋은글
- [딥러닝] Fine Tuning(미세 조정) 꿀 tip
- [pytorch] model.eval() vs torch.no_grad() 차이
- [pytorch] transforms.Compose 사용 방법
1. custom 모델 생성
예제에서 사용할 간단한 모델을 생성하자.
class LinearRegressionModel(nn.Module):
def __init__(self):
super().__init__()
# Parameter 초기화
self.weights = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float)) # Parameter 생성 (requires_grad=True)
self.bias = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float)) # Parameter 생성 (requires_grad=True)
def forward(self, x): # 순방향 전파 (Forward pass)
return self.weights * x + self.bias # y = ax + b
2. [ic]model.parameters()[/ic] vs [ic]model.state_dict()[/ic]
위 메소드를 통해 모델에 저장된 파라미터값을 볼 수 있다.
[ic]model.parameters()[/ic]
model = LinearRegressionModel()
print(list(model.parameters()))
#output
'''
[Parameter containing:
tensor([0.3367], requires_grad=True),
Parameter containing:
tensor([0.1288], requires_grad=True)]
'''
[ic]model.state_dict()[/ic]
model = LinearRegressionModel()
print(print(model.state_dict()))
#output
'''
OrderedDict([('weights', tensor([0.3367])), ('bias', tensor([0.1288]))])
'''
두 메소드 모두 파라미터 정보를 가지고 있다.
차이가 있다면 [ic]model.parameters()[/ic]는 optimizer 인자로 사용된다.
optimizer = torch.optim.Adam(params=model.parameters(), # "parameters" to optimize (apply gradient descent)
lr=0.01)
반면 [ic]model.state_dict()[/ic] 는 학습이 완료된 모델을 저장할 때 넘겨준다.
3. 모델 저장/불러오기
모델을 저장/불러오기 할 때는 크게 두 가지 방식이 사용된다.
- 파라미터만 저장/불러오기
- 파라미터 + 모델 구조 저장/불러오기
1) 파라미터만 저장/불러오기
(1) 모델 저장
torch.save(obj=model.state_dict(), # only parameters
f='파일명.pth') # .pt 확장자 사용가능
obj = 모델 파라미터, [ic]state_dict()[/ic] 함수 이용
f = 파일 저장경로
(2) 모델 불러오기
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(f='파일명.pth'))
파라미터만 불러오려면 모델 코드가 구현되어 있어야 한다.
먼저 모델 객체([ic]LinearRegressionModel()[/ic]) 을 생성한다.
[ic]torch.load[/ic]로 파라미터값을 읽은 뒤 [ic]load_state_dict[/ic]를 통해 모델 객체로 넘겨준다.
2) 파라미터 + 모델 구조 저장/불러오기
(1) 모델 저장
torch.save(obj=model, # entire model
f='파일명.pth') # .pt 확장자 사용가능
obj = [ic]state_dict()[/ic]를 이용하지 않고 모델을 통째로 넘긴다.
f = 파일 저장 경로
(2) 모델 불러오기
loaded_model = torch.load('파일명.pth')
[ic]torch.load[/ic]를 통해 모델 객체와 파라미터 값을 한 번에 리턴 받는다.
'머신러닝,딥러닝 > 딥러닝' 카테고리의 다른 글
[pytorch] nn.BCEWithLogitsLoss VS nn.BCELoss 차이 (0) | 2022.12.30 |
---|---|
[pytorch] Subset 사용법 정리 (0) | 2022.12.29 |
[딥러닝] Fine Tuning(미세 조정) 꿀 tip (0) | 2022.12.27 |
[pytorch] pretrained model 쉽게 사용하는 방법 (0) | 2022.12.26 |
[pytorch] nn.Dropout inplace 역할은 무엇일까? (0) | 2022.12.23 |
댓글