Import the required packages

```python
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
```
Prepare the data

```python
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
```
Inspecting the raw arrays: `train_images.shape` is `(60000, 28, 28)` and `train_images.dtype` is `dtype('uint8')`.
Add a channel dimension and cast to floats:

```python
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
```
After the reshape and cast: `train_images.shape` is `(60000, 28, 28, 1)` and `train_images.dtype` is `dtype('float32')`.
Normalize the pixels from [0, 255] to [-1, 1] so they match the generator's `tanh` output range:

```python
train_images = (train_images - 127.5) / 127.5
```
```python
BATCH_SIZE = 256
BUFFER_SIZE = 60000
```
```python
datasets = tf.data.Dataset.from_tensor_slices(train_images)
# <TensorSliceDataset shapes: (28, 28, 1), types: tf.float32>

datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# <BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
```
Build the generator and discriminator networks

```python
class Generator_model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(7 * 7 * 256, use_bias=False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.leakyrelu1 = tf.keras.layers.LeakyReLU()
        self.reshape = tf.keras.layers.Reshape((7, 7, 256))
        self.convT1 = tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.leakyrelu2 = tf.keras.layers.LeakyReLU()
        self.convT2 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.leakyrelu3 = tf.keras.layers.LeakyReLU()
        self.convT3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')

    def call(self, inputs, c, training=True):
        # Concatenate the noise vector with the one-hot latent code c
        concated = tf.concat([inputs, c], axis=1)
        x = self.dense(concated)
        x = self.bn1(x, training=training)
        x = self.leakyrelu1(x)
        x = self.reshape(x)    # (batch, 7, 7, 256)
        x = self.convT1(x)     # (batch, 7, 7, 128)
        x = self.bn2(x, training=training)
        x = self.leakyrelu2(x)
        x = self.convT2(x)     # (batch, 14, 14, 64)
        x = self.bn3(x, training=training)
        x = self.leakyrelu3(x)
        x = self.convT3(x)     # (batch, 28, 28, 1), values in [-1, 1]
        return x
```
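As a quick smoke test (not part of the original post), we can check that a couple of noise vectors plus one-hot codes come out as 28×28 single-channel images:

```python
g = Generator_model()
z = tf.random.normal((2, 100))            # noise_dim = 100, as set below
c = tf.one_hot([3, 7], depth=10)          # one-hot latent codes for digits 3 and 7
print(g(z, c, training=False).shape)      # expect (2, 28, 28, 1)
```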
```python
class Discriminator_model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')
        self.leakyrelu1 = tf.keras.layers.LeakyReLU()
        self.dropout1 = tf.keras.layers.Dropout(0.3)
        self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')
        self.leakyrelu2 = tf.keras.layers.LeakyReLU()
        self.dropout2 = tf.keras.layers.Dropout(0.3)
        self.flatten = tf.keras.layers.Flatten()
        self.dense_validity = tf.keras.layers.Dense(1)   # real/fake logit
        self.dense_c = tf.keras.layers.Dense(10)         # logits for recovering the latent code c

    def call(self, inputs, training=True):
        x = self.conv1(inputs)
        x = self.leakyrelu1(x)
        x = self.dropout1(x, training=training)
        x = self.conv2(x)      # second conv block consumes the first block's feature map
        x = self.leakyrelu2(x)
        x = self.dropout2(x, training=training)
        features = self.flatten(x)
        validity = self.dense_validity(features)
        c_out = self.dense_c(features)
        return validity, c_out
```
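Similarly (an illustrative check, not from the original), the discriminator should return one validity logit and ten class logits per image:

```python
d = Discriminator_model()
validity, c_logits = d(tf.random.normal((2, 28, 28, 1)), training=False)
print(validity.shape, c_logits.shape)     # expect (2, 1) (2, 10)
```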
Define the loss functions

```python
def mutual_info(c, c_given_x):
    """Mutual-information loss term the generator minimizes: the cross-entropy
    between the sampled code c and the recovered distribution Q(c|x), plus the
    entropy of c (which is zero for one-hot codes)."""
    eps = 1e-8
    conditional_entropy = tf.reduce_mean(-tf.reduce_sum(tf.math.log(c_given_x + eps) * c, axis=1))
    entropy = tf.reduce_mean(-tf.reduce_sum(tf.math.log(c + eps) * c, axis=1))
    return conditional_entropy + entropy
```
```python
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
```
```python
def discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out), real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss
```
```python
def generator_loss(fake_out, c, fake_c_given_x):
    fake_loss = cross_entropy(tf.ones_like(fake_out), fake_out)
    # dense_c produces logits, so convert them to probabilities first
    mutual_info_loss = mutual_info(c, tf.nn.softmax(fake_c_given_x))
    return fake_loss + mutual_info_loss
```
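A quick numeric sanity check of the mutual-information term (illustrative values, not from the original): a confident correct recovery of a one-hot code gives a near-zero loss, while a confident wrong one is heavily penalized:

```python
c = tf.constant([[0.0, 1.0]])
good = tf.constant([[0.01, 0.99]])   # Q(c|x) close to the true code
bad = tf.constant([[0.99, 0.01]])    # Q(c|x) confidently wrong
print(mutual_info(c, good).numpy())  # ~ -log(0.99) ≈ 0.01
print(mutual_info(c, bad).numpy())   # ~ -log(0.01) ≈ 4.61
```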
Define the optimizers

```python
generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
```
Set the hyperparameters and instantiate the generator and discriminator

```python
EPOCHS = 100
noise_dim = 100

generator = Generator_model()
discriminator = Discriminator_model()
```
Define the per-batch training step

```python
def sample_noise_and_c(batch_size):
    noise = tf.random.normal(shape=(batch_size, noise_dim))
    c = np.random.randint(0, 10, batch_size).reshape(-1, 1)
    c = to_categorical(c, num_classes=10)   # one-hot encode the latent code
    return noise, c
```
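For illustration (not in the original), a batch of 4 yields a `(4, 100)` noise tensor and a `(4, 10)` one-hot array:

```python
noise, c = sample_noise_and_c(4)
print(noise.shape, c.shape)   # (4, 100) (4, 10)
```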
```python
def train_step(images_one_batch):
    noise, c = sample_noise_and_c(images_one_batch.shape[0])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_out, real_c_given_x = discriminator(images_one_batch, training=True)
        gen_image = generator(noise, c, training=True)
        fake_out, fake_c_given_x = discriminator(gen_image, training=True)

        gen_loss = generator_loss(fake_out, c, fake_c_given_x)
        disc_loss = discriminator_loss(real_out, fake_out)

    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))
```
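A side note: the NumPy sampling in `sample_noise_and_c` is fine while `train_step` runs eagerly, but it would be frozen at trace time under `@tf.function`. If you later want a compiled step, a pure-TensorFlow sampler along these lines (a sketch, assuming the same shapes) is the safer choice:

```python
def sample_noise_and_c_tf(batch_size):
    # Pure-TF sampler: every op stays in the graph, so it is safe under @tf.function
    noise = tf.random.normal(shape=(batch_size, noise_dim))
    digits = tf.random.uniform((batch_size,), minval=0, maxval=10, dtype=tf.int32)
    return noise, tf.one_hot(digits, depth=10)
```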
Define a function to plot generated images

```python
num_examples_to_generate = 10
noise_seed, _ = sample_noise_and_c(num_examples_to_generate)
c_seed = np.random.randint(0, 10, num_examples_to_generate).reshape(-1, 1)
```

Note that `generate_plot_image` below samples fresh noise on every call, so these fixed seeds are not actually used.
```python
def generate_plot_image():
    row, col = 10, 10
    fig, axs = plt.subplots(row, col)
    for i in range(col):
        sampled_noise, _ = sample_noise_and_c(row)
        # Fix the latent code to digit class i for the whole column
        c = to_categorical(np.full(fill_value=i, shape=(row, 1)), num_classes=10)
        gen_imgs = generator(sampled_noise, c, training=False)
        for j in range(row):
            axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')
            axs[j, i].axis('off')
    plt.show()
```
Define the training loop

```python
def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
        generate_plot_image()   # preview a 10x10 grid after each epoch
```
Start training
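The post stops before the actual call; a minimal sketch that kicks off training with the objects defined above (100 epochs, as set in `EPOCHS`):

```python
train(datasets, EPOCHS)
```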