Import the required packages

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import Concatenate, LeakyReLU, UpSampling2D, Conv2D
import matplotlib.pyplot as plt
import os
from skimage.transform import resize
import imageio
from glob import glob
import numpy as np

tf.keras.backend.set_floatx('float64')

Loading the data


class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    # used at test/sampling time
    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('/home/fanxi/Desktop/dataset/%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = resize(img, self.img_res)

                # random horizontal flip as light augmentation
                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = resize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs) / 127.5 - 1.

        return imgs

    # used during training
    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('/home/fanxi/Desktop/dataset/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('/home/fanxi/Desktop/dataset/%s/%sB/*' % (self.dataset_name, data_type))

        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size paths from each domain, without
        # repetition, so the model sees samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = resize(img_A, self.img_res)
                img_B = resize(img_B, self.img_res)

                # random horizontal flip during training
                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            # scale pixel values from [0, 255] into [-1, 1] to match tanh output
            imgs_A = np.array(imgs_A) / 127.5 - 1.
            imgs_B = np.array(imgs_B) / 127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = resize(img, self.img_res)
        img = img / 127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        # np.float was removed from NumPy; use a concrete dtype instead
        return imageio.imread(path).astype(np.float64)
dataloader = DataLoader('apple2orange')
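
Before building any networks it is worth a quick smoke test of the loader; a minimal sketch, assuming the apple2orange dataset is unpacked under the hard-coded /home/fanxi/Desktop/dataset directory:

imgs_A, imgs_B = next(dataloader.load_batch(batch_size=1))
print(imgs_A.shape, imgs_B.shape)  # (1, 128, 128, 3) (1, 128, 128, 3)
print(imgs_A.min(), imgs_A.max())  # pixel values scaled into [-1, 1]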

Building the generator and discriminator networks

# U-Net generator
class Build_generator(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.gf = 32  # number of filters in the first layer of G
        self.channels = 3

        # All weighted layers must be created here in __init__, not inside
        # call(): layers built during call() get fresh weights on every
        # forward pass and are never tracked by trainable_variables.

        # downsampling layers
        self.down_convs = [Conv2D(self.gf * m, kernel_size=4, strides=2, padding='same')
                           for m in (1, 2, 4, 8)]
        self.down_norms = [tfa.layers.InstanceNormalization() for _ in range(4)]

        # upsampling layers
        self.up_convs = [Conv2D(self.gf * m, kernel_size=4, strides=1, padding='same',
                                activation='relu') for m in (4, 2, 1)]
        self.up_norms = [tfa.layers.InstanceNormalization() for _ in range(3)]

        # final layers
        self.upsampling = UpSampling2D(size=2)
        self.conv = Conv2D(filters=self.channels, kernel_size=4, strides=1,
                           padding='same', activation='tanh')

    def call(self, inputs):  # inputs: img
        # downsampling; keep each feature map for the skip connections
        skips = []
        d = inputs
        for conv, norm in zip(self.down_convs, self.down_norms):
            d = norm(LeakyReLU(alpha=0.2)(conv(d)))
            skips.append(d)

        # upsampling; concatenate the matching encoder feature map (skip connection)
        u = skips[-1]
        for conv, norm, skip in zip(self.up_convs, self.up_norms, reversed(skips[:-1])):
            u = norm(conv(UpSampling2D(size=2)(u)))
            u = Concatenate()([u, skip])

        u = self.upsampling(u)
        output_img = self.conv(u)

        return output_img
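
A quick shape trace (a sketch, not part of the training code) confirms the symmetric U-Net: four stride-2 convolutions shrink 128 -> 64 -> 32 -> 16 -> 8, then three upsampling blocks plus the final layers bring it back to 128 x 128 with 3 channels:

g = Build_generator()
print(g(tf.zeros((1, 128, 128, 3), dtype=tf.float64)).shape)  # (1, 128, 128, 3)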
class Build_discriminator(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.df = 64  # number of filters in the first layer of D

        # As with the generator, the weighted layers are created here so
        # their weights are tracked and reused across calls.
        self.convs = [Conv2D(self.df * m, kernel_size=4, strides=2, padding='same')
                      for m in (1, 2, 4, 8)]
        self.norms = [tfa.layers.InstanceNormalization() for _ in range(3)]
        # 1-channel output: a grid of per-patch real/fake scores (PatchGAN)
        self.conv = Conv2D(filters=1, kernel_size=4, strides=1, padding='same')

    def call(self, inputs):  # inputs: img
        d = inputs
        for i, conv in enumerate(self.convs):
            d = LeakyReLU(alpha=0.2)(conv(d))
            if i > 0:  # no normalization after the first layer
                d = self.norms[i - 1](d)

        validity = self.conv(d)

        return validity

Defining the loss functions


mse = tf.keras.losses.MeanSquaredError()
mae = tf.keras.losses.MeanAbsoluteError()

#criterion_GAN = mse
#criterion_cycle = mae
#criterion_identity = mae
def generator_loss(imgs_A, imgs_B, reconstr_A, reconstr_B, dA_fake_out, dB_fake_out):
    # The discriminator is a PatchGAN, so an LSGAN-style MSE replaces the
    # usual cross-entropy. valid and fake have shape [1, 8, 8, 1]; with a
    # batch_size other than 1 they broadcast against the discriminator
    # outputs, so any batch size (2, 3, 10, 100, etc.) works.
    gAB_loss = mse(valid, dB_fake_out)  # g_AB fools d_B when fake_B is scored close to 1
    gBA_loss = mse(valid, dA_fake_out)  # g_BA fools d_A when fake_A is scored close to 1

    # Cycle-consistency loss from the paper, measured with L1: an image
    # translated to the other domain and back should match the original.
    cycle_A_loss = mae(imgs_A, reconstr_A)
    cycle_B_loss = mae(imgs_B, reconstr_B)

    return gAB_loss + gBA_loss + cycle_A_loss + cycle_B_loss
def discriminator_loss(dA_real_out, dA_fake_out, dB_real_out, dB_fake_out):
    # Same LSGAN-style MSE as above; valid and fake broadcast over the batch.
    dA_loss_real = mse(valid, dA_real_out)  # real domain-A images should score close to 1
    dA_loss_fake = mse(fake, dA_fake_out)   # generated domain-A images should score close to 0

    dB_loss_real = mse(valid, dB_real_out)  # real domain-B images should score close to 1
    dB_loss_fake = mse(fake, dB_fake_out)   # generated domain-B images should score close to 0

    return dA_loss_real + dA_loss_fake + dB_loss_real + dB_loss_fake
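
The broadcasting behavior these comments rely on is easy to verify in isolation; a minimal sketch with stand-in shapes:

target = np.ones((1, 8, 8, 1))      # same shape as valid below
scores = np.zeros((4, 8, 8, 1))     # a pretend batch of 4 patch grids
print(mse(target, scores).numpy())  # 1.0 -- the size-1 batch dimension broadcasts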

Defining the optimizers

generatorAB_opt = tf.keras.optimizers.Adam(1e-4)
generatorBA_opt = tf.keras.optimizers.Adam(1e-4)
discriminatorA_opt = tf.keras.optimizers.Adam(1e-4)
discriminatorB_opt = tf.keras.optimizers.Adam(1e-4)
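
For reference, the original CycleGAN paper trains with Adam at a learning rate of 2e-4 and beta_1 = 0.5; the plain 1e-4 defaults above are enough for a quick demo.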

Setting hyperparameters and instantiating the generators and discriminators

EPOCHS = 1
d_A = Build_discriminator()
d_B = Build_discriminator()
g_AB = Build_generator()
g_BA = Build_generator()

Defining the training step for each batch

img_rows, img_cols, img_channels = 128, 128, 3
img_shape = (img_rows, img_cols, img_channels)

# The discriminator halves the spatial size four times (four stride-2
# convolutions), so each side of its output grid is 128 / 2**4 = 8
patch = int(img_rows / 2**4)
disc_patch = (patch, patch, 1)
print(disc_patch)
# Adversarial loss ground truths
valid = np.ones((1,) + disc_patch)
fake = np.zeros((1,) + disc_patch)
print(fake.shape)
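
A sanity-check sketch ties this to the discriminator itself: its output grid should match disc_patch exactly:

print(d_A(np.zeros((1,) + img_shape)).shape)  # (1, 8, 8, 1)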
# runs once per batch
def train_step(imgs_A, imgs_B):
    with tf.GradientTape() as genAB_tape, tf.GradientTape() as genBA_tape, \
         tf.GradientTape() as disc_tapeA, tf.GradientTape() as disc_tapeB:
        fake_B = g_AB(imgs_A)  # translate domain-A images into domain-B style
        fake_A = g_BA(imgs_B)  # translate domain-B images into domain-A style

        # translate back for the cycle-consistency loss
        reconstr_A = g_BA(fake_B)
        reconstr_B = g_AB(fake_A)

        dA_real_out = d_A(imgs_A)
        dA_fake_out = d_A(fake_A)

        dB_real_out = d_B(imgs_B)
        dB_fake_out = d_B(fake_B)

        # compute the generator and discriminator losses
        disc_loss = discriminator_loss(dA_real_out, dA_fake_out, dB_real_out, dB_fake_out)
        gen_loss = generator_loss(imgs_A, imgs_B, reconstr_A, reconstr_B, dA_fake_out, dB_fake_out)

    # gradients of the trainable parameters
    gradient_genAB = genAB_tape.gradient(gen_loss, g_AB.trainable_variables)
    gradient_genBA = genBA_tape.gradient(gen_loss, g_BA.trainable_variables)

    gradient_discA = disc_tapeA.gradient(disc_loss, d_A.trainable_variables)
    gradient_discB = disc_tapeB.gradient(disc_loss, d_B.trainable_variables)

    # let the optimizers update the trainable parameters
    generatorAB_opt.apply_gradients(zip(gradient_genAB, g_AB.trainable_variables))
    generatorBA_opt.apply_gradients(zip(gradient_genBA, g_BA.trainable_variables))
    discriminatorA_opt.apply_gradients(zip(gradient_discA, d_A.trainable_variables))
    discriminatorB_opt.apply_gradients(zip(gradient_discB, d_B.trainable_variables))
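
One optional speedup (a sketch): wrapping the step in tf.function traces it into a graph, which typically makes training noticeably faster; the eager version above is easier to debug:

train_step = tf.function(train_step)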

Defining a function to display generated images

def sample_images(epoch, batch_i):
    dataset_name = 'apple2orange'
    os.makedirs('images/%s' % dataset_name, exist_ok=True)
    r, c = 2, 3
    data_loader = DataLoader(dataset_name)

    imgs_A = data_loader.load_data(domain="A", batch_size=1, is_testing=True)
    imgs_B = data_loader.load_data(domain="B", batch_size=1, is_testing=True)

    # Demo (for GIF)
    #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
    #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')

    # Translate images to the other domain
    fake_B = g_AB(imgs_A)
    fake_A = g_BA(imgs_B)
    # Translate back to the original domain
    reconstr_A = g_BA(fake_B)
    reconstr_B = g_AB(fake_A)

    gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

    # Rescale images from [-1, 1] to [0, 1]
    gen_imgs = 0.5 * gen_imgs + 0.5

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt])
            axs[i, j].set_title(titles[j])
            axs[i, j].axis('off')
            cnt += 1
    plt.show()
    fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i))
    plt.close()

Defining the training function

def train(data_loader, epochs, batch_size=1):
    batch_i = 0
    for epoch in range(epochs):
        print('Starting epoch {}:'.format(epoch))
        # ask the loader for a fresh generator each epoch; a single generator
        # would be exhausted after the first epoch
        for imgs_A, imgs_B in data_loader.load_batch(batch_size=batch_size):
            train_step(imgs_A, imgs_B)
            sample_images(epoch, batch_i)  # sample after every batch; fine for a demo
            batch_i += 1

Start training

train(dataloader, EPOCHS)
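
Once training finishes, the two generators are the artifacts worth keeping. A minimal sketch using Keras' standard checkpointing (the weights/ paths are hypothetical):

g_AB.save_weights('weights/g_AB')  # TensorFlow checkpoint format
g_BA.save_weights('weights/g_BA')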