import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, multiply, Flatten, Embedding
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
def train_step(images_one_batch, batch_labels):
    """Run one conditional-GAN optimisation step on a single batch.

    Both networks are conditioned on ``batch_labels``. The discriminator scores
    the real images with their labels. The generator produces one fake image
    per label in the batch, and the discriminator then scores those fakes with
    the same labels.

    Args:
        images_one_batch: a batch of real images (layout assumed to match the
            discriminator's image input — TODO confirm against the model).
        batch_labels: the class labels for ``images_one_batch``; its first
            dimension is used as the batch size.

    Returns:
        A ``(gen_loss, disc_loss)`` tuple for this step. Callers that ignore
        the return value are unaffected.

    Relies on the module-level ``noise_dim``, ``generator``, ``discriminator``,
    ``generator_loss``, ``discriminator_loss``, ``generator_opt`` and
    ``discriminator_opt``.
    """
    # The generator must produce exactly as many images as there are labels in
    # this batch, because those labels are reused for the discriminator.
    # tf.shape yields the runtime size. This keeps working for a smaller final
    # batch and inside a tf.function whose static batch dimension is None.
    batch_size = tf.shape(batch_labels)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Discriminator predictions for the real images.
        real_out = discriminator((images_one_batch, batch_labels), training=True)
        # Fake images conditioned on the real batch's labels.
        gen_image = generator((noise, batch_labels), training=True)
        # Discriminator predictions for the generated images.
        fake_out = discriminator((gen_image, batch_labels), training=True)

        # Losses are computed inside the tapes so they are recorded.
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)

    # Each loss is differentiated only w.r.t. its own network's variables.
    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Apply the updates with each network's own optimizer.
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(
        zip(gradient_disc, discriminator.trainable_variables)
    )
    return gen_loss, disc_loss
定义生成图片展示的函数
#将test_noise和test_label送入gen_model,以产生指定label的假图片 defgenerate_plot_image(gen_model,test_noise,test_label,epoch): pre_images=gen_model((test_noise,test_label),training=False)#此时无需训练生成器网络 fig=plt.figure(figsize=(5,18)) #pred = tf.squeeze(pred) for i inrange(10): plt.subplot(1,10, i + 1) plt.imshow(pre_images[i, :, :,0],cmap='Greys') plt.axis('off') plt.savefig('image/image_of_epoch{:04d}.png'.format(epoch)) plt.close()
定义训练函数
# Fixed sampling inputs, reused after every epoch so that the saved sample
# images are directly comparable across epochs. There is one noise vector per
# label below.
noise_seed = tf.random.normal([10, noise_dim])
# Labels 0..9 as a (10, 1) column, so each sample image shows the digits
# 0 to 9, one each, in order.
label_seed = np.array([[i] for i in range(10)])
def train(dataset, epochs):
    """Train the conditional GAN and save sample images after each epoch.

    Args:
        dataset: an iterable yielding ``(image_batch, label)`` pairs, e.g. a
            batched ``tf.data.Dataset`` (assumed — confirm against the caller).
        epochs: number of full passes over ``dataset``.

    After every epoch, ``generate_plot_image`` draws samples from the
    module-level ``generator`` using the fixed ``noise_seed`` and
    ``label_seed`` inputs.
    """
    for epoch in range(epochs):
        for image_batch, label in dataset:
            train_step(image_batch, label)
        # The same seeds are used every epoch, so the saved images track
        # training progress.
        generate_plot_image(generator, noise_seed, label_seed, epoch)