LLM技术系列之-基于RLHF的生成式大模型实现思路及其中的关键算法PPO介绍

这篇文章将向大家介绍基于RLHF的生成式大模型总体思路流程,以及介绍TRL库中的PPO训练算法实现中的几个关键细节点。

1、rollout数据回合采样的方法,rollout为对数据进行采集,类似于强化学习的gym模拟仿真环境中的step等环境仿真模拟方法(传统的RL任务,游戏,机器人仿真环境等,提供了统一的接口,如step,render等)。在LLM场景下,rollout 过程用于收集序列(states, actions, rewards)以供训练策略优化,states对应于queries tokens,actions对应于生成的response的tokens,以及通过奖励模型(Reward Model, RM) 计算生成的tokens序列奖励得分(一个query的得分为一个scalar的标量值来度量生成的tokens序列的质量)。这里对计算过程中的相关变量做一下说明:(1)logits为未经过softmax的原始值,如GTP-2的输出output有两个属性,一个为scores即为logits,另一个sequences为输出的tokens的序列。logprob为对logits进行log_softmax计算后选择生成的token对应索引位置的值,这里的参考模型和策略模型输出的KL散度的定义为基于对应的token id的概率之间的差异,而不是基于整个logits的所有词汇表(vocabulary)空间的概率分布去计算KL散度信息,详细的代码注释说明请见[1]。

2、从训练框架代码上需要说明的是,accelerator库提供不同GPU型号的memory伸缩训练支持(通过超参如local_batch_size,local_mini_batch_size,per_device_train_batch_size进行设置)。accelerator库可以通过多个小批次的梯度累积计算来模拟大批量的训练,具体的实现方法为:在 with accelerator.accumulate(model) 作用域内:减少显存占用,允许在小显存的 GPU 上模拟大批量训练。仅在梯度累积步数达到设定值时执行 optimizer.step()optimizer.zero_grad(),否则只进行梯度计算,不更新权重。避免手动使用 if step % accumulation_steps == 0: 逻辑,提高代码可读性及实现便捷性。

问题1:rewards代表即时回报,计算过程中为什么中间token的rewards值只与KL散度相关,而最后的token为KL散度相关的奖励加上reward_model给出的评分(score)之和?这么设计为什么有效?

回答:rewards在优势函数中参与了相关的计算,引文[2]公式中的r即为reward,根据回报函数或优势函数的计算方式,回报函数或优势函数为当前即时奖励到最后的每一个step对应奖励的序列的折扣和(具体公式见[2]),这样最后的回报通过传递都会对前面的token的优势函数有影响。在具体的设计过程中,KL loss前还有一个权重超参kl_coef。

问题2:在RLHRF算法中,参考策略模型,策略网络,价值网络,评分模型(网络)之间的关系是怎样的?以及ppo(近端策略优化)算法中的数据采样和迭代训练的大体流程是怎样的?

回答:在RLHF的训练过程中,一般包含三个步骤:(1)、在超大规模文本数据上基于自监督预训练的LLM模型上进行有监督微调(Supervised Fine-Tunning,基于人类标注的监督式数据),SFT模型在基于RLHF的PPO算法中用作参考策略模型提供训练数据样本生成;(2)、奖励建模,通过对比损失函数和基于人类反馈的偏好评分rank标记数据集(RLHF数据集)训练reward model;(3)采用PPO算法训练问答大模型,在实现时,参考模型既可以是独立传入的SFT训练的模型,也可以和策略网络共享部分或所有参数参与更新(具体可以查看参数num_shared_layers以及函数create_reference_model,但基于现有的调研来看,参考模型通常是冻结的,不参与训练,参数也不共享。策略模型从参考模型初始化,但在训练过程中独立更新),而价值网络一般可以和策略模型共享backbone(如类AutoModelForCausalLMWithValueHead),其在输出时多一个输出token对应的标量的value值(value值和reward值一起用于参与优势函数的计算)。而上述的rollout过程即为训练数据在线生成的过程,通过queries分别调用策略模型,参考策略模型,价值网络模型(可能和策略模型共享backbone),以及reward model(打分模型)推理以生成训练数据,然后基于一次rollout生成的数据进行num_ppo_epochs次参数更新迭代训练(policy网络和value网络的联合loss实现参数更新)。

References


by

Tags:

Comments

Leave a Reply

Your email address will not be published. Required fields are marked *