这篇文章对应视频:“[veRL] fsdp sft trainer 补充,teacher forcing、shift labels shift logits、loss mask”(BV1eWjtzbEdP)。
它是上一篇 SFT trainer 文章的“补充篇”,专门把三个最容易写错、但一错就会把模型训歪的细节讲透:
- Teacher forcing:SFT 到底在“喂什么”给模型,喂错会导致什么偏差。
- Shift labels / shift logits:为什么 causal LM 的 CE loss 天生存在“错一位”,实现里你必须显式对齐。
- Loss mask:multi-turn + tool-use 数据里,你到底要监督哪些 token;mask 在 shift 前后怎么对齐。
系列导航:
关联阅读(建议先看主篇再看补充):
- veRL:FSDP SFT Trainer 主篇(交叉熵 / loss mask / scheduler)
- veRL:MultiTurn Tool Use / Coding Agent SFT(Cold Start for RL)
- Tokenizer 非对称性与 Token-in-Token-out(RL 训练崩溃的根因)
配套仓库(你本地已下载)里,这篇最相关的材料:
- multi-turn SFT 的
loss_mask可视化与数据结构:/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/ReTool-sft.ipynb
- tokenization 一致性检查(不一致会让 mask 与 RL 都一起崩):
/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/tokenizer/encode-decode.ipynb
- 一个可跑的 SFT 启动脚本(FSDP2 + torchrun):
/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/scripts/run_qwen2_7b_sft.sh
1. Teacher forcing:SFT 在训练什么“条件分布”
“Teacher forcing” 在 LLM 的 SFT 里几乎是默认行为:训练时每一步都把 ground-truth token 当作历史喂给模型。
设一条序列是 $x_0,x_1,\dots,x_{T-1}$,causal LM 学的是:
$$\max_\theta\ \sum_{t=0}^{T-2}\log p_\theta(x_{t+1}\mid x_{\le t})$$
这意味着:
- 训练时在位置 $t$ 的条件是 $x_{\le t}$(全是“老师给的正确历史”)
- 但推理时在位置 $t$ 的条件变成了模型自己生成的 $\hat x_{\le t}$(历史里混入了“自己的错误”)
这个 train/test mismatch 通常被叫做 exposure bias。它不是一个“理论问题”,而是你在 agent 场景里马上会踩到的工程问题:
- SFT 能把格式训得很稳(JSON/tool-call/引用结构),但只要历史里出现一次小错,后续格式就可能崩。
- 所以你后面做 RL(PPO/GRPO/RLOO 等)时,rollout 是“自回归采样”,会天然暴露并放大这些问题。
结论(工程视角):
- SFT 阶段必须把协议训稳(工具调用、引用、结构化输出)。否则 RL 的探索只会把漏洞学得更快。
- RL 阶段不是替你发明协议,而是在已有轨迹附近做 reweighting(尤其还有 KL anchor 时)。
2. 为什么要 shift:CE loss 的“错一位”从哪来
causal LM 的 logits 是一个长度为 $T$ 的序列:logits[t] 表示“在看到 $x_{\le t}$ 后对下一个 token 的预测”。
因此 logits 的时间轴 与 labels 的时间轴天然错开 1:
logits[t]用来预测labels[t+1]
把它写成最常见的实现形式就是:
logits[:, :-1]对齐labels[:, 1:]
对应数学形式:
$$\mathcal{L}_\text{CE} = -\sum_{t=0}^{T-2}\log p_\theta(x_{t+1}\mid x_{\le t})$$
你在工程里看到的“shift logits / shift labels”,本质上就是把“预测 $x_{t+1}$ 的那一项”对齐到同一个 index 上去算交叉熵。
这件事看起来基础,但它和 loss_mask 一起时会出现非常典型的 off-by-one bug:
- 你以为你在训 assistant response
- 实际你训到 prompt(甚至训到 system 工具说明)
- 或者你把 response 的第一个 token 的监督丢掉(模型会学得很慢、且开头质量差)
3. Loss mask:你到底要监督哪些 token(以及 shift 前后如何对齐)
3.1 先约定一个最重要的语义:mask 标的是“哪些 label 要被预测”
我强烈建议你在工程里把 loss_mask 理解成:
loss_mask[t] = 1表示 tokenx_t这个 label 需要被预测,也就是它应该出现在 CE 的 target 里。
这样做的好处是:它天然和 “labels=-100 ignore_index” 等价,也能和 HuggingFace 的内置 shift 对齐。
如果你的 mask 标的是“哪些输入 token 参与条件”(也就是 prompt 部分),那它在 CE 上没有直接意义,必须再映射一次,极容易出错。
3.2 一个 4-token 的最小例子:用表格把错位讲明白
假设我们把 prompt 和 response 拼在同一条序列里:
- prompt:
P0 P1 - response:
R0 R1
input_ids(原始 token 序列):
| index | token |
|---|---|
| 0 | P0 |
| 1 | P1 |
| 2 | R0 |
| 3 | R1 |
“只监督 response”的直觉是:我们希望模型学会生成 R0 R1,也就是:
labels[2]=R0、labels[3]=R1参与 losslabels[0], labels[1]不参与
但是 CE 真正在算的是 logits[t] -> labels[t+1],所以 预测 R0 的那一项在 t=1(看到 P0 P1 后预测 R0)。
这就是为什么:
- 如果你手写 per-token loss matrix,mask 通常要跟着 shift:
loss_mask_shifted = loss_mask[:, 1:] - 如果你用 HF 内置 loss(传入
labels),那就把“不要监督的 label”设成-100即可,shift 交给框架
3.3 两种实现方式(推荐优先用 A)
A) 推荐:用 labels=-100 让模型内部 shift(最少坑)
核心是:
labels = input_ids.clone()- 对不监督的位置设
-100(ignore_index) - 把
labels直接传给model(...),让模型内部做 shift
示例(关键逻辑,适配 HF 大多数 causal LM):
1 | input_ids = batch["input_ids"] # [B, T] |
你只要保证 loss_mask 的语义是“哪些 label token 要监督”,这套就稳定。
B) 手动实现:自己 shift + 自己做 masked CE(调试/定制时用)
当你要做 packed sequence、或者要拿到 token-level loss matrix 做复杂聚合(比如 veRL 里常见的 agg_loss),你可能需要手动实现。
要点:
logits = logits[:, :-1, :]labels = labels[:, 1:]loss_mask = loss_mask[:, 1:]
然后对齐算 CE 并按 mask 聚合:
1 | logits = out.logits[:, :-1, :] # [B, T-1, V] |
这段代码写错一个 1:,你就会得到“看起来还能训练,但就是不对劲”的模型。
4. multi-turn + tool-use:mask 到底该怎么标(以及为什么别训 tool output)
仓库里的 ReTool-sft.ipynb 给了一个我认为非常正确的工程习惯:把每条样本逐 token decode 并按 loss_mask 上色可视化。
它的默认策略可以概括成一句话:
- system/user/tool 这三类 token 一般不参与 loss
- 只训练 assistant 的输出(包括你认为需要稳定生成的 tool-call 协议片段)
为什么强烈建议不要训 tool output?
- tool output 本质是“外部世界事实”,它不应该被模型“生成”出来
- 一旦你监督了 tool output,模型会倾向于把它当作可自由编造的 continuation
- 到 RL 阶段,这会被 reward 漏洞放大(模型学会“编造工具返回”骗 verifier/judge)
如果你确实需要模型“复述工具结果”,更稳的做法是:
- tool output 仍然作为输入条件(mask=0)
- assistant 的摘要/解释作为输出监督(mask=1)
5. Debug Checklist:一眼抓住 off-by-one / mask 漂移
这份 checklist 的目标是:不用跑大训练,10 分钟内确认你没把 CE 训歪。
- 抽 20 条样本做 mask 可视化(推荐直接用
ReTool-sft.ipynb的思路) - 打印
masked_tokens / nonpad_tokens的分布:- 过低:有效监督太少,loss 会很抖且学习慢
- 过高:你可能在训 prompt/system/tool
- 做一次“强一致性测试”:
- 把同一条 multi-turn 数据用两种 tokenization 拼接方式编码
- 确认 token 与边界一致(参考
encode-decode.ipynb的警告)
- 用一个极小 batch 过一遍 forward:
- 开启手动 shift 版(B)和 labels=-100 版(A)
- 对比两者 loss 数值是否在同一量级(不要求完全相等,但不应差一个数量级)
6. 为什么这篇补充对 agentic RL 很关键
你后面做 PPO/GRPO 时,核心计算都是 token-level logprob:
logp_new(a_t|s_t)、logp_old(a_t|s_t)、logp_ref(a_t|s_t)
这些 logprob 的“对齐方式”与这里的 shift 完全同构:你还是在拿某个位置的 logits 去对应某个位置的 action token。
所以你可以把这篇文章当成一个基本功定位:
- SFT 阶段把 shift + mask 写对了
- RL 阶段你计算 ratio/KL/entropy 时就不容易把统计量算错(也更容易 debug)
如果你愿意,我下一步可以把这套“shift + mask 对齐”进一步延伸到 RL 端:
logp计算的对齐(prompt/response 的边界如何切)response_mask与loss_mask在 RL 里各自是什么语义- vLLM rollout 下
prompt_logprobs / token_logprobs的坑点如何验证

