Masked Autoencoders Are Scalable Vision Learners

引言

MAE (Masked Autoencoders) 由He Kaiming团队在2021年提出,为视觉自监督学习带来了新的范式。论文标题“Masked Autoencoders Are Scalable Vision Learners”凸显了其两大特性:一是基于掩码的自重构任务;二是能在大规模数据和模型上稳定扩展。和SimCLR、MoCo等对比学习方法相比,MAE丢弃了昂贵的负样本构造环节,通过简单的遮挡-重建目标即可学习高质量的视觉特征。

在图像理解任务中,过去的自监督方法往往依赖对比学习或生成式建模。MAE将NLP中成熟的Masked Language Modeling理念迁移到视觉领域,将图片切分为patch token,然后随机遮挡大部分token,让模型仅凭剩余少量可见token推断出被遮挡的像素,从而学到上下文结构。

背景知识

自监督视觉预训练的演进

  1. 预文本任务 (Pretext Task):如旋转预测、拼图恢复等,但任务与下游语义差距较大。
  2. 对比学习时代:SimCLR、MoCo、BYOL通过实例判别与数据增强获得强特征,但需要大batch或额外队列。
  3. 生成式方法回潮:iGPT、VQ-VAE尝试像素重建,但计算成本高、收敛慢。
  4. Vision Transformer普及:ViT把图像转成patch序列,为视觉领域引入“token”概念,使得掩码预测成为可能。

MAE正是在ViT基础上发展的生成式自监督方法。

自动编码器 vs. 对比学习

特性自动编码器对比学习
目标自重构输入拉近正样本、推远负样本
数据增强依赖
训练稳定性较稳定依赖大batch/动量队列
计算需求可低(MAE只编码可见token)往往高
表征性质偏向局部+全局偏向判别特征

MAE融合了自动编码器的补全思想与Transformer的全局建模能力。

论文核心思想

主要贡献

  1. 高掩码率:高达75%~90%的随机遮挡仍能有效预训练,极大降低计算量。
  2. 轻量解码器:编码器专注表示学习,解码器仅用于预训练阶段重建。
  3. 非对称架构:编码器处理少量可见token,解码器负责重建全部token,简化了训练。
  4. 可扩展性:在ImageNet上与有监督预训练相当甚至更优,对下游检测、分割任务具备竞争力。

整体流程

1
输入图像 → Patch划分 → 随机掩码 → 编码器处理可见patch → 嵌入掩码token → 轻量解码器重建像素 → 计算重建损失

模型架构

Patch Embedding 与 ViT兼容性

  • 使用ViT相同的patchify方式:将图像划分为16×16的patch,拼接成序列。
  • 对每个patch做线性投影得到D维token,并加上位置编码。
  • 由于MAE后续要还原像素,额外保存patch形状信息用于unpatchify。

随机掩码策略

  1. 均匀随机:从所有patch中随机采样保留25%,掩码75%。
  2. 可见token排序:保留patch在顺序上也会随机打乱,增加任务难度。
  3. 原因:视觉信息冗余高,遮挡大部分区域仍能推断结构,训练效率也随掩码率提高而提高。

编码器(Encoder)

  • 直接使用标准ViT Backbone(如ViT-B/16)。
  • 输入仅限可见token,大幅减少自注意力计算量。
  • 由于只处理25%的token,训练速度提高约3~4倍。

解码器(Decoder)

  • 结构比编码器浅(如8层Transformer),隐藏维度更小(如512)。
  • 输入由“编码器输出的可见token + 掩码token嵌入”组成。
  • 仅在预训练阶段存在,下游微调时丢弃,避免增加推理成本。

重建目标

  • 对所有patch输出预测像素,常用均方误差(MSE)。
  • 论文采用归一化的pixel值(对每个patch做均值方差归一),提升稳定性。
  • 也可以替换为DCT系数、特征空间等,扩展空间很大。

训练策略

预训练设置

