LLM自回归预训练过程详解-大模型炼丹术(五)
在前面的4篇文章中,我们已经完成了整个数据流向所需的模块构建,包括tokinizer,embedding,注意力机制,并串联得到了GPT2这个LLM架构。
现在,是时候准备开始训练我们的LLM了。
相比于前面发布的4篇文章,本文将更加偏重于代码实战。
一、准备自回归预训练数据集
在开始编写训练脚本之前,我们需要先构建训练所需数据集。这里使用the-verdict.txt
,这是在本系列一开始就作为示例使用的一本书。
1 | import os |
现在有了原始数据,还需要用tokinizer进一步编码成token ID序列的形式。先把我们之前定义好的tokinizer搬过来:
1 | import tiktoken |
可以看到,这本书很小,总共包含20479个字符,使用BPE进行编码后,总共得到5145个token。
定义基本的编码解码函数:
1 | def text_to_token_ids(text, tokenizer): |
同样,我们在之前已经定义好了数据加载器,这里也直接搬过来:
1 | from torch.utils.data import Dataset, DataLoader |
配置文件粘过来:
1 | GPT_CONFIG_124M = { |
调用数据加载器来定义训练/验证loader:
1 | # Train/validation ratio |
确保训练/验证集中至少包含一个样本(长度为context_size
):
1 | # Sanity check |
查看数据集:
1 | print("Train loader:") |
输出:
1 | Train loader: |
由此可知,在batch_size设置为2,context_length设置为256时,总共得到10个样本,这是一个相当小的数据集。
二、准备模型架构与损失函数
直接把我们在上一篇文章中定义的GPT2架构搬过来:
1 | class GPTModel(nn.Module): |
使用交叉熵作为损失函数:
1 | def calc_loss_batch(input_batch, target_batch, model, device): |
在开始训练之前,可以先查看一下整体的训练和验证集的loss:
1 | if torch.cuda.is_available(): |
三、编写LLM自回归预训练循环
这部分代码也遵循PyTorch深度学习中的经典训练循环形式,代码非常简单,这里不再细说。
1 | def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs, |
现在使用定义好的训练循环函数开始执行训练:
1 | model = GPTModel(GPT_CONFIG_124M) |
训练日志如下:
1 | Ep 1 (Step 000000): Train loss 9.817, Val loss 9.924 |
我们给定的start_context
是Every effort moves you
在最开始,模型只会输出Every effort moves you,,,,,,,,,,,,.
而到了最后一个epoch,模型输出了语法基本正确的句子:Every effort moves you?" "Yes--quite insensible to the irony. She wanted him vindicated--and by me!" He laughed again, and threw back his head to look up at the sketch of the donkey. "There were days when I
你可能会疑惑,预设的max_tokens
不是50吗,这两次的测试输入都是Every effort moves
,可是为什么输出的句子长度却不一样呢?
因为max_tokens=50
指的是生成的token数量上限,而不是句子的字数或单词数。一个token可能是:
- 一个单词(例如 “donkey”)
- 一个子词(例如 “sketch” 可能被拆分为 [“sk”, “etch”])
- 一个标点符号(例如 “,”、”.” 可能单独算作 token)
随着训练的进行,模型的语言能力增强:
- 早期:模型可能随机输出大量逗号、”and” 等低信息量的 token,使得句子看起来短而混乱。
- 后期:模型学会了输出完整的单词、短语和句子,因此即使 max_tokens限制为50,生成的文本可能更连贯、信息密度更高,看起来更长。
最后来看一下loss:
可以看到,整体的训练loss是下降的,但存在过拟合(验证集loss后期上升),这是因为我们所使用的数据集比较小,仅仅用于演示。
到这里,我们完成了LLM的预训练。模型已经掌握了基本的语言模式,但如何让它更好地生成高质量文本,还需要合理的解码策略。
在下一篇文章中,我们将深入探讨LLM的一些解码策略,并对这些策略的优缺点进行详细分析。敬请期待!