# Down part of UNET for feature in features: self.downs.append(DoubleConv(in_channels, feature)) in_channels = feature
# Up part of UNET for feature inreversed(features): self.ups.append( nn.ConvTranspose2d( feature*2, feature, kernel_size=2, stride=2, ) ) self.ups.append(DoubleConv(feature*2, feature)) self.bottleneck = DoubleConv(features[-1], features[-1]*2) self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
def forward(self, x):
    """Run the U-Net encoder, bottleneck and decoder over ``x``.

    Encoder: every module in ``self.downs`` is applied in turn; its output is
    saved as a skip connection and then passed through ``self.pool``.

    Decoder: ``self.ups`` is consumed two entries at a time. The first entry
    of each pair upsamples ``x``; the second fuses the upsampled tensor with
    the matching skip connection. Skips are used deepest-first.

    Args:
        x: Input tensor. Presumably (N, C, H, W), given the ``shape[2:]``
            resize target and the ``dim=1`` concatenation — TODO confirm.
    """
    saved = []
    for encoder_stage in self.downs:
        x = encoder_stage(x)
        saved.append(x)  # stored before pooling, for use by the decoder
        x = self.pool(x)

    x = self.bottleneck(x)
    saved.reverse()  # decoder walks from the deepest skip to the shallowest

    for level, pos in enumerate(range(0, len(self.ups), 2)):
        x = self.ups[pos](x)
        skip = saved[level]
        # Lets the network accept inputs of (almost) arbitrary size — within a
        # reasonable range; very small inputs will still fail. When the
        # upsampled tensor's shape does not match the skip, resize it to the
        # skip's trailing (presumably spatial) dimensions.
        if x.shape != skip.shape:
            x = TF.resize(x, size=skip.shape[2:])
        fused = torch.cat((skip, x), dim=1)  # skip first, then upsampled x
        x = self.ups[pos + 1](fused)
    # NOTE(review): the visible chunk ends here; presumably the method goes on
    # to apply ``self.final_conv`` to ``x`` and return it — confirm in the
    # full file.