import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2
import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import time
def forward(self, x):
    x = self.leaky(self.conv1(x))
    x = self.pn(x) if self.use_pn else x
    x = self.leaky(self.conv2(x))
    x = self.pn(x) if self.use_pn else x
    return x
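The self.pn used above is ProGAN's pixelwise feature normalization; its class is defined outside this excerpt. As a reference, a minimal sketch of the standard PixelNorm layer (this exact body is my assumption, matching the formula from the ProGAN paper):

class PixelNorm(nn.Module):
    """Pixelwise feature normalization: scale each spatial position
    to unit average magnitude across the channel dimension."""
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        # Divide by the RMS over channels (dim=1); epsilon avoids division by zero
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)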
# Create progression blocks and rgb layers
channels = in_channels  # 16
# We need one block per doubling from 4x4 up to img_size, plus 1 for
# the initial 4x4: log2(512/4) = 7, and 7 + 1 = 8 blocks in total
for idx in range(int(log2(img_size / 4)) + 1):
    conv_in = channels  # 16
    # factors[idx] is the factor by which the number of feature maps shrinks
    conv_out = int(in_channels * factors[idx])  # 16
    # prog_blocks do not change the spatial size of the feature maps
    self.prog_blocks.append(ConvBlock(conv_in, conv_out))
    # rgb_layers simply map the channels down to 3; with kernel_size=1 this
    # is a 1x1 convolution, so it does not change the feature map size either
    self.rgb_layers.append(WSConv2d(conv_out, img_channels, kernel_size=1, stride=1, padding=0))
    channels = conv_out
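WSConv2d appears throughout this code but is also defined outside this excerpt. It implements ProGAN's equalized learning rate: weights are drawn from N(0, 1) and rescaled at runtime by the He-initialization constant. A minimal sketch under that assumption:

class WSConv2d(nn.Module):
    """Conv2d with equalized learning rate (weight-scaled convolution)."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # He constant: sqrt(gain / fan_in)
        self.scale = (gain / (in_channels * kernel_size ** 2)) ** 0.5
        # Keep the bias out of the scaling by detaching it from the conv
        self.bias = self.conv.bias
        self.conv.bias = None
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)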
def fade_in(self, alpha, upscaled, generated):
    # assert 0 <= alpha <= 1, "Alpha not between 0 and 1"
    # assert upscaled.shape == generated.shape
    return torch.tanh(alpha * generated + (1 - alpha) * upscaled)
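A quick numeric check of how fade_in shifts weight from the upscaled path to the freshly generated path as alpha moves from 0 to 1 (standalone toy tensors, not part of the model):

up = torch.zeros(1, 3, 8, 8)   # stand-in for the upscaled image
gen = torch.ones(1, 3, 8, 8)   # stand-in for the generated image
for alpha in (0.0, 0.5, 1.0):
    blended = torch.tanh(alpha * gen + (1 - alpha) * up)
    print(alpha, blended.mean().item())  # 0.0 -> tanh(0)=0, 1.0 -> tanh(1)≈0.76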
def forward(self, x, alpha, steps):
    # print('x:', x.shape)  # torch.Size([5, 100, 1, 1])
    # This is the only place a transposed convolution is used for
    # upsampling; all later upsampling is done by interpolation
    upscaled = self.initial(x)
    # print('upscaled:', upscaled.shape)  # torch.Size([5, 16, 4, 4])
    out = self.prog_blocks[0](upscaled)  # feature map size is unchanged
    # print('out1:', out.shape)  # torch.Size([5, 16, 4, 4])
    # At the first stage there is nothing to fade in: map straight
    # to RGB (channel=3) and we are done
    if steps == 0:
        return self.rgb_layers[0](out)
    # Grow the network one stage at a time
    for step in range(1, steps + 1):
        # Upsample by interpolation; the feature map size doubles
        upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
        # print('upscaled in for loop:', upscaled.shape)
        out = self.prog_blocks[step](upscaled)
        # print('out in for loop:', out.shape)
    # print('out after:', out.shape)  # torch.Size([5, 2, 512, 512])
    # The number of channels in upscaled stays the same, while out,
    # which has moved through prog_blocks, might change. To ensure we
    # can convert both to RGB we use different rgb_layers:
    # (steps - 1) for upscaled and steps for out
    final_upscaled = self.rgb_layers[steps - 1](upscaled)  # second-to-last stage
    # print('final_upscaled:', final_upscaled.shape)  # torch.Size([5, 3, 512, 512])
    final_out = self.rgb_layers[steps](out)  # last stage
    # print('final_out:', final_out.shape)  # torch.Size([5, 3, 512, 512])
    return self.fade_in(alpha, final_upscaled, final_out)  # blend the two paths
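To sanity-check the generator's forward pass, a shape test matching the debug prints above (z_dim=100, in_channels=16, img_size=512; the exact values are assumptions read off the printed shapes):

gen = Generator(100, 16, img_size=512, img_channels=3)
x = torch.randn(5, 100, 1, 1)
print(gen(x, alpha=0.5, steps=0).shape)  # torch.Size([5, 3, 4, 4])
print(gen(x, alpha=0.5, steps=3).shape)  # torch.Size([5, 3, 32, 32]), i.e. 4 * 2**3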
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
# +1 to in_channels because we concatenate the minibatch std channel
self.conv = WSConv2d(in_channels + 1, z_dim, kernel_size=4, stride=1, padding=0)
self.linear = nn.Linear(z_dim, 1)
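self.minibatch_std is used in forward below but defined outside this excerpt. The standard minibatch standard deviation layer computes the std of every feature over the batch, averages it to a single scalar, and appends it as an extra feature map, which is exactly why self.conv takes in_channels + 1. A minimal sketch (the body is my assumption):

def minibatch_std(self, x):
    # Std over the batch for each feature, averaged to one scalar,
    # then broadcast to an extra [N, 1, H, W] map and concatenated
    batch_statistics = torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
    return torch.cat([x, batch_statistics], dim=1)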
def fade_in(self, alpha, downscaled, out):
    """Fade in between the avg-pooled (downscaled) input and the CNN output"""
    # assert 0 <= alpha <= 1, "Alpha needs to be between [0, 1]"
    # assert downscaled.shape == out.shape
    return alpha * out + (1 - alpha) * downscaled
def forward(self, x, alpha, steps):
    out = self.rgb_layers[steps](x)  # convert from RGB as the initial step
    # print('out from rgb as initial step:', out.shape)  # torch.Size([5, 2, 512, 512])
    # At the first stage no fade_in is needed: output shape [5, 1] directly
    # and this stage is done
    if steps == 0:  # i.e., image is 4x4
        out = self.minibatch_std(out)
        out = self.conv(out)
        return self.linear(out.view(-1, out.shape[1]))
    # Otherwise we need the fade_in;
    # index (steps - 1) gives the "reverse" fade_in
    downscaled = self.rgb_layers[steps - 1](self.avg_pool(x))  # previous stage
    # print('downscaled:', downscaled.shape)  # torch.Size([5, 4, 256, 256])
    out = self.avg_pool(self.prog_blocks[steps](out))  # last stage
    # print('out after avg_pool:', out.shape)  # torch.Size([5, 4, 256, 256])
    # Blend the two paths
    out = self.fade_in(alpha, downscaled, out)
    # print('out after fade_in:', out.shape)
    # The remaining stages downsample, mapping out into the shape of the
    # critic score. This mirrors the generator in reverse: the prog_blocks
    # start with few feature maps and gradually grow
    for step in range(steps - 1, 0, -1):
        # Downsample by average pooling; the feature map size shrinks
        downscaled = self.avg_pool(out)
        # Feed the downsampled result into the next prog_block
        out = self.prog_blocks[step](downscaled)
        # print('out bef:', out.shape)  # torch.Size([5, 16, 4, 4])
    # This adds one channel to the channel dimension
    out = self.minibatch_std(out)
    # print('out after minibatch_std:', out.shape)  # torch.Size([5, 17, 4, 4])
    out = self.conv(out)
    # print('out after conv:', out.shape)  # torch.Size([5, 100, 1, 1])
    # print(out.view(-1, out.shape[1]).shape)  # [5, 100]
    # print(self.linear(out.view(-1, out.shape[1])).shape)  # torch.Size([5, 1])
    return self.linear(out.view(-1, out.shape[1]))
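A matching shape check for the critic, with the constructor order taken from main() below (the concrete sizes are assumptions): one image at the current resolution in, one score per sample out.

critic = Discriminator(512, 100, 16, img_channels=3)
img = torch.randn(5, 3, 32, 32)  # resolution for steps=3 is 4 * 2**3 = 32
print(critic(img, alpha=0.5, steps=3).shape)  # torch.Size([5, 1])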
def main():
    # Initialize gen and disc. Note: the discriminator should be called a
    # critic according to the WGAN paper (it no longer outputs in [0, 1])
    gen = Generator(Z_DIM, IN_CHANNELS, img_size=IMAGE_SIZE, img_channels=CHANNELS_IMG).to(device)
    critic = Discriminator(IMAGE_SIZE, Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(device)
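Because the critic is trained with a WGAN objective, ProGAN pairs it with a gradient penalty (WGAN-GP). A minimal sketch of that penalty adapted to the (alpha, steps) signature used here; the helper name and details are my assumptions, not code from this post:

def gradient_penalty(critic, real, fake, alpha, steps, device="cpu"):
    N, C, H, W = real.shape
    # Random per-sample interpolation between real and fake images
    beta = torch.rand((N, 1, 1, 1), device=device).expand_as(real)
    interpolated = (real * beta + fake.detach() * (1 - beta)).requires_grad_(True)
    mixed_scores = critic(interpolated, alpha, steps)
    # Gradient of the critic score w.r.t. the interpolated images
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0].view(N, -1)
    # Penalize deviation of the gradient L2 norm from 1
    return torch.mean((gradient.norm(2, dim=1) - 1) ** 2)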