网络结构

ACGAN 可以看作是CGAN和SGAN的融合:

  • 模仿CGAN,将类别标签class也输入生成器
  • 模仿SGAN,判别器不仅仅输出真假,还充当分类器

搭建模型

由于ACGAN可以看作是CGAN和SGAN的融合,因此在代码实现上也是综合了两者的代码,主要修改部分为损失函数和两个网络结构,具体见代码注释

ACGAN的损失部分如下:

判别损失:

分类损失:

判别器的损失为$L_C+L_S$,意思是让判别器既能判别图片的真伪,又有不错的分类能力;

生成器的损失为$L_C-L_S$,意思是让判别器不能判别图片的真伪,但还有不错的分类能力。

导入相关函数

1
2
3
4
5
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical

准备数据

1
(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
1
train_images.shape
(60000, 28, 28)
1
train_images.dtype
dtype('uint8')
1
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
1
train_images.shape
(60000, 28, 28, 1)
1
train_images.dtype
dtype('float32')
1
train_images=(train_images-127.5)/127.1#归一化
1
train_labels.shape
(60000,)
1
2
BATCH_SIZE=256
BUFFER_SIZE=60000
1
datasets=tf.data.Dataset.from_tensor_slices((train_images,train_labels))
1
datasets
<TensorSliceDataset shapes: ((28, 28, 1), ()), types: (tf.float32, tf.uint8)>
1
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
1
datasets
<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.uint8)>

搭建生成器和判别器网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class Generator_model(tf.keras.Model):
def __init__(self):
super().__init__()

###先把noise和输入的label进行融合
self.noise_dim=100
self.embedding=tf.keras.layers.Embedding(10,noise_dim)
self.multiply=tf.keras.layers.Multiply()#得到一个noise_dim维度的长向量

###前向传播逻辑
self.dense=tf.keras.layers.Dense(7*7*256,use_bias=False)
self.bn1=tf.keras.layers.BatchNormalization()
self.leakyrelu1=tf.keras.layers.LeakyReLU()

self.reshape=tf.keras.layers.Reshape((7,7,256))

self.convT1=tf.keras.layers.Conv2DTranspose(128,(5,5),strides=(1,1),padding='same',use_bias=False)
self.bn2=tf.keras.layers.BatchNormalization()
self.leakyrelu2=tf.keras.layers.LeakyReLU()

self.convT2=tf.keras.layers.Conv2DTranspose(64,(5,5),strides=(2,2),padding='same',use_bias=False)
self.bn3=tf.keras.layers.BatchNormalization()
self.leakyrelu3=tf.keras.layers.LeakyReLU()

self.convT3=tf.keras.layers.Conv2DTranspose(1,(5,5),strides=(2,2),padding='same',use_bias=False,activation='tanh')


def call(self,inputs,labels,training=True):#inputs是输入的noise,labels是真实图片的标签

label_embedding=self.embedding(labels)
x=self.multiply([inputs,label_embedding])

x=self.dense(x)
x=self.bn1(x,training)
x=self.leakyrelu1(x)

x=self.reshape(x)

x=self.convT1(x)
x=self.bn2(x,training)
x=self.leakyrelu2(x)

x=self.convT2(x)
x=self.bn3(x,training)
x=self.leakyrelu3(x)

x=self.convT3(x)

return x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Discriminator_model(tf.keras.Model):
def __init__(self,num_classes):
self.num_classes=num_classes
super().__init__()

self.conv1=tf.keras.layers.Conv2D(64,(5,5),strides=(2,2),padding='same')
self.leakyrelu1=tf.keras.layers.LeakyReLU()
self.dropout1=tf.keras.layers.Dropout(0.3)

self.conv2=tf.keras.layers.Conv2D(128,(5,5),strides=(2,2),padding='same')
self.leakyrelu2=tf.keras.layers.LeakyReLU()
self.dropout2=tf.keras.layers.Dropout(0.3)

self.flatten=tf.keras.layers.Flatten()

#真实图片还是生成图片,二分类
self.dense_valid=tf.keras.layers.Dense(1,activation='sigmoid')

#分类器
self.dense_label=tf.keras.layers.Dense(self.num_classes)
self.softmax=tf.keras.layers.Softmax()

def call(self,inputs,training=True):
x=self.conv1(inputs)
x=self.leakyrelu1(x)
x=self.dropout1(x,training)

x=self.conv2(inputs)
x=self.leakyrelu2(x)
x=self.dropout2(x,training)

features=self.flatten(x)#特征

valid=self.dense_valid(features)#真or假

label=self.dense_label(features)
label=self.softmax(label)#属于哪一类

return valid,label#真实图片还是生成图片:valid;属于哪一类(n_classes,fake):label

定义损失函数

1
2
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)
categorical_cross_entropy=tf.keras.losses.CategoricalCrossentropy()#接受one_hot形式
1
2
3
4
5
6
def discriminator_loss(real_out,fake_out,real_pred_labels,to_categorical_real_labels):
#判别损失+分类损失
real_loss=cross_entropy(tf.ones_like(real_out),real_out)+categorical_cross_entropy(real_pred_labels,to_categorical_real_labels)
#fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)+categorical_cross_entropy(fake_pred_labels,to_categorical_real_labels)
fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)#只保留判别损失,不同于SGAN,这里只有10类,没有fake类?
return real_loss+fake_loss
1
2
3
4
5
def generator_loss(fake_out,fake_pred_labels,to_categorical_real_labels):
#判别损失+分类损失
#其中分类损失指的是:让判别器输出的类别和图片的真实图片的类别越接近越好,因为生成器在生成图片时使用了真实图片的标签
fake_loss=cross_entropy(tf.ones_like(fake_out),fake_out)+categorical_cross_entropy(fake_pred_labels,to_categorical_real_labels)
return fake_loss

定义优化器

1
2
generator_opt=tf.keras.optimizers.Adam(1e-4)
discriminator_opt=tf.keras.optimizers.Adam(1e-4)

设置超参数,实例化生成器和判别器

1
2
3
4
5
6
7
EPOCHS=100
noise_dim=100
#num_examples_to_generate=16
#seed=tf.random.normal([num_examples_to_generate,noise_dim])

generator=Generator_model()
discriminator=Discriminator_model(num_classes=10)

定义每个batch的训练过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def train_step(images_one_batch,one_batch_labels):
noise=tf.random.normal([images_one_batch.shape[0],noise_dim])#noise=seed
with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
real_out,real_pred_labels=discriminator(images_one_batch,training=True)#真实图片送入判别器之后得到的预测真假标签,预测类别

#sampled_labels=np.random.randint(0,10,(images_one_batch.shape[0],1))#github上的纯keras用这个来表示"要生成的图片的标签信息",但我感觉还是用输入的真实图片的label比较贴合原文?
gen_image=generator(noise,one_batch_labels,training=True)
fake_out,fake_pred_labels=discriminator(gen_image,training=True)#生成的假图片送入判别器之后得到的预测真假标签,预测类别
#在ACGAN中,fake_pred_labels并没有用到(分类器只有10类,没有fake类别),写出来只是为了和SGAN的代码统一下

to_categorical_real_labels=to_categorical(one_batch_labels,num_classes=10)
#ACGAN分类的类别只有10类(手写数字0-9),不包含fake类别,因此下面一行代码被注释掉了
#to_categorical_fake_labels=to_categorical(np.full((images_one_batch.shape[0],1),10),num_classes=10)#类别10代表假,真实样本的类别为0

#分别计算两者的损失
gen_loss=generator_loss(fake_out,fake_pred_labels,to_categorical_real_labels)
disc_loss=discriminator_loss(real_out,fake_out,real_pred_labels,to_categorical_real_labels)

#求可训练参数的梯度
gradient_gen=gen_tape.gradient(gen_loss,generator.trainable_variables)
gradient_disc=disc_tape.gradient(disc_loss,discriminator.trainable_variables)

#使用优化器更新可训练参数的权值
generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

定义生成图片的展示函数

1
2
3
4
5
6
7
8
9
10
#将test_noise送入gen_model,以产生假图片
def generate_plot_image(gen_model,test_noise,test_label,epoch):
pre_images=gen_model(test_noise,test_label,training=False)#此时无需训练生成器网络
fig=plt.figure(figsize=(5,18))
#pred = tf.squeeze(pred)
for i in range(10):
plt.subplot(1,10, i + 1)
plt.imshow(pre_images[i, :, :,0],cmap='Greys')
plt.axis('off')
plt.show()

定义训练函数

1
2
3
4
5
6
7
8
9
10
#noise_seed=tf.random.normal([num_example_to_generate,noise_dim])
noise_seed=tf.random.normal([10,noise_dim])
#print(noise_seed)
label_seed=np.array([i for i in range(10)])

def train(dataset,epochs):
for epoch in range(epochs):
for image_batch,label in dataset:
train_step(image_batch,label)
generate_plot_image(generator,noise_seed,label_seed,epoch)
1
#generator(noise_seed,label_seed,training=False)

开始训练

1
train(datasets,EPOCHS)

本模型最终的结果如下

训练过程中报warning

1
WARNING:tensorflow:Gradients do not exist for variables ['discriminator_model/conv2d/kernel:0', 'discriminator_model/conv2d/bias:0'] when minimizing the loss.

猜想原因和之前一样,毕竟判别器的网络结构和SGAN中的是一样的。

具体见:

https://fx0809.gitee.io/2020/10/15/SGAN/

注意!并不是代码改成了继承自Model类的风格才导致的梯度问题,因为我把基础GAN

https://fx0809.gitee.io/2020/10/06/%E5%9F%BA%E7%A1%80GAN/

中的判别器代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def discriminator_model():
model=tf.keras.Sequential()

model.add(layers.Flatten())
model.add(layers.Dense(512,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Dense(256,use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())

model.add(layers.Dense(1))#判断真假,二分类

return model

改成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Discriminator_model(tf.keras.Model):
def __init__(self,num_classes):
self.num_classes=num_classes
super().__init__()

self.flatten=tf.keras.layers.Flatten()
self.dense1=tf.keras.layers.Dense(512,use_bias=False)
self.bn1=tf.keras.layers.BatchNormalization()
self.leakyrelu1=tf.keras.layers.LeakyReLU()

self.dense2=tf.keras.layers.Dense(256)
self.bn2=tf.keras.layers.BatchNormalization()
self.leakyrelu2=tf.keras.layers.LeakyReLU()


self.dense=tf.keras.layers.Dense(1)

def call(self,inputs,training=True):
x=self.flatten(inputs)

x=self.dense1(x)
x=self.bn1(x,training)
x=self.leakyrelu1(x)

x=self.dense2(x)
x=self.bn2(x,training)
x=self.leakyrelu2(x)

return self.dense(x)

之后,模型仍可以正常训练。

【试错】把判别器的网络改成全连接层实现

搭建模型过程中判别器的代码改为

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Discriminator_model(tf.keras.Model):
def __init__(self,num_classes):
self.num_classes=num_classes
super().__init__()

self.flatten=tf.keras.layers.Flatten()

self.dense1=tf.keras.layers.Dense(512,use_bias=False)
self.bn1=tf.keras.layers.BatchNormalization()
self.leakyrelu1=tf.keras.layers.LeakyReLU()

self.dense2=tf.keras.layers.Dense(256)
self.bn2=tf.keras.layers.BatchNormalization()
self.leakyrelu2=tf.keras.layers.LeakyReLU()



#真实图片还是生成图片,二分类
self.dense_valid=tf.keras.layers.Dense(1,activation='sigmoid')

#分类器
self.dense_label=tf.keras.layers.Dense(self.num_classes)
self.softmax=tf.keras.layers.Softmax()

def call(self,inputs,training=True):

x=self.flatten(inputs)

x=self.dense1(x)
x=self.bn1(x,training)
x=self.leakyrelu1(x)

x=self.dense2(x)
x=self.bn2(x,training)
x=self.leakyrelu2(x)

features=x

valid=self.dense_valid(features)#真or假

label=self.dense_label(features)
label=self.softmax(label)#属于哪一类

return valid,label#真实图片还是生成图片:valid;属于哪一类(n_classes,fake):label

再次运行,不报错,可以正常训练,生成器产生的图片效果如下:

这说明之前(SGAN)中的报错都是判别器的网络结构导致的

因此,优化的方向大致为:修改判别器的网络结构!包括之前的SGAN的优化,应该也是这个方向。