ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [논문 리뷰 및 코드구현] UNet++ (Nested UNet)
    Image Segmentation 2020. 10. 10. 22:12

    [Review] UNet++: A Nested U-Net Architecture for Medical Image Segmentation,
    DLMIA(Deep Learning Medical Image Analysis) 2018

     

     이번 포스팅은 객체를 인식하는 방법 중 하나인 U-Net의 업그레이드 버전인 U-Net++ (Nested U-Net) 논문을 살펴보겠습니다. 이 포스팅은 U-Net++U-Net논문, 그리고 MEDIUM 블로그를 참고하여 작성하였습니다.

    객체를 인식하는 방법에는 아래 그림과 같이 크게 Image Classification, Detection, Segmentation로 세 가지가 있습니다.

    이미지 인식 방법 예시 (출처)

     

     U-Net은 이 중에서 Segmentation을 목적으로 제안된 End-to-End 방식의 Fully-Convolutional 기반의 모델입니다. 특히, 의료 분야의 이미지를 다루기 위한 목적으로 제안되었습니다.

    U-Net 구조

     U-Net은 크게 2개의 영역으로 구분되어 있다고 할 수 있습니다.

    1. Contracting Path(수축 경로) : 큰 입력 이미지로부터 의미(Context)있는 정보를 추출하는 부분

       일반적인 CNN Architecture가 작동하는 부분과 같습니다.

    2. Expanding Path(확장 경로) : 수축 경로에서 추출된 의미정보와 수축 경로에서 각 Layer에 존재하는 픽셀의

       위치정보를 결합(Skip)하여 Up-Sampling을 진행하는 부분

     

    먼저, Contracting Path는 이미지 픽셀의 차원을 축소하면서 의미있는 정보를 추출하는 부분입니다. 총 4개의 DownSampling 과정을 거치는데, 각 층마다 3x3 Conv를 두 번 거친 후 Stride가 2인 Pooling을 사용해 이미지 픽셀(Feature map)의 크기를 1/2로 줄입니다. Feature map의 크기는 줄어들지만 채널은 이전 단계의 2배(Feature Scale, 하이퍼파라미터)씩 증가합니다. 이렇게 여러개 층을 거쳐 Bottle Neck 구간의 Feature map이 형성됩니다.

     

    두 번째로, Expanding Path를 알아보겠습니다. U-Net구조의 중요한 부분은 Expanding Path에 있습니다.

    우리의 최종 목적은 이미지가 주어졌을 때, Segmentation을 수행하는 것입니다. 따라서 Contracting Path에서 축소하였던 정보를, 원래의 이미지와 픽셀 단위로 비교하기 위해서는 같은 크기의 픽셀로 복원을 시켜주어야 합니다.

    그런데, 우리는 차원을 축소하는 과정에서 매 Layer마다 Stride가 2인 Pooling을 사용했습니다. 이를 복원시키려면 Upsampling을 해주어야 하는데, Upsampling을 하는 과정은 매우 많은 정보가 손실됩니다.

    이를 해결하기 위해, U-Net에서는 Contracting Path 과정에서 각 레이어마다 가지고 있는 Feature map을 Expanding Path의 Feature map과 더해주어 Upsampling으로 뭉뚱그려진 위치정보를 보완해주는 효과를 가지게 합니다. 더해준다는 것을 조금 더 자세히 언급하자면, Add 연산이 아닌 Channel 차원으로 Concatenate를 시켜 다음 레이어로 넘겨줍니다.

     

     

    U-Net++을 설명하기 전에, 먼저 U-Net에 대한 설명을 해보았습니다. 지금부터는 U-Net++이 U-Net과는 어떤 차이점이 존재하는지를 설명해보겠습니다. U-Net++은 U-Net과 크게 2가지의 차이점이 있습니다. 

    1. Re-designed skip pathways : U-Net에서도 Skip-Connection을 해주는 부분이 있었지만, U-Net++에서는 DenseNet의 아이디어를 차용하여 Encoder(수축 경로)와 Decoder(확장 경로)사이의 Semantic(의미적) Gap을 연결시켜 줍니다.

    2. Deep Supervision : 각 브랜치의 출력(빨간색 선으로 표시된 부분)을 평균해서 최종 결과로서 사용하는 방법입니다.

     

     

    UNet++ 구조

     

    위는 UNet++의 구조를 나타냅니다. 검은색 동그라미와 선은 기존 UNet의 구조를 의미하는 것이고, 파란색 선과 초록색 선은 UNet++의 추가적인 아이디어를 의미합니다. 이것을 조금 더 자세하게 알아보겠습니다.

     

     

     

    위 그림은 Feature map(이미지)이 첫 번째 Skip Pathway를 통과하는 것을 보여줍니다.

    기존 U-Net에서는 X0_0에서 X0_4로 가는 하나의 Skip만이 존재하였습니다. 그러나, U-Net++에서는 X0_0이

    크기가 키워진(Upsampling) X1_0과 Concatenate되어 X0_1로 가는 것을 볼 수 있습니다. 그리고 이렇게 만들어진 X0_1은 또 다시 Upsampling된 X1_1과 Concatenate되어 X1_2로 흘려줍니다. 이런식으로 하면 저자들은 Encoder와 Decoder의 Feature map간의 Semantic Gap을 더 줄일 수 있게 된다고 언급합니다.

     

    그리고, Deep Supervision은 여러 Semantic Level(위 구조에서는 4개)은 각각 Feature map을 생성하여 정보를 가지고 있습니다. 따라서, 4개의 시맨틱 정보를 모두 이용하여 평균내어 결과를 예측하였습니다. Deep Supervision 방법은 선택적으로 적용할 수 있습니다.

     

     

    U-Net++ 코드구현

    1. Import 라이브러리

    import os
    import cv2
    from collections import OrderedDict
    from glob import glob
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    import torch
    import torch.backends.cudnn as cudnn
    import torch.nn as nn
    import torch.optim as optim
    import yaml
    from albumentations.augmentations import transforms
    from albumentations.core.composition import Compose, OneOf
    from sklearn.model_selection import train_test_split
    from torch.optim import lr_scheduler
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    
    from data_loader import Nuclie_dataset
    from model import Unet_block, UNet
    from utils import BCEDiceLoss, AverageMeter, count_params, iou_score

     

    2. 데이터셋 다운로드

    U-Net++에서는 총 4가지의 데이터 셋으로 실험하였는데 저는 그 중 하나인 Nuclie Dataset으로 진행 하겠습니다.

    데이터셋은 해당 링크를 통해 받으실 수 있습니다.

     

     

    3. 모델 구축

    class Unet_block(nn.Module):
        def __init__(self, in_channels, mid_channels, out_channels):
            super().__init__()
            self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(mid_channels)
            self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(out_channels)        
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
            return out

    먼저, Unet은 각 레이어마다 Convolution과 Batchnorm, 그리고 ReLU를 2번 반복하는 구조를 가지고 있습니다. 따라서, 모델 구축에 있어 반복되는 부분을 하나의 Class로 선언하였습니다.

     

    class Nested_UNet(nn.Module):
        def __init__(self, num_classes, input_channels=3, deep_supervision=False):
            super().__init__()
    
            num_filter = [32, 64, 128, 256, 512]
            self.deep_supervision = deep_supervision
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            
            # DownSampling
            self.conv0_0 = Unet_block(input_channels, num_filter[0], num_filter[0])
            self.conv1_0 = Unet_block(num_filter[0], num_filter[1], num_filter[1])
            self.conv2_0 = Unet_block(num_filter[1], num_filter[2], num_filter[2])
            self.conv3_0 = Unet_block(num_filter[2], num_filter[3], num_filter[3])
            self.conv4_0 = Unet_block(num_filter[3], num_filter[4], num_filter[4])
    
            # Upsampling & Dense skip
            # N to 1 skip
            self.conv0_1 = Unet_block(num_filter[0] + num_filter[1], num_filter[0], num_filter[0])
            self.conv1_1 = Unet_block(num_filter[1] + num_filter[2], num_filter[1], num_filter[1])
            self.conv2_1 = Unet_block(num_filter[2] + num_filter[3], num_filter[2], num_filter[2])
            self.conv3_1 = Unet_block(num_filter[3] + num_filter[4], num_filter[3], num_filter[3])
           
            # N to 2 skip
            self.conv0_2 = Unet_block(num_filter[0]*2 + num_filter[1], num_filter[0], num_filter[0])
            self.conv1_2 = Unet_block(num_filter[1]*2 + num_filter[2], num_filter[1], num_filter[1])
            self.conv2_2 = Unet_block(num_filter[2]*2 + num_filter[3], num_filter[2], num_filter[2])
    
            # N to 3 skip
            self.conv0_3 = Unet_block(num_filter[0]*3 + num_filter[1], num_filter[0], num_filter[0])
            self.conv1_3 = Unet_block(num_filter[1]*3 + num_filter[2], num_filter[1], num_filter[1])
    
            # N to 4 skip
            self.conv0_4 = Unet_block(num_filter[0]*4 + num_filter[1], num_filter[0], num_filter[0])
    
            if self.deep_supervision:
                self.output1 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
                self.output2 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
                self.output3 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
                self.output4 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
    
            else:
                self.output = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
    
            # initialise weights
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init_weights(m, init_type='kaiming')
                elif isinstance(m, nn.BatchNorm2d):
                    init_weights(m, init_type='kaiming')
    
        def forward(self, x):                    # (Batch, 3, 256, 256)
    
            x0_0 = self.conv0_0(x)               
            x1_0 = self.conv1_0(self.pool(x0_0))
            x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], dim=1))
            
            x2_0 = self.conv2_0(self.pool(x1_0))
            x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], dim=1))
            x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], dim=1))
    
            x3_0 = self.conv3_0(self.pool(x2_0))
            x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], dim=1))
            x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], dim=1))
            x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], dim=1))
    
            x4_0 = self.conv4_0(self.pool(x3_0))
            x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], dim=1))
            x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], dim=1))
            x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], dim=1))
            x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], dim=1))
    
            if self.deep_supervision:
                output1 = self.output1(x0_1)
                output2 = self.output2(x0_2)
                output3 = self.output3(x0_3)
                output4 = self.output4(x0_4)
                output = (output1 + output2 + output3 + output4) / 4
            else:
                output = self.output(x0_4)
    
            return output

     위는, U-Net++의 모델구조입니다. forward 살펴보겠습니다. 처음에 이미지(x)가 하나의 Convolution Block

    (Conv-BN-ReLU -> 2번 반복)을 거칩니다. 그렇게 해서 나온 결과 x0_0를 Pooling을 통해 x1_0이 나옵니다. 여기까지는 U-Net과 동일합니다. 이후, 풀링한 x1_0을 Upsampling을 통해 사이즈를 키우고 x0_0과 Concatenate를 합니다. 이 때, dim=1은 채널 차원으로 결합하는 것입니다. 이를 차원을 써가며 살펴보겠습니다.

     

    원본 이미지의 차원은 3개의 채널을 가진 (Batch Size, 3, 256, 256) 입니다. 그리고 하나의 Convolution Block을 통과한 x0_0의 차원은 (Batch, 32, 256, 256)입니다. Convolution Block을 통과하고도 이미지 사이즈가 같은 이유는 Padding을 해주었기 때문입니다. 그리고, x1_0은 x0_0를 풀링하고 Convolution Block을 통과해 (Batch, 64, 128, 128)이 됩니다. 이후 x1_0을 Upsampling하면 (Batch, 64, 256, 256)이 됩니다. 이를, x0_0 (Batch, 32, 256, 256)와 Concatenate를 해주면 (Batch, 96, 256, 256)이 됩니다. 이를, 한 번 더 Convolution Block에 통과시키면 (Batch, 32, 256, 256)이 됩니다. 이런 방식으로 기존 U-Net에 Skip을 촘촘히 연결하여 U-Net++ 모델이 구성됩니다.

     

    4. Preprocessing & Dataset

    IMG_HEIGHT=256
    IMG_WIDTH=256
    
    # Image Preprocessing & Augmentation
    train_transform = Compose([
        transforms.Resize(IMG_HEIGHT, IMG_WIDTH),
        OneOf([
        transforms.HorizontalFlip(),
        transforms.VerticalFlip(),
        transforms.RandomRotate90(),], p=1),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),], p=1),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    
    val_transform = Compose([
        transforms.Resize(IMG_HEIGHT, IMG_WIDTH),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    
    

    이후, 이미지를 전처리 및 Augmentation하는 부분인 transform을 정의합니다.

    # Image Folder 위치
    base_path = '../data/stage1_train/'
    
    # DataLoader
    train_dataset = Nuclie_dataset(base_path, train=True, transform=train_transform)
    val_dataset = Nuclie_dataset(base_path, train=False, transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    
    # Model & Loss & Optimizer
    model = Nested_UNet(1, 3, deep_supervision=False).to(device)
    criterion = BCEDiceLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.000068)

    그 후, 앞서 정의했던 Dataset을 train과 val로 나누고, 모델과 Loss Fuction, Optimizer를 선언합니다.

     

     

    5. Training & Validation

    def train(train_loader, model, criterion, optimizer):
        avg_meters = {'loss':AverageMeter(),
                      'iou' :AverageMeter()}
    
        model.train()
        for inputs, labels in train_loader:
            inputs = torch.tensor(inputs, device=device, dtype=torch.float32)
            labels = torch.tensor(labels, device=device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            # log
            iou = iou_score(outputs, labels, threshold=0.8)
            avg_meters['loss'].update(loss.item(), n=inputs.size(0))
            avg_meters['iou'].update(iou, n=inputs.size(0))
    
            log = OrderedDict([
                        ('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                    ])
        return log, model
    
    def validation(val_loader, model, criterion):
        avg_meters = {'loss': AverageMeter(),
                      'iou': AverageMeter()}
        
        model.eval()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = torch.tensor(inputs, device=device, dtype=torch.float32)
                labels = torch.tensor(labels, device=device, dtype=torch.float32)
    
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                iou = iou_score(outputs, labels, threshold=0.8)
    
                avg_meters['loss'].update(loss.item(), n=inputs.size(0))
                avg_meters['iou'].update(iou, n=inputs.size(0))
    
                log = OrderedDict([
                            ('loss', avg_meters['loss'].avg),
                            ('iou', avg_meters['iou'].avg),
                        ])
        return log

    그리고 Train과 Validation 함수를 정의합니다.

     

     

    6. 학습

    epochs=20
    best_iou = 0
    for epoch in range(1, epochs+1):
        train_log, model = train(train_loader, model, criterion, optimizer)
        val_log =  validation(val_loader, model, criterion)
        print(f'{epoch}Epoch')
        print(f'train loss:{train_log["loss"]:.3f} |train iou:{train_log["iou"]:.3f}')
        print(f'val loss:{val_log["loss"]:.3f} |val iou:{val_log["iou"]:.3f}\n')
        valid_iou = val_log['iou']
        if best_iou < valid_iou:
            best_iou = valid_iou
            torch.save(model.state_dict(), f'../results/unet/best_model.pth')

    최종적으로 학습을 수행하고 iou가 가장 높을 때, 모델을 저장합니다.

     

     

    7. 결과

     

    추가적으로 UNet을 구현해 UNet++ 모델과 비교해본 결과는 위와 같습니다.

    시간상 최적의 HyperParameter를 찾지는 못했으나, 간략하게 모델을 학습한 결과는 위에서 보는 것처럼 UNet++ 모델이 더 깔끔하게 분할해내는 것을 볼 수 있었습니다.

    댓글

Designed by Tistory.