之前介绍了GAN的原理,并使用celeba数据集训练了一个基于DCGAN的”假”人脸生成器(传送门戳我),这里我把它的生成效果图搬运过来了
Alt text
在GAN问世后,其出色的表现使得对于GAN的研究一时风生水起(至今还在持续),越来越多关于GAN的研究成果被发表,GAN本身存在的缺陷也逐步被挖掘出来。

本文不会陷入繁杂的数学推导中,而是指出WGAN相比于原始GAN的改进之处,以及进一步提出的WGAN-GP,并动手用PyTorch进行实现。

WGAN

WGAN便是对于原始GAN的一种改进方案,它的作者用了大量篇幅指出了原始GAN的不足之处,并最终给出了自己的解决方案。虽然其中蕴含了大量的数学推导,但推导的结论却出乎意料的简单,或许这就是数学的魅力。

说完一堆废话后,来看看改进得到的WGAN相比于原始GAN有哪些改动,这里直接把WGAN作者给出的训练算法贴出来,然后做简要分析。
Alt text

510分别给出了判别器和生成器的损失函数,相比于原始GAN 的损失函数,仅仅是去掉了log。如果不能一下子看出来,可以把原始GAN的目标函数搬过来对比下。

原始GAN的优化目标:
Alt text

其中的E代表期望,由大数定律,我们可以用均值近似代替期望,这便有了WGAN中的$\frac 1m\sum_{i=1}^{m}$。

7将判别器的权值截断到一个指定的区间内,这使得判别器满足$Lipschitz$限制,其导数取值会被限制在这个区间内。

在原始的GAN中,判别器做的是二分类任务,所以最后一层是Sigmoid。但在WGAN中,判别器拟合的是$Wasserstein$距离,做的是一个回归任务,因此最后一层的输出不再需要Sigmoid了。

总结来说,WGAN相比于原始GANA 的改动如下:

(1) 损失函数不取log。原始GAN的损失函数有log,因此在代码实现时选择了交叉熵度量损失,而WGAN的损失函数不取log,从而可以根据大数定律用均值近似代替期望,这样在代码实现时用类似torch.mean的函数来度量损失即可;
(2) 将判别器的权值截断到一个指定的区间内;
(3) 去掉判别器最后一层的Simoid
(4) 不用基于动量的优化算法(momentum,Adam),这里按照原论文用RMSProp。

WGAN-GP

WGAN虽然在理论上改正了原始GAN 的一些问题,但是它对于判别器采用的梯度截断技巧,使得更新判别器时的梯度并不是真正的属于判别器的梯度,事实上,这不利于网络的训练。我之前用Keras训练过一次WGAN,大家可以看看效果
Alt text

嗯,效果还没有之前的DCGAN结果好。

主要原因就在于权值截断这个动作太过于一刀切

若判别器的权值被限制在很小的区间,那么梯度在经过很多层传播后可能会变为0(梯度消失);若判别器权值的取值范围很大,那么梯度在经过很多层传播后可能会飞出去(梯度爆炸)。也就是说,只有不断尝试,找到合适的限制范围,才能保证梯度是正常的。那有没有更好的解决方案呢?

所以有了接下来针对WGAN的改进:WGAN-GP

还是直接把WGAN-GP的算法贴出来

Alt text

它摒弃了权值截断这种粗暴的方式,而是采用添加正则项的方式来使得判别器满足$Lipschitz$限制。7中的$\lambda$后面就是该正则项,它迫使判别器的梯度与1接近,接近的具体程度由$\lambda$来控制。这等价于对判别器实施了$Lipschitz$限制。

用PyTorch实现WGAN-GP

本节将根据WGAN-GP的算法步骤,使用PyTorch,基于之前使用过的celeba数据集训练WGAN-GP模型,得到一个WGAN-GP人脸生成器。

导入依赖库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

参数设置+数据集准备

这里还是使用celeba人脸数据集。
Alt text

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_CRITIC = 16
FEATURES_GEN = 16
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

dataroot='celeba'
dataset = datasets.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=2)

搭建生成器和判别器网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Discriminator(nn.Module):
def __init__(self, channels_img, features_d):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
# input: N x channels_img x 64 x 64
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
# _block(in_channels, out_channels, kernel_size, stride, padding)
self._block(features_d, features_d * 2, 4, 2, 1),
self._block(features_d * 2, features_d * 4, 4, 2, 1),
self._block(features_d * 4, features_d * 8, 4, 2, 1),
# After all _block img output is 4x4 (Conv2d below makes into 1x1)
nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
)