配置常用取值
数据集ImageNet-1K无标签
掩码率75%
Patch size16
训练时长400 epoch
优化器AdamW (lr=1.5e-4, weight decay=0.05)
学习率策略Cosine + warmup
数据增强仅RandomResizeCrop + 随机水平翻转

微调策略

  1. 使用预训练好的编码器权重初始化ViT Backbone。
  2. 替换为任务特定头(分类、检测、分割)。
  3. 学习率通常更小(如5e-4),训练epoch也减少(如100 epoch)。
  4. 对分类任务,还会加入Mixup、CutMix等常规增强。

Linear Probe & KNN

  • 冻结编码器,仅训练线性分类头,可评估特征线性可分性。
  • 5-NN评估也常用来衡量无监督表示质量。

实验结果

ImageNet分类

模型预训练方式Top-1 Acc
ViT-B/16 (监督)有标签81.8%
ViT-B/16 (MAE)400 epoch83.7%
ViT-L/16 (MAE)400 epoch85.9%

MAE在有限标注下游训练(如少量epoch)同样保持优势。

下游任务

  • 目标检测 (COCO):与有监督预训练持平甚至略优。
  • 语义分割 (ADE20K):微调后比监督预训练高1~2 mIoU。
  • 鲁棒性:对遮挡和噪声具有更强适应性。

消融实验

  1. 掩码率:75%最佳,过低浪费计算,过高训练不稳定。
  2. 解码器深度:浅层即可,过深收益有限。
  3. 重建目标:直接像素即可,无需复杂loss。
  4. 位置编码:无需额外改动,使用ViT默认编码即可。

表征分析

  • MAE倾向关注全局结构,对局部纹理不过拟合。
  • 可视化显示模型能推断出缺失的主体轮廓,说明捕捉到高层语义。
  • 线性探针结果表明特征具备良好的线性可分性。

代码实现要点

基础组件

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

def forward(self, x):
# x: [B, C, H, W]
x = self.proj(x) # [B, embed_dim, H/patch, W/patch]
x = x.flatten(2).transpose(1, 2) # [B, N, embed_dim]
return x

随机掩码函数

1
2
3
4
5
6
7
8
9
def random_masking(x, mask_ratio=0.75):
B, N, _ = x.shape
len_keep = int(N * (1 - mask_ratio))
noise = torch.rand(B, N, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, x.size(-1)))
return x_masked, ids_restore, ids_keep

训练循环示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def mae_forward(model, imgs, mask_ratio=0.75):
tokens = model.patch_embed(imgs)
tokens = tokens + model.pos_embed[:, 1:, :]
tokens, ids_restore, _ = random_masking(tokens, mask_ratio)
latent = model.encoder(tokens)

# prepare decoder input
mask_tokens = model.mask_token.repeat(latent.size(0), ids_restore.size(1) - latent.size(1), 1)
latent_full = torch.cat([latent, mask_tokens], dim=1)
latent_full = torch.gather(latent_full, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, latent.size(2)))

rec = model.decoder(latent_full)
target = model.patchify(imgs)
loss = ((rec - target) ** 2).mean()
return loss

PyTorch Lightning/Timm 的现成实现

  • timm.models.mae 提供官方实现,可直接加载mae_vit_base_patch16等模型。
  • Hugging Face Transformers也提供Mask2Former等派生实现。

工程实践建议

  • Mixed Precision:配合高掩码率,训练非常高效。
  • 数据增强:预训练阶段增强较轻,避免破坏像素重建难度;微调阶段再加重。
  • 学习率调试:预训练用较大学习率,微调用较小的学习率。
  • 梯度累计:若显存受限,可结合梯度累计保持批量大小。

与其他方法的比较

模型训练范式计算需求表征特点
MAE掩码重建低 (编码少量token)全局语义强
BEiTToken重建中 (需tokenizer)依赖dVAE词表
SimMIM像素重建中等无解码器分离
MaskFeatHOG特征重建更注重低级特征

MAE以简单高效著称,成为后续大量工作(如MaskFeat、iBOT)的基础。

