前两篇实现了简版GPT,并对其进行了SFT,我们接下来看ChatGPT整体训练流程的最后一个环节——对齐训练(Alignment Training)

1.方法3:对齐训练(Alignment Training)

(1)与ChatGPT整体训练流程图的对应关系

  • 方法3对应于ChatGPT整体训练流程的STEP2、STEP3
  • 方法3的核心思想是利用了强化学习,最终将GPT3演进为了更通人性的ChatGPT

image-20231020062050959

  • ChatGPT整体训练流程中的STEP2、STEP3,就是大名鼎鼎的RLHF——基于人类反馈的强化学习
    • RL:Reinforcement Learning
    • HF:Human Feedback
  • ChatGPT整体训练流程中的STEP2对应于强化学习模型的Interpreter模型
  • ChatGPT整体训练流程中的STEP3对应于强化学习模型的Action模型

image-20231020085551189

(2)什么是对齐训练

  • 对齐训练:Alignment Training,它就是一种机器学习的模型训练方法。

  • 核心思想:训练出人类主观感受的模型,这个模型具备预测人类的决策的能力。

    • 这样,训练好的模型,就可以在未见过的场景下,按照类似人的行为模式做出选择。
  • 对齐训练与强化学习的关系:OpenAI在对齐训练中,结合了强化学习。

    • ChatGPT整体训练流程的STEP2就是对齐训练,学习出预测人类回答问题的偏好模型。
    • ChatGPT整体训练流程的STEP3就是强化学习,STEP2输出的这个模型,作为强化学习的Interpreter模型。STEP3不断迭代,最终学习到Action模型。
      • 通过SFT训练之后GPT3,本质就是一个能机械式地回答问题的机器人。
      • 通过RLHF学习的Action模型,才是帮助SFT之后的GPT3,类似人类回答问题的关键机关。
  • 细节:ChatGPT整体训练流程图中,出现了PPO算法,PPO算法是近端策略梯度优化,增加一个限制Action模型在训练过程中梯度上升速度,本质就是避免Action模型产生一个离谱的Action。

    • PPO算法展开说内容太多,本文不赘述,详见论文:https://arxiv.org/abs/1707.06347

(3)STEP2的Reward Model模型训练伪码

  • 我们再来看看STEP2的伪码,如下图:

image-20231020093407332

(4)STEP3的RLHF训练伪码

  • 我们再来看看STEP3的伪码,如下图:

image-20231020095154134

2.DeepSpeed

RLHF,是ChatGPT最核心的技术机密,除了在《Introducing ChatGPT》(https://openai.com/blog/chatgpt)中提到了,并未公开过源码。

在前文的伪码实现部分,虽然通过伪码描述了RLHF的核心逻辑,但距离商用还欠缺很多东西(如:分布式训练等)。

幸好微软开源了类似的框架,DeepSpeed,我们可以通过阅读它的源码、使用它,开展RLHF。

image-20231020100752488

3.实例-开展RLHF训练

STEP0.前置准备

  • 硬件:V100一块,32G显存
  • 基础软件:Ubuntun20.04,Minicoda3,Pytorch3.8,CUDA11.6,Python3.10
  • 预训练模型:选择Facebook的opt1.3B,即13亿参数的预训练模型。
  • 环境初始配置:创建虚拟环境,

image-20231020102932305

  • 安装依赖:进入DeepSpeed-Chat目录,安装相关依赖

image-20231020103116451

  • 环境测试:确认相关基础软件版本号。

image-20231020103548543

STEP1.SFT

  • 开展SFT训练时,由于服务器资源不足(省钱),需要避免OOM异常,因此需要修改一下训练脚本。

路径:training/step1_supervised_finetuning/training_scripts/opt/single_gpu/run_1.3b.sh

image-20231020104103388

  • 设置待微调的预训练模型,以及输出路径。

路径:training/step1_supervised_finetuning/evaluation_scripts/run_prompt.sh

image-20231020110616405

  • 执行训练脚本run_1.3b.sh,触发DeepSpeed开始SFT训练。

STEP2.RM

  • 开展RM训练时,由于服务器资源不足(省钱),需要避免OOM异常,因此需要修改一下训练脚本。

路径:training_scripts/opt/single_gpu/run_350m.sh

image-20231020145149415

  • 指定Reward Model的输出路径。

路径:evaluation_scripts/run_eval.sh

image-20231020145843742

  • 执行训练脚本run_350m.sh,触发DeepSpeed开始RW训练。

STEP3.RLHF

  • 开展RLHF训练时,由于服务器资源不足(省钱),需要避免OOM异常,因此需要修改一下训练脚本。

training_scripts/opt/single_gpu/run_1.3b.sh

image-20231020155704428

  • 执行训练脚本run_1.3b.sh,触发DeepSpeed开始RLHF训练。

STEP4.模型测试

  • 执行测试脚本python chat.py --path training/step3_rlhf_finetuning/output/actor,不赘述。

4.小结

本文是实现简版GPT的三篇中的最后一篇,也是最难理解的一部分内容:

  • 对齐训练是什么?
  • 对齐训练和强化学习的关系是什么?
  • ChatGPT整体训练流程的STEP2、STEP3与强化学习的Interpreter和Action模型如何对应?
  • DeepSpeed的实际操作?

本文也有没有展开探讨的内容,待本专栏后续继续展开:

  • RLHF的策略梯度优化算法
  • PPO算法
  • ……

编写本专栏受益匪浅,也非常感恩因为编写本专栏认识的大神们,期待与各位小伙伴持续的讨论和思辨!