网络结构

基础GANDCGAN(本文)的基础上,使用推土机距离衡量真实样本分布与生成样本分布之间的距离,此时,即使两个分布没有重合部分(这经常发生,容易导致梯度突变),也能准确的衡量分布的差异。


从上图可以看出,WGAN的梯度永远不会为0,而普通GAN会出现梯度为0的情况。

根据


可以确定判别器和生成器的损失为:


其中,1-Lipscgitz是指$||f(x_1)-f(x_2)||\le1*||x_1-x_2||$

也就是导数要≤1

这一点,可以通过对权值进行clip的方式,使得权值固定在某个区间内,从而也使得导数固定在某一区间内(比如≤1)

下面的代码在DCGAN的基础上,修改了判别器和生成器的损失函数,并在每次权值更新后做了权值裁剪,其余未变。

导入相关函数

1
2
3
4
5
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

准备数据

1
(train_images,_),(_,_)=tf.keras.datasets.mnist.load_data()
1
train_images.shape
(60000, 28, 28)
1
train_images.dtype
dtype('uint8')
1
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
1
train_images.shape
(60000, 28, 28, 1)
1
train_images.dtype
dtype('float32')
1
train_images=(train_images-127.5)/127.1#归一化
1
2
BATCH_SIZE=256
BUFFER_SIZE=60000
1
datasets=tf.data.Dataset.from_tensor_slices(train_images)
1
datasets
<TensorSliceDataset shapes: (28, 28, 1), types: tf.float32>
1
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
1
datasets
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

搭建生成器网络和判别器网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def generator_model():
model=tf.keras.Sequential()

model.add(layers.Dense(7*7*256,use_bias=False,input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Reshape((7,7,256)))

model.add(layers.Conv2DTranspose(128,(5,5),strides=(1,1),padding='same',use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Conv2DTranspose(64,(5,5),strides=(2,2),padding='same',use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

return model


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def discriminator_model():
model=tf.keras.Sequential()

model.add(layers.Conv2D(64,(5,5),strides=(2,2),padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))

model.add(layers.Conv2D(128,(5,5),strides=(2,2),padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))

model.add(layers.Flatten())
model.add(layers.Dense(1))#判断真假,二分类

return model

定义损失函数

1
2
3
4
5
6
#real_out和fake_out都是预测的输出值(连续型)
#分别代表真实图片和生成的假图片被送入判别器之后得到的输出值
def discriminator_loss(real_out,fake_out):
real_loss = - tf.reduce_mean(real_out)#越接近1越好(大),加负号就是变小了
fake_loss = tf.reduce_mean(fake_out)#越接近0越好(小)
return real_loss+fake_loss
1
2
3
4
#fake_out是生成的假图片被送入判别器之后得到的输出值
def generator_loss(fake_out):
fake_loss = - tf.reduce_mean(fake_out)#越接近1越好(大),加负号就是变小了
return fake_loss

定义优化器

1
2
generator_opt=tf.keras.optimizers.Adam(1e-4)#原文使用的是非动量的优化器
discriminator_opt=tf.keras.optimizers.Adam(1e-4)

设置超参数,实例化生成器和判别器

1
2
3
4
5
6
7
EPOCHS=100
noise_dim=100
num_example_to_generate=16
seed=tf.random.normal([num_example_to_generate,noise_dim])

generator=generator_model()
discriminator=discriminator_model()

定义每个batch训练的过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def train_step(images_one_batch):
noise=tf.random.normal([num_example_to_generate,noise_dim])#noise=seed
with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
real_out=discriminator(images_one_batch,training=True)#真实图片送入判别器之后得到的预测标签

gen_image=generator(noise,training=True)#生成的假图片
fake_out=discriminator(gen_image,training=True)#生成的假图片送入判别器之后得到的预测标签

#分别计算两者的损失
gen_loss=generator_loss(fake_out)
disc_loss=discriminator_loss(real_out,fake_out)

#求可训练参数的梯度
gradient_gen=gen_tape.gradient(gen_loss,generator.trainable_variables)
gradient_disc=disc_tape.gradient(disc_loss,discriminator.trainable_variables)

#使用优化器更新可训练参数的权值
generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

# Clip 判别器 weights,即判别器权重裁剪
clip_value=0.01
for l in discriminator.layers:
weights = l.get_weights()
weights = [np.clip(w, -clip_value, clip_value) for w in weights]
l.set_weights(weights)

定义生成图片展示的函数

1
2
3
4
5
6
7
8
9
#将test_noise送入gen_model,以产生假图片
def generate_plot_image(gen_model,test_noise):
pre_images=gen_model(test_noise,training=False)#此时无需训练生成器网络
fig=plt.figure(figsize=(4,4))
for i in range(pre_images.shape[0]):
plt.subplot(4,4,i+1)
plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')
plt.axis('off')
plt.show()

定义训练函数

1
2
3
4
5
def train(dataset,epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)
generate_plot_image(generator,seed)

开始训练

1
train(datasets,EPOCHS)

漫长的等待过后,最终的生成图片如下:

进一步思考

可以发现,最终的效果并没有之前的好了

猜想:可能是由于将权值限制在了[-0.01,0.01]之间(代码中clip_value=0.01)?

https://kionkim.github.io/2018/07/26/WGAN_3/ 里面说即使是WGAN甚至其改进版WGAN-GP,也无法很好的生成图片

参考资料

https://zhuanlan.zhihu.com/p/25071913

https://blog.csdn.net/u013289254/article/details/97561162

https://github.com/KUASWoodyLIN/TF2-WGAN/blob/master/utils/losses.py

(损失函数参考第三个链接)

https://github.com/keras-team/keras-contrib/issues/280

https://zhuanlan.zhihu.com/p/52799555

https://github.com/keras-team/keras-contrib/blob/3fc5ef709e061416f4bc8a92ca3750c824b5d2b0/examples/improved_wgan.py#L50

(推土机距离作为loss的代码中,有的直接写 mean(y_true * y_pred) ,比如上面的链接,此时的label是-1和1,但还是有点不明白,没有tf实现的代码中的loss容易理解。)