Image Segmentation

[논문 리뷰 및 코드구현] Simple Does It: Weakly Supervised Instance and Semantic Segmentation

Barca 2020. 12. 27. 23:06

Simple Does It: Weakly Supervised Instance and Semantic Segmentation, CVPR 2017

 

이번 포스팅은 Semantic Segmentation을 위한 Weakly-Supervised Learning 관련 논문에 대해서 살펴보겠습니다.

 

먼저, Weakly Supervised Learning이라는 분야에 대해서 살펴보겠습니다. 기존의 Computer Vision 분야는 크게 이미지 인식(Recognition)과 이미지 분할(Segmentation), 그리고 객체 탐지(Object Detection)으로 이루어져 있습니다.

이미지 인식은 우리가 주로 알고 있는 태스크로, 학습시에 입력으로 이미지와 해당 이미지에 대한 Class(Label) 정보를 사용하게 됩니다. 이미지 분할에서는 이미지와 해당 이미지에서 분할하고자 하는 영역이 Pixel 단위로 라벨링이 되어있는 Mask가 사용됩니다. 그리고 객체 탐지는 이미지와 이미지의 Class정보, 그리고 Bounding Box의 좌표가 함께 학습에 사용되게 됩니다.

 

그러나, 이미지 분할과 객체 탐지 영역에서 Labeling을 하는 작업은 많은 인력과 시간, 그리고 비용이 소모되게 됩니다. 특히, 픽셀단위로 정확하게 분할해야하는 이미지 분할(Segmentation)의 라벨링은 더욱 시간이 많이 소모됩니다. 그리고 Labeling을 하는 사람의 컨디션이나 익숙함에 따라 라벨링된 결과물의 차이가 발생할 수도 있습니다. Weakly Supervised Learning은 학습에있어 필수적인 Label을 최대한 사용하지 않는 방법이 없을까? 하는 의문에서 시작된 연구입니다. 따라서, 이미지 분할이나 객체 탐지를 할 때, 상대적으로 약한(Weakly) 정보로 분할이나 탐지를 해보자는 아이디어 입니다. 여기서 약한 정보라는 것은 이미지 분할이나 객체 탐지를 할 때, 기존에 주어지던 분할 Mask(Label) 또는 Bounding box는 상대적으로 강한 정보이고, Class에 대한 Label 정보는 약한정보라고 할 수 있습니다. 실제로 픽셀별로 분할을 하거나 바운딩박스를 치는 것은 단순히 해당 이미지에 어떤 물체의 Class가 있다라는 라벨링보다 비용이 훨씬 많이 듭니다. 따라서, 이미지 분할 또는 객체 탐지 시에 Class에 대한 정보(or +위치 정보)만을 이용하자는 것입니다.

 

그렇다면 해당 논문에서는 어떤 방식으로 Segmentation을 진행했는지 살펴보겠습니다.

 위 그림은 해당 논문에서 발췌한 그림으로 (a)는 입력 이미지이고, (b)는 학습에 사용되는 것이 아닌 평가에만 사용되는 GT이고, (c)는 완전한 Labeled Mask가 아닌 박스 형태로 해당 영역에 어떤 물체들이 있다라는 Box형태로 주어진 Label입니다. (b)와 (c)를 한눈에 비교해봐도 라벨링을 할 시에 (b)가 훨씬 어려운 작업이라는 것을 알 수 있습니다.

따라서, 해당 논문에서는 사람과 말이라는 Class정보와 위치정보인 Bounding Box만을 이용하여 Segmentation을 수행하였습니다.

 

해당 논문에서는 Recursive Learning이라는 방법을 사용하였습니다.

위 그림은 Recursive Learning 어떻게 구현하였는지에 대한 그림입니다. 먼저 처음에는 Ground Truth가 Bounding box형태로 되어있어 Pixel 단위로의 분할보다는 훨씬 거칩니다. 처음 1Round동안은 이 Box를 사용하여 모델 학습을 수행합니다. 그리고 다음 Round를 학습하기전에 다음과 같은 3가지의 후처리 단계를 거치게 됩니다.

1. Bounding box 외부의 모든 픽셀은 배경 레이블로 설정합니다.

2. Segment된 영역이 해당 Bounding box에 비해 너무 작은 경우(IOU < 50%)는 Bounding box를 그대로 레이블로 가져갑니다. 이러한 방식은 최소한의 영역은 겹치게하려는 후처리입니다.

3. 이미지의 경계부분을 더 잘 보존하기 위해서, 후처리 방법으로 자주 사용되는 Dense CRF(Conditional Random Field) 기법을 사용하였습니다.

 

따라서, 이렇게 매 라운드마다 3가지의 후처리를 거친 후의 Segment된 영역을 다음 영역의 Ground Truth로 사용하는 방식으로 학습이 진행됩니다. 이렇게 하였을 때, 10Round를 거친뒤의 예측 맵은 Ground Truth와 완벽히 일치하진 않지만, 꽤 디테일하고 준수한 성능을 보여주는 것을 확인할 수 있습니다.

 

