무효 클릭 IP 추적 중...
머신러닝,딥러닝/딥러닝

AutoEncoder(오토인코더) 예제 코드

꼬예 2022. 12. 19.
import torch
from torchvision import transforms, datasets
BATCH_SIZE = 64

# FashionMNIST training split. It is downloaded into ./data/FASHIONMNIST/ when
# missing, and every image is converted to a tensor by transforms.ToTensor().
trainset = datasets.FashionMNIST(
    root='./data/FASHIONMNIST/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches of BATCH_SIZE samples, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

 

from torch import nn, optim

class AE(nn.Module):
    """Fully connected autoencoder for flattened images.

    The encoder maps an ``input_dim`` vector to a ``latent_dim`` code through
    one sigmoid-activated hidden layer of width ``hidden_dim``. The decoder
    mirrors that path back to ``input_dim``. The defaults
    (784 -> 128 -> 64 -> 128 -> 784) are the original hard-coded sizes, meant
    for flattened 28x28 images.

    Args:
        input_dim: Length of each flattened input vector.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Length of the encoded (bottleneck) representation.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=128, latent_dim=64):
        super().__init__()

        # Shrink the feature size so the code keeps only the most useful features.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim, latent_dim),
        )

        # Expand the code back to the original size. The final Linear has no
        # activation, so reconstructions are not clamped to any range.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        """Return ``(encoded, decoded)`` for a batch ``x`` of shape (N, input_dim)."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
# Use the GPU whenever PyTorch reports one; otherwise run on the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# Autoencoder placed on DEVICE, optimized with Adam (lr=0.005) against a
# mean-squared-error reconstruction loss.
model = AE().to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=0.005)
criterion = nn.MSELoss()

print("Model: ", model)
print("Device: ", DEVICE)

 

# Define Train & Evaluate

def train(model, train_loader, optimizer, epoch=None):
    """Run one epoch of reconstruction training.

    Each image batch is flattened per sample, moved to the module-level
    ``DEVICE``, and used as both input and target. The loss is the
    module-level ``criterion`` applied to the model's decoded output. A
    progress line is printed every 100 steps.

    Args:
        model: Module whose forward returns ``(encoded, decoded)``.
        train_loader: Iterable of ``(images, labels)`` batches. It must expose
            ``.dataset`` and support ``len()`` for the progress message.
            Labels are ignored.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number shown in the progress message. If omitted, the
            module-level ``epoch`` variable set by the script's training loop
            is used (the original behaviour), or ``None`` if it does not exist.
    """
    if epoch is None:
        # Backward compatibility: the script calls train() without an epoch and
        # previously relied on its global loop variable being visible here.
        epoch = globals().get("epoch")

    model.train()
    for step, (x, _label) in enumerate(train_loader):
        # Autoencoder objective: the flattened input is also the target.
        # Labels are unused, so they are no longer copied to the device.
        # For (N, 1, 28, 28) batches this equals view(-1, 28 * 28).
        x = x.flatten(start_dim=1).to(DEVICE)

        _encoded, decoded = model(x)
        loss = criterion(decoded, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, step * len(x), len(train_loader.dataset),
                100. * step / len(train_loader), loss.item()))

 

# Training
import numpy as np 
import matplotlib.pyplot as plt

EPOCHS = 100

# Fixed sample for previewing reconstructions after each epoch: the first five
# training images, flattened to (5, 784) and scaled from uint8 [0, 255] to
# [0, 1] so they match the ToTensor() inputs seen during training.
# (The original listing used `view_data` without ever defining it.)
view_data = trainset.data[:5].view(-1, 28 * 28).float() / 255.


def _show_reconstructions(originals, reconstructions, n_images=5):
    """Plot ``n_images`` originals (top row) above their reconstructions (bottom row)."""
    # squeeze=False keeps `axes` 2-D even when n_images == 1.
    fig, axes = plt.subplots(2, n_images, figsize=(10, 4), squeeze=False)
    for row, images in enumerate((originals, reconstructions)):
        for idx in range(n_images):
            ax = axes[row][idx]
            ax.imshow(np.reshape(images[idx], (28, 28)), cmap="gray")
            ax.set_xticks(())
            ax.set_yticks(())
    plt.show()
    # Release the figure. Otherwise one figure per epoch accumulates under
    # non-interactive backends.
    plt.close(fig)


for epoch in range(1, EPOCHS + 1):
    train(model, trainloader, optimizer)

    # The preview is inference only: skip building an autograd graph.
    # train() switches the model back to training mode at its start.
    model.eval()
    with torch.no_grad():
        test_x = view_data.to(DEVICE)
        encoded_data, decoded_data = model(test_x)

    print("[Epoch {}]".format(epoch))
    _show_reconstructions(view_data.numpy(), decoded_data.cpu().numpy())
  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글