-
[논문 리뷰 및 코드구현] AnoGANAnomaly Detection 2020. 11. 14. 21:10
[Review] Unsupervised Anomaly Detection with Generative Adversarial Networks to Guide Marker Discovery, 2017
이번 포스팅은 생성모델인 GAN을 활용하여 Anomaly Detection을 처음 시도했던 논문인 AnoGAN에 논문을 살펴보겠습니다. 잘못된 부분이 있으면 언제든지 댓글로 지적해주시면 감사하겠습니다.
먼저, GAN에 대해서 간단하게 살펴보겠습니다.
GAN은 판별자(분류기)로 불리는 Discriminator와 생성자로 불리는 Generator 두 네트워크가 서로 대립하면서 학습을 합니다. 각각의 네트워크를 학습하는 과정을 살펴보면 위 그림에서는 진짜 이미지를 Discriminator에게 진짜인지 가짜인지를 이진분류를 하게합니다. 실제 이미지를 넣어줬으니 Discriminator는 이미지가 진짜라고 판단을 해야하겠죠. 따라서 확률이 1에 가깝게 판단해야합니다.
그리고 아래 부분에서는 먼저 우리가 설정한 Latent vector(z)로부터 무작위의 랜덤 난수를 가져옵니다. 보통 128차원정도면 충분하다고 합니다. Original GAN에서는 Convolution연산이 아닌 Linear(Dense layer)이므로 (Batch, 128)의 Latent vector를 뽑고, DCGAN에서는 (Batch, 128, 1, 1)차원의 벡터를 Generator에 넣어 ConvTranspose를 통해 원본 이미지 크기로 생성합니다.
그렇게 Generator가 생성한 가짜 이미지를 다시 Discriminator에 넣어줍니다. 이 때, Discriminator는 가짜 이미지를 입력으로 받았으므로 가짜라고 판별을 해야합니다. 즉, 확률이 0에 가깝게 해야합니다.
AnoGAN
Abstract
의료 이미지에서 정확하게 이미지를 Labeling(정상, 비정상) 처리한 대량의 데이터를 구하는 것은 비용이 많이듭니다.
그리고 Supervised Learning같은 경우 데이터가 많고 모든 분포를 다룰 수 있다면 성능이 확실하지만, 새로운 유형의 정상, 비정상이 나타나면 그것을 정확하게 잡지 못할 수 있습니다. 또 의료 이미지는 정상 비정상 데이터 수의 불균형이 심하기 때문에 Supervised Learning만으로는 한계가 있습니다. 따라서 Unsupervised Learning을 활용해 이 문제를 해결해 보자는 것이 해당 논문에서 주장하는 바입니다.
* 여담으로 이 논문에서는 Unsupervised라고 주장하지만 제가 생각하기에는 Y 레이블이 학습되는 것이 아니므로 Unsupervised라고 할 수도 있지만, 엄밀하게는 정상, 비정상이라는 레이블을 알고 있어야 하므로 Semi-supervised 방법이라고 생각합니다.
Introduction & Method
Discriminator는 입력 이미지가 진짜인지 가짜인지를 판별해주는 분류기입니다.
그리고 Generator는 학습을 통해 가짜 이미지를 진짜 이미지처럼 생성하도록 학습됩니다. 그러나, 만약 Generator가 Discriminator를 속이도록 학습해본적이 없는 이미지들은 생성을 하게 되더라도 Discriminator가 가짜라고 판별을 할 것입니다. 즉, 이러한 개념을 Anomaly에 적용하여 주어진 데이터셋 내에서 정상 이미지만을 학습에 사용합니다. 따라서, 학습에서 사용되지 않은 비정상 이미지를 입력으로 넣었을 때, Generator는 해당 이미지를 제대로 생성해내지 못할 것이고, 그렇다면 Discriminator는 가짜라고 판별할 것이라는 아이디어입니다.
아래는 제가 생각하는 GAN과 AnoGAN의 차이점을 간략하게 정리해보았습니다.
지금부터는 AnoGAN에 대한 프로세스를 좀 더 자세히 살펴보겠습니다.
GAN은 결국 z(Latent)에서 X(image)로 향하는 이미지(가짜)를 생성하는 과정입니다.
그러나 Anomaly Detection에서 수행하기 위한 이미지는 z로 생성하는 것이 아니라 실제 이미지 X를 사용해야 합니다. 따라서, X를 z로 맵핑하는 과정이 필요합니다. 이 과정은 Discriminator와 Generator만으로는 불가능합니다.
이 논문에서는 이러한 것을 해결하기 위해 X에서 z로 가는 과정 대신에 X와 비슷한 이미지를 만들어내는 z를 학습시킵니다. 저도 처음에 이 부분이 이해가 잘 되지 않았는데, 좀 더 자세히 설명해보겠습니다. 아래 코드리뷰에서 더 명확히 확인하실 수 있습니다.
Latent z를 학습시키기 이전에, 먼저 Generator와 Discriminator는 학습이 완료되어있는 상태여야 합니다.
이후 위의 Loss Function을 살펴보면, Residual Loss와 Discrimination Loss 두 가지의 조합으로 되어있습니다.
Generator와 Discriminator의 Parameter는 고정시킨 상태로, 랜덤한 z인 z1을 샘플링했다고 합시다. 그러면 Residual Loss는 입력 이미지 x와 생성이미지 G(z1)의 Pixel별 차이를 Loss로 사용합니다.
그리고 Discrimination Loss는 입력이미지를 넣어 판별한 값과 z1으로 생성한 이미지의 차이를 Loss로 사용합니다. 그리고 이 두 Loss를 더할 때, Residual Loss에 0.9, Discrimination Loss에 0.1의 가중치를 주어 z를 업데이트합니다.
이렇게 학습된, z는 입력 이미지와 유사한 이미지라고 할 수 있습니다. 따라서, 실제 테스트 이미지와 유사하도록 학습된 z를 넣는것입니다. 그리고 실제로 Anomaly를 판별할 때, 위의 z를 학습할 때 사용했던 Loss Function을 Anomaly score로 사용하여 특정 임계치를 기준으로 정상인지 비정상인지를 판별합니다. 보통 우변의 왼쪽항인 Residual Loss가 이미지에 anomalous한 영역을 판단하는 것으로 사용됩니다.
Experiment
AnoGAN은 실험으로 SD-OCT라는 의료 영상 이미지(망막)를 사용하였습니다. 그리고 Discriminator와 Generator의 구조는 DCGAN과 동일하고, 다만 채널의 수만 절반으로 줄여서 사용하였습니다. 이는 Gray scale 이미지라서 그랬다고 합니다. 이미지의 사이즈는 64x64를 사용했습니다.
실험은 총 4가지로 비교를 했습니다. 먼저 가장 왼쪽의 ROC 커브를 살펴보면 GAN_R과 AnoGAN이 가장 좋게 나오는 것을 확인할 수 있습니다. (c)는 Residual score의 분포이고 (d)는 Discrimination score의 분포입니다. 한눈에 보기에도 Residual score가 정상과 비정상을 더 명확히 구별하는 것을 볼 수 있습니다.
**
aCAE: Adversarial Convolution AutoEncoder 구조로 Encoder-Decoder Based
P_D: Discriminator가 판단한 확률값을 Anomaly score로 사용한 것
GAN_R: Referenced Adversarial score를 Anomaly score로 사용한 것
AnoGAN: 해당 논문에서 언급한 Residual Loss를 포함한 것을 Anomaly score로 사용한 모델
**
AnoGAN 코드구현
구현에 사용한 데이터셋은 CIFAR-10을 사용하였고, 10개의 클래스중에 1번 클래스 Car를 정상으로 보고 학습시켰고, 나머지 9개는 비정상으로 테스트셋으로 사용하였습니다.
1. Import 라이브러리
import numpy as np import pandas as pd import matplotlib.pyplot as plt import matplotlib.animation as animation import seaborn as sns import torch import torch.nn as nn import torch.nn.functional as F import torchvision from torch.utils.data import Dataset, DataLoader from torchvision import transforms from torchvision.utils import make_grid, save_image from network import Generator, Discriminator, Encoder from util import * import torchvision.utils as vutils import warnings warnings.filterwarnings('ignore') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2. 데이터셋 불러오기 및 전처리
class anomaly_dataset(Dataset): def __init__(self, inputs, labels, transform=None): self.inputs = inputs self.labels = labels self.transform = transform def __len__(self): return len(self.inputs) def __getitem__(self, idx): inputs = np.transpose(self.inputs[idx], (2, 0, 1)) labels = self.labels[idx] if self.transform: inputs = self.transform(inputs) return inputs, labels epochs = 100 batch_size = 64 lr = 0.0002 ndf = 64 ngf = 64 latent_dim = 128 img_size = 64 channels = 3 n_critic = 5 split_rate = 0.8 # %% trainset = torchvision.datasets.CIFAR10(root='../data/', train=True, download=True) testset = torchvision.datasets.CIFAR10(root='../data/', train=False, download=True) # Label 1을 Normal data, 나머지 0,2,3,4,5,6,7,8,9는 abnormal data x_train_temp = torch.ByteTensor(trainset.data[torch.IntTensor(trainset.targets) == 1]) x_train_normal, x_valid_normal = x_train_temp.split((int(len(x_train_temp) * split_rate)), dim=0) y_train_temp = torch.ByteTensor(trainset.targets)[torch.tensor(trainset.targets)==1] y_train_normal, y_valid_normal = y_train_temp.split((int(len(y_train_temp) * split_rate)), dim=0) train_cifar10 = anomaly_dataset(x_train_normal, y_train_normal, transform=transforms.Compose([ transforms.ToPILImage(), transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])) train_loader = DataLoader(train_cifar10, batch_size=batch_size, shuffle=True)
3. 모델 구축
class Generator(nn.Module): def __init__(self, latent_dim=100, num_gf=32, channels=3, bias=True): ''' latent_dim: Latent vector dimension num_gf: Number of Generator Filters channels: Number of Generator output channels ''' super(Generator, self).__init__() self.layer = nn.Sequential( nn.ConvTranspose2d(latent_dim, num_gf*8, 4, 1, 0, bias=bias), nn.BatchNorm2d(num_gf*8), nn.ReLU(), nn.ConvTranspose2d(num_gf*8, num_gf*4, 4, 2, 1, bias=bias), nn.BatchNorm2d(num_gf*4), nn.ReLU(), nn.ConvTranspose2d(num_gf*4, num_gf*2, 4, 2, 1, bias=bias), nn.BatchNorm2d(num_gf*2), nn.ReLU(), nn.ConvTranspose2d(num_gf*2, num_gf, 4, 2, 1, bias=bias), nn.BatchNorm2d(num_gf), nn.ReLU(), nn.ConvTranspose2d(num_gf, channels, 4, 2, 1, bias=bias), nn.Tanh() ) def forward(self, z): z = self.layer(z) return z class Discriminator(nn.Module): def __init__(self, num_df=32, channels=3, bias=True): super(Discriminator, self).__init__() self.feature_layer = nn.Sequential( nn.Conv2d(channels, num_df, 4, 2, 1, bias=bias), nn.BatchNorm2d(num_df), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(num_df, num_df*2, 4, 2, 1, bias=bias), nn.BatchNorm2d(num_df*2), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(num_df*2, num_df*4, 4, 2, 1, bias=bias), nn.BatchNorm2d(num_df*4), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(num_df*4, num_df*8, 4, 2, 1, bias=bias), nn.BatchNorm2d(num_df*8), nn.LeakyReLU(0.2, inplace=True), ) self.dis_layer = nn.Sequential(nn.Conv2d(num_df*8, 1, 4, 1, 0, bias=bias), nn.Sigmoid() ) def forward_features(self, x): features = self.feature_layer(x) return features def forward(self, x): features = self.forward_features(x) discrimination = self.dis_layer(features) return discrimination # weight 초기화 def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: torch.nn.init.normal_(m.weight, 0.0, 0.02) elif classname.find('BatchNorm') != -1: torch.nn.init.normal_(m.weight, 1.0, 0.02) torch.nn.init.zeros_(m.bias)
4. Loss, Optimizer, Network 선언
G = Generator(latent_dim=latent_dim, num_gf=ngf, channels=channels, bias=False).to(device) G.apply(weights_init) D = Discriminator(num_df=ndf, channels=channels, bias=False).to(device) D.apply(weights_init) criterion = nn.BCELoss() optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, weight_decay=1e-5, betas=(0.5, 0.999)) optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, weight_decay=1e-5, betas=(0.5, 0.999)) scheduler_G = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_G, T_0=15, T_mult=2) scheduler_D = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_D, T_0=15, T_mult=2)
5. Train/valid
img_list = [] d_losses = [] g_losses = [] iters = 0 D.train() G.train() steps_per_epoch = len(train_loader) for epoch in range(1, epochs+1): for i, (images, _) in enumerate(train_loader): real_images = images.to(device) batch_num = images.size(0) ###### Update D network: maximize log(D(x)) + log(1 - D(G(z))) ###### optimizer_D.zero_grad() real_output = D(real_images) real_label = torch.ones_like(real_output, device=device) z = torch.randn(batch_num, latent_dim, 1, 1, device=device) fake_images = G(z) fake_output = D(fake_images.detach()) fake_label = torch.zeros_like(fake_output, device=device) real_lossD = criterion(real_output, real_label) fake_lossD = criterion(fake_output, fake_label) D_loss = real_lossD + fake_lossD D_loss.backward() optimizer_D.step() P_real = real_output.mean().item() # Discriminator가 real image를 진짜라고 판별한 확률 P_fake = fake_output.mean().item() # Discriminator가 fake image를 진짜라고 판별한 확률 for _ in range(3): ############# Update G network: maximize log(D(G(z))) ############## optimizer_G.zero_grad() fake_images = G(z) fake_output = D(fake_images) G_loss = criterion(fake_output, torch.ones_like(fake_output, device=device)) G_loss.backward() optimizer_G.step() scheduler_D.step(epoch + i * steps_per_epoch) scheduler_G.step(epoch + i * steps_per_epoch) d_losses.append(D_loss.item()) g_losses.append(G_loss.item()) # if i % 62==0: # print(f'Epoch {epoch}/{epochs} | Batch {i*batch_size + batch_num}/{len(train_loader.dataset)} | D loss: {D_loss.item():.6f} | G loss: {G_loss.item():.6f} | P(real): {P_real:.4f} | P(fake): {P_fake:.4f}') if (iters % 500 == 0) or ((epoch == epochs) and (i == len(train_loader)-1)): with torch.no_grad(): fake_images = G(z).detach().cpu() img_list.append(vutils.make_grid(fake_images, padding=2, normalize=True)) iters += 1 if epoch % 5 == 0: print(f'Epoch {epoch}/{epochs} | D loss: {D_loss.item():.6f} | G loss: {G_loss.item():.6f} | P(real): {P_real:.4f} | P(fake): {P_fake:.4f}')
위 코드 중간에 for _ in range(3)이 들어간 부분은 Discriminator를 1번 학습시킬 때, Generator를 3번 학습시키겠다는 것입니다. 위와 같이 한 이유는 GAN의 경우 mode collapse 즉, Discriminator나 Generator 중 하나의 모델이 너무 잘 맞추는 경우에 다른 모델이 학습이 잘 되지 않았는데, 이번 구현의 경우에도 Discriminator의 Loss가 0으로 수렴하여 Generator가 학습이 되지않았습니다. 따라서 위와같이 Generator를 3번 학습하게 하였더니 서로 균형을 맞추어가며 잘 학습하였습니다.
6. 학습 Loss Plot
# Loss Plotting plt.plot(g_losses, label='g_loss') plt.plot(d_losses, label='d_loss') plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) plt.show()
제가 구현한 코드에서 Loss 결과는 위와 같이 초반에는심하게 요동치다가 후반으로 갈수록 서로 적절한 Loss 띄는것을 확인할 수 있었습니다.
# batch of real images from the dataloader real_batch = next(iter(train_loader)) # Plot the real images plt.figure(figsize=(15,15)) plt.subplot(1,2,1) plt.axis("off") plt.title("Real Images") plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0))) # Plot the fake images from the last epoch plt.subplot(1,2,2) plt.axis("off") plt.title("Fake Images") plt.imshow(np.transpose(img_list[-1],(1,2,0))) plt.show()
이후 위 코드를 사용해 진짜 이미지와 생성된 가짜 이미지를 한 번 뽑아보았습니다.
왼쪽이 진짜 이미지이고, 오른쪽이 가짜 이미지입니다. 가짜 이미지가 흐릿하게 나오긴 하지만 차라는 것을 알아볼 정도로는 학습이 된 것을 확인할 수 있었습니다. 조금 더 정교하게 파라미터를 조정해 학습을 시킨다면 위의 그림보다는 개선이 될 것입니다.
7. Anomaly Score 계산 함수
def residual_loss(real_images, generated_images): subtract = real_images - generated_images return torch.sum(torch.abs(subtract)) def discriminator_loss(netD, real_images, generated_images): real_features = D.forward_features(real_images) generated_features = D.forward_features(generated_images) subtract = real_features - generated_features return torch.sum(torch.abs(subtract)) def anomaly_loss(residual_loss, d_loss, l=0.1): return (1 - l) * residual_loss + l * d_loss def estimate_anomaly_score(real_images, generated_images, D): real_images = real_images generated_images = generated_images resi_loss = residual_loss(real_images, generated_images) disc_loss = discriminator_loss(D, real_images, generated_images) ano_loss = anomaly_loss(resi_loss, disc_loss, l=0.5) return ano_loss.cpu().data.numpy() def compare_images(G, D, real_image, generated_image, labels, reverse, idx, threshold=1): score = estimate_anomaly_score(real_image, generated_image, D) score = np.round(score, 2) real_image = np.transpose(real_image.cpu().data.numpy().squeeze(), axes=(1, 2, 0)) * 255 generated_image = np.transpose(generated_image.cpu().data.numpy().squeeze(), axes=(1, 2, 0)) * 255 negative = np.zeros_like(real_image) if not reverse: diff_image = real_image - generated_image else: diff_image = generated_image - real_image diff_image[diff_image <= threshold] = 0 ano_image = np.zeros(shape=(img_size, img_size, 3)) ano_image[:, :, 0] = real_image[:, :, 0] - diff_image[:, :, 0] ano_image[:, :, 1] = real_image[:, :, 1] - diff_image[:, :, 1] ano_image[:, :, 2] = real_image[:, :, 2] - diff_image[:, :, 2] ano_image[:, :, 0] = ano_image[:,:, 0] + diff_image[:, :, 0] ano_image = ano_image.astype(np.uint8) fig, plots = plt.subplots(1, 4) fig.suptitle(f'Anomaly - (anomaly score: {score:.4f}) \n Class {labels.item()}') fig.set_figwidth(9) fig.set_tight_layout(True) plots = plots.reshape(-1) plots[0].imshow(real_image.astype('uint8'), cmap='gray', label='real') plots[1].imshow(generated_image.astype('uint8'), cmap='gray') plots[2].imshow(diff_image.astype('uint8'), cmap='gray') plots[3].imshow(ano_image.astype('uint8')) plots[0].set_title('real') plots[1].set_title('generated') plots[2].set_title('difference') plots[3].set_title('Anomaly Detection') plt.show()
8. 테스트로 사용할 데이터 정의
# Test Data Load ''' x_test_normal: Trainset 중 valid용으로 빼놓은 정상 class data(여기선 class1) x_train_abnormal: Trainset 중 비정상 class data (0, 2, 3 ~ 9) x_test_total: Total Testset (모든 class 존재) ''' x_train_abnormal = torch.ByteTensor(trainset.data[torch.IntTensor(trainset.targets) != 1]) y_train_abnormal = torch.ByteTensor(trainset.targets)[torch.tensor(trainset.targets) != 1] x_test_total = torch.ByteTensor(testset.data) y_test_total = torch.ByteTensor(testset.targets) x_test = torch.cat([x_valid_normal, x_train_abnormal, x_test_total], dim=0) y_test = torch.cat([y_valid_normal, y_train_abnormal, y_test_total], dim=0) test_cifar10 = anomaly_dataset(x_test, y_test, transform=transforms.Compose([ transforms.ToPILImage(), transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])) test_loader = DataLoader(test_cifar10, batch_size=1, shuffle=False)
테스트 셋으로는 정상과 비정상을 섞어서 사용합니다.
9. Latent Z 학습
# Train Latent Space Z latent_space = [] G.eval() D.eval() z = torch.randn(1, latent_dim, 1, 1, device=device, requires_grad=True) optimizer_z = torch.optim.Adam([z], lr=lr) latent_space = [] auc = [] for i, (images, _) in enumerate(test_loader): real_images = images.to(device) print(f'image{i+1}') for step in range(401): generated_images = G(z) optimizer_z.zero_grad() resi_loss = residual_loss(real_images, generated_images) disc_loss = discriminator_loss(D, real_images, generated_images) ano_loss = anomaly_loss(resi_loss, disc_loss, l=0.1) ano_loss.backward(retain_graph = True) optimizer_z.step() if step%200 == 0: loss = ano_loss.item() noises = torch.sum(z).item() print("[%d]\t loss_Ano:%.4f Sum_of_z:%.4f" %(step,loss,noises)) if step==400: latent_space.append(z.cpu().data.numpy()) latent_space = np.array(latent_space) latent_space = torch.Tensor(latent_space).to(device)
위에서 언급했던 Latent z의 학습은 위와 같이 이루어집니다. torch.randn으로 생성한 128차원의 랜덤한 z를 변할 수 있는 Weight로 보아 optimizer의 Parameter로 넣어줍니다. 보통 optimizer에 params=model.parameters()하는 부분에 z를 넣어준다고 생각하시면 됩니다.
이후 test_loader에서 하나의 이미지씩을 뽑고, 해당 이미지와 z로 생성된 가짜이미지를 Residual loss와 Discrimination loss가 더해진 anomaly loss를 줄여나가는 식으로 z가 학습하게 됩니다. 그리고 하나의 이미지당 z를 400번씩 학습을 하도록 하였습니다. 그리고 그렇게 학습된 z를 latent_space라는 list에 담아주어 아래 테스트에서 사용하게 됩니다.
for i, (images, labels) in enumerate(test_loader): real_image = images.to(device) update_z = torch.as_tensor(latent_space[i], device=device, dtype=torch.float32) # Latent Z를 학습할때 400iter돌고난 후의 z값 generated_image = G(update_z).to(device) if i % 1==0: compare_images(G, D, real_image, generated_image, labels, False, i, threshold=50)
따라서, 위에서 모든 테스트 이미지를 latent z로 학습하여 만든 리스트 latent_space를 update_z = latent_space[i]로 하나씩 뽑아 실제 이미지와 학습된 z로 생성한 이미지를 비교하는 코드입니다.
10. 결과
돌린 모델을 토대로 나온 결과는 위와 같습니다. CIFAR-10 데이터는 정상과 비정상을 구별해냄에 있어서 비정상의 Anomaly score가 약간은 오른쪽으로 치우쳐있으나 대체로 정상데이터와 겹친다는것을 볼 수 있었습니다.
이는 CIFAR-10이 Anomaly Task로 사용하기에 무리인것도 있고, AnoGAN모델이 Anomaly를 Detection함에 있어서 완전한 성능을 보이는 것이 아니라는 것을 확인할 수 있습니다.
AnoGAN에서 Latent z를 학습시키는 부분이 매우 오래걸린다는 점을 보완한 f-AnoGAN 논문도 다음에 리뷰해보도록 하겠습니다. 읽어주셔서 감사합니다.
'Anomaly Detection' 카테고리의 다른 글
[논문 리뷰 및 코드구현] Deep One-Class Classification(Deep SVDD) (7) 2020.10.30