그리고, 해당 논문은 Recursive Learning외에도 GrabCut, GrabCut+, MCG 등의 알고리즘을 추가로 사용하여 성능을 개선했습니다.

 

GrabCut알고리즘은 전통적인 Computer vision 알고리즘으로, 물체가 있는 곳에 Bounding box를 치면 물체 외의 배경 부분을 제거해주는 알고리즘입니다.

GrabCut Algorithm

GrabCut+ 알고리즘은 GrabCut에 이미지의 경계 영역의 테두리를 잘 추출할 수 있는 HED Boundary Detector를 추가로 적용한 방법입니다.

 

그리고, 마찬가지로 CV 알고리즘인 MCG(Multiscale combinatorial grouping) 알고리즘을 적용한 것을 Ground Truth로 사용하였습니다.

 

그리고 MCG와 GrabCut+을 동시에 적용한 방법으로도 실험을 진행하였습니다.

 

 위 그림은 해당 알고리즘들을 적용하여 사용한 초기 Ground Truth입니다. i가 붙은 방법은 해당 방법에 ignore region이라는 방법을 통해 Box 내부의 20%만 GT Mask로 사용하고 나머지 영역은 무시하겠다는 방법입니다. (어떻게보면 Recall을 낮추는 대신 Precision을 높이는 방법이라고 할 수 있습니다.)

위 방법들의 그림을 보시면 기존 Box만 있는 방법보다 훨씬 나은 수준의 Ground Truth를 초기 GT로 사용할 수 있는 것을 확인할 수 있습니다. 

 

따라서, 위의 방법들을 토대로 대표적인 벤치마크 데이터셋인 PASCAL VOC와 MS COCO에 대해 실험을 수행하였습니다.

결과를 살펴보면, Box, Box(i), MCG, GrapCut, GrapCut+, MCG & GrapCut+ 방법순으로 점점 Mean IOU가 좋아지는 것을 볼 수 있습니다. 그리고, 기존에 Label이 섬세하게 된 GT를 사용한 Fully Supervised 방법의 성능과 비교했을 때도 3.4point정도만이 차이가 나는 것을 확인할 수 있습니다. 이를 통해 구체적으로 Labeling된 Mask가 존재하지 않아도 Weak(Box)정보만으로도 준수한 성능을 내는 것을 볼 수 있습니다.

 

 

해당 논문은 Labeling Cost가 매우 큰 Segmentation의 문제점을 Weakly supervised Learning으로 해결하였다는 점에서 좋은 방향성을 제시했다고 생각합니다. 그리고 이를 달성하기 위해 기존의 전통적인 Computer Vision 알고리즘들을 적절하게 조합하여 Fully Supervised 방법과 비슷한 수준의 성능을 달성하였다는 점에서 인상깊은 논문이라고 생각합니다.

 

 

 

코드리뷰

 

1. Import Library

import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as Et
from xml.etree.ElementTree import Element, ElementTree
import pydensecrf.densecrf as dcrf

Dense CRF를 사용하기 위해 pip install pydensecrf를 진행하여줍니다. pip in

tall pydensecrf

 

2. Dataset (PASCAL-VOC 2012)

 

PASCAL-VOC 2012 데이터셋을 받은 후 아래와 같이 Dataset Class를 구성해줍니다.

해당 논문에서 필요한 Initial GT를 생성하기 위하여 아래의 코드를 사용하여 getitem에서 Annotation (Bounding box) 좌표를 추출하여 만들어줍니다.

image_path = '../data/VOCdevkit/VOC2012/JPEGImages/'
anno_path = '../data/VOCdevkit/VOC2012/Annotations/'
GT_path = '../data/VOCdevkit/VOC2012/SegmentationClass/'

class WSSS_dataset(Dataset):
    def __init__(self, image_path, anno_path, GT_path, transform):
        self.images = glob.glob(image_path + '*')
        self.annotations = glob.glob(anno_path + '*')
        self.GTs = glob.glob(GT_path + '*')
        self.images = sorted(self.images)
        self.annotations = sorted(self.annotations)
        self.GTs = sorted(self.GTs)
        self.transform = transform

    def __getitem__(self, idx):

        img = cv2.imread(self.images[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        eval_GT = cv2.imread(self.GTs[idx])
        eval_GT = cv2.cvtColor(eval_GT, cv2.COLOR_BGR2GRAY)

        # Generate Weak GT
        xml = open(self.annotations[idx], 'r')
        tree = Et.parse(xml)
        root = tree.getroot()
        objects = root.findall("object")
        
        empty_GT = torch.zeros(img.shape[:2])
        for i, _object in enumerate(objects):
            name = _object.find("name").text
            bndbox = _object.find("bndbox")
            xmin = int(bndbox.find("xmin").text)
            ymin = int(bndbox.find("ymin").text)
            xmax = int(bndbox.find("xmax").text)
            ymax = int(bndbox.find("ymax").text)

            img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255,0,0), 1)  # Image + Bounding Box
            empty_GT[ymin:ymax, xmin:xmax] = i+1

        if self.transform is not None:
            img = self.transform(img)

        return img, empty_GT, eval_GT

    def __len__(self):
        return self.images

 

