
AutoEncoder Example Code

꼬예 2022. 12. 19.

    import torch
    from torchvision import transforms, datasets
    BATCH_SIZE = 64
    
    trainset = datasets.FashionMNIST(
        root = './data/FASHIONMNIST/', # where to store the dataset
        train = True, # use the training split
        download = True, # download if not already present
        transform = transforms.ToTensor() # convert images to FloatTensor in [0, 1]
    )
    
    trainloader = torch.utils.data.DataLoader(
        dataset = trainset,
        batch_size = BATCH_SIZE,
        shuffle = True,
        num_workers = 2
    )
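
As a quick sanity check (not part of the original post), one batch can be pulled from the trainloader to confirm the shapes and value range that ToTensor() produces:

    # Assumed sanity check: inspect a single batch from the DataLoader.
    images, labels = next(iter(trainloader))
    print(images.shape)   # torch.Size([64, 1, 28, 28])
    print(labels.shape)   # torch.Size([64])
    print(images.min().item(), images.max().item())   # pixel values are scaled to [0, 1]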

     

    from torch import nn, optim
    
    class AE(nn.Module):
        def __init__(self):
            super().__init__()
            
            # Encoder: shrink the feature dimension to extract the important features
            self.encoder = nn.Sequential(
                nn.Linear(28*28, 128),
                nn.Sigmoid(),
                nn.Linear(128, 64),
            ) 
            
            # Decoder: reconstruct back to the original size
            self.decoder = nn.Sequential(
                nn.Linear(64, 128),
                nn.Sigmoid(),
                nn.Linear(128, 28*28),
            )
        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return encoded, decoded
        
    USE_CUDA = torch.cuda.is_available()
    DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
    
    model = AE().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr = 0.005)
    criterion = nn.MSELoss()
    
    print("Model: ", model)
    print("Device: ", DEVICE)

     

    # Define Train & Evaluate
    
    def train(model, train_loader, optimizer):
        model.train()
        for step, (x, label) in enumerate(train_loader):
            x = x.view(-1, 28 * 28).to(DEVICE)   # flatten each 28x28 image into a 784-dim vector
            y = x.view(-1, 28 * 28).to(DEVICE)   # the reconstruction target is the input itself
            label = label.to(DEVICE)             # class labels are not used by the autoencoder
            
            encoded, decoded = model(x)
            loss = criterion(decoded, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if step % 100 == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}"\
                      .format(epoch, step * len(x), len(train_loader.dataset),\
                      100. * step / len(train_loader), loss.item()))
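
The comment above mentions "Train & Evaluate", but the post only defines train. Below is a minimal evaluation sketch (assumed, not from the original post) that reports the average reconstruction MSE on the FashionMNIST test split, loaded the same way as the training set.

    # Assumed sketch: average reconstruction loss on a held-out loader,
    # with gradients disabled and the model in eval mode.
    def evaluate(model, data_loader):
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for x, _ in data_loader:
                x = x.view(-1, 28 * 28).to(DEVICE)
                _, decoded = model(x)
                total_loss += criterion(decoded, x).item() * x.size(0)
        return total_loss / len(data_loader.dataset)
    
    # Example usage with the test split (assumed setup):
    testset = datasets.FashionMNIST(
        root = './data/FASHIONMNIST/',
        train = False,
        download = True,
        transform = transforms.ToTensor()
    )
    testloader = torch.utils.data.DataLoader(testset, batch_size = BATCH_SIZE, shuffle = False)
    print("Test reconstruction MSE:", evaluate(model, testloader))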

     

    # Training
    import numpy as np
    import matplotlib.pyplot as plt
    
    # 5 sample images used to visualize reconstructions after every epoch;
    # flatten to 784 and scale to [0, 1] so they match the network input
    view_data = trainset.data[:5].view(-1, 28 * 28)
    view_data = view_data.type(torch.FloatTensor) / 255.
    
    EPOCHS = 100
    for epoch in range(1, EPOCHS + 1):
        train(model, trainloader, optimizer)
        test_x = view_data.to(DEVICE)
        encoded_data, decoded_data = model(test_x)
        f, a = plt.subplots(2, 5, figsize= (10, 4))
        print("[Epoch {}]".format(epoch))
        for idx in range(5):
            img = np.reshape(view_data.data.numpy()[idx], (28, 28))
            a[0][idx].imshow(img, cmap = "gray")
            a[0][idx].set_xticks(())
            a[0][idx].set_yticks(())
            
        for idx in range(5):
            img = np.reshape(decoded_data.to("cpu").data.numpy()[idx], (28, 28))
            a[1][idx].imshow(img, cmap = "gray")
            a[1][idx].set_xticks(())
            a[1][idx].set_yticks(())
        plt.show()
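
Once training is done, the encoder half can be used on its own to compress images into 64-dimensional codes, for example as features for another model. A minimal sketch (assumed, not part of the original post):

    # Assumed sketch: encode the 5 sample images into their 64-dim latent vectors.
    model.eval()
    with torch.no_grad():
        codes, _ = model(view_data.to(DEVICE))
    print(codes.shape)   # torch.Size([5, 64])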