通过采样等方式获取高清图片的低分辨率版本,两者形成一一映射的关系,作为准备好的数据集。

不同于之前的GAN的输入为noise,SRGAN 的输入为低分辨率图片,希望通过对抗的方式学习如何生成低分辨率图片的超清版本。

主要改动的地方除了生成器和判别器的架构外,就是损失函数了:判别器的损失函数无需改动,而生成器的损失函数在原来的基础上,需要再增加两项,一个是真实图片与生成图片的均方误差,另一个是真实图片与生成图片经过vgg19提取得到的特征之间的均方误差(论文中把这两个合起来叫做内容损失,而之前生成器的损失叫做对抗损失,内容损失+对抗损失=感知损失)。


代码

导入相关函数

1
2
3
4
5
6
7
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers
import scipy#需要执行 pip install scipy==1.2.1 来给scipy降级
from glob import glob
import datetime

准备数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class DataLoader():
def __init__(self, dataset_name, img_res=(128, 128)):
self.dataset_name = dataset_name
self.img_res = img_res

def load_data(self,is_testing=False):
data_type = "train" if not is_testing else "test"

path = glob(r'C:\Users\fanxi\Desktop\%s\*' % (self.dataset_name))

imgs_hr = []
imgs_lr = []
#对每一个图片做预处理
for img_path in path:
img = self.imread(img_path)

h, w = self.img_res
low_h, low_w = int(h / 4), int(w / 4)

img_hr = scipy.misc.imresize(img, self.img_res)
img_lr = scipy.misc.imresize(img, (low_h, low_w))

# If training => do random flip
if not is_testing and np.random.random() < 0.5:
img_hr = np.fliplr(img_hr)
img_lr = np.fliplr(img_lr)

imgs_hr.append(img_hr)
imgs_lr.append(img_lr)

imgs_hr = np.array(imgs_hr,dtype=np.float32) / 127.5 - 1.
imgs_lr = np.array(imgs_lr,dtype=np.float32) / 127.5 - 1.

#封装成tf.dataset的格式
BATCH_SIZE=256
BUFFER_SIZE=len(imgs_hr)
datasets=tf.data.Dataset.from_tensor_slices((imgs_hr,imgs_lr))
datasets=datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

return datasets


def imread(self, path):
return scipy.misc.imread(path, mode='RGB').astype(np.float)

data_loader=DataLoader('test')#西安去除几张图片放入test文件夹试一下效果
datasets=data_loader.load_data()

构建VGG

1
2
3
4
5
6
7
8
9
#提取图像特征
class BuildVGG(tf.keras.models.Model):
def __init__(self):
super().__init__()
vgg=tf.keras.applications.VGG19(weights='imagenet')
vgg.outputs=[vgg.layers[9].output]
def call(self,img):#输入的图像img有两种:生成的高分辨率图像(sr)和原来的高分辨率图像(hr)
img_features=vgg(img)
return img_features

搭建生成器和判别器网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Discriminator_model(tf.keras.Model):
def __init__(self):
super().__init__()
self.df=64#判别器的第一层中filter的个数为64

self.dense1=layers.Dense(self.df*16)
self.leakyrelu=layers.LeakyReLU()
self.dense2=layers.Dense(1,activation='sigmoid')

#判别器的层,因为会被使用多次,故而写成一个方法以供调用
def d_block(self,layer_input,filters,strides=1,bn=True):
d=layers.Conv2D(filters,kernel_size=3,strides=strides,padding='same')(layer_input)
d=layers.LeakyReLU(alpha=0.2)(d)
if bn:
d=layers.BatchNormalization(momentum=0.8)(d)
return d

#判别器的输入是一张图片img,它可能是真实的高分辨率图片(hr),也可能是生成的高分辨率图片(sr),但无论img属于哪一类,它们的尺寸都是一样的
def call(self,img,training=True):
x=self.d_block(img,self.df,bn=False)
x=self.d_block(x,filters=self.df,strides=2)
x=self.d_block(x,filters=self.df*2)
x=self.d_block(x,filters=self.df*2,strides=2)
x=self.d_block(x,self.df*4)
x=self.d_block(x,filters=self.df*4,strides=2)
x=self.d_block(x,filters=self.df*8)
x=self.d_block(x,filters=self.df*8,strides=2)

