[ic]Subset[/ic] 모듈은 데이터셋을 관리하는데 편리함을 제공한다.
먼저 기본 컨셉부터 이해하자.
1) [ic]Subset[/ic] 기본 컨셉
(1) toy 데이터셋 준비
import torch
from torch.utils.data import Subset, DataLoader
# Create a dataset with 5 examples
dataset = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
print(dataset)
#output
'''
tensor([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10]])
'''
(2) 추출하고 싶은 데이터의 인덱스 넘버
# Create a list of indices to select a subset of the dataset
subset_indices = [0, 2, 4]
index 번호에 상응하는 데이터는 아래와 같다.
(3) [ic]subset[/ic] 사용
# Create a Subset of the original dataset using the subset_indices
subset = Subset(dataset, subset_indices)
첫 번째 인자 => 전체 데이터셋
두 번째 인자 => 추려내고 싶은 데이터의 인덱스번호(리스트형태로 넘김)
아래코드를 확인해 보자.
for a in subset:
print(a)
#output
'''
tensor([1, 2])
tensor([5, 6])
tensor([ 9, 10])
'''
인덱스에 해당하는 데이터가 subset에 저장되었다.
2) Subset을 언제 사용할까?
- training/validation 데이터 분리
- 디버그/테스트용으로 소량의 데이터 추출
- 데이터의 특정 부분만 원하는 대로 추출
3) 3번째 케이스 실습
3번째에 해당하는 '데이터의 특정 부분만 원하는 대로 추출'을 기준으로 예제를 다루겠다.
이 글을 읽기 전 선수 지식 포스팅
pretrained 모델과 Food101 데이터셋을 기반으로 한 예제다.
import torchvision
weights = torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT # .DEFAULT = best available weights on ImageNet
transform_convnext_tiny = weights.transforms() # "transforms" that were used to train "weights"
train_data = torchvision.datasets.Food101(root="data_full",
split="train",
transform=transform_convnext_tiny,
download=True)
데이터 정보
print('총 데이터 수: ', len(train_data))
print('class 수: ',len(train_data.classes))
#output
'''
총 데이터 수: 75750
class 수: 101
'''
print(train_data.class_to_idx)
#output
'''
{'apple_pie': 0, 'baby_back_ribs': 1, 'baklava': 2, 'beef_carpaccio': 3,
'beef_tartare': 4, 'beet_salad': 5, 'beignets': 6, 'bibimbap': 7, 'bread_pudding': 8,
'breakfast_burrito': 9, 'bruschetta': 10, 'caesar_salad': 11, 'cannoli': 12,
'caprese_salad': 13, 'carrot_cake': 14, 'ceviche': 15, 'cheese_plate': 16,
'cheesecake': 17, 'chicken_curry': 18, 'chicken_quesadilla': 19, 'chicken_wings': 20,
'chocolate_cake': 21, 'chocolate_mousse': 22, 'churros': 23, 'clam_chowder': 24,
'club_sandwich': 25, 'crab_cakes': 26, 'creme_brulee': 27, 'croque_madame': 28,
'cup_cakes': 29, 'deviled_eggs': 30, 'donuts': 31, 'dumplings': 32, 'edamame': 33,
'eggs_benedict': 34, 'escargots': 35, 'falafel': 36, 'filet_mignon': 37,
'fish_and_chips': 38, 'foie_gras': 39, 'french_fries': 40, 'french_onion_soup': 41,
'french_toast': 42, 'fried_calamari': 43, 'fried_rice': 44, 'frozen_yogurt': 45,
'garlic_bread': 46, 'gnocchi': 47, 'greek_salad': 48, 'grilled_cheese_sandwich': 49,
'grilled_salmon': 50, 'guacamole': 51, 'gyoza': 52, 'hamburger': 53,
'hot_and_sour_soup': 54, 'hot_dog': 55, 'huevos_rancheros': 56, 'hummus': 57,
'ice_cream': 58, 'lasagna': 59, 'lobster_bisque': 60, 'lobster_roll_sandwich': 61,
'macaroni_and_cheese': 62, 'macarons': 63, 'miso_soup': 64, 'mussels': 65,
'nachos': 66, 'omelette': 67, 'onion_rings': 68, 'oysters': 69, 'pad_thai': 70,
'paella': 71, 'pancakes': 72, 'panna_cotta': 73, 'peking_duck': 74, 'pho': 75,
'pizza': 76, 'pork_chop': 77, 'poutine': 78, 'prime_rib': 79,
'pulled_pork_sandwich': 80, 'ramen': 81, 'ravioli': 82, 'red_velvet_cake': 83,
'risotto': 84, 'samosa': 85, 'sashimi': 86, 'scallops': 87, 'seaweed_salad': 88,
'shrimp_and_grits': 89, 'spaghetti_bolognese': 90, 'spaghetti_carbonara': 91,
'spring_rolls': 92, 'steak': 93, 'strawberry_shortcake': 94, 'sushi': 95, 'tacos': 96,
'takoyaki': 97, 'tiramisu': 98, 'tuna_tartare': 99, 'waffles': 100}
'''
이 중 3개 클래스([ic]pizza[/ic], [ic]steak[/ic], [ic]sush[/ic])에 해당하는 데이터만 사용하고 싶다.
3개 클래스의 라벨 번호를 추출한다.
for key in train_data.class_to_idx:
if key in ['pizza', 'steak', 'sushi']:
print(key, ':', train_data.class_to_idx[key])
# output
'''
pizza : 76
steak : 93
sushi : 95
'''
라벨 번호를 기반으로 전체데이터에서 사용할 index번호를 알아낸다.
target_idx = []
for idx, label in enumerate(train_data._labels): # len(train_data._labels) => 75750
if label in [76, 93, 95]: # 전체 데이터중 해당 라벨 갑을 가진 데이터 idx 모으기
target_idx.append(idx)
else:
pass
필요한 인덱스번호를 구했으므로 [ic]Subset[/ic]을 적용한다.
train_data = Subset(train_data, target_idx) # 원하는 라벨들로만 이루어진 데이터를 구할 수있다.
추출한 데이터를 [ic]DataLoader[/ic]에 넣어주면 손쉽게 데이터셋을 구성할 수 있다.
train_dataloader = DataLoader(train_data,
batch_size=BATCH_SIZE,
num_workers=os.cpu_count(), # number of subprocesses to use for data loading
shuffle=True,
pin_memory=True)
이 글과 읽으면 좋은 글
'머신러닝,딥러닝 > 딥러닝' 카테고리의 다른 글
[pytorch] nn.BCEWithLogitsLoss VS nn.BCELoss 차이 (0) | 2022.12.30 |
---|---|
[pytorch] 모델 save/load 하는 방법 (0) | 2022.12.29 |
[딥러닝] Fine Tuning(미세 조정) 꿀 tip (0) | 2022.12.27 |
[pytorch] pretrained model 쉽게 사용하는 방법 (0) | 2022.12.26 |
[pytorch] nn.Dropout inplace 역할은 무엇일까? (0) | 2022.12.23 |
댓글