优缺点总结

优点

  1. 训练高效:编码器只处理可见token,显著降低计算。
  2. 鲁棒性强:对遮挡与噪声有更好表现。
  3. 迁移能力好:分类、检测、分割任务均有优势。
  4. 实现简单:无须复杂的数据增强或对比对。

缺点

  1. 重建目标限制:重建像素可能关注低级细节,对高层语义关注不足。
  2. 遮挡策略固定:随机掩码未利用场景先验,对结构化遮挡可能欠佳。
  3. 不适合生成:目标是补全而非生成高质量图像。
  4. 超参数敏感:掩码率、解码器宽度等需要调优。

相关与后续工作

  1. MAE v2:引入多尺度特征、更强的数据增强。
  2. SimMIM / MaskFeat:探索不同重建目标(直接像素、HOG等)。
  3. MaskCLIP:结合跨模态信息,用文本引导掩码学习。
  4. Masked Siamese Networks:将掩码与对比学习结合。
  5. VideoMAE:扩展到视频,随机掩码时空patch。

总结

MAE通过高掩码率的自重构任务,在不依赖负样本对的情况下实现了高质量的视觉自监督学习。其关键在于:

  1. 非对称架构:编码器轻量输入,解码器仅用于训练。
  2. 高掩码率:降低计算同时保持学习难度。
  3. 贴近NLP的设计:借鉴MLM理念,实现跨模态迁移。

MAE证明了生成式自监督在视觉任务中的可行性,为后续的Mask-based预训练方法奠定了基础。

参考文献

  1. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2021). Masked autoencoders are scalable vision learners. arXiv:2111.06377.
  2. Dosovitskiy, A., et al. (2020). An image is worth 16x16 words: Transformers for image recognition at scale.
  3. Bao, H., Dong, L., & Wei, F. (2021). BEiT: BERT Pre-Training of Image Transformers.
  4. Xie, Z., et al. (2022). SimMIM: A Simple Framework for Masked Image Modeling.

思考题

  1. 为什么MAE可以使用75%的高掩码率?是否存在场景需要降低掩码率?
  2. 解码器仅用于预训练阶段,是否意味着我们可以用更复杂的解码器来提升效果?
  3. MAE的像素重建目标会不会让模型学到过多低级特征?如果是,该如何改进?
  4. 与对比学习相比,MAE缺少显式的判别约束,如何弥补这一点?
  5. 如果要把MAE扩展到视频或多模态任务,需要额外注意哪些设计?

思考题答案

1. 为什么MAE可以使用75%的高掩码率?是否需要调整?

  • 图像存在大量冗余,局部区域通常可由上下文推断。
  • Transformer具备全局建模能力,即便仅看到25%的patch也能捕捉结构。
  • 高掩码率减少计算成本;在纹理细节极其丰富的任务(如医学影像)可以适当调低(如50%)以保证信息量。

2. 解码器能否更复杂?

  • 理论上可以,但实验表明解码器过深收益有限,且增加训练成本。
  • 解码器主要提供学习信号,过强的解码器会掩盖编码器能力。
  • 若要改进,不如让解码器重建更高层次表征(如语义分割mask)。

3. 像素重建是否导致偏向低级细节?

  • 像素Loss会促使模型拟合纹理,但高掩码率迫使模型理解结构。
  • 可通过重建特征空间(如DINO特征)、频域信息或多任务损失提升语义理解。

4. 如何引入判别约束?

  • 在MAE基础上叠加对比学习头(如MAE+MoCo)。
  • 结合监督信号(半监督设置)或知识蒸馏。
  • 在线性探针阶段引入额外的判别任务。

5. 拓展到视频或多模态的注意点

  • 视频:需要处理时序维度,可随机掩码时空patch,并采用时序位置编码。
  • 多模态(图文):需联合掩码不同模态,并设计跨模态对齐目标,如重建文本描述。
  • 计算资源:视频/多模态数据更大,掩码率、模型尺寸需兼顾效率。