Import the required packages

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical

Prepare the data

(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
train_images.shape
(60000, 28, 28)
train_images.dtype
dtype('uint8')
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images.shape
(60000, 28, 28, 1)
train_images.dtype
dtype('float32')
train_images = (train_images - 127.5) / 127.5  # normalize to [-1, 1]
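As a quick sanity check (my addition, not in the original notebook), the pixel values should now span [-1, 1], matching the tanh output of the generator defined later:

print(train_images.min(), train_images.max())  # -1.0 1.0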
BATCH_SIZE=256
BUFFER_SIZE=60000
datasets=tf.data.Dataset.from_tensor_slices(train_images)
datasets
<TensorSliceDataset shapes: (28, 28, 1), types: tf.float32>
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
datasets
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
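Each element of the dataset is now a batch of up to BATCH_SIZE images; a minimal check of the batch shape (illustrative, not from the original output):

for image_batch in datasets.take(1):
    print(image_batch.shape)  # (256, 28, 28, 1)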

Build the generator and discriminator networks

class Generator_model(tf.keras.Model):
    def __init__(self):
        super().__init__()

        self.dense = tf.keras.layers.Dense(7*7*256, use_bias=False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.leakyrelu1 = tf.keras.layers.LeakyReLU()

        self.reshape = tf.keras.layers.Reshape((7, 7, 256))

        self.convT1 = tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.leakyrelu2 = tf.keras.layers.LeakyReLU()

        self.convT2 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.leakyrelu3 = tf.keras.layers.LeakyReLU()

        # final transposed conv upsamples to 28x28x1; tanh keeps outputs in [-1, 1]
        self.convT3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')

    def call(self, inputs, c, training=True):
        # concatenate the noise vector with the one-hot latent code c
        concated = tf.concat([inputs, c], axis=1)

        x = self.dense(concated)
        x = self.bn1(x, training=training)
        x = self.leakyrelu1(x)

        x = self.reshape(x)

        x = self.convT1(x)
        x = self.bn2(x, training=training)
        x = self.leakyrelu2(x)

        x = self.convT2(x)
        x = self.bn3(x, training=training)
        x = self.leakyrelu3(x)

        x = self.convT3(x)

        return x
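A quick smoke test (my addition) confirms the generator's output shape; tf.one_hot stands in for to_categorical here:

g = Generator_model()
noise = tf.random.normal((1, 100))
c = tf.one_hot([3], depth=10)  # one-hot code for digit 3
print(g(noise, c, training=False).shape)  # (1, 28, 28, 1)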
class Discriminator_model(tf.keras.Model):
    def __init__(self):
        super().__init__()

        self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')
        self.leakyrelu1 = tf.keras.layers.LeakyReLU()
        self.dropout1 = tf.keras.layers.Dropout(0.3)

        self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')
        self.leakyrelu2 = tf.keras.layers.LeakyReLU()
        self.dropout2 = tf.keras.layers.Dropout(0.3)

        self.flatten = tf.keras.layers.Flatten()

        # two heads: real/fake validity (raw logits) and the latent-code prediction;
        # softmax makes c_given_x a distribution, since mutual_info below takes its log
        self.dense_validity = tf.keras.layers.Dense(1)
        self.dense_c = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=True):
        x = self.conv1(inputs)
        x = self.leakyrelu1(x)
        x = self.dropout1(x, training=training)

        x = self.conv2(x)
        x = self.leakyrelu2(x)
        x = self.dropout2(x, training=training)

        features = self.flatten(x)

        validity = self.dense_validity(features)
        c_out = self.dense_c(features)

        return validity, c_out
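And a matching smoke test (my addition) for the discriminator's two heads:

d = Discriminator_model()
validity, c_out = d(tf.random.normal((1, 28, 28, 1)), training=False)
print(validity.shape, c_out.shape)  # (1, 1) (1, 10)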

Define the loss functions

def mutual_info(c, c_given_x):
    """The mutual information metric we aim to minimize"""
    # use tf ops rather than numpy so gradients can flow through this term
    eps = 1e-8
    conditional_entropy = tf.reduce_mean(-tf.reduce_sum(tf.math.log(c_given_x + eps) * c, axis=1))
    entropy = tf.reduce_mean(-tf.reduce_sum(tf.math.log(c + eps) * c, axis=1))

    return conditional_entropy + entropy
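For a one-hot c the entropy term is essentially zero, so this reduces to the cross-entropy between c and c_given_x; a toy example with illustrative values:

c = tf.constant([[0., 1.]])
c_given_x = tf.constant([[0.2, 0.8]])
print(mutual_info(c, c_given_x).numpy())  # ≈ -log(0.8) ≈ 0.223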
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out), real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss
# fake_c_given_x is the code c the discriminator predicts for generated images
def generator_loss(fake_out, c, fake_c_given_x):
    fake_loss = cross_entropy(tf.ones_like(fake_out), fake_out)
    mutual_info_loss = mutual_info(c, fake_c_given_x)
    return fake_loss + mutual_info_loss
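Both losses can be exercised on dummy tensors before training (values are illustrative only):

real_out = tf.constant([[2.0]])    # validity logit for a real image
fake_out = tf.constant([[-1.0]])   # validity logit for a generated image
c = tf.constant([[0., 1.]])
fake_c_given_x = tf.constant([[0.3, 0.7]])
print(discriminator_loss(real_out, fake_out).numpy())
print(generator_loss(fake_out, c, fake_c_given_x).numpy())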

Define the optimizers

generator_opt=tf.keras.optimizers.Adam(1e-4)
discriminator_opt=tf.keras.optimizers.Adam(1e-4)
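DCGAN-style setups often lower Adam's beta_1 to 0.5 to stabilize adversarial training; a possible variant (my suggestion, not in the original):

generator_opt = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
discriminator_opt = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)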

Set the hyperparameters and instantiate the generator and discriminator

EPOCHS=100
noise_dim=100


generator=Generator_model()
discriminator=Discriminator_model()

Define the training process for each batch

def sample_noise_and_c(batch_size):
    noise = tf.random.normal(shape=(batch_size, 100))
    c = np.random.randint(0, 10, batch_size).reshape(-1, 1)  # draw a random digit c for each sample
    c = to_categorical(c, num_classes=10)  # one-hot encode c into a (batch_size, 10) vector
    return noise, c
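For example (shapes only):

noise, c = sample_noise_and_c(4)
print(noise.shape, c.shape)  # (4, 100) (4, 10)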
def train_step(images_one_batch):
    noise, c = sample_noise_and_c(images_one_batch.shape[0])
    # noise = tf.random.normal([num_examples_to_generate, noise_dim])  # noise = seed
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # the generator does not care about real_c_given_x, so it goes unused here
        real_out, real_c_given_x = discriminator(images_one_batch, training=True)  # validity prediction real_out for real images

        gen_image = generator(noise, c, training=True)
        # the generator cares about the code the discriminator recovers from its images, i.e. fake_c_given_x
        fake_out, fake_c_given_x = discriminator(gen_image, training=True)  # validity prediction fake_out for generated images

        # compute the two losses
        gen_loss = generator_loss(fake_out, c, fake_c_given_x)
        disc_loss = discriminator_loss(real_out, fake_out)

    # gradients of the trainable variables
    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # apply the gradients with the optimizers
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))
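If you later wrap train_step in @tf.function for speed, the numpy calls in sample_noise_and_c will not work in graph mode; a tf-native sampler would be needed (a sketch under that assumption, with a hypothetical helper name):

def sample_noise_and_c_tf(batch_size):
    # graph-friendly sampler: tf ops only, so it works inside @tf.function
    noise = tf.random.normal(shape=(batch_size, 100))
    digits = tf.random.uniform((batch_size,), minval=0, maxval=10, dtype=tf.int32)
    c = tf.one_hot(digits, depth=10)
    return noise, c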

Define a function to display the generated images

num_examples_to_generate = 10

# fixed noise/code seeds for visualization (note: generate_plot_image below samples fresh noise instead)
noise_seed, _ = sample_noise_and_c(num_examples_to_generate)
c_seed = np.random.randint(0, 10, num_examples_to_generate).reshape(-1, 1)
def generate_plot_image():
    row, col = 10, 10

    fig, axs = plt.subplots(row, col)
    for i in range(col):
        sampled_noise, _ = sample_noise_and_c(row)  # sampled_noise shape: [row, 100]
        c = to_categorical(np.full(fill_value=i, shape=(row, 1)), num_classes=10)  # every image in column i uses code c = i
        # gen_input = np.concatenate((sampled_noise, c), axis=1)
        gen_imgs = generator(sampled_noise, c, training=False)
        gen_imgs = 0.5 * gen_imgs + 0.5  # map tanh outputs from [-1, 1] back to [0, 1]
        for j in range(row):
            axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')
            axs[j, i].axis('off')
    plt.show()

Define the training function

def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
        generate_plot_image()  # visualize a 10x10 grid after each epoch

Start training

train(datasets,EPOCHS)