def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
),
#注意不能用BN
nn.InstanceNorm2d(out_channels, affine=True),
nn.LeakyReLU(0.2),
)

def forward(self, x):
return self.disc(x)


class Generator(nn.Module):
def __init__(self, channels_noise, channels_img, features_g):
super(Generator, self).__init__()
self.net = nn.Sequential(
# Input: N x channels_noise x 1 x 1
self._block(channels_noise, features_g * 16, 4, 1, 0), # img: 4x4
self._block(features_g * 16, features_g * 8, 4, 2, 1), # img: 8x8
self._block(features_g * 8, features_g * 4, 4, 2, 1), # img: 16x16
self._block(features_g * 4, features_g * 2, 4, 2, 1), # img: 32x32
nn.ConvTranspose2d(
features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
),
# Output: N x channels_img x 64 x 64
nn.Tanh(),
)

def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)

def forward(self, x):
return self.net(x)

Alt text
Alt text

权值初始化方案

1
2
3
4
5
# 权值初始化方案
def initialize_weights(model):
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
nn.init.normal_(m.weight.data, 0.0, 0.02)

初始化网络+设置优化器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#这里的critic就是discriminator
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

#测试用
fixed_noise = torch.randn(49, Z_DIM, 1, 1).to(device)

# 优化器初始化
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

#设置为训练状态
gen.train()
critic.train()

梯度惩罚(GP)

这就是WGAN-GP中的GP,按照论文中给出的算法步骤实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def gradient_penalty(critic, real, fake, device="cpu"):
BATCH_SIZE, C, H, W = real.shape
# 从均匀分布U(0,1)中采样
alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
# 插值
interpolated_images = real * alpha + fake * (1 - alpha)

# 计算混合后的判别器输出
mixed_scores = critic(interpolated_images)

# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
# 打平
gradient = gradient.view(gradient.shape[0], -1)
# 计算梯度向量的2范数
gradient_norm = gradient.norm(2, dim=1)
# 计算最终的gp
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty

开始训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
img_list = []
print('training...')
for epoch in range(NUM_EPOCHS):
# Target labels not needed! <3 unsupervised
for batch_idx, (real, _) in enumerate(dataloader):
real = real.to(device)
cur_batch_size = real.shape[0]

# 训练判别器: max E[critic(real)] - E[critic(fake)]
#等价于 min -(E[critic(real)] - E[critic(fake)])
for _ in range(CRITIC_ITERATIONS):
noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
fake = gen(noise)
critic_real = critic(real).reshape(-1)
critic_fake = critic(fake).reshape(-1)
gp = gradient_penalty(critic, real, fake, device=device)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
)
critic.zero_grad()
loss_critic.backward(retain_graph=True)
opt_critic.step()

# 训练生成器: max E[critic(gen_fake)]
# 等价于 min -E[critic(gen_fake)]
gen_fake = critic(fake).reshape(-1)
loss_gen = -torch.mean(gen_fake)
gen.zero_grad()
loss_gen.backward()
opt_gen.step()

# 打印训练过程信息
if batch_idx % 100 == 0 and batch_idx > 0:
print(
f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
)
#测试
with torch.no_grad():
fake = gen(fixed_noise).detach().cpu()
# 取49张图片
img_list.append(torchvision.utils.make_grid(fake[:49], normalize=True))

Alt text

人脸生成质量演变过程可视化

1
2
3
4
5
6
7
8
fig = plt.figure(figsize=(8,8))
plt.axis("off")
#需要事先安装ffmpeg:https://www.gyan.dev/ffmpeg/builds/packages/
#plt.rcParams['animation.ffmpeg_path'] = r'D:\ffmpeg\bin\ffmpeg.exe'
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list ]
ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000, blit=True)
#Writer = animation.FFMpegWriter(fps=20, metadata=dict(artist='Me'))
ani.save('out.gif', writer = 'pillow')

Alt text

一开始生成的是无意义的噪声图像,随着训练迭代次数的增加,生成的图像越来越接近人脸,不过总体来说还是与真实人脸图像有些差距。

从中可以很清楚的观察到WGAN-GP相较于原始GAN的一个优点,那就是WGN-GP更加稳定。具体来说,WGAN-GP生成的人脸的大致形状在若干次迭代后基本固定了,在后续的迭代中,做的更多的是对生成人脸的细节进行微调,而没有大幅度的变动。

参考: