前两篇实现了简版GPT,并对其进行了SFT,我们接下来看ChatGPT整体训练流程的最后一个环节——对齐训练(Alignment Training)。
1.方法3:对齐训练(Alignment Training)
(1)与ChatGPT整体训练流程图的对应关系
- 方法3对应于ChatGPT整体训练流程的STEP2、STEP3。
- 方法3的核心思想是利用了强化学习,最终将GPT3演进为了更通人性的ChatGPT。
- ChatGPT整体训练流程中的STEP2、STEP3,就是大名鼎鼎的RLHF——基于人类反馈的强化学习。
- RL:Reinforcement Learning
- HF:Human Feedback
- ChatGPT整体训练流程中的STEP2,对应于强化学习模型的Interpreter模型。
- ChatGPT整体训练流程中的STEP3,对应于强化学习模型的Action模型。
(2)什么是对齐训练
对齐训练:Alignment Training,它就是一种机器学习的模型训练方法。
核心思想:训练出人类主观感受的模型,这个模型具备预测人类的决策的能力。
- 这样,训练好的模型,就可以在未见过的场景下,按照类似人的行为模式做出选择。
对齐训练与强化学习的关系:OpenAI在对齐训练中,结合了强化学习。
- ChatGPT整体训练流程的STEP2就是对齐训练,学习出预测人类回答问题的偏好模型。
- ChatGPT整体训练流程的STEP3就是强化学习,STEP2输出的这个模型,作为强化学习的Interpreter模型。STEP3不断迭代,最终学习到Action模型。
- 通过SFT训练之后GPT3,本质就是一个能机械式地回答问题的机器人。
- 通过RLHF学习的Action模型,才是帮助SFT之后的GPT3,类似人类回答问题的关键机关。
细节:ChatGPT整体训练流程图中,出现了PPO算法,PPO算法是近端策略梯度优化,增加一个限制Action模型在训练过程中梯度上升速度,本质就是避免Action模型产生一个离谱的Action。
- PPO算法展开说内容太多,本文不赘述,详见论文:https://arxiv.org/abs/1707.06347
(3)STEP2的Reward Model模型训练伪码
- 我们再来看看STEP2的伪码,如下图:
(4)STEP3的RLHF训练伪码
- 我们再来看看STEP3的伪码,如下图:
2.DeepSpeed
RLHF,是ChatGPT最核心的技术机密,除了在《Introducing ChatGPT》(https://openai.com/blog/chatgpt)中提到了,并未公开过源码。
在前文的伪码实现部分,虽然通过伪码描述了RLHF的核心逻辑,但距离商用还欠缺很多东西(如:分布式训练等)。
幸好微软开源了类似的框架,DeepSpeed,我们可以通过阅读它的源码、使用它,开展RLHF。
3.实例-开展RLHF训练
STEP0.前置准备
- 硬件:V100一块,32G显存
- 基础软件:Ubuntun20.04,Minicoda3,Pytorch3.8,CUDA11.6,Python3.10
- 预训练模型:选择Facebook的opt1.3B,即13亿参数的预训练模型。
- 环境初始配置:创建虚拟环境,
- 安装依赖:进入DeepSpeed-Chat目录,安装相关依赖
- 环境测试:确认相关基础软件版本号。
STEP1.SFT
- 开展SFT训练时,由于服务器资源不足(省钱),需要避免OOM异常,因此需要修改一下训练脚本。
路径:training/step1_supervised_finetuning/training_scripts/opt/single_gpu/run_1.3b.sh
- 设置待微调的预训练模型,以及输出路径。
路径:training/step1_supervised_finetuning/evaluation_scripts/run_prompt.sh
- 执行训练脚本run_1.3b.sh,触发DeepSpeed开始SFT训练。
STEP2.RM
- 开展RM训练时,由于服务器资源不足(省钱),需要避免OOM异常,因此需要修改一下训练脚本。
路径:training_scripts/opt/single_gpu/run_350m.sh
- 指定Reward Model的输出路径。
路径:evaluation_scripts/run_eval.sh
- 执行训练脚本run_350m.sh,触发DeepSpeed开始RW训练。
STEP3.RLHF
- 开展RLHF训练时,由于服务器资源不足(省钱),需要避免OOM异常,因此需要修改一下训练脚本。
training_scripts/opt/single_gpu/run_1.3b.sh
- 执行训练脚本run_1.3b.sh,触发DeepSpeed开始RLHF训练。
STEP4.模型测试
- 执行测试脚本
python chat.py --path training/step3_rlhf_finetuning/output/actor
,不赘述。
4.小结
本文是实现简版GPT的三篇中的最后一篇,也是最难理解的一部分内容:
- 对齐训练是什么?
- 对齐训练和强化学习的关系是什么?
- ChatGPT整体训练流程的STEP2、STEP3与强化学习的Interpreter和Action模型如何对应?
- DeepSpeed的实际操作?
本文也有没有展开探讨的内容,待本专栏后续继续展开:
- RLHF的策略梯度优化算法
- PPO算法
- ……
编写本专栏受益匪浅,也非常感恩因为编写本专栏认识的大神们,期待与各位小伙伴持续的讨论和思辨!