什么是CGAN

所谓CGAN,就是在GAN的基础上,多施加了一些条件信息,比如图像的标签等,使得生成器可以按照我们指定的标签去生成所对应的图像。

普通GAN的目标函数为:

而CGAN的目标函数为:

CGAN的网络结构如下:

如何构建CGAN

本文在普通GAN(全连接层搭建)的基础上,将生成器的输入由“噪声”改为“噪声+对应该批次图像的真实标签”,将判别器的输入由”图像”改为”图像+对应该批次图像的真实标签”,最后在测试生成器的生成能力时,人为构建了0到9这10个数字作为标签(因为训练数据是mnist数据集),和随机噪声一起喂入生成器以产生新的图片。

Tensorflow2.0 实现CGAN

导入所需函数

1
2
3
4
5
6
7
8
9
10
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Input,multiply,Flatten,Embedding
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np
import glob
import os
from tensorflow.keras.utils import to_categorical

准备数据

1
(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
1
train_images.shape
(60000, 28, 28)
1
train_labels.shape
(60000,)
1
train_images.dtype
dtype('uint8')
1
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
1
train_images.shape
(60000, 28, 28, 1)
1
train_images.dtype
dtype('float32')
1
train_images=(train_images-127.5)/127.1#归一化
1
2
BATCH_SIZE=256
BUFFER_SIZE=60000
1
datasets=tf.data.Dataset.from_tensor_slices((train_images,train_labels))
1
datasets
<TensorSliceDataset shapes: ((28, 28, 1), ()), types: (tf.float32, tf.uint8)>
1
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
1
datasets
<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.uint8)>

搭建生成器网络和判别器网络

与之前的代码不同,这里使用了Embedding层,将输入的noise(in generator)或图像(in discriminator) 与label编码在一起了,并且最后返回了使用了Model来构建的模型,其中的两个参数分别是生成器和判别器型对应的输入和输出(注意并不是我们所定义的生成器和判别器的函数对应的输入与输出)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def generator_model():
model=tf.keras.Sequential()

#model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
model.add(layers.Dense(256, input_dim=100))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Dense(512,use_bias=False))#因为有BN层
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))#因为有BN层
model.add(layers.BatchNormalization())

model.add(layers.Reshape((28,28,1)))

noise_dim=100
noise = Input(shape=(noise_dim,))
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(10, noise_dim)(label)) # class, z dimension,由于这里是一维的,所以不打平也可以

model_input = multiply([noise, label_embedding]) # 把 label 和 noise embedding 在一起,作为 model 的输入

img = model(model_input) # output (28,28,1)

return Model([noise,label],img)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def discriminator_model():    
model=tf.keras.Sequential()
model.add(layers.Dense(512))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Dense(256,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Dense(1))#判断真假,二分类

images = Input(shape=(28, 28, 1)) # 输入 (28,28,1)
label = Input(shape=(1,), dtype='int32')

label_embedding = Flatten()(Embedding(10, np.prod((28, 28, 1)))(label))
flat_img = Flatten()(images)

model_input = multiply([flat_img, label_embedding])

validity = model(model_input)

return Model([images,label],validity)

定义损失函数

1
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)
1
2
3
4
5
6
#real_out和fake_out都是预测的标签(0或1)
#分别代表真实图片和生成的假图片被送入判别器之后得到的标签
def discriminator_loss(real_out,fake_out):
real_loss=cross_entropy(tf.ones_like(real_out),real_out)
fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
return real_loss+fake_loss
1
2
3
4
#fake_out是生成的假图片被送入判别器之后得到的标签
def generator_loss(fake_out):
fake_loss=cross_entropy(tf.ones_like(fake_out),fake_out)
return fake_loss

定义优化器

1
2
generator_opt=tf.keras.optimizers.Adam(1e-4)
discriminator_opt=tf.keras.optimizers.Adam(1e-4)

设置超参数,实例化生成器和判别器

1
2
3
EPOCHS=100
noise_dim=100
#num_example_to_generate=16
1
2
generator=generator_model()
discriminator=discriminator_model()

定义每个batch训练的过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def train_step(images_one_batch,batch_labels):
#noise=tf.random.normal([num_example_to_generate,noise_dim])#这里生成器生成的图片个数必须和batch中样本数一样,因为用到了batch中的样本label
noise = tf.random.normal([batch_labels.shape[0],noise_dim])
with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
real_out=discriminator((images_one_batch,batch_labels),training=True)#真实图片送入判别器之后得到的预测标签

gen_image=generator((noise,batch_labels),training=True)#用原来的label训练生成器
fake_out=discriminator((gen_image,batch_labels),training=True)#生成的假图片送入判别器之后得到的预测标签

#分别计算两者的损失
gen_loss=generator_loss(fake_out)
disc_loss=discriminator_loss(real_out,fake_out)

#求可训练参数的梯度
gradient_gen=gen_tape.gradient(gen_loss,generator.trainable_variables)
gradient_disc=disc_tape.gradient(disc_loss,discriminator.trainable_variables)

#使用优化器更新可训练参数的权值
generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

定义生成图片展示的函数

1
2
3
4
5
6
7
8
9
10
11
#将test_noise和test_label送入gen_model,以产生指定label的假图片
def generate_plot_image(gen_model,test_noise,test_label,epoch):
pre_images=gen_model((test_noise,test_label),training=False)#此时无需训练生成器网络
fig=plt.figure(figsize=(5,18))
#pred = tf.squeeze(pred)
for i in range(10):
plt.subplot(1,10, i + 1)
plt.imshow(pre_images[i, :, :,0],cmap='Greys')
plt.axis('off')
plt.savefig('image/image_of_epoch{:04d}.png'.format(epoch))
plt.close()

定义训练函数

1
2
3
4
5
6
7
8
9
#noise_seed=tf.random.normal([num_example_to_generate,noise_dim])
noise_seed=tf.random.normal([10,noise_dim])#noise_dim=000
label_seed=np.array([[i] for i in range(10)])#指定每次分别生成0到9这10个数字

def train(dataset,epochs):
for epoch in range(epochs):
for image_batch,label in dataset:
train_step(image_batch,label)
generate_plot_image(generator,noise_seed,label_seed,epoch)

开始训练

1
train(datasets,EPOCHS)

训练结果

刚开始的时候,生成器的能力是逐步提高的,但随着训练轮次的继续增加,生成能力反而下降了,如下图所示:

之后又训练了一次,这次的生成能力是逐步提升的,最后一个轮次生成的图片如下:

这可能是基于普通GAN训练不稳定造成的,之后会考虑换其他GAN,再次实验。

经查资料,这里给出了可能的解决方案:https://www.zhihu.com/question/298003702/answer/1135515657

参考资料

https://github.com/mabagheri/CGAN/blob/master/cgan_Keras.py

https://ustccoder.github.io/2020/05/31/generative_adversarial%20CGAN/

https://blog.csdn.net/liyihao17/article/details/106592738/?utm_medium=distribute.pc_relevant_download.none-task-blog-baidujs-1.nonecase&depth_1-utm_source=distribute.pc_relevant_download.none-task-blog-baidujs-1.nonecase

https://machinelearningmastery.com/how-to-develop-a-conditional-generative-adversarial-network-from-scratch/