这篇文章对应视频:“[veRL] FSDP SFT trainer,SFT vs. RL,交叉熵损失 | loss mask | learning rate scheduler”(BV1CkJgzAEAG)。

补充篇(更聚焦 teacher forcing / shift labels-logits / loss mask 对齐):

进一步把 SFT 接到 tool-use agent 的 cold start(MultiTurn Tool Use / Coding Agent):

但我会把它写成一份“可落地的工程读物”,而不是视频逐句复刻。你看完应该能回答这些问题:

  1. 为什么做 agentic RL / RLHF 之前,SFT 反而是你最不该糊弄的一步?
  2. causal LM 的交叉熵损失到底在算什么?loss_masklabels=-100 到底是一回事吗?
  3. multi-turn 数据里,哪些 token 应该参与 loss?如果 mask 搞错,会把模型训成什么鬼样?
  4. FSDP SFT trainer 到底解决的是什么瓶颈(显存/吞吐/可扩展性)?FSDP2 又是什么?
  5. learning rate scheduler 为什么在 SFT 阶段更关键(甚至比 RL 阶段更“可解释”)?

系列导航:

关联阅读(建议顺序):

  1. veRL 训练参数理解(PPO/GRPO、Batch、KL、Entropy)
  2. veRL 核心算法(GRPO/RLOO/REINFORCE++)与 Baseline
  3. PG loss 组件详解(PPO-clip / KL / Entropy / 聚合)
  4. Tokenizer 非对称性与 Token-in-Token-out(RL 训练崩溃的根因)

0. 资料对齐(视频 + 本地仓库)

视频:

配套仓库(你本地已下载)里,这篇最相关的材料我建议按重要程度看这几份(不要求你都读完):

  1. multi-turn SFT 的 loss_mask 可视化与数据结构:
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/ReTool-sft.ipynb
  2. SFT 与长序列的显存瓶颈(micro batch / SP / remove padding 等):
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/训练及调参经验/sft.ipynb
  3. FSDP 与并行(DP/TP/PP、FSDP vs FSDP2):
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/3D/DP_TP_PP.ipynb
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/3D/fsdp_fsdp2.ipynb
  4. 一个真实可跑的 SFT 启动脚本(FSDP2 + torchrun):
    • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/scripts/run_qwen2_7b_sft.sh

另外一个“很关键但经常被忽略”的坑点(会直接让 RL 不收敛,也会污染 SFT 的 mask 对齐):

  • /Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/tokenizer/encode-decode.ipynb
    • 里面强调了 encode(messages) != (prompt_ids ⊕ response_ids) 这类 token 不一致问题,在训练里是致命的。

1. SFT vs RL:不要把它理解成“先热身再上强度”

很多人把 pipeline 简化成一句话:先 SFT,再 RL。然后就自然得出一个误判:

SFT 只是 warmup,真正的提升靠 RL。

在 agentic RL(工具调用、多轮对话、长上下文)场景里,这个直觉经常是反的:

  1. SFT 决定“动作空间”是否可学
    • tool call 的 JSON schema、引用格式、分段结构、思考模板,本质上都是一种“可执行协议”。
    • 你协议都没训稳,RL 只能在一个错误的动作空间里瞎探索,最后学到的往往是“投机取巧地骗 judge/骗 verifier”。
  2. RL 往往被 KL anchor 限制在 SFT 附近
    • 你后面的 PPO/GRPO 训练几乎一定会用 KL 把策略约束在 ref(通常就是 SFT)附近。
    • 这意味着:SFT 的上限直接限制 RL 的上限
  3. SFT 的 loss 可解释、可 debug,RL 的 reward 不一定
    • 交叉熵的下降是“真的在拟合数据分布”。
    • reward 的上升可能是“骗指标”,你要花更多成本做验证。

所以一个更实用的心智模型是:

  • SFT:把“能做正确动作的先验”写进模型(格式、工具、引用、推理风格)。
  • RL:在这个先验附近做 reweighting,让某些轨迹更常出现(采样效率、稳定性、成本约束)。

2. 交叉熵损失:它到底在优化什么(以及为什么要 shift)

对 causal LM(自回归模型),我们给定输入序列 token:$x_0,x_1,\dots,x_{T-1}$,模型在位置 $t$ 预测下一个 token 的分布 $p_\theta(\cdot|x_{\le t})$。

标准的 token-level 负对数似然(也就是交叉熵)是:

$$\mathcal{L}_{\text{CE}} = -\sum_{t=0}^{T-2}\log p_\theta(x_{t+1}|x_{\le t})$$

工程实现里这就是你常见的“shift”:

  • logits[:, :-1] 对齐 labels[:, 1:]

为什么要强调这个?因为 loss_mask/labels=-100 也必须跟着 shift 对齐,否则你以为你在训 assistant,实际在训 user prompt。

2.1 loss_masklabels=-100:本质上是一回事

你可以用两种等价方式做“只对某些 token 计算 loss”:

  1. mask 版:先算出每个 token 的 loss matrix,再乘以 mask 做聚合。
  2. ignore_index 版:把不参与训练的位置 label 设成 -100(PyTorch/HF 默认 ignore)。

如果把它写成一个统一公式(用 mask 表达更直观):

$$\mathcal{L} = \frac{\sum_t m_t\cdot\left(-\log p_\theta(x_{t+1}|x_{\le t})\right)}{\sum_t m_t + \epsilon}$$

其中 $m_t\in{0,1}$ 表示“这个位置是否计入 loss”。

工程里我更推荐你在日志里同时打印两件事:

  1. masked_token_count = sum(m_t)(否则你根本不知道有效 batch 有多大)
  2. loss_agg_mode(你是 token-mean 还是 seq-mean)

因为这两项经常比“学习率大小”更决定训练动态。


3. loss mask:你在 multi-turn SFT 里到底该训哪些 token

3.1 最常见原则:只训 assistant 说的话

在 chat / tool-use 数据里,一条样本往往包含:

  • system:系统提示
  • user:用户输入
  • tool:工具返回(或者工具 schema)
  • assistant:模型应该输出的内容(可能包含 reasoning、tool_call、final answer)

ReTool-sft.ipynb 的总结非常实用:system/user/tool 都应该 mask 掉,只对 assistant 的输出计算 loss

如果你没这么做,你很容易把模型训坏:

  1. 训练它复读 system prompt(线上看起来像“人格绑定”)
  2. 训练它复读 user 的提问(答非所问)
  3. 训练它去“生成工具输出”(这是最危险的:工具输出本来应该是外部世界的事实)

3.2 multi-turn 的难点:assistant 的“哪一段”算输出

tool-use 体系里 assistant 可能输出两类东西:

  1. 直接回答用户(final answer)
  2. 输出工具调用(<tool_call>{...}</tool_call> 或 JSON)

要不要对 tool_call 计算 loss?

工程上我建议你按目标来:

  1. 如果你的 agent 需要稳定地产生可解析的 tool_call:tool_call 必须训(mask=1)。
  2. 如果 tool_call 会被上层 planner/DSL 强约束生成:可以考虑只训 final answer,把 tool_call 交给规则或 structured decoding。

但无论哪一种,system/user/tool output 一般都不应该训。

3.3 最可靠的做法:把 mask 可视化(别凭感觉)

ReTool-sft.ipynb 给了一个很好的工程习惯:把 loss_mask 逐 token decode 并染色打印出来,肉眼一眼就能看出你到底在训练哪里。

你不需要照抄 notebook 里的彩色控制符,核心逻辑就两步:

  1. decode 每个 token
  2. 同步打印 mask=1 的位置(比如用 [] 包起来)

示例伪代码(核心思想,非 veRL 专用):

1
2
3
4
5
6
7
8
def debug_print_mask(tokenizer, input_ids, loss_mask, max_tokens=400):
pairs = list(zip(input_ids, loss_mask))[:max_tokens]
out = []
for tid, m in pairs:
tok = tokenizer.decode([tid])
out.append(f\"[{tok}]\" if m == 1 else tok)
print(\"\".join(out))
print(\"masked_tokens:\", sum(loss_mask), \"/\", len(loss_mask))

你应该把它当成 SFT 的“单元测试”。跑 3 条样本都对齐了,再开始跑大训练。


4. 一个容易被低估的坑:token 不一致会让 mask 和训练目标一起漂

encode-decode.ipynb 里有一句我强烈建议你记住:

encode(messages) != (prompt_ids ⊕ response_ids)

直觉上我们会以为:

  • 把 messages 用 apply_chat_template 编码一次
  • 等价于“每轮把 prompt 编码,再把 response 拼上去”

但在真实框架里,这两者可能不等价(模板中间插入的控制 token、空格处理、tool role 的拼接方式都可能导致差异)。

对 SFT,这会导致两类问题:

  1. 你以为 mask 在 assistant response 上,但 token 边界其实错位了
  2. 你以为你在拟合“最终消息序列分布”,实际拟合了一个不存在的拼接分布

对 RL(PPO/GRPO)更致命:它会让轨迹偏离策略分布,直接导致 ratio/KL 统计失真,训练不收敛。

工程建议(不一定要完全照 veRL,但思想要一致):

  1. 训练用的 tokenization 路径必须和线上/rollout 一致(同一个 template,同一个拼接方式)。
  2. 只要涉及 multi-turn + tool,你就应该加入 “token-level 对齐检查”(mask 可视化只是第一步)。

5. FSDP SFT Trainer:它解决的不是“更快”,而是“能跑”

当你开始做长上下文(8k/16k/32k)+ 多轮 + 大模型,SFT 最先爆的往往不是算力,而是显存。

5.1 FSDP 的一句话解释:把“模型本体”切碎分到多卡上

在传统 DDP 下,每张卡都有一份完整模型参数、梯度、优化器状态。

FSDP(Fully Sharded Data Parallel)的核心是:把这些大头都 sharding 掉:

  1. 参数(weights)分片
  2. 梯度(grads)分片
  3. 优化器状态(optimizer states)分片

于是每张卡只存自己那一片。计算时需要哪一片,就在前向时 all-gather,反向时 reduce-scatter。

这就是你在很多笔记里看到的关键词:

  • all-gather / reduce-scatter

5.2 FSDP2(fully_shard)你可以先理解成“更细粒度、更少常驻副本”

如果你不想陷进实现细节,一个够用的工程直觉是:

  1. FSDP1:以模块为单位 all-gather(FlatParameter),副本驻留窗口相对大
  2. FSDP2:更细粒度按需 shard/reshard,副本驻留窗口更窄,可重叠机会更多

仓库里的 run_qwen2_7b_sft.sh 也直接把 model.strategy=fsdp2 当作默认推荐。

5.3 SFT 训练时,你真正要盯的是“micro batch 才是显存开关”

SFT 的工程规律非常朴素:

  • 真正把你打爆显存的通常是 激活(activations)+ KV/attention 的中间量,而不是参数本身。

这也是为什么即便用了 FSDP,你仍然需要:

  1. micro batch 小(data.micro_batch_size_per_gpu
  2. gradient accumulation 做大 effective batch
  3. remove padding / packed sequence(减少无效 token)
  4. 必要时上 sequence parallel(SP)或 activation checkpointing

你可以把它和你前面那篇 veRL 参数文章联系起来:很多“看起来像训练参数”的东西,本质是在调显存与吞吐的 tradeoff。


6. learning rate scheduler:SFT 阶段反而更值得你认真设计

SFT 的 loss 是交叉熵,它的形态比 RL 的 reward 更“线性可解释”,所以 scheduler 的收益也更可预期。

6.1 一个够用的默认策略:warmup + cosine decay

如果你没有特别理由,SFT 我建议你先用:

  1. warmup:1% - 3% 的总步数
  2. cosine decay 到一个较小的 min_lr(比如 lr * 0.1 或更低)

直觉:

  • warmup 解决 early-stage 的梯度不稳定(尤其是长上下文、混合精度、FSDP 通信下)
  • cosine decay 让后期收敛更平滑,减少“后半程学坏”的风险

6.2 scheduler 的最大坑:你用的是 optimizer step 还是 micro step

有 gradient accumulation 时:

  • micro step(每次 forward/backward)不等于 optimizer step(真正更新参数)

scheduler 应该跟着 optimizer step 走,否则你的 lr 会被错误地 decay 得过快。

工程上最推荐你在日志里打印三件事:

  1. global_step(optimizer step)
  2. lr(真实学习率)
  3. tokens_per_step(有效训练 token 数)

你一旦把“lr 曲线”和“有效 token 曲线”对齐,很多训练异常(loss spike/不收敛/学坏)会变得很好解释。


7. SFT 写在 agentic RL 计划里:我建议你怎么做

如果你的最终目标是 deep research / tool-use agentic RL,我会把 SFT 当成一个明确的阶段目标,而不是“先跑个 baseline”:

  1. 先用 SFT 把协议训稳:
    • tool schema、引用格式、输出结构(可解析、可验证)
  2. 然后再用 RL 优化你真正关心的指标:
    • 引用正确率、事实一致性、覆盖度、成本/时延、pass@k 等

原因很现实:RL 不能替你把协议发明出来,RL 只会把已有轨迹的概率 reweight。协议都不稳定,RL 只会放大漏洞。


8. 一份能直接用的工程 checklist(建议你贴到 Notion 里)

SFT 跑之前:

  1. 抽样可视化 loss_mask(至少 20 条样本)
  2. 统计每条样本的 masked_token_ratio(太低说明你在训很少 token,训练会很慢且不稳定)
  3. 确认 tokenization 路径一致(尤其 multi-turn + tool)

SFT 跑起来之后(每 200-500 step 一次):

  1. 看训练/验证 loss 是否同时下降(只看 train loss 容易被过拟合骗)
  2. 看输出样例是否“协议稳定”(工具调用是否可解析、引用是否仍存在)
  3. lr/grad_norm/tokens_per_step 是否健康(scheduler + 有效 batch 闭环)

如果你愿意把这一步做扎实,后面的 RL(PPO/GRPO)会容易一个数量级。