网络结构

同普通GAN ,只是将生成器与判别器网络中的Dense层换为了卷积层与转置卷积层,故整体代码只需改动生成器和判别器的网络搭建函数即可。

导入相关函数

1
2
3
4
5
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

准备数据

1
(train_images,_),(_,_)=tf.keras.datasets.mnist.load_data()
1
train_images.shape
(60000, 28, 28)
1
train_images.dtype
dtype('uint8')
1
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
1
train_images.shape
(60000, 28, 28, 1)
1
train_images.dtype
dtype('float32')
1
train_images=(train_images-127.5)/127.1#归一化
1
2
BATCH_SIZE=256
BUFFER_SIZE=60000
1
datasets=tf.data.Dataset.from_tensor_slices(train_images)
1
datasets
<TensorSliceDataset shapes: (28, 28, 1), types: tf.float32>
1
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
1
datasets
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

搭建生成器网络和判别器网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def generator_model():
model=tf.keras.Sequential()

model.add(layers.Dense(7*7*256,use_bias=False,input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Reshape((7,7,256)))

model.add(layers.Conv2DTranspose(128,(5,5),strides=(1,1),padding='same',use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Conv2DTranspose(64,(5,5),strides=(2,2),padding='same',use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

return model


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def discriminator_model():
model=tf.keras.Sequential()

model.add(layers.Conv2D(64,(5,5),strides=(2,2),padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))

model.add(layers.Conv2D(128,(5,5),strides=(2,2),padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))

model.add(layers.Flatten())
model.add(layers.Dense(1))#判断真假,二分类

return model

定义损失函数

1
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)
1
2
3
4
5
6
#real_out和fake_out都是预测的标签(0或1)
#分别代表真实图片和生成的假图片被送入判别器之后得到的标签
def discriminator_loss(real_out,fake_out):
real_loss=cross_entropy(tf.ones_like(real_out),real_out)
fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
return real_loss+fake_loss
1
2
3
4
#fake_out是生成的假图片被送入判别器之后得到的标签
def generator_loss(fake_out):
fake_loss=cross_entropy(tf.ones_like(fake_out),fake_out)
return fake_loss

定义优化器

1
2
generator_opt=tf.keras.optimizers.Adam(1e-4)
discriminator_opt=tf.keras.optimizers.Adam(1e-4)

设置超参数,实例化生成器和判别器

1
2
3
4
5
6
7
EPOCHS=100
noise_dim=100
num_example_to_generate=16
seed=tf.random.normal([num_example_to_generate,noise_dim])

generator=generator_model()
discriminator=discriminator_model()

定义每个batch训练的过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def train_step(images_one_batch):
noise=tf.random.normal([num_example_to_generate,noise_dim])#noise=seed
with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
real_out=discriminator(images_one_batch,training=True)#真实图片送入判别器之后得到的预测标签

gen_image=generator(noise,training=True)#生成的假图片
fake_out=discriminator(gen_image,training=True)#生成的假图片送入判别器之后得到的预测标签

#分别计算两者的损失
gen_loss=generator_loss(fake_out)
disc_loss=discriminator_loss(real_out,fake_out)

#求可训练参数的梯度
gradient_gen=gen_tape.gradient(gen_loss,generator.trainable_variables)
gradient_disc=disc_tape.gradient(disc_loss,discriminator.trainable_variables)

#使用优化器更新可训练参数的权值
generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

定义生成图片展示的函数

1
2
3
4
5
6
7
8
9
#将test_noise送入gen_model,以产生假图片
def generate_plot_image(gen_model,test_noise):
pre_images=gen_model(test_noise,training=False)#此时无需训练生成器网络
fig=plt.figure(figsize=(4,4))
for i in range(pre_images.shape[0]):
plt.subplot(4,4,i+1)
plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')
plt.axis('off')
plt.show()

定义训练函数

1
2
3
4
5
def train(dataset,epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)
generate_plot_image(generator,seed)

开始训练

1
train(datasets,EPOCHS)

漫长的等待过后,最终的生成图片如下: