解决过拟合的手段有很多,比如early stopping, dropout, weight regularization,然而,这些手段无法解决模型对于标签过度自信的问题:在标签有误时,容易导致模型学习到错误的内容。
在普通的交叉熵函数中,只有预测正确的类别才会对损失作出贡献。标签平滑的思想是对标签target的onehot形式进行改造,使其取值不再是非0即1,这样,预测错误的类别也会对损失作出较小的贡献,从而迫使模型进一步学习不同类别之间的区别,避免了模型的过度自信。
使用标签平滑,只需对标签target进行变换即可,其余部分和交叉熵的计算方式是一样的,平滑后的标签如下:
$$y_{ls}=(1-\alpha)*y_{onehot}+\frac{\alpha}K$$
其中,K是类别数,$\alpha$是平滑系数,$y_{onehot}$是原始标签的onehot结果。
标签平滑的PyTorch代码实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| import torch import torch.nn.functional as F
def cross_entropy_loss(preds, target): logp = F.log_softmax(preds, dim=1) loss = torch.sum(-logp * target, dim=1) return loss.mean()
def label_smoothing_cross_entropy_loss(preds, targets,epsilon=0.1): n_classes = preds.size(1) onehot = F.one_hot(targets).squeeze(1) targets = onehot * (1 - epsilon) + torch.ones_like(onehot) * epsilon / n_classes loss = cross_entropy_loss(preds, targets) return loss
if __name__=='__main__': preds=torch.tensor([[0,0,0.1],[0.3,0.4,0.2],[0.2,0.3,0.4],[0.9,0.8,0.7]]) targets=torch.tensor([1,2,1,0]).reshape(4,1) print(label_smoothing_cross_entropy_loss(preds,targets))
|
参考: