deftrain_step(images_one_batch): noise=tf.random.normal([num_example_to_generate,noise_dim])#noise=seed with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape: real_out=discriminator(images_one_batch,training=True)#真实图片送入判别器之后得到的预测标签 gen_image=generator(noise,training=True)#生成的假图片 fake_out=discriminator(gen_image,training=True)#生成的假图片送入判别器之后得到的预测标签 #分别计算两者的损失 gen_loss=generator_loss(fake_out) disc_loss=discriminator_loss(real_out,fake_out) #求可训练参数的梯度 gradient_gen=gen_tape.gradient(gen_loss,generator.trainable_variables) gradient_disc=disc_tape.gradient(disc_loss,discriminator.trainable_variables) #使用优化器更新可训练参数的权值 generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables)) discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables)) # Clip 判别器 weights,即判别器权重裁剪 clip_value=0.01 for l in discriminator.layers: weights = l.get_weights() weights = [np.clip(w, -clip_value, clip_value) for w in weights] l.set_weights(weights)
定义生成图片展示的函数
1 2 3 4 5 6 7 8 9
#将test_noise送入gen_model,以产生假图片 defgenerate_plot_image(gen_model,test_noise): pre_images=gen_model(test_noise,training=False)#此时无需训练生成器网络 fig=plt.figure(figsize=(4,4)) for i inrange(pre_images.shape[0]): plt.subplot(4,4,i+1) plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray') plt.axis('off') plt.show()
定义训练函数
1 2 3 4 5
deftrain(dataset,epochs): for epoch inrange(epochs): for image_batch in dataset: train_step(image_batch) generate_plot_image(generator,seed)