무효 클릭 IP 추적 중...
머신러닝,딥러닝/딥러닝

[pytorch] 모델 save/load 하는 방법

꼬예 2022. 12. 29.

파이토치 모델을 저장하고, 저장된 모델을 다시 불러오는 방법에 대해 알아보자.

 

 

이 글과 읽으면 좋은글

 

1. custom 모델 생성

예제에서 사용할 간단한 모델을 생성하자.

class LinearRegressionModel(nn.Module): 
    
    def __init__(self): 
        
        super().__init__()
        
        # Parameter 초기화
        self.weights = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float)) # Parameter 생성 (requires_grad=True)
        self.bias = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float))    # Parameter 생성 (requires_grad=True)
        

    def forward(self, x): # 순방향 전파 (Forward pass)
        
        return self.weights * x + self.bias # y = ax + b

 

2. [ic]model.parameters()[/ic] vs [ic]model.state_dict()[/ic]

위 메소드를 통해 모델에 저장된 파라미터값을 볼 수 있다.

 

[ic]model.parameters()[/ic]

model = LinearRegressionModel()

print(list(model.parameters()))

#output
'''
[Parameter containing:
 tensor([0.3367], requires_grad=True),
 Parameter containing:
 tensor([0.1288], requires_grad=True)]
'''

 

[ic]model.state_dict()[/ic]

model = LinearRegressionModel()

print(print(model.state_dict()))

#output
'''
OrderedDict([('weights', tensor([0.3367])), ('bias', tensor([0.1288]))])
'''

 

두 메소드 모두 파라미터 정보를 가지고 있다. 

 

차이가 있다면 [ic]model.parameters()[/ic]는 optimizer 인자로 사용된다.

optimizer = torch.optim.Adam(params=model.parameters(), # "parameters" to optimize (apply gradient descent)
                             lr=0.01)

 

반면 [ic]model.state_dict()[/ic] 는 학습이 완료된 모델을 저장할 때 넘겨준다.

 

3. 모델 저장/불러오기

모델을 저장/불러오기 할 때는 크게 두 가지 방식이 사용된다.

- 파라미터만 저장/불러오기

- 파라미터 + 모델 구조 저장/불러오기

 

1) 파라미터만 저장/불러오기

(1) 모델 저장

torch.save(obj=model.state_dict(), # only parameters
           f='파일명.pth') # .pt 확장자 사용가능
obj = 모델 파라미터, [ic]state_dict()[/ic] 함수 이용
f = 파일 저장경로

 

(2) 모델 불러오기

loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(f='파일명.pth'))

파라미터만 불러오려면 모델 코드가 구현되어 있어야 한다.

먼저 모델 객체([ic]LinearRegressionModel()[/ic]) 을 생성한다.

 

[ic]torch.load[/ic]로 파라미터값을 읽은 뒤 [ic]load_state_dict[/ic]를 통해 모델 객체로 넘겨준다.

 

2) 파라미터 + 모델 구조 저장/불러오기

(1) 모델 저장

torch.save(obj=model, # entire model
           f='파일명.pth') # .pt 확장자 사용가능
obj = [ic]state_dict()[/ic]를 이용하지 않고 모델을 통째로 넘긴다.
f = 파일 저장 경로

 

(2) 모델 불러오기

loaded_model = torch.load('파일명.pth')

[ic]torch.load[/ic]를 통해 모델 객체와 파라미터 값을 한 번에 리턴 받는다.

  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글