这篇文章对应视频:“[veRL] FSDP SFT trainer,SFT vs. RL,交叉熵损失 | loss mask | learning rate scheduler”(BV1CkJgzAEAG)。
补充篇(更聚焦 teacher forcing / shift labels-logits / loss mask 对齐):
进一步把 SFT 接到 tool-use agent 的 cold start(MultiTurn Tool Use / Coding Agent):
但我会把它写成一份“可落地的工程读物”,而不是视频逐句复刻。你看完应该能回答这些问题:
- 为什么做 agentic RL / RLHF 之前,SFT 反而是你最不该糊弄的一步?
- causal LM 的交叉熵损失到底在算什么?
loss_mask和labels=-100到底是一回事吗? - multi-turn 数据里,哪些 token 应该参与 loss?如果 mask 搞错,会把模型训成什么鬼样?
- FSDP SFT trainer 到底解决的是什么瓶颈(显存/吞吐/可扩展性)?FSDP2 又是什么?
- learning rate scheduler 为什么在 SFT 阶段更关键(甚至比 RL 阶段更“可解释”)?
系列导航:
关联阅读(建议顺序):
- veRL 训练参数理解(PPO/GRPO、Batch、KL、Entropy)
- veRL 核心算法(GRPO/RLOO/REINFORCE++)与 Baseline
- PG loss 组件详解(PPO-clip / KL / Entropy / 聚合)
- Tokenizer 非对称性与 Token-in-Token-out(RL 训练崩溃的根因)
0. 资料对齐(视频 + 本地仓库)
视频:
BV1CkJgzAEAG
配套仓库(你本地已下载)里,这篇最相关的材料我建议按重要程度看这几份(不要求你都读完):
- multi-turn SFT 的
loss_mask可视化与数据结构:/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/ReTool-sft.ipynb
- SFT 与长序列的显存瓶颈(micro batch / SP / remove padding 等):
/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/训练及调参经验/sft.ipynb
- FSDP 与并行(DP/TP/PP、FSDP vs FSDP2):
/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/3D/DP_TP_PP.ipynb/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/3D/fsdp_fsdp2.ipynb
- 一个真实可跑的 SFT 启动脚本(FSDP2 + torchrun):
/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/retool/scripts/run_qwen2_7b_sft.sh
另外一个“很关键但经常被忽略”的坑点(会直接让 RL 不收敛,也会污染 SFT 的 mask 对齐):
/Users/wangpeng/Downloads/modern_genai_bilibili-main/agentic_rl/verl/tokenizer/encode-decode.ipynb- 里面强调了
encode(messages) != (prompt_ids ⊕ response_ids)这类 token 不一致问题,在训练里是致命的。
- 里面强调了
1. SFT vs RL:不要把它理解成“先热身再上强度”
很多人把 pipeline 简化成一句话:先 SFT,再 RL。然后就自然得出一个误判:
SFT 只是 warmup,真正的提升靠 RL。
在 agentic RL(工具调用、多轮对话、长上下文)场景里,这个直觉经常是反的:
- SFT 决定“动作空间”是否可学
- tool call 的 JSON schema、引用格式、分段结构、思考模板,本质上都是一种“可执行协议”。
- 你协议都没训稳,RL 只能在一个错误的动作空间里瞎探索,最后学到的往往是“投机取巧地骗 judge/骗 verifier”。
- RL 往往被 KL anchor 限制在 SFT 附近
- 你后面的 PPO/GRPO 训练几乎一定会用 KL 把策略约束在 ref(通常就是 SFT)附近。
- 这意味着:SFT 的上限直接限制 RL 的上限。
- SFT 的 loss 可解释、可 debug,RL 的 reward 不一定
- 交叉熵的下降是“真的在拟合数据分布”。
- reward 的上升可能是“骗指标”,你要花更多成本做验证。
所以一个更实用的心智模型是:
- SFT:把“能做正确动作的先验”写进模型(格式、工具、引用、推理风格)。
- RL:在这个先验附近做 reweighting,让某些轨迹更常出现(采样效率、稳定性、成本约束)。
2. 交叉熵损失:它到底在优化什么(以及为什么要 shift)
对 causal LM(自回归模型),我们给定输入序列 token:$x_0,x_1,\dots,x_{T-1}$,模型在位置 $t$ 预测下一个 token 的分布 $p_\theta(\cdot|x_{\le t})$。
标准的 token-level 负对数似然(也就是交叉熵)是:
$$\mathcal{L}_{\text{CE}} = -\sum_{t=0}^{T-2}\log p_\theta(x_{t+1}|x_{\le t})$$
工程实现里这就是你常见的“shift”:
logits[:, :-1]对齐labels[:, 1:]
为什么要强调这个?因为 loss_mask/labels=-100 也必须跟着 shift 对齐,否则你以为你在训 assistant,实际在训 user prompt。
2.1 loss_mask 与 labels=-100:本质上是一回事
你可以用两种等价方式做“只对某些 token 计算 loss”:
- mask 版:先算出每个 token 的 loss matrix,再乘以 mask 做聚合。
- ignore_index 版:把不参与训练的位置 label 设成
-100(PyTorch/HF 默认 ignore)。
如果把它写成一个统一公式(用 mask 表达更直观):
$$\mathcal{L} = \frac{\sum_t m_t\cdot\left(-\log p_\theta(x_{t+1}|x_{\le t})\right)}{\sum_t m_t + \epsilon}$$
其中 $m_t\in{0,1}$ 表示“这个位置是否计入 loss”。
工程里我更推荐你在日志里同时打印两件事:
masked_token_count = sum(m_t)(否则你根本不知道有效 batch 有多大)loss_agg_mode(你是 token-mean 还是 seq-mean)
因为这两项经常比“学习率大小”更决定训练动态。
3. loss mask:你在 multi-turn SFT 里到底该训哪些 token
3.1 最常见原则:只训 assistant 说的话
在 chat / tool-use 数据里,一条样本往往包含:
- system:系统提示
- user:用户输入
- tool:工具返回(或者工具 schema)
- assistant:模型应该输出的内容(可能包含 reasoning、tool_call、final answer)
ReTool-sft.ipynb 的总结非常实用:system/user/tool 都应该 mask 掉,只对 assistant 的输出计算 loss。
如果你没这么做,你很容易把模型训坏:
- 训练它复读 system prompt(线上看起来像“人格绑定”)
- 训练它复读 user 的提问(答非所问)
- 训练它去“生成工具输出”(这是最危险的:工具输出本来应该是外部世界的事实)
3.2 multi-turn 的难点:assistant 的“哪一段”算输出
tool-use 体系里 assistant 可能输出两类东西:
- 直接回答用户(final answer)
- 输出工具调用(
<tool_call>{...}</tool_call>或 JSON)
要不要对 tool_call 计算 loss?
工程上我建议你按目标来:
- 如果你的 agent 需要稳定地产生可解析的 tool_call:tool_call 必须训(mask=1)。
- 如果 tool_call 会被上层 planner/DSL 强约束生成:可以考虑只训 final answer,把 tool_call 交给规则或 structured decoding。
但无论哪一种,system/user/tool output 一般都不应该训。
3.3 最可靠的做法:把 mask 可视化(别凭感觉)
ReTool-sft.ipynb 给了一个很好的工程习惯:把 loss_mask 逐 token decode 并染色打印出来,肉眼一眼就能看出你到底在训练哪里。
你不需要照抄 notebook 里的彩色控制符,核心逻辑就两步:
- decode 每个 token
- 同步打印 mask=1 的位置(比如用
[]包起来)
示例伪代码(核心思想,非 veRL 专用):
1 | def debug_print_mask(tokenizer, input_ids, loss_mask, max_tokens=400): |
你应该把它当成 SFT 的“单元测试”。跑 3 条样本都对齐了,再开始跑大训练。
4. 一个容易被低估的坑:token 不一致会让 mask 和训练目标一起漂
encode-decode.ipynb 里有一句我强烈建议你记住:
encode(messages) != (prompt_ids ⊕ response_ids)
直觉上我们会以为:
- 把 messages 用
apply_chat_template编码一次 - 等价于“每轮把 prompt 编码,再把 response 拼上去”
但在真实框架里,这两者可能不等价(模板中间插入的控制 token、空格处理、tool role 的拼接方式都可能导致差异)。
对 SFT,这会导致两类问题:
- 你以为 mask 在 assistant response 上,但 token 边界其实错位了
- 你以为你在拟合“最终消息序列分布”,实际拟合了一个不存在的拼接分布
对 RL(PPO/GRPO)更致命:它会让轨迹偏离策略分布,直接导致 ratio/KL 统计失真,训练不收敛。
工程建议(不一定要完全照 veRL,但思想要一致):
- 训练用的 tokenization 路径必须和线上/rollout 一致(同一个 template,同一个拼接方式)。
- 只要涉及 multi-turn + tool,你就应该加入 “token-level 对齐检查”(mask 可视化只是第一步)。
5. FSDP SFT Trainer:它解决的不是“更快”,而是“能跑”
当你开始做长上下文(8k/16k/32k)+ 多轮 + 大模型,SFT 最先爆的往往不是算力,而是显存。
5.1 FSDP 的一句话解释:把“模型本体”切碎分到多卡上
在传统 DDP 下,每张卡都有一份完整模型参数、梯度、优化器状态。
FSDP(Fully Sharded Data Parallel)的核心是:把这些大头都 sharding 掉:
- 参数(weights)分片
- 梯度(grads)分片
- 优化器状态(optimizer states)分片
于是每张卡只存自己那一片。计算时需要哪一片,就在前向时 all-gather,反向时 reduce-scatter。
这就是你在很多笔记里看到的关键词:
- all-gather / reduce-scatter
5.2 FSDP2(fully_shard)你可以先理解成“更细粒度、更少常驻副本”
如果你不想陷进实现细节,一个够用的工程直觉是:
- FSDP1:以模块为单位 all-gather(FlatParameter),副本驻留窗口相对大
- FSDP2:更细粒度按需 shard/reshard,副本驻留窗口更窄,可重叠机会更多
仓库里的 run_qwen2_7b_sft.sh 也直接把 model.strategy=fsdp2 当作默认推荐。
5.3 SFT 训练时,你真正要盯的是“micro batch 才是显存开关”
SFT 的工程规律非常朴素:
- 真正把你打爆显存的通常是 激活(activations)+ KV/attention 的中间量,而不是参数本身。
这也是为什么即便用了 FSDP,你仍然需要:
- micro batch 小(
data.micro_batch_size_per_gpu) - gradient accumulation 做大 effective batch
- remove padding / packed sequence(减少无效 token)
- 必要时上 sequence parallel(SP)或 activation checkpointing
你可以把它和你前面那篇 veRL 参数文章联系起来:很多“看起来像训练参数”的东西,本质是在调显存与吞吐的 tradeoff。
6. learning rate scheduler:SFT 阶段反而更值得你认真设计
SFT 的 loss 是交叉熵,它的形态比 RL 的 reward 更“线性可解释”,所以 scheduler 的收益也更可预期。
6.1 一个够用的默认策略:warmup + cosine decay
如果你没有特别理由,SFT 我建议你先用:
- warmup:1% - 3% 的总步数
- cosine decay 到一个较小的
min_lr(比如lr * 0.1或更低)
直觉:
- warmup 解决 early-stage 的梯度不稳定(尤其是长上下文、混合精度、FSDP 通信下)
- cosine decay 让后期收敛更平滑,减少“后半程学坏”的风险
6.2 scheduler 的最大坑:你用的是 optimizer step 还是 micro step
有 gradient accumulation 时:
- micro step(每次 forward/backward)不等于 optimizer step(真正更新参数)
scheduler 应该跟着 optimizer step 走,否则你的 lr 会被错误地 decay 得过快。
工程上最推荐你在日志里打印三件事:
global_step(optimizer step)lr(真实学习率)tokens_per_step(有效训练 token 数)
你一旦把“lr 曲线”和“有效 token 曲线”对齐,很多训练异常(loss spike/不收敛/学坏)会变得很好解释。
7. SFT 写在 agentic RL 计划里:我建议你怎么做
如果你的最终目标是 deep research / tool-use agentic RL,我会把 SFT 当成一个明确的阶段目标,而不是“先跑个 baseline”:
- 先用 SFT 把协议训稳:
- tool schema、引用格式、输出结构(可解析、可验证)
- 然后再用 RL 优化你真正关心的指标:
- 引用正确率、事实一致性、覆盖度、成本/时延、pass@k 等
原因很现实:RL 不能替你把协议发明出来,RL 只会把已有轨迹的概率 reweight。协议都不稳定,RL 只会放大漏洞。
8. 一份能直接用的工程 checklist(建议你贴到 Notion 里)
SFT 跑之前:
- 抽样可视化
loss_mask(至少 20 条样本) - 统计每条样本的
masked_token_ratio(太低说明你在训很少 token,训练会很慢且不稳定) - 确认 tokenization 路径一致(尤其 multi-turn + tool)
SFT 跑起来之后(每 200-500 step 一次):
- 看训练/验证 loss 是否同时下降(只看 train loss 容易被过拟合骗)
- 看输出样例是否“协议稳定”(工具调用是否可解析、引用是否仍存在)
- 看
lr/grad_norm/tokens_per_step是否健康(scheduler + 有效 batch 闭环)
如果你愿意把这一步做扎实,后面的 RL(PPO/GRPO)会容易一个数量级。

