import torch
import torch.nn as nn
class Unet(nn.Module):
def __init__(self):
super().__init__()
# convolution + batch normalize + pooling
# 3x3 convolution, unppaded convolution
def CBR_2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
layer = [] # := 왜 안되는거지
layer = [nn.Conv2d(in_channels, out_channels)]
layer += [nn.BatchNorm2d(num_features=out_channels)]
layer += [nn.ReLU()]
layers = nn.Sequential(*layer)
return layers
# --- extract --- #
self.down_1_1 = CBR_2d(in_channels=1, out_channels=64)
self.down_1_2 = CBR_2d(in_channels=64, out_channels=64)
self.pool_1 = nn.MaxPool2d(kernel_size=2)
self.down_2_1 = CBR_2d(in_channels=64, out_channels=128)
self.down_2_2 = CBR_2d(in_channels=128, out_channels=128)
self.pool_2 = nn.MaxPool2d(kernel_size=2)
self.down_3_1 = CBR_2d(in_channels=256, out_channels=256)
self.down_3_2 = CBR_2d(in_channels=256, out_channels=256)
self.pool_3 = nn.MaxPool2d(kernel_size=2)
self.down_4_1 = CBR_2d(in_channels=512, out_channels=512)
self.down_4_2 = CBR_2d(in_channels=512, out_channels=512)
self.pool_4 = nn.MaxPool2d(kernel_size=2)
# --- bridge --- #
self.bridge_1 = CBR_2d(in_channels=512, out_channels=1024)
self.bridge_2 = CBR_2d(in_channels=1024, out_channels=512)
# --- expand --- #
# 2x2 convolution (“up-convolution”)
# halves the number of feature channels
self.up_4_1 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, padding=0)
self.up_4_2 = CBR_2d(in_channels=512, out_channels=512)
self.up_4_3 = CBR_2d(in_channels=512, out_channels=512)
self.up_3_1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, padding=0)
self.up_3_2 = CBR_2d(in_channels=256, out_channels=256)
self.up_3_3 = CBR_2d(in_channels=256, out_channels=256)
self.up_2_1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, padding=0)
self.up_2_2 = CBR_2d(in_channels=128, out_channels=128)
self.up_2_3 = CBR_2d(in_channels=128, out_channels=128)
self.up_1_1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, padding=0)
self.up_1_2 = CBR_2d(in_channels=64, out_channels=64)
self.up_1_3 = CBR_2d(in_channels=64, out_channels=64)
self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)
def forward(self, x):
# down
down_1_1 = self.down_1_1(x)
down_1_2 = self.down_1_2(down_1_1)
pool_1 = self.pool_1(down_1_2)
down_2_1 = self.down_2_1(pool_1)
down_2_2 = self.down_2_2(down_2_1)
pool_2 = self.pool_2(down_2_2)
down_3_1 = self.down_3_1(pool_2)
down_3_2 = self.down_3_2(down_3_1)
pool_3 = self.pool_3(down_3_2)
down_4_1 = self.down_4_1(pool_3)
down_4_2 = self.down_4_2(down_4_1)
pool_4 = self.pool_4(down_4_2)
# bridge
bridge_1 = self.bridge(pool_4)
bridge_2 = self.bridge(bridge_1)
# up
upconv_4 = self.up_4_1(bridge_2) # up convolution
concat_4 = torch.cat((upconv_4, down_4_2), dim = 1)
up_4_2 = self.up_4_2(concat_4)
up_4_1 = self.up_4_3(up_4_2)
upconv_3 = self.up_3_1(up_4_1)
concat_3 = torch.cat((upconv_3, down_3_2), dim = 1)
up_3_2 = self.up_3_2(concat_3)
up_3_1 = self.up_3_3(up_3_2)
upconv_2 = self.up_2_1(up_3_1)
concat_2 = torch.cat((upconv_2, down_2_2), dim = 1)
up_2_2 = self.up_2_2(concat_2)
up_2_1 = self.up_2_3(up_2_2)
upconv_1 = self.up_1_1(up_2_1)
concat_1 = torch.cat((upconv_1, down_1_2), dim = 1)
up_1_2 = self.up_1_2(concat_1)
up_1_1 = self.up_1_3(up_1_2)
# out
x = self.fc(up_1_1)
return x
목표
Swin U-Net을 이해하기 위한 기반 지식을 쌓는다.
Abstract
제안할 내용: New network, Training strategy, strong Augmentation
- New Network; Architecture that enables precise localization1
- symmetric expanding path
- 특징 추출을 위한 contracting path
- 위치 파악을 위한 expanding path
- end to end
- speed
- symmetric expanding path
- 성능평가
1 segmentation에서 중요
2 아마도 당시 SOTA 방법론
3 segmentation of neuronal structures in electron microscopic stacks
Kvasir-SEG (dataset)
- MediaEval에서 공개된 데이터셋으로 위장관4 내시경5 이미지를 포함하고 있다.
- 8 classes, 1000 images per class, total 8000
- 검증된 의사들에 의해 annotate 되었다.
- 기준:
- 해부학적 지점 3개: Z line, pylorus 유문, cecum 맹장
- 병리학적 소견 3개: esophagitis 식도염, polyps 용종, ulcerative colitis 궤양성 대장염
- 용종 제거 과정 2개
그 외
Kvasir-SEG | CVC-ClinicDB | ACDC |
---|---|---|
- 대체로 U-Net 또는 U-Net based model이 벤치마크의 시작점에 있는 것을 확인할 수 있다.
- 많아봐야 40개 내외의 모델이 테스트 되었다.
왜 벤치마크를 사용하는 모델의 수가 적은 편일까?
- 일단 개인정보라는 점이 주된 문제로 보인다.
- 벤치마크가 아니라 각자가 선택한 데이터셋을 사용하는 이유는 대부분 해당 모델이 해결하려는 문제가 domain specific한 문제이기 때문 아닐까?
- 도메인 특성상 벤치마크 성능평가가 의료영상에서는 중요하지 않은 것 같다.
5 endoscopy
4 GI: Gastrointestinal의 약어, 위장관은 GI tract라고 한다.
특징 추출 없이 원본 데이터를 그대로 사용하는 모델을 E2E 모델이라고 한다.
이미지 분류의 예 |
---|
원래 인공지능이 이런게 아니었나? 생각해보면 아니다. 인공지능을 처음 배울때 하는 feature engineering 방법이 전통적 방식의 인공지능이다. 나는 머신러닝이라고 분류해서 부르고 있었던 방식이 알고보니 전통적 방식이었고, 지금의 인공지능 학습 방식은 대부분 E2E 방식이라고 할 수 있겠다.6 방대한 양의 데이터를 다루게 되면서 데이터 자체의 특수함까지 학습할 수 있게 되었다.
- 효율: 기존의 방식은 인간이 개입해야 하는 부분이 많았고 도메인 지식이 많은 영향을 미쳤는데 E2E는 그 과정을 건너뛰어도 된다.
- 한계: 일반적인 딥러닝의 한계다. 거대한 데이터셋이 필요하며 내부에서 일어나는 일을 일반적으로는 설명할 수 없다.
6 출처에서도 그렇게 부르는 것 같은데 일반적인 명명 방식인지는 모르겠다.
선행연구
- 선행연구 1에서는 레이어를 늘리고 데이터를 불려 심층적인 네트워크를 구축하는 것이었다.
- single class label을 결과값으로 하는 분류 모델이라는 한계점이 있다.
- 그러나 의료영상 분야에서는 지역화7가 필요하고 데이터셋을 수천개씩 구축할 수 없다.
- 선행연구 2는 각 픽셀의 클래스를 예측하는 방식으로 지역화에 성공했다. 패치를 기준으로 학습하기 때문에 적은양의 데이터로도 학습할 수 있고, 본 연구와 같은 challenge인 ISBI 2012에서 우승했다.
- 지역화 정확도가 높지 않다.
- 패치가 클수록 더 큰 Max Pooling layer가 필요하므로8 정확도가 낮아진다.
- 선행 연구 3은 FCN으로, pooling layer를 upsampling layers 로 대체하는 방식으로, 출력의 해상도9를 높일 수 있다.
- 즉, 레이어와 데이터를 무작정 키우는 것은 의료 영상 분야에서 방법이 되지 못하고, 각 픽셀을 분류함으로써 localization을 할 수 있으니 pooling layer를 upsampling layer로 대체하여 정확도와 출력 해상도를 높이는 방법이 현재 의료영상 분야에서 시행 또는 연구되고 있다.
7 localization
8 레이어를 줄일수록 feature map은 줄어드는데 크기는 지켜야 하므로
9 왜?
U-Net
U-Net |
One important modification in our architecture is that in the upsampling part we have also a large number of feature channels, which allow the network to propagate context information to higher resolution layers
‘Upsampling’ 과정에서 더 많은 feature channels 부분을 추가하는게 주된 아이디어다. 그 결과: Fully Connected Layer를 빼고 지역화에 필요한 기능을 수행하는 레이어만 사용하게 된다. 픽셀 예측을 위해 대칭된 위치에서10 입력 이미지를 참조한다.11
구조상 출력 해상도가 더 낮으므로 목표 이미지의 크기보다 더 큰 이미지를 입력한다. (extrapolation)
손실 함수는 cross-entropy 를 사용하는데, 거기에 가중치 함수를 곱하는 형태다.
\[E = \sum_{x \in \Omega} w(x) \log(p_{l(x)}(x))\]
왜? 세포를 대상으로 하는 모델이었으므로 세포를 명확히 구분하는 것이 중요하다. 따라서 작은 분리 경계 (small separation border)를 학습해야 하다. 이 작업을 \(w(x)\)가 배경 레이블에 높은 가중치를 부여함으로써 진행한다.
10 mirroring
11 tiling strategy, GPU 메모리 보존 가능
Auto Encoder
Skip-Connection
구현
- 모델만 구현해보았다.
- 상세 구조는 아래를 참고했다.
brain-segmentation-pytorch |