무효 클릭 IP 추적 중...
머신러닝,딥러닝/computer vision

mmdetection 커스텀/모델 학습 완벽정리

꼬예 2023. 1. 17.

(이 포스팅은 해당 포스팅을 읽고 왔다는 전제로 작성되었습니다.)

 

이제 본격적으로 custom을 통해 우리 데이터를 학습시켜보자.

디테일한 내용은 큰 흐름을 방해할 수 있기에 중요 부분을 위주로 정리한다.

 

mmdetection에서는 다양한 데이터 셋 형태를 지원한다.

 

개인적으로 코코데이터셋mmdetection을 사용할때 가장 편리해  CocoData를 기준으로 한다.

 

 

이 글과 읽으면 좋은글

 

Config파일 Custom

사용할 데이터셋 지정

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset

@DATASETS.register_module(force=True)
class CupDataset(CocoDataset):
    CLASSES = ('stain', 'crumpled')

[ic]CoCoDataset[/ic] 클래스를 상속받은 클래스를 생성한다.

예제에서는 [ic]CupDataset[/ic]라는 클래스를 생성하였다. 이름은 원하는 명칭을 사용하면 된다.

[ic]CLASSES[/ic]에는 사용할 class name을 tuple형태로 지정해준다.

 

디폴트 config파일 불러오기

사용할 pretrained 모델에 상응하는 config파일을 불러온다.

from mmcv import Config

config_file = 'mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'

cfg = Config.fromfile(config_file)
print(cfg.pretty_text)

 

coco데이터셋 디렉토리 구조 만들기

예제에서 사용할 데이터셋 디렉토리 구조다.

예제 데이터 디렉토리 구조

 

PNGImages 전체이미지가 담긴 디렉토리
train.json train 데이터 annotation
train.txt(optional) train 데이터 파일 경로
valid.json valid 데이터 annotaiton
valid.txt(optional) valid 데이터 파일 경로

 

코코데이터 셋을 사용하기 전 위 디렉토리 구조를 만들어 준다.

[ic].txt파일[/ic]은 선택사항이지만 나머지 파일들은 꼭 세팅 해야 한다.

 

데이터셋 정보 config에 입력

config 파일은 dictonary형태로 저장 되어 있다.

수정하려면 딕셔너리 수정방식을 이용하면 된다.

 

아래 코드는 데이터셋 정보를 수정한 예제다.

from mmdet.apis import set_random_seed
import mmcv
import os.path as osp

cfg.dataset_type = 'CupDataset'
cfg.data_root = 'CupDataCoco/'

#train데이터 데이터셋 정보기입
cfg.data.train.type = 'CupDataset'
cfg.data.train.data_root ='CupDataCoco/'
cfg.data.train.ann_file = 'train.json'
cfg.data.train.img_prefix = 'PNGImages'

#valid 데이터 데이터셋 정보기입
cfg.data.val.type = 'CupDataset'
cfg.data.val.data_root ='CupDataCoco/'
cfg.data.val.ann_file = 'valid.json'
cfg.data.val.img_prefix = 'PNGImages'

#test데이터 데이터셋 정보기입
cfg.data.test.type = 'CupDataset'
cfg.data.test.data_root ='CupDataCoco/'
cfg.data.test.ann_file = 'valid.json'
cfg.data.test.img_prefix = 'PNGImages'

다소 중복되는 내용이 많아 보이나 이게 mmdetection의 규칙이니 따라야한다.

train, valid, test 모두 반복되는 코드임으로 이중 하나만 이해하면 충분하다.

.type 등록한 클래스명으로 작성
.data_root 데이터셋의 root가되는 디렉토리명
.ann_file annotation 파일명
.img_prefix 이미지가 들어있는 디렉토리명

 

내부적으로 [ic]data_root[/ic] 경로를 기반으로 ann_file, img_prefix 경로를 합치기 때문에 앞서 설명한 데이터 구조를 맞춰주는게 중요하다.

ex)
.data_root = ‘CupDataCoco/'
.ann_file = 'train.json'

⇒ 최종적으로 모델이 읽는 ann_file 경로는 ? ‘CupDataCoco/train.json'

 

class 갯수 변경

클래스 디폴트 상태

코코데이터 set으로 학습한 모델(faster-rcnn)이라 디폴트는 80개 설정되어있다.

 

우리 데이터의 클래스 갯수는 2개임으로 그에 맞춰 변경한다.

