这篇文章对应视频:“[veRL] fsdp sft trainer 补充,teacher forcing、shift labels shift logits、loss mask”(BV1eWjtzbEdP)。

它是上一篇 SFT trainer 文章的“补充篇”,专门把三个最容易写错、但一错就会把模型训歪的细节讲透:

  1. Teacher forcing:SFT 到底在“喂什么”给模型,喂错会导致什么偏差。
  2. Shift labels / shift logits:为什么 causal LM 的 CE loss 天生存在“错一位”,实现里你必须显式对齐。
  3. Loss mask:multi-turn + tool-use 数据里,你到底要监督哪些 token;mask 在 shift 前后怎么对齐。

系列导航:

关联阅读(建议先看主篇再看补充):

配套仓库(你本地已下载)里,这篇最相关的材料:

  1. multi-turn SFT 的 loss_mask 可视化与数据结构:
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/ReTool-sft.ipynb
  2. tokenization 一致性检查(不一致会让 mask 与 RL 都一起崩):
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/tokenizer/encode-decode.ipynb
  3. 一个可跑的 SFT 启动脚本(FSDP2 + torchrun):
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/scripts/run_qwen2_7b_sft.sh

1. Teacher forcing:SFT 在训练什么“条件分布”

“Teacher forcing” 在 LLM 的 SFT 里几乎是默认行为:训练时每一步都把 ground-truth token 当作历史喂给模型

设一条序列是 $x_0,x_1,\dots,x_{T-1}$,causal LM 学的是:

$$\max_\theta\ \sum_{t=0}^{T-2}\log p_\theta(x_{t+1}\mid x_{\le t})$$

这意味着:

  • 训练时在位置 $t$ 的条件是 $x_{\le t}$(全是“老师给的正确历史”)
  • 但推理时在位置 $t$ 的条件变成了模型自己生成的 $\hat x_{\le t}$(历史里混入了“自己的错误”)

这个 train/test mismatch 通常被叫做 exposure bias。它不是一个“理论问题”,而是你在 agent 场景里马上会踩到的工程问题:

  • SFT 能把格式训得很稳(JSON/tool-call/引用结构),但只要历史里出现一次小错,后续格式就可能崩。
  • 所以你后面做 RL(PPO/GRPO/RLOO 等)时,rollout 是“自回归采样”,会天然暴露并放大这些问题。

结论(工程视角):

  1. SFT 阶段必须把协议训稳(工具调用、引用、结构化输出)。否则 RL 的探索只会把漏洞学得更快。
  2. RL 阶段不是替你发明协议,而是在已有轨迹附近做 reweighting(尤其还有 KL anchor 时)。

2. 为什么要 shift:CE loss 的“错一位”从哪来

causal LM 的 logits 是一个长度为 $T$ 的序列:logits[t] 表示“在看到 $x_{\le t}$ 后对下一个 token 的预测”。

因此 logits 的时间轴labels 的时间轴天然错开 1:

  • logits[t] 用来预测 labels[t+1]

把它写成最常见的实现形式就是:

  • logits[:, :-1] 对齐 labels[:, 1:]

对应数学形式:

$$\mathcal{L}_\text{CE} = -\sum_{t=0}^{T-2}\log p_\theta(x_{t+1}\mid x_{\le t})$$

你在工程里看到的“shift logits / shift labels”,本质上就是把“预测 $x_{t+1}$ 的那一项”对齐到同一个 index 上去算交叉熵。

这件事看起来基础,但它和 loss_mask 一起时会出现非常典型的 off-by-one bug:

  • 你以为你在训 assistant response
  • 实际你训到 prompt(甚至训到 system 工具说明)
  • 或者你把 response 的第一个 token 的监督丢掉(模型会学得很慢、且开头质量差)

3. Loss mask:你到底要监督哪些 token(以及 shift 前后如何对齐)

3.1 先约定一个最重要的语义:mask 标的是“哪些 label 要被预测”

我强烈建议你在工程里把 loss_mask 理解成:

loss_mask[t] = 1 表示 token x_t 这个 label 需要被预测,也就是它应该出现在 CE 的 target 里。

这样做的好处是:它天然和 “labels=-100 ignore_index” 等价,也能和 HuggingFace 的内置 shift 对齐。

如果你的 mask 标的是“哪些输入 token 参与条件”(也就是 prompt 部分),那它在 CE 上没有直接意义,必须再映射一次,极容易出错。

3.2 一个 4-token 的最小例子:用表格把错位讲明白

假设我们把 prompt 和 response 拼在同一条序列里:

  • prompt: P0 P1
  • response: R0 R1

input_ids(原始 token 序列):

indextoken
0P0
1P1
2R0
3R1

“只监督 response”的直觉是:我们希望模型学会生成 R0 R1,也就是:

  • labels[2]=R0labels[3]=R1 参与 loss
  • labels[0], labels[1] 不参与

但是 CE 真正在算的是 logits[t] -> labels[t+1],所以 预测 R0 的那一项在 t=1(看到 P0 P1 后预测 R0)。

这就是为什么:

  • 如果你手写 per-token loss matrix,mask 通常要跟着 shift:loss_mask_shifted = loss_mask[:, 1:]
  • 如果你用 HF 内置 loss(传入 labels),那就把“不要监督的 label”设成 -100 即可,shift 交给框架

3.3 两种实现方式(推荐优先用 A)

A) 推荐:用 labels=-100 让模型内部 shift(最少坑)

核心是:

  1. labels = input_ids.clone()
  2. 对不监督的位置设 -100(ignore_index)
  3. labels 直接传给 model(...),让模型内部做 shift

示例(关键逻辑,适配 HF 大多数 causal LM):

1
2
3
4
5
6
7
8
input_ids = batch["input_ids"]          # [B, T]
loss_mask = batch["loss_mask"].bool() # [B, T] 1=supervise label token

labels = input_ids.clone()
labels[~loss_mask] = -100

out = model(input_ids=input_ids, labels=labels)
loss = out.loss

你只要保证 loss_mask 的语义是“哪些 label token 要监督”,这套就稳定。

B) 手动实现:自己 shift + 自己做 masked CE(调试/定制时用)

当你要做 packed sequence、或者要拿到 token-level loss matrix 做复杂聚合(比如 veRL 里常见的 agg_loss),你可能需要手动实现。

要点:

  • logits = logits[:, :-1, :]
  • labels = labels[:, 1:]
  • loss_mask = loss_mask[:, 1:]

然后对齐算 CE 并按 mask 聚合:

1
2
3
4
5
6
7
8
9
10
11
logits = out.logits[:, :-1, :]              # [B, T-1, V]
labels = input_ids[:, 1:] # [B, T-1]
mask = loss_mask[:, 1:].float() # [B, T-1]

per_tok = torch.nn.functional.cross_entropy(
logits.reshape(-1, logits.size(-1)),
labels.reshape(-1),
reduction="none",
).reshape_as(labels) # [B, T-1]

loss = (per_tok * mask).sum() / (mask.sum() + 1e-8)

这段代码写错一个 1:,你就会得到“看起来还能训练,但就是不对劲”的模型。


4. multi-turn + tool-use:mask 到底该怎么标(以及为什么别训 tool output)

仓库里的 ReTool-sft.ipynb 给了一个我认为非常正确的工程习惯:把每条样本逐 token decode 并按 loss_mask 上色可视化

它的默认策略可以概括成一句话:

  • system/user/tool 这三类 token 一般不参与 loss
  • 只训练 assistant 的输出(包括你认为需要稳定生成的 tool-call 协议片段)

为什么强烈建议不要训 tool output?

  1. tool output 本质是“外部世界事实”,它不应该被模型“生成”出来
  2. 一旦你监督了 tool output,模型会倾向于把它当作可自由编造的 continuation
  3. 到 RL 阶段,这会被 reward 漏洞放大(模型学会“编造工具返回”骗 verifier/judge)

如果你确实需要模型“复述工具结果”,更稳的做法是:

  • tool output 仍然作为输入条件(mask=0)
  • assistant 的摘要/解释作为输出监督(mask=1)

5. Debug Checklist:一眼抓住 off-by-one / mask 漂移

这份 checklist 的目标是:不用跑大训练,10 分钟内确认你没把 CE 训歪

  1. 抽 20 条样本做 mask 可视化(推荐直接用 ReTool-sft.ipynb 的思路)
  2. 打印 masked_tokens / nonpad_tokens 的分布:
    • 过低:有效监督太少,loss 会很抖且学习慢
    • 过高:你可能在训 prompt/system/tool
  3. 做一次“强一致性测试”:
    • 把同一条 multi-turn 数据用两种 tokenization 拼接方式编码
    • 确认 token 与边界一致(参考 encode-decode.ipynb 的警告)
  4. 用一个极小 batch 过一遍 forward:
    • 开启手动 shift 版(B)和 labels=-100 版(A)
    • 对比两者 loss 数值是否在同一量级(不要求完全相等,但不应差一个数量级)

6. 为什么这篇补充对 agentic RL 很关键

你后面做 PPO/GRPO 时,核心计算都是 token-level logprob:

  • logp_new(a_t|s_t)logp_old(a_t|s_t)logp_ref(a_t|s_t)

这些 logprob 的“对齐方式”与这里的 shift 完全同构:你还是在拿某个位置的 logits 去对应某个位置的 action token

所以你可以把这篇文章当成一个基本功定位:

  • SFT 阶段把 shift + mask 写对了
  • RL 阶段你计算 ratio/KL/entropy 时就不容易把统计量算错(也更容易 debug)

如果你愿意,我下一步可以把这套“shift + mask 对齐”进一步延伸到 RL 端:

  1. logp 计算的对齐(prompt/response 的边界如何切)
  2. response_maskloss_mask 在 RL 里各自是什么语义
  3. vLLM rollout 下 prompt_logprobs / token_logprobs 的坑点如何验证