x=self.dense1(x)
x=self.leakyrelu(x)
validity=self.dense2(x)#ture or fake

return validity
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class Generator_model(tf.keras.Model):
def __init__(self):
super().__init__()

#残差块的个数
self.n_residual_blocks=16

#1.定义残差块之前的层
self.c1=layers.Conv2D(64,kernel_size=9,strides=1,padding='same')
self.relu=layers.ReLU()

#2.定义残差块(见residual_block方法)

#3.定义残差块后面的层
self.c2=layers.Conv2D(64,kernel_size=3,strides=1,padding='same')#生成器第一层中的filter的个数是64
self.bn=layers.BatchNormalization(momentum=0.8)
self.add=layers.Add()

#4.定义上采样过程(见deconv2d方法)

#5.定义最终用于生成高分辨率图像的输出层
self.c3=layers.Conv2D(3,kernel_size=9,strides=1,padding='same',activation='tanh')

#定义残差块(也是层的堆叠),因为会被使用多次,故而写成一个方法以供调用
def residual_block(self,layer_input,filters):#filters:生成器第一层中的filter的个数
d=layers.Conv2D(filters,kernel_size=3,strides=1,padding='same')(layer_input)
d=layers.Activation('relu')(d)
d=layers.BatchNormalization(momentum=0.8)(d)
d=layers.Conv2D(filters,kernel_size=3,strides=1,padding='same')(d)
d=layers.Add()([d,layer_input])
return d

#定义上采样过程(也是层的堆叠),因为会被使用多次,故而写成一个方法以供调用
def deconv2d(self,layer_input):
u=layers.UpSampling2D(size=2)(layer_input)
u=layers.Conv2D(256,kernel_size=3,strides=1,padding='same')(u)
u=layers.Activation('relu')(u)
return u


#生成器输入的是低分辨率的图片(img_lr) img_lr:(32,32,3)
def call(self,img_lr,training=True):

#1.在残差块之前的层中前向传播
r=self.c1(img_lr)
r=self.relu(r)

#2.在残差块中前向传播
x=self.residual_block(r,64)#64:生成器第一层中的filter的个数
for _ in range(self.n_residual_blocks-1):
x=self.residual_block(x,64)

#3.在残差块后面的层中前向传播
x=self.c2(x)
x=self.bn(x)
x=self.add([x,r])

#4.在上采样层中前向传播
for _ in range(2):
x=self.deconv2d(x)

#5. 生成高分辨率图像(3通道彩色)
x=self.c3(x)

return x

定义损失函数

再把这张图搬过来


共discriminator_loss、generator_content_loss、generator_adversarial_loss三种loss,第一个用来训练判别器,后两个加起来,训练生成器。

训练流程:

  • 沿着红色虚线,计算判别损失,更新判别器参数Dθ
  • 沿着粉色虚线,计算产生损失,更新产生器参数Gθ
1
2
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=False)
mse=tf.keras.losses.MeanSquaredError()
1
2
3
4
5
#判别器无需改动
def discriminator_loss(real_out,fake_out):
real_loss=cross_entropy(tf.ones_like(real_out),real_out)
fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
return real_loss+fake_loss
1
2
3
4
5
6
7
8
9
10
11
#生成器的loss包含两部分,一部分是和之前一样的,这里称之为adversarial loss,即对抗损失;另一部分是content loss,即内容损失
def generator_loss(fake_out,img_real,img_fake):
#img_fake:生成的图片;
#img_real:用于训练的真实图片

adversarial_loss=cross_entropy(tf.ones_like(fake_out),fake_out)#对抗损失

fea_img_real=vgg(img_real)
fea_img_fake=vgg(img_fake)
content_loss=mse(fea_img_real,fea_img_fake)*0.006 + mse(img_real,img_fake)#内容损失
return adversarial_loss+content_loss

定义优化器

1
2
generator_opt=tf.keras.optimizers.Adam(1e-4)
discriminator_opt=tf.keras.optimizers.Adam(1e-4)

设置超参数,实例化生成器和判别器

1
2
3
4
5
EPOCHS=100
vgg=BuildVGG()
vgg.trainable=False
generator=Generator_model()
discriminator=Discriminator_model()

定义每个batch的训练过程

