UNet

UNet是为了解决医学图像分割任务提出的,它基于FCN,并设计了更为巧妙的网络结构:
Alt text

由于网络结构看起来像一个U字母,因此得名UNet.

从整体来看,左侧是编码器部分,负责提取输入图像的高层语义信息;右侧是解码器部分,负责解码高层语义信息并输出最终的分割图。

在解码阶段,还引入了类似跳连的结构,将编码阶段的中间结果直接加了进来。

观察UNet的网络结构,可以发现:

编码器重复如下结构:两次卷积使得通道数加倍,后接一次池化使得特征图尺寸减半。

解码器重复如下结构:上采样使得特征图尺寸加倍,后接两次卷积使得通道数减半。

值得注意的是,在语义分割中,输入与输出的尺寸一般都是是一致的,而UNet却不同。

UNet是为医学图像分割任务而设计的,为了更好的预测图像边界区域,采用了Overlap-tile strategy,这种策略将待预测区域使用镜像进行了填充。

如下图所示,黄色部分是待预测区域,而真正输入UNet的是蓝色区域:
Alt text

PyTorch实现UNet

具体实现与原论文中有些许不同。

这里实现的UNet会保证输入与输出的尺寸一致,且能够适应任意(合理范围内,不能太小)的输入尺寸。

先导入所需库:

1
2
3
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

在论文中,UNet中的两次卷积会使得特征图尺寸发生改变,而在实现时,我们可以通过padding的方式使得这两次卷积操作不改变特征图的尺寸。

两次卷积的实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#只改变通道数,不改变特征图尺寸的卷积过程
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)

def forward(self, x):
return self.conv(x)

测试一下:
Alt text

看,特征图尺寸并不会发生变化。

观察UNet结构,特征图的通道数变化如下:

1
2
3
4
5
编码器:输入特征图通道数->64->128->256->512

中间:512->1024

解码器:1024->512->256->128->64->输出特征图通道数

因此可以用一个列表记录特通道数的变化:

1
features=[64, 128, 256, 512]

UNet实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class UNET(nn.Module):
def __init__(
self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
):
super(UNET, self).__init__()
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
#池化下采样,使得特征图尺寸减半
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

# Down part of UNET
for feature in features:
self.downs.append(DoubleConv(in_channels, feature))
in_channels = feature

# Up part of UNET
for feature in reversed(features):
self.ups.append(
nn.ConvTranspose2d(
feature*2, feature, kernel_size=2, stride=2,
)
)
self.ups.append(DoubleConv(feature*2, feature))
self.bottleneck = DoubleConv(features[-1], features[-1]*2)
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

def forward(self, x):
skip_connections = []

for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)

x = self.bottleneck(x)
skip_connections = skip_connections[::-1]

for idx in range(0, len(self.ups), 2):
x = self.ups[idx](x)
skip_connection = skip_connections[idx//2]
#使得网络能够接收任意(合理范围内,不能太小)尺寸的输入
if x.shape != skip_connection.shape:
x = TF.resize(x, size=skip_connection.shape[2:])

concat_skip = torch.cat((skip_connection, x), dim=1)
x = self.ups[idx+1](concat_skip)

return self.final_conv(x)

ups和downs分别存储了编码器和解码器的网络层,bottleneck用于将通道数从512升到1024,final_conv用于将通道数变成总类别数。

为了使得网络能够适应任意尺寸(不能太小,否则卷积层无法处理)的输入,引入了resize操作。

在原论文中,上采样不改变通道数,而这里,需要将解码阶段的中间结果与上采样的结果做concat(类似跳连操作,只不过这里用的concat而不是直接相加),然后输入DoubleConv模块,因此在上采样的同时也将通道数减半。

测试一下,假设输入大小为(4,3,72,72),经过UNet,其维度变化如下:

Encoder:

1
2
3
4
5
6
7
8
1down: torch.Size([4, 64, 72, 72])
1down pool: torch.Size([4, 64, 36, 36])
2down: torch.Size([4, 128, 36, 36])
2down pool: torch.Size([4, 128, 18, 18])
3down: torch.Size([4, 256, 18, 18])
3down pool: torch.Size([4, 256, 9, 9])
4down: torch.Size([4, 512, 9, 9])
4down pool: torch.Size([4, 512, 4, 4])

bottleneck:

1
after bottleneck: torch.Size([4, 1024, 4, 4])

Decoder:

1
2
3
4
5
6
7
8
9
10
11
1up ConvTranspose2d: torch.Size([4, 512, 8, 8])

x.shape != skip_connection.shape,需要resize,resize之后的维度是:torch.Size([4, 512, 9, 9])

1up DoubleConv: torch.Size([4, 512, 9, 9])
2up ConvTranspose2d: torch.Size([4, 256, 18, 18])
2up DoubleConv: torch.Size([4, 256, 18, 18])
3up ConvTranspose2d: torch.Size([4, 128, 36, 36])
3up DoubleConv: torch.Size([4, 128, 36, 36])
4up ConvTranspose2d: torch.Size([4, 64, 72, 72])
4up DoubleConv: torch.Size([4, 64, 72, 72])

final_conv:

1
final output:torch.Size([4, 1, 72, 72])

参考: