The Problem

While reading the implementation of a GAN, I came across this piece of code:

```python
# Imports for the whole GAN class (including the methods shown further below)
import numpy as np
from keras.datasets import mnist
from keras.layers import BatchNormalization, Dense, Flatten, Input, LeakyReLU, Reshape
from keras.models import Model, Sequential
from keras.optimizers import Adam


class GAN:
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Build and compile the discriminator
        # Adam(learning rate 0.0002, beta_1 0.5). Each model gets its own optimizer
        # instance, since newer Keras/TF optimizers may refuse to be shared by two models.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=Adam(0.0002, 0.5),
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
```

Narrowing the code above down, look at this line:

```python
# For the combined model we will only train the generator
self.discriminator.trainable = False
```

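This sets the discriminator to a non-trainable state.

Does that mean the discriminator can no longer be trained?

No!

Let's keep reading the code that follows: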

```python
# Methods of class GAN (shown without the class-level indentation)
def build_generator(self):

    model = Sequential()

    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))

    model.summary()

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)

    return Model(noise, img)

def build_discriminator(self):

    model = Sequential()

    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)
```

This is the code that builds the generator and discriminator networks, respectively.

Nothing looks suspicious here.
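Even so, it helps to count how many weights each network owns, because `trainable` ultimately decides which of those weights get updated. The sketch below is plain Python arithmetic, not a Keras call. The helper names (`dense_params`, `batchnorm_params`, `generator_params`, `discriminator_params`) are hypothetical. It assumes Keras's usual bookkeeping: a `Dense` layer holds a kernel plus a bias, `BatchNormalization` holds trainable `gamma`/`beta` and non-trainable moving mean/variance, and `Flatten`, `LeakyReLU` and `Reshape` hold no weights. Under those assumptions, if freezing works as intended, `combined.summary()` should report only the generator's weights as trainable.

```python
LATENT_DIM = 100
IMG_PIXELS = 28 * 28 * 1  # np.prod(img_shape)


def dense_params(n_in, n_out):
    """Hypothetical helper: an n_in x n_out kernel plus n_out biases."""
    return n_in * n_out + n_out


def batchnorm_params(n):
    """Hypothetical helper: (trainable, non_trainable) counts for BatchNormalization."""
    return 2 * n, 2 * n  # (gamma, beta), (moving_mean, moving_variance)


def generator_params():
    trainable, non_trainable = 0, 0
    n_in = LATENT_DIM
    for units in (256, 512, 1024):
        trainable += dense_params(n_in, units)
        bn_trainable, bn_frozen = batchnorm_params(units)
        trainable += bn_trainable
        non_trainable += bn_frozen
        n_in = units
    trainable += dense_params(n_in, IMG_PIXELS)  # final tanh layer
    return trainable, non_trainable


def discriminator_params():
    total = dense_params(IMG_PIXELS, 512) + dense_params(512, 256) + dense_params(256, 1)
    return total, 0


g_trainable, g_frozen = generator_params()
d_trainable, _ = discriminator_params()

# Inside `combined` the discriminator is frozen, so all of its weights
# are expected in the non-trainable column.
combined_trainable = g_trainable
combined_non_trainable = g_frozen + d_trainable

print(g_trainable, g_frozen)                       # 1489936 3584
print(d_trainable)                                 # 533505
print(combined_trainable, combined_non_trainable)  # 1489936 537089
```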

Let's read on:

```python
def train(self, epochs, batch_size=128, sample_interval=50):

    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

        # Generate a batch of new images
        gen_imgs = self.generator.predict(noise)

        # Train the discriminator
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

        # Train the generator (to have the discriminator label samples as valid)
        g_loss = self.combined.train_on_batch(noise, valid)

        # Plot the progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # If at save interval => save generated image samples
        # (sample_images is not shown in this excerpt)
        if epoch % sample_interval == 0:
            self.sample_images(epoch)
```
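A side note on `d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)`: because the discriminator was compiled with `metrics=['accuracy']`, each `train_on_batch` call returns a `[loss, accuracy]` list (at least in the Keras 2.x style this code targets). That is why the progress line prints `d_loss[0]` as the loss and `100*d_loss[1]` as a percentage. The loss values below are made up for illustration.

```python
import numpy as np

# Hypothetical return values of discriminator.train_on_batch: [loss, accuracy]
d_loss_real = [0.60, 0.80]  # batch of real images
d_loss_fake = [0.40, 0.60]  # batch of generated images

# Element-wise mean of the two [loss, accuracy] lists
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
print("D loss: %f, acc.: %.2f%%" % (d_loss[0], 100 * d_loss[1]))
# D loss: 0.500000, acc.: 70.00%
```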

Something interesting turns up in the `train` function. Let's narrow the code down again:

```python
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

# Generate a batch of new images
gen_imgs = self.generator.predict(noise)

# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# ---------------------
#  Train Generator
# ---------------------

noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)
```

Look! The code contains a `# Train the discriminator` step!

Didn't we already set the discriminator to non-trainable?

Hmm...

The Answer

After some searching, I finally found the answer:

By setting trainable=False after the discriminator has been compiled the discriminator is still trained during discriminator.train_on_batch but since it's set to non-trainable before the combined model is compiled it's not trained during combined.train_on_batch
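  • Once `discriminator` has been compiled, it can still be trained through `discriminator.train_on_batch`, even after `discriminator.trainable = False` has been set;

  • But if `discriminator.trainable` is set to `False` before `discriminator` is compiled, then not even `discriminator.train_on_batch` will train it.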
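A toy model makes these two rules concrete. The sketch below is plain Python that imitates the compile-time snapshot described in the quote, as reported for the Keras versions discussed in the cited issue. `ToyLayer`, `ToyModel`, and the fixed `step` that stands in for a gradient update are all hypothetical and are not Keras APIs.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer/model: one scalar weight plus a trainable flag."""

    def __init__(self, name):
        self.name = name
        self.weight = 0.0
        self.trainable = True


class ToyModel:
    """Hypothetical stand-in for a compiled Keras model (not the Keras API)."""

    def __init__(self, layers):
        self.layers = layers
        self.train_list = []

    def compile(self):
        # Snapshot which layers are trainable right now; changing
        # layer.trainable later does not touch this list.
        self.train_list = [layer for layer in self.layers if layer.trainable]

    def train_on_batch(self, step=1.0):
        # Stand-in for one optimizer step on every snapshotted layer.
        for layer in self.train_list:
            layer.weight += step


# Rule 1: compile first, then freeze -> train_on_batch still updates the weight.
d_compiled_first = ToyLayer("discriminator")
model_a = ToyModel([d_compiled_first])
model_a.compile()
d_compiled_first.trainable = False
model_a.train_on_batch()
print(d_compiled_first.weight)  # 1.0

# Rule 2: freeze first, then compile -> train_on_batch never updates it.
d_frozen_first = ToyLayer("discriminator")
d_frozen_first.trainable = False
model_b = ToyModel([d_frozen_first])
model_b.compile()
model_b.train_on_batch()
print(d_frozen_first.weight)  # 0.0
```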

When you call compile, it builds a trainable model, and uses the current trainable flags. A compiled model can then not have its trainable flags changed, so we are free to change them and compile another model with different flags.

Then, for some reason, those two compiled models can still have the same weights (I guess?) even though tensorflow/keras still sees them as clearly separated things.
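  • When a model is compiled, its trainable state is fixed at whatever it is at that moment (e.g. `discriminator.trainable = True`);

  • Changing `discriminator.trainable` after compilation does not affect the already-compiled discriminator. The new value only applies to models compiled afterwards, here the `combined` model;

  • The two compiled models still share the same discriminator weights, because `combined` is built by calling the very same `self.discriminator` object (`validity = self.discriminator(img)`), and therefore uses the same layers and weight variables, even though Keras treats them as two separate models.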
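Putting the three points together, the following toy sketch replays the GAN setup: compile the discriminator, flip its flag, compile the combined model around the same object, then alternate the two training calls. As before, this is plain Python under the assumption that compiling snapshots the trainable flags, as the quoted explanation describes. `ToyNet` and `CompiledModel` are hypothetical stand-ins, and each `train_on_batch` call simply adds 1.0 to every weight it is allowed to update.

```python
class ToyNet:
    """Hypothetical stand-in for a Keras model: one scalar weight plus a trainable flag."""

    def __init__(self, name):
        self.name = name
        self.weight = 0.0
        self.trainable = True


class CompiledModel:
    """Hypothetical stand-in for model.compile(): fixes the trainable list at creation."""

    def __init__(self, nets):
        self.nets = nets  # references, not copies, so weights are shared
        self.train_list = [net for net in nets if net.trainable]

    def train_on_batch(self, step=1.0):
        for net in self.train_list:
            net.weight += step


generator = ToyNet("generator")
discriminator = ToyNet("discriminator")

d_compiled = CompiledModel([discriminator])            # snapshot: discriminator trainable
discriminator.trainable = False                         # affects later compiles only
combined = CompiledModel([generator, discriminator])   # snapshot: generator only

for _ in range(3):  # three iterations of the alternating training loop
    d_compiled.train_on_batch()  # like self.discriminator.train_on_batch(...)
    combined.train_on_batch()    # like self.combined.train_on_batch(noise, valid)

print(generator.weight, discriminator.weight)  # 3.0 3.0
print(combined.nets[1] is discriminator)       # True
```

The discriminator ends at 3.0 rather than 6.0. Only its own compiled model updates it. `combined` updates the generator alone, yet it still sees the discriminator's latest weights, because both models hold a reference to the same object.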

References

https://github.com/eriklindernoren/Keras-GAN/issues/73