Network Structure


Import the required packages

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

Prepare the data

(train_images,_),(_,_)=tf.keras.datasets.mnist.load_data()
train_images.shape
(60000, 28, 28)
train_images.dtype
dtype('uint8')
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')#add a channel dimension and cast to float32
train_images.shape
(60000, 28, 28, 1)
train_images.dtype
dtype('float32')
train_images=(train_images-127.5)/127.5#normalize to [-1,1]
BATCH_SIZE=256
BUFFER_SIZE=60000
datasets=tf.data.Dataset.from_tensor_slices(train_images)
datasets
<TensorSliceDataset shapes: (28, 28, 1), types: tf.float32>
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
datasets
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
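
To make sure the pipeline delivers what the models below expect, it is worth pulling a single batch and checking its shape and value range. This is a quick sanity check I added; it is not part of the original post:

for batch in datasets.take(1):
    print(batch.shape)  # (256, 28, 28, 1)
    # after the (x-127.5)/127.5 normalization the values should lie in [-1, 1]
    print(float(tf.reduce_min(batch)),float(tf.reduce_max(batch)))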

Build the generator and discriminator networks

def generator_model():
    model=tf.keras.Sequential()
    model.add(layers.Dense(256,input_shape=(100,),use_bias=False))#no bias: a BN layer follows
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(512,use_bias=False))#no bias: a BN layer follows
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))#no bias: a BN layer follows
    model.add(layers.BatchNormalization())

    model.add(layers.Reshape((28,28,1)))

    return model

def discriminator_model():
    model=tf.keras.Sequential()
    model.add(layers.Flatten())
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(256,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(1))#single real/fake logit (binary classification)

    return model
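
Before wiring the two networks together, a quick shape check (my addition, not from the original post) confirms that a noise batch flows through both models and produces the expected tensors:

gen=generator_model()
disc=discriminator_model()
test_noise=tf.random.normal([4,100])
fake=gen(test_noise)  # -> (4, 28, 28, 1)
logits=disc(fake)     # -> (4, 1), raw logits (no sigmoid, hence from_logits=True below)
print(fake.shape,logits.shape)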

Define the loss functions

cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)
#real_out and fake_out are the discriminator's raw output logits for
#the real images and the generated fake images, respectively
def discriminator_loss(real_out,fake_out):
    real_loss=cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss+fake_loss
#fake_out holds the discriminator's logits for the generated fake images
def generator_loss(fake_out):
    fake_loss=cross_entropy(tf.ones_like(fake_out),fake_out)
    return fake_loss
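
A small numeric check (my addition) makes the from_logits=True convention concrete: large positive logits mean "real", large negative logits mean "fake":

confident_real=tf.constant([[10.0]])   # discriminator is sure the image is real
confident_fake=tf.constant([[-10.0]])  # discriminator is sure the image is fake
print(discriminator_loss(confident_real,confident_fake).numpy())  # ~0: the discriminator got both right
print(generator_loss(confident_fake).numpy())                     # ~10: the generator fooled nobody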

Define the optimizers

generator_opt=tf.keras.optimizers.Adam(1e-4)
discriminator_opt=tf.keras.optimizers.Adam(1e-4)

Set the hyperparameters and instantiate the generator and discriminator

EPOCHS=100
noise_dim=100
num_example_to_generate=16
seed=tf.random.normal([num_example_to_generate,noise_dim])#fixed noise, reused every epoch so progress stays comparable

generator=generator_model()
discriminator=discriminator_model()

Define the training step for one batch

def train_step(images_one_batch):
    #fresh noise is sampled on every step (it is not the fixed seed);
    #sampling BATCH_SIZE noise vectors here would be an equally valid choice
    noise=tf.random.normal([num_example_to_generate,noise_dim])
    with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
        real_out=discriminator(images_one_batch,training=True)#logits for the real images

        gen_image=generator(noise,training=True)#generated fake images
        fake_out=discriminator(gen_image,training=True)#logits for the generated fake images

        #compute the two losses
        gen_loss=generator_loss(fake_out)
        disc_loss=discriminator_loss(real_out,fake_out)

    #gradients of the trainable variables
    gradient_gen=gen_tape.gradient(gen_loss,generator.trainable_variables)
    gradient_disc=disc_tape.gradient(disc_loss,discriminator.trainable_variables)

    #let the optimizers update the trainable weights
    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))
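
One optional improvement, borrowed from the official TensorFlow DCGAN tutorial rather than from the original post: tracing the step into a graph with tf.function usually speeds up training considerably:

# optional: compile the eager train_step into a TensorFlow graph
train_step=tf.function(train_step)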

Define a function to display generated images

#feed test_noise into gen_model to produce fake images
def generate_plot_image(gen_model,test_noise):
    pre_images=gen_model(test_noise,training=False)#inference only; the generator is not trained here
    fig=plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')#map the tanh output from [-1,1] to [0,1]
        plt.axis('off')
    plt.show()
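
If you also want a record of training progress on disk, a plt.savefig call can be added before plt.show(). The variant and filename pattern below are my assumptions, mirroring what the alternative implementation at the end of this post does:

def generate_plot_image_and_save(gen_model,test_noise,epoch):
    pre_images=gen_model(test_noise,training=False)
    fig=plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')
        plt.axis('off')
    plt.savefig('gan_epoch_%d.png'%epoch)  # hypothetical filename pattern
    plt.show()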

Define the training loop

def train(dataset,epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
        #plot the fixed-seed samples after every epoch to track progress
        generate_plot_image(generator,seed)

Start training

train(datasets,EPOCHS)
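
Once training has finished, it is worth persisting the generator and sampling digits from noise it has never seen. A minimal sketch (the filename is my assumption, not from the post):

# save the trained generator's weights for later reuse
generator.save_weights('gan_generator.h5')

# generate brand-new digits from fresh noise
new_noise=tf.random.normal([num_example_to_generate,noise_dim])
generate_plot_image(generator,new_noise)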

After a long wait, the final generated images look like the ones below:


Reference code

https://blog.auberginesolutions.com/implementing-gan-using-tensorflow-2-0/

https://www.bilibili.com/video/BV1f7411E7wU?

More

  • Consider removing the BatchNormalization layers and running the experiment again (a sketch of a BN-free generator appears at the end of this post).

  • Another version of the code:

    from tensorflow.keras.layers import Input,Dense,Reshape,Flatten
    from tensorflow.keras.layers import BatchNormalization,Activation,ZeroPadding2D
    from tensorflow.keras.layers import LeakyReLU
    from tensorflow.keras.models import Sequential,Model
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.datasets import mnist
    import matplotlib.pyplot as plt
    import os
    import numpy as np


    class GAN():
        def __init__(self):
            self.img_rows,self.img_cols,self.channels=28,28,1
            self.img_shape=(self.img_rows,self.img_cols,self.channels)
            #the latent vector has 100 dimensions
            self.latent_dim=100
            #optimizer
            optimizer=Adam(0.0002,0.5)

            #discriminator
            self.discriminator=self.build_discriminator()
            self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])

            #generator
            self.generator=self.build_generator()
            gan_input=Input(shape=(self.latent_dim,))
            img=self.generator(gan_input)

            #freeze the discriminator while the generator is being trained
            self.discriminator.trainable=False

            #score the generated fake images
            validity=self.discriminator(img)
            self.combined=Model(gan_input,validity)
            self.combined.compile(loss='binary_crossentropy',optimizer=optimizer)

        #build the generator
        def build_generator(self):
            model=Sequential()
            model.add(Dense(256,input_dim=self.latent_dim))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(1024))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(np.prod(self.img_shape),activation='tanh'))
            model.add(Reshape(self.img_shape))

            noise=Input(shape=(self.latent_dim,))
            img=model(noise)

            return Model(noise,img)

        def build_discriminator(self):
            model=Sequential()
            model.add(Flatten(input_shape=self.img_shape))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(256))
            model.add(LeakyReLU(alpha=0.2))

            model.add(Dense(1,activation='sigmoid'))

            #classify real vs. fake: validity
            img=Input(shape=self.img_shape)
            validity=model(img)

            return Model(img,validity)

        def sample_images(self,epoch):
            r,c = 5,5
            noise = np.random.normal(0, 1, (r * c, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            #rescale from [-1,1] to [0,1]
            gen_imgs = 0.5 * gen_imgs + 0.5

            fig, axs = plt.subplots(r, c)
            cnt = 0
            for i in range(r):
                for j in range(c):
                    axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                    axs[i, j].axis('off')
                    cnt += 1
            fig.savefig("gan_mnist/%d.png" % epoch)
            plt.close()

        def train(self,epochs,batch_size=128,sample_interval=50):
            (X_train,_),(_,_)=mnist.load_data()
            #normalize to [-1,1]
            X_train=X_train/127.5-1.
            X_train=np.expand_dims(X_train,axis=3)

            #labels for real (1) and fake (0) images
            valid=np.ones((batch_size,1))
            fake=np.zeros((batch_size,1))

            for epoch in range(epochs):
                #train the discriminator
                idx=np.random.randint(0,X_train.shape[0],batch_size)
                imgs=X_train[idx]

                noise=np.random.normal(0,1,(batch_size,self.latent_dim))

                gen_imgs=self.generator.predict(noise)#images generated from the sampled noise

                # train_on_batch returns the loss (and the accuracy, since the model was compiled with metrics=['accuracy'])
                d_loss_real=self.discriminator.train_on_batch(imgs,valid)#real images should be scored close to 1
                d_loss_fake=self.discriminator.train_on_batch(gen_imgs,fake)#fake images should be scored close to 0
                d_loss=0.5*np.add(d_loss_real,d_loss_fake)#the total loss is the average of the two parts

                #train the generator
                noise=np.random.normal(0,1,(batch_size,self.latent_dim))
                g_loss=self.combined.train_on_batch(noise,valid)#push the discriminator's score for fakes towards 1; smaller is better
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

                if epoch%sample_interval==0:
                    self.sample_images(epoch)


    if __name__ == '__main__':
        if not os.path.exists("./gan_mnist"):
            os.makedirs("./gan_mnist")
        gan = GAN()
        gan.train(epochs=3000, batch_size=256, sample_interval=200)
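
As suggested in the first bullet above, here is a minimal sketch of the generator with the BatchNormalization layers removed (bias re-enabled, since there is no longer a BN layer to absorb it); whether training becomes more or less stable is exactly the experiment to run:

def generator_model_no_bn():
    model=tf.keras.Sequential()
    model.add(layers.Dense(256,input_shape=(100,)))
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(28*28*1,activation='tanh'))
    model.add(layers.Reshape((28,28,1)))
    return model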