入门深度学习至今已经一年了,趁着这个暑假参加了FlyAI举办的蘑菇分类竞赛,并侥幸取得了第5名。虽然之前也分享过类似竞赛的文章,但都是作为学习帖(传送门戳我),因此本文算是第一次正式参加机器学习类竞赛的浅薄经验分享帖,不足之处,欢迎指正。
Alt text

赛题介绍

这里就不复制官方大段的赛事介绍了,只讲一下重点部分:

给定一批图片,其中共有9种不同品种的蘑菇,希望你用机器学习的方法训练一个分类器,使得该分类器能够将不同种类的蘑菇区分开来。
Alt text
Alt text

评估指标是准确率(Accuracy),它定义为测试集中所有预测正确的样本数与实际总样本数之比。

图片被划分为训练集和测试集两部分。

和Kaggle等竞赛不同,这里我们是看不到测试数据的,并且只能看到训练集的很小部分数据。

官方提供训练平台,并有免费算力发放,因此不需要担心显卡的问题。

baseline搭建

官方已经定义好了数据的读取代码,并给出了基于PyTorch的baseline,所以并不需要我们从头去搭建baseline。

我们要做的,就是在这个baseline基础做优化,尽可能地去提升准确率指标。

模型优化

合适的骨干网络

骨干网络用于提取图片中的特征,并将特征送入分类器(一般是全连接层),输出预测的类别概率分布。

这里,我选择了EfficientNet B6作为骨干网络,理由见下图:
Alt text

可以看出,EfficientNet系列无论是在参数量还是Top-1准确度上,都具备很大的优势。

而且这些骨干网络都会提供预训练模型,你可以直接加载使用,而不是从头开始去训练它。

事实上,在很多时候,我们都不会从头去做训练,因为这样做很容易因为数据量不足而导致模型不收敛等问题。

我们都站在巨人的肩膀上。

数据增强

数据增强起到了数据集扩充的作用,它对于分类器性能提升起了很大的作用。

通过多次实验,我选取了如下数据增强方式:

  • RandomScaleCrop
  • RandomErasing
  • HorizontalFlip
  • VerticalFlip
  • ShiftScaleRotate
  • RandomBrightnessContrast
  • OneOf(GaussNoise, CLAHE,ImageCompression,RandomGamma,Posterize,Blur)

除了前面两个数据增强方法需要手动实现外,剩余的可以通过调用albumentations来实现。这是一款专为数据增强打造的第三方库,提供了更多的数据增强方法,且能与PyTorch无缝衔接,使用起来非常方便。

以上数据增强方法都在Dataset模块中完成。除此之外,在train过程中,还引入了mixup,这也是一种数据增强方法,但不同于上述方法都是对于一张图片进行操作,mixup会将两张图片混叠在一起,得到的图片”看山不是山,看水不是水”,你可以通过下面这张经过mixup处理的图来感受下:
Alt text

这种数据增强的方法对于准确度的提升非常有帮助,当然,代价是更长的训练时间。具体地,当使用了mixup后,需要加大epoch数进行训练,剩下的就是耐心等待。

实现代码就几行:

1
2
3
4
5
6
7
8
9
10
11
12
...

alpha=1
lam = np.random.beta(alpha,alpha)
index = torch.randperm(inputs.size(0)).cuda()
inputs = lam*inputs + (1-lam)*inputs[index,:]
targets_a, targets_b = labels, labels[index]

outputs = self.net(inputs)
loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)

...

训练策略

在训练过程中,发现模型在训练集和验证集上表现还不错,但测试集上的结果并不好,怀疑出现了过拟合问题。

为了应对这个问题,除了上面的数据增强方式外,还引入了SAM。

SAM来自论文《SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION》,它能够使得网络收敛到相对平滑的最小值,而不是尖锐的最小值点。
Alt text

注意SAM并不是一个新的优化器,需要配合其他优化器一起使用。

损失函数

这里使用了经典的多分类交叉熵损失函数。

我尝试过标签平滑,类别加权等损失,但效果还不如交叉熵,因此就没有在损失上做更多尝试。

TTA

TTA,即”测试时增强”。对于测试集中的一张图片,通过多次不同的数据增强,可以得到多张不同数据增强方法下的测试图片。然后分别将每张图片输入训练好的模型进行预测,最终预测结果等于这些预测结果的平均值。

可能的改进方式

  • 尝试更强大的骨干网络
  • 尝试其他损失,比如人脸问题中常用的ArcFace
  • 尝试不同的学习率衰减策略
  • 训练多个模型,做模型融合
  • 在算力足够的前提下,增大输入模型图像的尺寸
  • 在算力足够且使用了mixup的前提下,尽可能训练更长的时间(大力出奇迹)

总结

以上分享都是个人在做了很多次实验后总结出来的,大部分方案可以迁移到其他类似的任务上,因此如果你正好在做类似竞赛,不妨尝试一下。当然,具体有没有用还要实验过后才知道,正所谓实验先行

最后附上一张五彩斑斓的实验记录部分截图
Alt text

以上就是关于本次竞赛的全部分享内容了,希望能够有所帮助。