cfg.model.roi_head.bbox_head.num_classes = 2
모델에 따라 구조가 다르기때문에 어떤 모델은 roi_head밑에 num_classes가 없는 경우가 있다.
이럴때는 config파일을 보며 적절하게 변경해주면된다.

 

pretrained모델 불러오기/저장

cfg.load_from = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
cfg.work_dir = './tutorial_exps_faster_rcnn'

 

.load_form  pretrained 모델 경로 입력
.work_dir 학습 완료 후 생성된 weight 저장 경로 입력

 

Optimizer 정보 입력

cfg.optimizer.lr = 0.02 
cfg.log_config.interval = 10
cfg.runner.max_epochs = 20
.optimizer.lr  learning rate 변경
.log_config.interval log정보를 몇 에폭마다 보여줄지 정한다.
.runner.max_epochs 에폭수 지정

 

기타

# CoCoDataset의 경우 metric을 bbox로 설정 해야함(maP아님. bbox로 설정하면 mAP를 iou threshold 0.5 ~ 0.95까지 변경하면서 측정)
cfg.evaluation.metric = 'bbox'

cfg.evaluation.interval = 12
cfg.checkpoint_config.interval = 12
cfg.lr_config.policy='step'

cfg.device = 'cuda'
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

mmdetection 업데이트마다 새롭게 요구하는 값들이 있다.

현재 기준 필요한 config값들을 셋팅하였다.

 

Learn about Configs — MMDetection 3.2.0 documentation

Learn about Configs MMDetection and other OpenMMLab repositories use MMEngine’s config system. It has a modular and inheritance design, which is convenient to conduct various experiments. Config file content MMDetection uses a modular design, all modules

mmdetection.readthedocs.io

수정 완료된 config 파일은 [ic].dump[/ic]를 통해 새로운 [ic].py[/ic]파일로 만든다.

cfg.dump('faster_rcnn_config.py')

 

config 파일 등록

from mmdet.datasets import build_dataset

datasets = [build_dataset(cfg.data.train)]

#output
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!

정상적으로 완료되면 위와 같은 output이 나온다.

무언가 잘못되었을 경우 친절한(?) 에러메세지가 나오니, 에러를 보며 config파일을 수정해주면 된다.

 

[ic]datasets[0][/ic]을 출력하면 mmdetection이 파악한 데이터 정보를 확인할 수 있다.

print(datasets[0])

#output
CupDataset Train dataset with number of images 98, and instance counts: 
+-----------+-------+--------------+-------+----------+-------+----------+-------+----------+-------+
| category  | count | category     | count | category | count | category | count | category | count |
+-----------+-------+--------------+-------+----------+-------+----------+-------+----------+-------+
|           |       |              |       |          |       |          |       |          |       |
| 0 [stain] | 747   | 1 [crumpled] | 44    |          |       |          |       |          |       |
+-----------+-------+--------------+-------+----------+-------+----------+-------+----------+-------+

 

작성한 config파일을 기반으로 모델을 build하고 학습해보자.

 

모델 학습(train)/추론(inference)

모델 build

from mmdet.models import build_detector
from mmdet.apis import train_detector

model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
model.CLASSES = datasets[0].CLASSES

 

모델 weight를 저장할 폴더를 만든다.

mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

 

모델 학습(train)

train_detector(model, datasets, cfg, distributed=False, validate=True)

[ic]validate=True[/ic]로 설정하면 학습하면서 valid evaluation을 동시에 실행한다.

 

모델 추론(inference)

모델 정의

from mmdet.apis import inference_detector, init_detector , show_result_pyplot
import cv2

checkpoint_file = 'tutorial_exps_faster_rcnn/epoch_20.pth'
model_ckpt = init_detector(cfg, checkpoint_file, device='cuda:0')
cfg  Config.fromfile()로 생성한 config정보
checkpoin_file  학습한 weight파일 경로

 

모델 테스트

# 이미지(test용) 불러오기
img = cv2.imread('CupDataYolo/valid/save_221216_163506.png')

result = inference_detector(model_ckpt, img)
show_result_pyplot(model_ckpt, img, result, score_thr=0.3)

(테스트 코드가 낯선분은 해당 포스팅 참조)

  • 트위터 공유하기
  • 페이스북 공유하기
  • 카카오톡 공유하기
이 컨텐츠가 마음에 드셨다면 커피 한잔(후원) ☕

댓글