导入相关函数

1
2
3
4
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers

准备数据

1
# Load MNIST; only the training images are needed for the GAN (labels unused).
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# Add a channel axis: (60000, 28, 28) -> (60000, 28, 28, 1), and cast to float32.
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')

# Normalize pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
# Bug fix: the original divided by 127.1, which does not map onto [-1, 1];
# the correct divisor is 127.5 (255 / 2).
train_images = (train_images - 127.5) / 127.5

BATCH_SIZE = 256
BUFFER_SIZE = 60000  # shuffle buffer covers the whole 60k-image training set

datasets = tf.data.Dataset.from_tensor_slices(train_images)
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

搭建生成器和判别器网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Generator_model(tf.keras.Model):
    """DCGAN generator: a latent noise vector -> a 28x28x1 image in [-1, 1]."""

    def __init__(self):
        super().__init__()

        # Project the noise vector onto a 7*7*256 feature volume.
        self.dense = tf.keras.layers.Dense(7 * 7 * 256, use_bias=False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.leakyrelu1 = tf.keras.layers.LeakyReLU()

        self.reshape = tf.keras.layers.Reshape((7, 7, 256))

        # 7x7x256 -> 7x7x128 (stride 1 keeps spatial size).
        self.convT1 = tf.keras.layers.Conv2DTranspose(
            128, (5, 5), strides=(1, 1), padding='same', use_bias=False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.leakyrelu2 = tf.keras.layers.LeakyReLU()

        # 7x7x128 -> 14x14x64.
        self.convT2 = tf.keras.layers.Conv2DTranspose(
            64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.leakyrelu3 = tf.keras.layers.LeakyReLU()

        # 14x14x64 -> 28x28x1; tanh keeps outputs in [-1, 1], matching the
        # normalization applied to the real images.
        self.convT3 = tf.keras.layers.Conv2DTranspose(
            1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')

    def call(self, inputs, training=True):
        """Map a batch of noise vectors to a batch of generated images."""
        # Each stage: (dense/deconv) -> batch-norm (respects training flag) -> LeakyReLU.
        x = self.leakyrelu1(self.bn1(self.dense(inputs), training))
        x = self.reshape(x)
        x = self.leakyrelu2(self.bn2(self.convT1(x), training))
        x = self.leakyrelu3(self.bn3(self.convT2(x), training))
        return self.convT3(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Discriminator_model(tf.keras.Model):
    """DCGAN discriminator: a 28x28x1 image -> a single real/fake logit."""

    def __init__(self):
        super().__init__()

        # 28x28x1 -> 14x14x64.
        self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')
        self.leakyrelu1 = tf.keras.layers.LeakyReLU()
        self.dropout1 = tf.keras.layers.Dropout(0.3)

        # 14x14x64 -> 7x7x128.
        self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')
        self.leakyrelu2 = tf.keras.layers.LeakyReLU()
        self.dropout2 = tf.keras.layers.Dropout(0.3)

        self.flatten = tf.keras.layers.Flatten()

        # Raw logit (no sigmoid): pair with BinaryCrossentropy(from_logits=True).
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, training=True):
        """Return a per-image logit; higher means 'more likely real'."""
        x = self.conv1(inputs)
        x = self.leakyrelu1(x)
        # Pass the training flag so Dropout only drops during training
        # (avoids the 'gradients do not exist' warning the author mentions).
        x = self.dropout1(x, training)

        # Bug fix: the original called self.conv2(inputs), feeding the raw image
        # into conv2 and silently discarding the entire first conv block.
        x = self.conv2(x)
        x = self.leakyrelu2(x)
        x = self.dropout2(x, training)

        x = self.flatten(x)

        return self.dense(x)

定义损失函数

1
# from_logits=True because the discriminator's final Dense(1) emits raw logits
# (no sigmoid) — the loss applies the sigmoid internally, which is numerically stabler.
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)
1
2
3
4
def discriminator_loss(real_out, fake_out):
    """Discriminator loss: real logits should score 1, generated logits 0.

    Both terms use the shared `cross_entropy` (from_logits=True) and are summed.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_out), real_out)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return loss_on_real + loss_on_fake
1
2
3
def generator_loss(fake_out):
    """Generator loss: reward fooling the discriminator into scoring fakes as 1."""
    return cross_entropy(tf.ones_like(fake_out), fake_out)

定义优化器

1
2
# Separate Adam optimizers (lr=1e-4 each) — the two networks are trained
# against each other, so their update steps must not share optimizer state.
generator_opt=tf.keras.optimizers.Adam(1e-4)
discriminator_opt=tf.keras.optimizers.Adam(1e-4)

设置超参数,实例化生成器和判别器

1
2
3
4
5
6
7
EPOCHS = 100
noise_dim = 100  # dimensionality of the generator's latent input
num_examples_to_generate = 16  # 4x4 preview grid

# Fixed seed so the epoch-end preview grid always shows the same latent vectors,
# making generator progress visually comparable across epochs.
seed = tf.random.normal([num_examples_to_generate, noise_dim])

generator = Generator_model()
# Bug fix: the original called discriminator_model() (lowercase d), which is a
# NameError — the class defined above is Discriminator_model.
discriminator = Discriminator_model()

定义每个batch的训练过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def train_step(images_one_batch):
    """Run one adversarial optimization step on a single batch of real images.

    Updates both networks: the discriminator from its loss on real + fake
    images, the generator from how well its fakes fooled the discriminator.
    """
    # Fix: size the noise batch to the real batch (originally it used
    # num_examples_to_generate=16 against a 256-image real batch, which
    # unbalances the discriminator's real/fake loss terms and also mishandles
    # the smaller final batch).
    noise = tf.random.normal([tf.shape(images_one_batch)[0], noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Discriminator's logits for real images.
        real_out = discriminator(images_one_batch, training=True)

        # Generate fakes and get the discriminator's logits for them.
        gen_image = generator(noise, training=True)
        fake_out = discriminator(gen_image, training=True)

        # Compute each network's loss.
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)

    # Gradients of each loss w.r.t. its own network's trainable variables.
    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Apply the updates.
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))

定义生成图片的展示函数

1
2
3
4
5
6
7
8
9
#将test_noise送入gen_model,以产生假图片
def generate_plot_image(gen_model,test_noise):
pre_images=gen_model(test_noise,training=False)#此时无需训练生成器网络
fig=plt.figure(figsize=(4,4))
for i in range(pre_images.shape[0]):
plt.subplot(4,4,i+1)
plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')
plt.axis('off')
plt.show()

定义训练函数

1
2
3
4
5
def train(dataset, epochs):
    """Train the GAN for `epochs` full passes over `dataset`.

    After each epoch, plot a preview grid from the fixed `seed` noise so
    generator progress is visible.
    """
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
        generate_plot_image(generator, seed)

开始训练

1
# Kick off training over the shuffled/batched MNIST dataset for EPOCHS epochs;
# a preview grid is rendered after every epoch.
train(datasets,EPOCHS)

结果展示