交叉验证时划分数据的方式
交叉验证是机器学习中常用的一种策略,其核心是对数据集进行划分,本文介绍sklearn中的3种交叉验证时划分数据集的方法:
1 | KFold |
这里来举例说明各个方法的使用。
先来导入这些方法,并创建一个简单的数据集:
1 | from sklearn.model_selection import KFold,StratifiedKFold,GroupKFold |
KFold
KFold只需要特征x就能够完成数据划分
1 | kf = KFold(n_splits=3,shuffle=True)# n_splits不能超过总的样本数 |
输出:
1 | [0 3 4 5 6 9] [1 2 7 8] |
StratifiedKFold
和上面的KFold不同的是,StratifiedKFold保证了划分后的两部分数据分布和原始数据集的标签分布是近似相同的
1 | skf = StratifiedKFold(n_splits=3,shuffle=True)#n_splits 不能超过每个类别中所含样本数 |
输出:
1 | [0 1 4 5 7 8] [2 3 6 9] |
GroupKFold
GroupKFold保证了同一组的样本会被同时划分为训练集或验证集,而不是既有样本在训练集也有样本在验证集。
这种方法在结构化数据建模时常用,比如同一个用户会购买许多商品,那么这个用户和他所购买每一个商品之间都会形成一个样本。
如果将同一个用户下的全部样本的一部分划分到训练集中,另一部分划分到验证集中,那么,由于用户的信息(比如年龄,居住地,性别,工作等字段信息)已经在模型训练时见过了,所以在验证时,模型可能很容易的就能对涉及该用户的样本做出预测,但这并不是因为模型有多厉害,而是模型在训练时已经见过了。这就造成了信息泄露。
如果采用GroupKFold的方法进行划分,将用户的id作为group,就可以避免上述信息泄露的问题。
1 | gkf=GroupKFold(n_splits=3)# n_splits不能超过总的 group数 |
输出:
1 | [0 1 2 7 8 9] [3 4 5 6] |
参考:
- [1] https://blog.csdn.net/qq_16761099/article/details/106091354
- [2] https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
- [3] https://scikit-learn.org/0.16/modules/generated/sklearn.cross_validation.StratifiedKFold.html
- [4] https://scikit-learn.org/0.20/modules/generated/sklearn.model_selection.GroupKFold.html