所有操作都是对于一个批次的全部样本”同时”进行的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#image_one_batch包含两部分 :img_lr与img_hr,即:低分辨率图片 与 原来的高分辨率图片
#此外,把生成器生成的图片叫做img_sr,即超分辨率图片
def train_step(imgs_hr,imgs_lr):

with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:

real_out=discriminator(imgs_hr,training=True)#真实图片送入判别器之后得到的预测标签

imgs_sr=generator(imgs_lr,training=True)#img_lr:低分辨率图片;img_sr:生成的高分辨率图片
fake_out=discriminator(imgs_sr,training=True)#生成图片送入判别器之后得到的预测标签

#分别计算两者的损失
gen_loss=generator_loss(fake_out,imgs_hr,imgs_sr) #img_sr:生成的图片; img_hr:用于训练的真实图片
disc_loss=discriminator_loss(real_out,fake_out)

#求可训练参数的梯度
gradient_gen=gen_tape.gradient(gen_loss,generator.trainable_variables)
gradient_disc=disc_tape.gradient(disc_loss,discriminator.trainable_variables)

#使用优化器更新可训练参数的权值
generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))

定义生成图片的展示函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def sample_images(epoch):
os.makedirs('images' , exist_ok=True)
r, c = 2, 2

imgs_hr, imgs_lr =[],[]
for hr,lr in datasets.take(2):
imgs_hr.append(hr)
imgs.lr.append(lr)

fake_hr = generator.predict(imgs_lr)

# Rescale images 0 - 1
imgs_lr = 0.5 * imgs_lr + 0.5
fake_hr = 0.5 * fake_hr + 0.5
imgs_hr = 0.5 * imgs_hr + 0.5

# Save generated images and the high resolution originals
titles = ['Generated', 'Original']
fig, axs = plt.subplots(r, c)
cnt = 0
for row in range(r):
for col, ione_batchmage in enumerate([fake_hr, imgs_hr]):
axs[row, col].imshow(image[row])
axs[row, col].set_title(titles[col])
axs[row, col].axis('off')
cnt += 1
fig.savefig("images/%d.png" % ( epoch))
plt.close()

# Save low resolution images for comparison
for i in range(r):
fig = plt.figure()
plt.imshow(imgs_lr[i])
fig.savefig('images/%d_lowres%d.png' % (epoch, i))
plt.close()

定义训练函数

1
2
3
4
5
6
def train(dataset,epochs):
for epoch in range(epochs):
for imgs_hr,imgs_lr in dataset:
train_step(imgs_hr,imgs_lr)
#运行 generate_plot_image function
sample_images(epoch)

开始训练

1
train(datasets,EPOCHS)

报错:

1
RecursionError: maximum recursion depth exceeded while calling a Python object

报错原因最后定位到了预训练模型vgg19

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
----> 8     fea_img_real=vgg(img_real)
9 fea_img_fake=vgg(img_fake)
10 content_loss=mse(fea_img_real,fea_img_fake)*0.006 + mse(img_real,img_fake)#内容损失

D:\Anoconda\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
966 with base_layer_utils.autocast_context_manager(
967 self._compute_dtype):
--> 968 outputs = self.call(cast_inputs, *args, **kwargs)
969 self._handle_activity_regularization(inputs, outputs)
970 self._set_mask_metadata(inputs, outputs, input_masks)

<ipython-input-6-69bcc14728ec> in call(self, img)
6 vgg.outputs=[vgg.layers[9].output]
7 def call(self,img):#输入的图像img有两种:生成的高分辨率图像(sr)和原来的高分辨率图像(hr)
----> 8 img_features=vgg(img)
9 return img_features

... last 2 frames repeated, from the frame below ...

D:\Anoconda\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
966 with base_layer_utils.autocast_context_manager(
967 self._compute_dtype):
--> 968 outputs = self.call(cast_inputs, *args, **kwargs)
969 self._handle_activity_regularization(inputs, outputs)
970 self._set_mask_metadata(inputs, outputs, input_masks)

RecursionError: maximum recursion depth exceeded while calling a Python object

暂时未解决,先放着。

参考资料

https://blog.csdn.net/DuinoDu/article/details/78819344

https://www.cnblogs.com/zgqcn/p/11260343.html