AI 快讯 编译自 marktechpost #模型发布#注意力机制#LLM 预训练

Parallax:保留 Softmax 注意力并添加学习协方差校正分支,提升 LLM 预训练效率

Parallax 是一种参数化局部线性注意力机制,通过添加可学习的协方差校正分支,在不替换 Softmax 的前提下提升 LLM 预训练性能。在 0.6B 和 1.7B 规模上,配合 Muon 优化器,Perplexity 和下游准确率均优于 Transformer 基线。本文解析其原理、硬件优势及对中文圈用户的潜在影响。

编译发布 2026/06/01 原文发布 2026/06/01

一句话看懂

Parallax 在保留 Softmax 注意力的基础上,添加一个可学习的协方差校正分支,通过增加计算密度而非削减计算来提升 LLM 预训练效率,在 0.6B 和 1.7B 模型上取得更优 Perplexity。

详细发生了什么

Transformer 的注意力机制自 2017 年以来几乎没有改变。大多数效率优化工作试图直接替换 Softmax 注意力。西北大学、Tilde Research 和华盛顿大学的研究团队提出了一种不同的方法:保留 Softmax 注意力,并附加一个校正分支。他们引入了一种参数化局部线性注意力(Local Linear Attention, LLA)机制,称为 Parallax,可扩展到 LLM 预训练,并与 Muon 优化器协同设计。

Parallax 建立在 LLA 之上。LLA 将注意力视为一个回归求解器:键是训练数据点,值是标签,查询是测试点。Softmax 注意力是一种非参数估计器(Nadaraya-Watson),为每个查询拟合局部常数函数。LLA 将其升级为局部线性估计,理论上具有更小的积分均方误差。但 LLA 在大规模应用中存在三个问题:密集的 I/O、正则化-表达能力权衡困难、低精度不兼容。

Parallax 移除了 LLA 中每个查询的共轭梯度求解器,转而学习一个额外的投影矩阵。它保持局部线性原则,但用可学习的查询式投影器替代了逐查询求解。输出等于 Softmax 注意力输出减去一个投影协方差项。当投影矩阵为零时,Parallax 退化为标准 Softmax 注意力,因此预训练检查点可通过添加投影矩阵并微调来转换。

硬件方面,Parallax 继承了 FlashAttention 的流式结构,增加了一个协方差分支,复用相同的键值流。它需要两个并行评分分支,共享在线最大值、缩放因子以及 K 和 V 块。在键值计算主导的场景下,Parallax 大致将算术强度(AI)翻倍,使注意力更接近计算受限状态,有利于现代 GPU 的 kernel 优化。研究团队在 NVIDIA H200 GPU 上使用 CuTeDSL 原型化了解码 kernel,在 BF16 精度下与 FlashAttention 2/3 对比,在计算匹配设置下实现了 1.54 倍加速,在 I/O 匹配设置下实现 1.14 倍加速。

实验在 0.6B 和 1.7B 规模的 LLM 预训练上进行,使用 Qwen-3 架构和 Ultra-FineWeb 数据集。Parallax 在 MAD-Benchmark 上达到 0.716 的平均准确率,在语言建模中配合 Muon 优化器取得最佳 Perplexity,1.7B 模型下游平均准确率 62.45,优于 Transformer 的 61.43。参数匹配和计算匹配的控制实验表明,增益来自机制本身而非额外参数或计算。

一个核心发现是优化器-架构协同:Parallax 在 Muon 下优势显著,在 AdamW 下优势大幅缩小甚至消失。Muon 优化器使校正分支在深层网络中贡献更大(校正输出比超过 8),而 AdamW 下模型倾向于抑制校正分支。

中文圈视角

Parallax 对中文圈用户意味着什么?首先,它不要求替换现有模型架构,只需在预训练 Transformer 上添加一个可学习投影矩阵并微调,降低了迁移成本。对于使用 Qwen 系列模型(如 Qwen-3)的国内团队,可以直接尝试将 Parallax 集成到预训练流程中。

其次,Parallax 依赖 Muon 优化器,而 Muon 在中文社区中已有一定讨论(如知乎、GitHub 项目),但尚未广泛用于 LLM 训练。国内主流训练框架(如 DeepSpeed、Megatron)对 Muon 的支持有限,需要额外适配。不过,Parallax 的代码已开源(基于 torchtitan),技术门槛不高。

与国产注意力机制对比:国内 Kimi 的 DeltaAttention 也属于线性注意力变体,但 Parallax 保留了 Softmax,更易于与现有系统兼容。对于中文长文本场景(如法律文档、学术论文),Parallax 更高的算术强度和更好的召回性能可能带来优势。

监管方面:Parallax 本身不涉及数据出境或内容安全问题,但若用于微调中文大模型,需注意训练数据的合规性。

一个中文圈尚未讨论的盲点:Parallax 在 WSD 学习率衰减阶段优势减弱,这提示在实际训练中可能需要调整衰减策略,国内团队可提前探索权重衰减退火等技巧。

几条值得记住的细节

  • Parallax 输出等于 Softmax 注意力输出减去投影协方差项,当投影矩阵为零时完全退化为 Softmax。
  • 在 H200 GPU 上,Parallax 解码 kernel 在计算匹配设置下比 FlashAttention 2/3 快 1.54 倍。
  • 1.7B 模型下游平均准确率 62.45,高于 Transformer 的 61.43。
  • 增益高度依赖 Muon 优化器;AdamW 下优势大幅缩小,校正输出比从 8 降至 4 以下。
  • 实验规模止步于 1.7B,未涉及 MoE 或更长上下文(当前为 4096)。

一句话总结

Parallax 通过保留 Softmax 并添加可学习校正分支,在 Muon 优化器配合下为 LLM 预训练提供了即插即用的效率提升,但实际效果取决于优化器选择和训练阶段。