扩散模型 · Transformer · 文生图

Mean Mode Screaming:稳住千层扩散 Transformer

极深 DiT 会塌缩进作者称为 Mean Mode Screaming 的均值主导态。把残差拆成均值与去均值两条路径即可修复,训出稳定的 1000 层 DiT,FID 2.77。

Mean Mode Screaming:稳住千层扩散 Transformer

快速答案

把扩散 Transformer(DiT)堆到几百层以上,它往往会塌缩进一种「均值主导态」:每个 token 的表示都漂向同一个向量,模型不再学习有用的空间变化。本文作者 Pengqi Lu 把这种失败命名为 Mean Mode Screaming(MMS,均值模式尖叫),并把它追溯到残差「写入梯度」会分裂成一个均值相干分量和一个去均值分量。修复方法 Mean-Variance Split Residuals(MV-Split,均值-方差分裂残差) 让这两个分量各走各的增益。用上它,400 层 DiT 在 LayerScale 基线发散的地方达到 FID 2.60 / IS 185.5;1000 层 DiT 在 ImageNet-256 latent 上稳定训练到 50k 步,FID 2.77 / IS 217.3。

Mean Mode Screaming 到底是什么

名字戏剧化,机制却很具体。行随机注意力(softmax 映射每一行加和为 1)会保留序列的纯 token 均值分量,却压制去均值的 token 间变化。层数越堆越多,去均值部分不断缩小,token 表示趋于同质——深层里 token 之间的余弦相似度逼近 1.0。一旦对齐,softmax 雅可比的零空间会抹掉查询/键梯度,注意力再也学不动,网络锁死在塌缩态。所谓「尖叫」,是残差写入梯度范数在进入该态那一刻的陡然飙升——不是缓慢漂移,而是一个尖锐事件。

Mean-Variance Split Residuals 如何工作

作者的关键分析动作,是把每个残差写入梯度精确拆成两种模式:

  • 均值相干分量:当 token 对齐时按 O(T)(T 为 token 数)放大——正是这部分爆掉、驱动塌缩;
  • 去均值分量:扩散式、依赖具体 token——这部分承载真实信号,却被压制。

深层 Transformer 的标准稳定器 LayerScale 用每通道一个标量把两者同比例缩小,因此无法在压住均值的同时保护去均值信号。MV-Split 改为给两个子空间各自的逐特征可学增益。均值子空间走漏积分器 (1-alpha)*J(X_l) + alpha*J(F_l),抑制失控的均值相干项;去均值子空间保留标准残差路径 beta * (P*F_l),让 token 变化存活。梯度轨迹验证了这一点:MV-Split 维持的去均值梯度带比 LayerScale 高 2-3 倍,同时把均值相干分量约束住。

关键结果

以下数字均为 50k 训练步,ImageNet-2012 VAE latent(256x256),单流 Post-Norm DiT,Rectified Flow 目标,冻结 FLUX.2 编码器 + Qwen3-0.6B 文本条件。

  • 400 层 MV-Split: FID 2.60,IS 185.5。
  • 400 层 LayerScale 基线: 同检查点 FID 2.90、IS 165.5——且基线在跑完完整计划前就发散,而 MV-Split 没有。
  • 1000 层 MV-Split: FID 2.77,IS 217.3——证明方法在极端深度下仍成立的核心演示。
  • 对齐-放大律(Alignment-Amplification Law): 预测的梯度放大标度与实测斜率拟合,注意力与 FFN 写入器的 R 方均超过 0.9,发散事件处放大约达 13 倍。

诚实地看:1000 层模型是稳定性证明,不是质量纪录——它的 FID(2.77)略逊于 400 层 MV-Split(2.60)。贡献是「这么深也能不塌缩地训出来」,而非「单靠堆深度就买到更好的图」。

为什么现在重要

深度扩展一直是 Transformer 唯一无法像宽度和数据那样自由推高的维度——超过几百层,图像和视频 DiT 就悄悄停止变好或直接崩,通常靠零碎技巧打补丁。本文给出了一个有名字、有机制的解释(行随机注意力对均值模式的放大),外加一个针对性、低成本的修复——只多加逐特征增益。这比又一个「反正能用、但说不清为什么」的稳定器更有价值,因为即使具体药方不通用,这个诊断本身可以迁移。

局限与存疑

作者对缺口很坦诚。其一,精确预测 MMS 何时触发仍未解决——该律描述的是放大幅度,不是发生时机,你依然无法提前判断哪次训练会在哪一步塌缩。其二,分析只针对 softmax 注意力;同样的均值模式病理与修复是否适用于 Mamba、线性注意力等替代方案,尚未验证。其三,结果仅限 ImageNet-256 的类别/文本条件生成——深度最该发挥作用的极长上下文时空(长视频)生成并未涉及。此外这是一篇无署名机构的单作者论文:实验固定在 50k 步预算、单一骨干上,因此在更长计划和其它架构上的独立复现,是显而易见的下一步检验。

常见问题

扩散 Transformer 里的 Mean Mode Screaming 是什么?

Mean Mode Screaming 是极深扩散 Transformer 的一种塌缩模式:token 表示同质化、都漂向共享的均值向量,去均值(token 间)信号消失,残差写入梯度在塌缩那一刻陡然飙升。根源是行随机注意力在多层叠加中保留均值分量、压制去均值变化。

Mean-Variance Split Residuals 如何修复塌缩?

Mean-Variance Split Residuals 把残差更新拆成均值相干路径和去均值路径,各给一个可学的逐特征增益。均值路径用漏积分器抑制,阻止 O(T) 失控;去均值路径保留标准残差,让 token 变化存活。不同于把两种模式一起缩小的 LayerScale,它在压住爆炸的同时保护了信号。

1000 层 DiT 能达到多少 FID?

1000 层 MV-Split DiT 在 ImageNet-256 latent、50k 步时达到 FID 2.77、IS 217.3。400 层 MV-Split 的 FID 其实略好(2.60),所以这个 1000 层数字是「极端深度下稳定训练」的证明,而非新的质量纪录。

Mean-Variance Split Residuals 比 LayerScale 好吗?

在同为 400 层、50k 步时,MV-Split 达 FID 2.60 / IS 185.5,LayerScale 为 2.90 / 165.5,且 LayerScale 基线没跑完就发散,MV-Split 保持稳定。论文认为 LayerScale 之所以失败,是因为它把均值与去均值模式同比例缩小,而非分开处理。

一句话:极深 DiT 塌缩,是因为行随机注意力放大了 token 均值——把残差拆开,就能在压住均值的同时不杀掉信号。阅读 arXiv 原文