위 그림은 코드로 추출한 이미지와 Initial GT입니다. 이제 위와 같이 준비된 데이터로 학습을 진행합니다.

 

3. Model & Utils

해당 논문에서는 DeepLab v2와 VGG16을 Backbone으로 사용하였으나, 본 논문에서는 DeepLab v3와 ResNet50을 사용하였습니다. 그리고 후처리에 사용되는 알고리즘인 Dense CRF와 iou score, 그리고 후처리로 제안하는 3가지를 함수로 정의합니다.

class dense_CRF:
    def __init__(self, img, masks):
        self.img = img
        self.masks = masks
        self.width = masks.shape[0]
        self.height = masks.shape[1]
        self.classes = masks.shape[2]

    def run_dense_CRF(self):
        # [w, h, class] to [class, w, h]
        U = self.masks.transpose(2, 0, 1).reshape((self.classes, -1))
        U = U.copy(order='C')
        # declare width, height, class
        d = dcrf.DenseCRF2D(self.height, self.width, self.classes)
        # set unary potential
        d.setUnaryEnergy(-np.log(U))
        # set pairwise potentials
        d.addPairwiseGaussian(sxy=(3, 3), compat=3)
        d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=self.img, compat=10)
        # inference with 5 iterations
        Q = d.inference(5)
        # MAP prediction
        map = np.argmax(Q, axis=0).reshape((self.width, self.height))
        # class-probabilities
        proba = np.array(map)

        return proba

def post_processing(outputs, init_GT):
    
    # Post process1
    init_GT = (init_GT>0).long    # 값이 있으면 1, 없으면 0
    outputs = outputs * init_GT   # 곱하면 BBox 바깥 배경은 0으로 사라짐

    # Post process2
    iou = iou_score(outputs, init_GT)
    if iou < 0.5:
        outputs = init_GT
    
    # Post process3
    proba = dense_CRF(outputs, init_GT).run_dense_CRF()
    outputs = outputs*proba
    return outputs

def iou_score(im1, im2):
    overlap = (im1>0.5) * (im2>0.5)
    union = (im1>0.5) + (im2>0.5)
    return overlap.sum()/float(union.sum())
    
    
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

 

4. Train

학습은 아래와 같이 첫 에폭에는 모델의 output과 init_GT와 Loss를 계산하여 학습을 수행하고, 이후부터는 output에 post_processing이라는 후처리 함수를 적용한 뒤 해당 Map을 리스트로 담아 다음 에폭에서 Loss로 사용합니다.

replace_GT_list = []
num_epochs = 35
for epoch in num_epochs:
    for i, (images, init_GT, eval_GT) in enumerate(train_loader):
        images = torch.tensor(images, device=self.device, dtype=torch.float32)
        init_GT = torch.tensor(init_GT, device=self.device, dtype=torch.float32)
        eval_GT = torch.tensor(eval_GT, device=self.device, dtype=torch.float32)

        optimizer.zero_grad()
        outputs = model(images)

        # First Round
        if epoch == 0:    
            loss = criterion(outputs, init_GT)
            loss.backward()
            optimizer.step()

            outputs = post_processing(outputs, init_GT)
            replace_GT_list.append(outputs)
        
        else:
            loss = criterion(outputs, replace_GT_list[i])
            loss.backward()
            optimizer.step()

            outputs = post_processing(outputs, replace_GT_list[i])
            replace_GT_list.append(outputs)
        
    replace_GT_list = torch.stack(replace_GT_list, dim=0)

 

 

 최종적으로 학습된 것을 통해 Bbox만을 사용하여 CRF등의 후처리를 사용한 최종맵은 위와 같습니다. GT에 비해 뚜렷하게 좋은 것은 아니지만 어느정도 소의 윤곽등은 잘 분할한 것을 확인할 수 있습니다.

 

 

이번 포스팅은 Weakly Supervised Learning이라는 개념과 이를 Segmentation에 적용한 논문에 대해 살펴보았습니다. 이 논문에서는 Weak label로 Bbox정보를 사용하였는데, 다음 포스팅은 이보다 더 약한 label이라고 할 수 있는 Class 정보(위 그림에서는 소)만을 이용하여 Segmentation 또는 Object detection을 수행하는 논문리뷰로 돌아오겠습니다. 읽어주셔서 감사합니다.