Sakana AI 提出 DiffusionBlocks:将残差网络转为独立可训练去噪模块的块训练框架,训练内存降至 1/B
Sakana AI 与东京大学提出 DiffusionBlocks,将残差网络视为扩散模型去噪步骤,实现块级独立训练。训练内存降至 1/B,在 ViT、DiT、MDM、AR Transformer 等架构上性能持平甚至优于端到端训练。本文详解原理、实验与中文用户视角。
一句话看懂
Sakana AI 提出 DiffusionBlocks,将残差网络每一层更新解释为扩散模型去噪步,实现块级独立训练,训练内存降至 1/B,且性能不降反升。
详细发生了什么
Sakana AI 与东京大学联合提出 DiffusionBlocks,一种将残差网络转化为独立可训练去噪模块的块训练框架。核心洞察在于:残差网络的更新公式 z_ℓ = z_{ℓ-1} + f_θ_ℓ(z_{ℓ-1}) 与扩散模型中的概率流 ODE 的欧拉离散化在结构上完全一致。因此,每个残差块可以被视为一个去噪步骤,而扩散模型的 score matching 目标天然支持每个噪声级别独立优化——这意味着每个块可以独立训练,无需块间通信。
具体转换分三步:1)将 L 层网络分为 B 个块;2)定义噪声分布 p_noise 和噪声范围 [σ_min, σ_max],并采用等概率划分(equi-probability partitioning)使每个块处理相同概率质量的噪声;3)为每个块添加噪声条件输入(通过 AdaLN),让块学习从其分配噪声范围内预测干净目标。训练时每次只采样一个块,内存消耗降至 L/B 层。
实验覆盖 ViT、DiT、MDM、AR Transformer、Huginn 等五种架构,在 CIFAR-100、ImageNet、text8、LM1B、OpenWebText 等数据集上,DiffusionBlocks 在多数任务上达到或超越端到端基线,同时实现 B× 训练内存缩减。对于 DiT 等扩散模型,推理时每个去噪步也只激活一个块,推理计算量同样降至 1/B。与 Forward-Forward 算法(CIFAR-100 仅 7.85% 准确率)相比,DiffusionBlocks 在相同 ViT 架构上达到 59.30% 准确率,差距巨大。
中文圈视角
对国内用户意味着什么?
-
训练成本直接降低:对于使用 ViT、DiT 等架构的团队,训练内存降至 1/B 意味着可以用更少 GPU 训练更深模型。例如 12 层 DiT 分 3 块,训练内存降至 1/3,推理计算量也降至 1/3。这对预算有限的学术团队和中小企业是直接利好。
-
国产模型可借鉴:国内类似残差网络架构(如 DeepSeek-V2 的 MoE、Qwen2 的 Transformer)理论上均可应用。但需要注意,DiffusionBlocks 要求模型使用残差连接,且需要添加噪声条件模块(AdaLN),对现有架构有侵入性修改。
-
与国产框架兼容性:该方法不依赖特定训练框架,可在 PyTorch、MindSpore、PaddlePaddle 中实现。但国内流行的 ModelScope 社区尚未有类似实现,早期采用者可获得先发优势。
-
潜在盲点:论文在 OpenWebText 上 MAUVE 略低于基线(0.82 vs 0.85),表明在超大规模语言模型上可能仍有差距。国内大模型厂商在千亿参数场景下需谨慎验证。
几条值得记住的细节
- 训练内存缩减比例:B 个块,内存降至 1/B。例如 B=4 时内存降至 1/4。
- 等概率划分:根据噪声分布 p_noise 的累积概率密度划分区间,中间噪声级别块获得更窄区间,极端噪声块更宽。CIFAR-10 上 FID 从 43.53(均匀划分)降至 38.03。
- 推理加速:对于 DiT 等扩散模型,每个去噪步只激活一个块,12 层 DiT 分 3 块时推理计算量降至 1/3。
- Huginn 训练:用 DiffusionBlocks 替代 BPTT,训练步数从 32 次迭代降至 1 次,总计算量减少约 10 倍。
- 与 NoProp 对比:NoProp 仅支持分类任务且需自定义 CNN,DiffusionBlocks 是唯一同时支持连续时间公式和块训练的方法,在 CIFAR-100 上准确率 46.88% 接近端到端 47.80%。
一句话总结
DiffusionBlocks 让你用更少 GPU 训练更深模型,训练内存降至 1/B,推理也加速,但大规模语言模型上仍需验证。