본문 바로가기
딥러닝/프로젝트

[Kaggle] CT Medical Image - (5) Augmentation

by 혜 림 2021. 10. 1.
CT Medical Image Classification

(0) 서론, 액션플랜
(1) DICOM 파일 array로 전환 및 시각화
(2) DataLoader 만들기
(3) 모델 만들기
(4) 성능 지표 만들기

 

 

실패기록... 

Augmentation에서 유명한 라이브러리 중 하나인 albumentation을 이용해 보고 싶어서 만든 부분이었는데! 맘 아프다. 일단은 여기서 마무리하기로 해서 정리는 하기로 한다. 

 

일단 albumentation은 파이토치 기반의 라이브러리라고 한다. (파이토치 기반이기는 하지만 텐서플로우에서도 이용가능한 듯 ) 그리고 코드 작성 역시 비슷하다. 

 

0. 라이브러리 임포트

 

왜 줄여서 A라고 부르는지는 모르겠지만 간단해서 좋다. 

 

import albumentations as A

 

1. Custom Transform 생성

 

transform = A.Compose([
		# 원하는 augmentation 기법 작성
    A.HorizontalFlip(p=1),
])

 

 위와 같이 transform을 A.Compose를 이용해서 만들어준다.

 약간 파이토치에서 nn.Sequential 같은 느낌이다. 

 

 그 안에 원하는 augmentation 기법들을 작성해준다. 

 이번에 어렵게 하고 싶지 않아서 horizontalflip만 작성하였다. 

 

 그런데도 실패할 줄은...ㅜ

 

2. CustomDataset 수정하기

 

 이렇게 만든 transform을 만들기 위해서는 customdataset에다가 이 transform을 시행하는 코드를 넣어주어야 한다. 

 주석 사이가 코드를 변경한 부분이다. 

 

class CustomDataset(Dataset):
    def __init__(self,train_images,train_labels,transform=None):
    
        self.x_data = [dicom_2_array('/content/drive/MyDrive/video/archive/dicom_dir/' + path) for path in train_images]
        self.y_data = [[i] for i in train_labels]
        self.transform = transform

    def __len__(self):

        return len(self.x_data)

    def __getitem__(self, idx):

        x = torch.FloatTensor(self.x_data[idx])
        y = torch.FloatTensor(self.y_data[idx])
####################################################
        if self.transform:
          sample = self.transform(**{
                  'image': x,
                  'labels': y
              })
          x = sample['image']
          y = sample['label']
####################################################
        return x, y

 

 transform을 하게 되면 딕셔너리 형태로 augmentation한 데이터를 준다. 

 이때 image 아이디에 붙은 벨류 값과 id 아이디에 붙은 벨류 값이 각각 이미지와 라벨이다. 

 

만약 object detection이나 segmentation 을 시행한다면 여기가 좀 더 복잡해진다. 왜냐하면 bbox와 segmentation의 좌표 값 역시 수정해야 하기 때문이다. 하지만 우리가 지금 하는 것은 간단한 segmentation이니까 일단 나 몰라라~

 

dataset = CustomDataset(train_images, train_labels,transform=transform)
sampler = customsampler(train_labels, 4,4)
dataloader = DataLoader(dataset, batch_size=8,sampler=sampler)

total_batch = len(dataloader)

이제 customdataset 객체를 생성할 때 위처럼 인자로 transform을 추가해서 주면 완성된다~ 

 

3. 그러나 겪고야 만 error

 

 알게 되는대로 업데이트 할 예정이지만 아무래도 나는 잘 모르겠다.

 심지어 이 에러는 스택오버플로우에도 없음.. 대체 step이 뭘 말하는거지????

 

 

 

 그래도 한 번 공부는 해봤다~로 의의를 가진다. 나이스 츄라이.

 다음번엔 꼭 에러 없이 쓸 수 있기를! 

 

댓글