무효 클릭 IP 추적 중...
머신러닝,딥러닝/딥러닝

[pytorch] Subset 사용법 정리

꼬예 2022. 12. 29.

[ic]Subset[/ic] 모듈은 데이터셋을 관리하는데 편리함을 제공한다.

 

먼저 기본 컨셉부터 이해하자.

 

1) [ic]Subset[/ic] 기본 컨셉

(1) toy 데이터셋 준비

import torch
from torch.utils.data import Subset, DataLoader

# Create a dataset with 5 examples
dataset = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
print(dataset)

#output
'''
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10]])
'''

 

(2) 추출하고 싶은 데이터의 인덱스 넘버

# Create a list of indices to select a subset of the dataset
subset_indices = [0, 2, 4]

index 번호에 상응하는 데이터는 아래와 같다.

추출할 인덱스 데이터

 

(3) [ic]subset[/ic] 사용

# Create a Subset of the original dataset using the subset_indices
subset = Subset(dataset, subset_indices)
첫 번째 인자 => 전체 데이터셋
두 번째 인자 => 추려내고 싶은 데이터의 인덱스번호(리스트형태로 넘김)

 

아래코드를 확인해 보자. 

for a in subset:
    print(a)

#output
'''
tensor([1, 2])
tensor([5, 6])
tensor([ 9, 10])
'''

인덱스에 해당하는 데이터가 subset에 저장되었다.

 

2) Subset을 언제 사용할까?

  • training/validation 데이터 분리
  • 디버그/테스트용으로 소량의 데이터 추출
  • 데이터의 특정 부분만 원하는 대로 추출

 

3) 3번째 케이스 실습

3번째에 해당하는 '데이터의 특정 부분만 원하는 대로 추출'을 기준으로 예제를 다루겠다.

 

 

이 글을 읽기 전 선수 지식 포스팅

 

pretrained 모델과 Food101 데이터셋을 기반으로 한 예제다.

import torchvision

weights = torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT # .DEFAULT = best available weights on ImageNet
transform_convnext_tiny = weights.transforms() # "transforms" that were used to train "weights"

train_data = torchvision.datasets.Food101(root="data_full",
                                          split="train",
                                          transform=transform_convnext_tiny,
                                          download=True)

 

데이터 정보

print('총 데이터 수: ', len(train_data)) 
print('class 수: ',len(train_data.classes))

#output
'''
총 데이터 수:  75750
class 수:  101
'''
print(train_data.class_to_idx)
#output
'''
{'apple_pie': 0, 'baby_back_ribs': 1, 'baklava': 2, 'beef_carpaccio': 3, 
'beef_tartare': 4, 'beet_salad': 5, 'beignets': 6, 'bibimbap': 7, 'bread_pudding': 8, 
'breakfast_burrito': 9, 'bruschetta': 10, 'caesar_salad': 11, 'cannoli': 12, 
'caprese_salad': 13, 'carrot_cake': 14, 'ceviche': 15, 'cheese_plate': 16, 
'cheesecake': 17, 'chicken_curry': 18, 'chicken_quesadilla': 19, 'chicken_wings': 20,
 'chocolate_cake': 21, 'chocolate_mousse': 22, 'churros': 23, 'clam_chowder': 24, 
'club_sandwich': 25, 'crab_cakes': 26, 'creme_brulee': 27, 'croque_madame': 28, 
'cup_cakes': 29, 'deviled_eggs': 30, 'donuts': 31, 'dumplings': 32, 'edamame': 33, 
'eggs_benedict': 34, 'escargots': 35, 'falafel': 36, 'filet_mignon': 37, 
'fish_and_chips': 38, 'foie_gras': 39, 'french_fries': 40, 'french_onion_soup': 41, 
'french_toast': 42, 'fried_calamari': 43, 'fried_rice': 44, 'frozen_yogurt': 45, 
'garlic_bread': 46, 'gnocchi': 47, 'greek_salad': 48, 'grilled_cheese_sandwich': 49, 
'grilled_salmon': 50, 'guacamole': 51, 'gyoza': 52, 'hamburger': 53,
 'hot_and_sour_soup': 54, 'hot_dog': 55, 'huevos_rancheros': 56, 'hummus': 57, 
'ice_cream': 58, 'lasagna': 59, 'lobster_bisque': 60, 'lobster_roll_sandwich': 61,
 'macaroni_and_cheese': 62, 'macarons': 63, 'miso_soup': 64, 'mussels': 65, 
'nachos': 66, 'omelette': 67, 'onion_rings': 68, 'oysters': 69, 'pad_thai': 70, 
'paella': 71, 'pancakes': 72, 'panna_cotta': 73, 'peking_duck': 74, 'pho': 75, 
'pizza': 76, 'pork_chop': 77, 'poutine': 78, 'prime_rib': 79,
 'pulled_pork_sandwich': 80, 'ramen': 81, 'ravioli': 82, 'red_velvet_cake': 83, 
'risotto': 84, 'samosa': 85, 'sashimi': 86, 'scallops': 87, 'seaweed_salad': 88, 
'shrimp_and_grits': 89, 'spaghetti_bolognese': 90, 'spaghetti_carbonara': 91, 
'spring_rolls': 92, 'steak': 93, 'strawberry_shortcake': 94, 'sushi': 95, 'tacos': 96, 
'takoyaki': 97, 'tiramisu': 98, 'tuna_tartare': 99, 'waffles': 100}
'''

 

이 중 3개 클래스([ic]pizza[/ic], [ic]steak[/ic], [ic]sush[/ic])에 해당하는 데이터만 사용하고 싶다.

 

3개 클래스의 라벨 번호를 추출한다.

for key in train_data.class_to_idx:
    if key in ['pizza', 'steak', 'sushi']:
        print(key, ':', train_data.class_to_idx[key])

# output
'''
pizza : 76
steak : 93
sushi : 95
'''

 

라벨 번호를 기반으로 전체데이터에서 사용할 index번호를 알아낸다.

target_idx = []

for idx, label in enumerate(train_data._labels): # len(train_data._labels) => 75750
    if label in [76, 93, 95]: # 전체 데이터중 해당 라벨 갑을 가진 데이터 idx 모으기
        target_idx.append(idx)
    else:
        pass

 

필요한 인덱스번호를 구했으므로 [ic]Subset[/ic]을 적용한다.

train_data = Subset(train_data, target_idx) # 원하는 라벨들로만 이루어진 데이터를 구할 수있다.

 

추출한 데이터를 [ic]DataLoader[/ic]에 넣어주면 손쉽게 데이터셋을 구성할 수 있다.

train_dataloader = DataLoader(train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=os.cpu_count(), # number of subprocesses to use for data loading
                              shuffle=True,
                              pin_memory=True)

 

이 글과 읽으면 좋은 글

  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글