DCGAN_V2.0
导入相关函数
python
1 | import tensorflow as tf |
准备数据
python
1 | (train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data() |
python
1 | train_images.shape |
python
1 | train_images.dtype |
python
1 | train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32') |
python
1 | train_images.shape |
python
1 | train_images.dtype |
python
1 | train_images=(train_images-127.5)/127.1#归一化 |
python
1 | BATCH_SIZE=256 |
python
1 | datasets=tf.data.Dataset.from_tensor_slices(train_images) |
python
1 | datasets |
python
1 | datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE) |
python
1 | datasets |
搭建生成器和判别器网络
python
1 | class Generator_model(tf.keras.Model): |
python
1 | class Discriminator_model(tf.keras.Model): |
定义损失函数
python
1 | cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True) |
python
1 | def discriminator_loss(real_out,fake_out): |
python
1 | def generator_loss(fake_out): |
定义优化器
python
1 | generator_opt=tf.keras.optimizers.Adam(1e-4) |
设置超参数,实例化生成器和判别器
python
1 | EPOCHS=100 |
定义每个batch的训练过程
python
1 | def train_step(images_one_batch): |
定义生成图片的展示函数
python
1 | #将test_noise送入gen_model,以产生假图片 |
定义训练函数
python
1 | def train(dataset,epochs): |
开始训练
python
1 | train(datasets,EPOCHS) |