Mamba: Linear-Time Sequence Modeling with Selective SSMs

论文地址:https://arxiv.org/pdf/2312.00752

代码地址:https://github.com/state-spaces/mamba

引言

Mamba 是一种基于状态空间模型(State Space Model, SSM)的高效序列建模框架,旨在在保持强表达能力的同时,将计算与内存复杂度降至与序列长度线性相关。与Transformer的二次复杂度相比,Mamba在超长序列、低时延和内存受限场景中具有显著优势。

Mamba的核心在于“选择性扫描(Selective Scan)”与“输入依赖的状态转移”,通过对经典S4(Structured State Space Sequence Model)的工程化与理论改进,实现端到端可训练、GPU友好、且具有SOTA性能的线性时间序列模型。

背景知识:状态空间模型(SSM)与S4

连续与离散SSM

连续时间SSM:
$$
\dot x(t) = A x(t) + B u(t), \quad y(t) = C x(t) + D u(t)
$$
离散化后:
$$
x_{k+1} = \bar A x_k + \bar B u_k, \quad y_k = C x_k + D u_k
$$
其中 $x$ 为隐状态,$u$ 为输入,$y$ 为输出。

S4的关键思想

  • 通过特定结构(HiPPO 等)构造稳定的 $A$ 矩阵,捕捉长程依赖;
  • 使用高效卷积实现序列到序列的映射(状态演化可转换为一维卷积核);
  • 优点:理论稳定、长依赖强;
  • 局限:实现复杂、某些硬件/场景下吞吐有限,且与输入的自适应性不足。

Mamba的核心创新

  1. 选择性扫描(Selective Scan)

    • 令状态转移与输入耦合,通过门控/选择性机制动态调节信息流;
    • 避免固定核卷积的刚性,增强对非稳态、稀疏事件的响应能力。
  2. 输入条件化的SSM参数化

    • 将 $B, C$ 等参数设为输入依赖(conditioning on input),提升表达力;
    • 结合高效实现,使其仍保持线性复杂度。
  3. GPU友好的实现与块并行

    • 通过分块扫描(block scan)与前缀-后缀合并,实现可并行化的线性时间推理;
    • 内存访问模式优化,显著提升吞吐。
  4. 端到端可训练的稳定性

    • 对状态矩阵参数施加稳定性约束(如谱半径控制、隐式参数化),保证数值稳定。

模型架构

典型的 Mamba Block 包含:

  • 输入投影与门控(选择性)
  • 选择性SSM扫描(线性复杂度)
  • 前馈网络(MLP/GEGLU)
  • 残差连接与归一化

流程示意:

1
X → InputProj → Selective SSM (scan) → OutputProj → MLP → Residual/Norm

相较Transformer:不使用全局自注意力,而以“可学习的一维核 + 选择性门控”的形式进行token混合;相较S4:对输入进行条件化,提升适配性与表达力。

选择性扫描(Selective Scan)要点

  • 将序列划分为若干块(blocks),每块内部执行线性递推扫描;
  • 记录块末端的状态作为“前缀状态”,在合并阶段向后续块传递;
  • 通过选择性/门控,抑制无用状态更新,突出关键位置的信息流;
  • 与并行卷积不同,选择性扫描可随输入动态调整有效感受野。

复杂度与性能

  • 时间复杂度:$O(N)$ 随序列长度线性;
  • 空间复杂度:$O(N)$(可通过流水线/检查点进一步降内存峰值);
  • 对超长序列(>4K、>16K tokens)训练与推理更加高效;
  • 在语言建模、语音、时间序列等任务上与或优于同级别Transformer。

训练与实践

  • 优化器:AdamW;学习率余弦退火;warmup数千步;
  • 正则:Dropout/Stochastic Depth 按深度线性增长;
  • 初始化:对状态矩阵采用稳定化参数化(如对角+低秩);
  • 长序列技巧:梯度检查点、混合精度、分块扫描大小调优。

与Transformer/Conv的比较

方向TransformerConvNetMamba(SSM)
复杂度O(N^2)O(N)O(N)
长距离依赖强(显式全连接)弱-中强(稳定递推核)
并行性训练强/推理弱(自回归)强(块并行扫描)
归纳偏置强(局部)中(递推+选择性)
适用场景通用局部模式超长序列/低时延

代码片段(概念化示例)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn

class SelectiveSSM(nn.Module):
def __init__(self, d_model, state_size):
super().__init__()
self.d_model = d_model
self.state_size = state_size
# 输入依赖的门控/参数化(示意)
self.in_proj = nn.Linear(d_model, 3 * state_size)
self.out_proj = nn.Linear(state_size, d_model)

def forward(self, x):
# x: [B, N, D]
B, N, D = x.shape
a, b, g = self.in_proj(x).chunk(3, dim=-1) # 输入条件化参数(示意)
state = torch.zeros(B, self.state_size, device=x.device)
outputs = []
for t in range(N): # 实际实现会使用块并行,这里仅示意
state = torch.tanh(a[:, t, :] * state + b[:, t, :])
state = g[:, t, :].sigmoid() * state # 选择性门控
outputs.append(state)
y = torch.stack(outputs, dim=1) # [B, N, S]
return self.out_proj(y)

class MambaBlock(nn.Module):
def __init__(self, d_model, state_size, mlp_ratio=4.0, p=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.ssm = SelectiveSSM(d_model, state_size)
self.drop1 = nn.Dropout(p)
self.norm2 = nn.LayerNorm(d_model)
hidden = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, hidden), nn.GELU(), nn.Dropout(p),
nn.Linear(hidden, d_model), nn.Dropout(p)
)

def forward(self, x):
x = x + self.drop1(self.ssm(self.norm1(x)))
x = x + self.mlp(self.norm2(x))
return x

注:上例为思想演示,真实Mamba使用更优化的扫描与参数化,且为高效CUDA/Flash实现;请参考官方实现。

实验要点与经验

  • 序列长度:尽量使用长序列预训练以发挥线性复杂度优势;
  • 批量与块大小:根据显存调优块扫描长度,保持吞吐与稳定;
  • 任务迁移:语言→语音/时间序列时,适当调整状态规模与门控强度;
  • 可解释性:通过可视化门控开启位置与核响应,分析模型关注点。

优缺点

优点

  1. 线性复杂度,适合超长序列与端侧低时延应用;
  2. 选择性门控使得对稀疏/非稳态事件更敏感;
  3. GPU友好实现,训练/推理吞吐高;
  4. 理论上可与注意力/卷积并行或互补,形成混合结构。

缺点

  1. 对全局两两交互的显式建模不如注意力直观;
  2. 超参数(状态规模、门控强度、离散化方式)对稳定性敏感;
  3. 在某些视觉任务(如密集预测)上需额外结构适配(见MambaOut)。

相关与后续工作

  • S4/S5:结构化SSM与其后续改进;
  • MambaOut:面向视觉分类的实用化改造;
  • Hyena/RetNet/WSI:其他线性/次线性序列混合范式;
  • 混合架构:SSM + Attention/Conv的多路并行。

参考文献

  1. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint arXiv:2312.00752. https://arxiv.org/pdf/2312.00752

  2. Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICML 2022.

  3. Sun, Y., et al. (2023). RetNet: Retentive Network: A Successor to Transformer for Large Language Models. NeurIPS 2023.

  4. Poli, M., et al. (2023). Hyena Hierarchy: Towards Larger Convolutional Language Models. ICML 2023.

思考题

  1. 为什么选择性扫描能在保持线性复杂度的同时提升表达力?
  2. 如何在保持稳定性的同时引入更强的输入条件化(例如多尺度门控)?
  3. 在语言建模中,Mamba与Transformer是否适合分工合作?如何设计混合路由?
  4. 针对超长上下文(>64K tokens),Mamba的块并行与跨块信息传递应如何权衡?

思考题答案

  1. 选择性扫描通过输入依赖的门控让状态更新“稀疏化/聚焦化”,避免固定核对非稳态信号的欠拟合,同时保持线性递推形式;因此表达力↑、复杂度仍为O(N)。
  2. 采用分层门控(通道/时间/块级多粒度)、稳定化参数化(对角+低秩)、以及正则(门控L1/温度限制)可提升条件化同时维持稳定;
  3. 可用“前段SSM提取远距记忆 + 中段注意力做细粒度交互”的混合体,或MoE式路由将易于局部建模的片段交给Conv/SSM,需在延迟与精度之间折中;
  4. 提高块大小、增加跨块汇总通道(summary tokens)、或在块边界引入轻量全局模块(如稀疏注意力)以减小信息切换损失;同时用检查点与流水线降低内存峰值。