CH 3 FLOPs 估算

读者定位:已完成 CH1-2 的参数计算,目标是推导 prefilling / decoding 的单 token 计算量,并理解不同架构(Full Attn / MSA / MLA / Mamba-2)的 FLOPs 差异根源。

系列导航(一)预备知识与参数分解 ← 当前 → (三)KV Cache 与推理显存(四)M3 实战 + Roofline(五)训练显存(六)通信分析(七)推理服务


3.1 通用原理

本文各节描述的是前向(推理)FLOPs。训练 FLOPs 需要乘以训练系数(线性投影 ×6,QK ×4,AV ×3,Indexer ×1)——系数推导见 §3.12 从推理到训练:系数体系

建立”前向 FLOPs = 所有权重矩阵乘法之和”的底层逻辑。参数量是”模型存了多少数”,FLOPs 是”每次前向要算多少下”——两者直接决定推理延迟和硬件成本。

核心公式

单层 FLOPs = 该层内所有矩阵乘法的 $2 \times m \times n \times k$ 之和(见 1.2 节)。

$$\text{FLOPs}_{\text{total}} = \sum_{l=1}^{L} \text{FLOPs}_{\text{attn}}^{(l)} + \text{FLOPs}_{\text{ffn}}^{(l)} + \text{FLOPs}_{\text{norm}}^{(l)}$$

其中 norm(RMSNorm / LayerNorm)的 FLOPs 为 $4 \times d$(乘 $\gamma$ + 加 $\beta$),在大模型中可忽略($d=8192$ 时 $\approx 32\text{K FLOPs}$,而 Q 投影是 $\approx 134\text{M FLOPs}$)。

Prefill vs Decode

  • Prefill:输入 $T_{in}$ 个 token,所有层对所有 token 完整计算一次。总 FLOPs 正比于 $T_{in}$(线性部分)或 $T_{in}^2$(注意力部分)。
  • Decode:每次只产生 1 个新 token,但需要 attend 到所有历史 token($T_{total}$)。只有新 token 的 QKV 需要投影,但 QK 点积和 V 加权要覆盖全部历史。
$$\text{FLOPs}_{\text{decode\_per\_token}} = \sum_{l=1}^{L} \text{FLOPs}_{\text{new\_token}}^{(l)}$$

Prefill 是“一口气读完整本书再回答问题”,Decode 是“每次多读一个字就要把所有笔记翻一遍”。前者吞吐高但延迟长,后者每步轻量但被历史长度拖累。Attention 的 O(T²) 项只在 Prefill 是全量爆炸,Decode 时变成 O(T)(因为只有 1 个 query)。

单 Token FLOPs 计算范式

对每个矩阵乘法,固定范式为:

$$\text{FLOPs} = 2 \times (\text{输出第一维}) \times (\text{输出第二维}) \times (\text{被缩并的公共维度})$$

案例:Attention 层 Q 投影,输入 hidden $[1, d]$,权重 $W_Q [d, H_q \times D_h]$:

$$\text{FLOPs}_Q = 2 \times 1 \times (H_q \times D_h) \times d$$

Nemotron 12 个 Attention 层之一($d=8192$,$H_q=64$,$D_h=128$):

$$\text{FLOPs}_Q = 2 \times 1 \times (64 \times 128) \times 8192 = 2 \times 8192 \times 8192 = 134{,}217{,}728 \approx 134.2\text{M FLOPs/token}$$

每产生一个 token,Q 投影就要把 8192 维向量乘上 $8192 \times 8192$ 的矩阵——相当于做 8192 次 8192 维的内积。这就是一个 token 经过一层 Attention 的“起步价”。


3.2 Full Attention FLOPs

逐项拆解标准 Attention(含 GQA)的四部分 FLOPs,区分线性项和平方项。不理解 O(T²) 项从哪里来,就无法理解为什么长上下文推理会变慢——以及为什么 MSA、Mamba 等替代架构有意义。

3.2.1 QKV 投影(线性项,O(T))

投影部分在 Prefill 时随 T 线性增长,在 Decode 时是常数(只投影新 token)。

$$\text{FLOPs}_{Q} = 2 \times d \times (H_q \times D_h) \times T_{\text{new}}$$$$\text{FLOPs}_{K} = 2 \times d \times (H_{kv} \times D_h) \times T_{\text{new}}$$$$\text{FLOPs}_{V} = 2 \times d \times (H_{kv} \times D_h) \times T_{\text{new}}$$

GQA 的精髓:K 和 V 投影的输出维度是 $H_{kv} \times D_h$ 而非 $H_q \times D_h$——这是 GQA 相比于 MHA 在计算量(而不仅是参数量)上的直接节省。

案例代入:Nemotron Attention 层(GQA 32:1,$d=8192$,$H_q=64$,$H_{kv}=2$,$D_h=128$)。

Prefill($T=4096$)

$$\text{FLOPs}_{Q} = 2 \times 8192 \times (64 \times 128) \times 4096 = 2 \times 8192 \times 8192 \times 4096$$$$= 2 \times 67{,}108{,}864 \times 4096 = 549{,}755{,}813{,}888 \approx 550 \text{ GFLOPs}$$$$\text{FLOPs}_{K} = 2 \times 8192 \times (2 \times 128) \times 4096 = 2 \times 8192 \times 256 \times 4096$$$$= 2 \times 2{,}097{,}152 \times 4096 = 17{,}179{,}869{,}184 \approx 17.2 \text{ GFLOPs}$$$$\text{FLOPs}_{V} = 17.2 \text{ GFLOPs} \quad (\text{与 K 相同})$$

Prefill 一次性投影所有 4096 个 token 的 Q、K、V。注意 K 投影(17 GFLOPs)只占 Q 投影(550 GFLOPs)的约 3%——因为 $H_{kv} = 2$ 只有 $H_q = 64$ 的 1/32。

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)

$$\text{FLOPs}_{Q} = 2 \times 8192 \times (64 \times 128) \times 1 = 134{,}217{,}728 \approx 134.2\text{M FLOPs}$$$$\text{FLOPs}_{K} = 2 \times 8192 \times (2 \times 128) \times 1 = 4{,}194{,}304 \approx 4.2\text{M FLOPs}$$$$\text{FLOPs}_{V} = 4.2\text{M FLOPs}$$

QKV 投影在 decode 时总共 $\approx 142.6\text{M FLOPs}$——与上下文长度无关

QKV 投影就像“打字”——每个新 token 只需要把自己的向量投影一次。历史 token 的 K 和 V 投影结果被缓存在 KV cache 里,不用重算。

3.2.2 QK 点积(平方项,O(T²) 的根源)

$$\text{FLOPs}_{\text{QK}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$$

Prefill($T=T_{\text{new}}=T_{\text{total}}=4096$,causal mask 下约计算一半)

$$\text{FLOPs}_{\text{QK}} = 2 \times 64 \times 4096 \times \frac{4096}{2} \times 128 = 2 \times 64 \times 4096 \times 2048 \times 128$$$$= 2 \times 64 \times 8{,}388{,}608 \times 128 = 137{,}438{,}953{,}472 \approx 137 \text{ GFLOPs}$$

(精确无 causal 时为 275 GFLOPs,causal mask 下约折半。)

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)——这就是长上下文问题的核心

$$\text{FLOPs}_{\text{QK}} = 2 \times 64 \times 1 \times 1{,}000{,}000 \times 128$$$$= 2 \times 64 \times 128 \times 10^6 = 16{,}384 \times 10^6 = 1.6384 \times 10^{10} \approx 16.4 \text{ GFLOPs}$$

当上下文达到 1M tokens 时,仅一个 Attention 层的 QK 点积就需要 164 亿次浮点运算。对于有 12 个 Attention 层的 Nemotron:$12 \times 16.4 \approx 197 \text{ GFLOPs}$,仅此一项就超过了 QKV 投影(12 × 142.6M ≈ 1.7 GFLOPs)两个数量级。

QK 点积是把新 token 的一个 query 与缓存中所有 1M 个 key 逐一算相似度。1M 个 key,每个 128 维,每个维度一次乘法+一次加法=$2 \times 128 = 256$ FLOPs,64 个 head 各做一次,总计就是 $64 \times 1\text{M} \times 256 = 16.4\text{GFLOPs}$。这就是 Attention 在长上下文下“喘不过气”的根本原因。

3.2.3 V 加权(同样是 O(T) 项,decode 中体量等于 QK)

$$\text{FLOPs}_{\text{V}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$$

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)

$$\text{FLOPs}_{\text{V}} = 2 \times 64 \times 1 \times 1{,}000{,}000 \times 128 = 16.4 \text{ GFLOPs}$$

与 QK 点积等量级!原因:注意力权重要乘上 V 矩阵——1M 个 value 向量,每个 128 维,64 个 head。计算量路径:$[1, 64, 1, 1\text{M}] \times [1, 64, 1\text{M}, 128] \to [1, 64, 1, 128]$,缩并维度是 1M。

算完“每个历史 token 有多相关”(QK 点积)后,还要把 1M 个 value 向量按相关性加权平均。这个“加权平均”的运算量跟“计算相似度”一样大——都是 $2 \times H \times T \times D_h$。所以 Attention 的 decode 成本 = QK + V ≈ $4 \times H \times T \times D_h$。

3.2.4 输出投影(线性项,O(T))

$$\text{FLOPs}_{\text{O}} = 2 \times d \times (H_q \times D_h) \times T_{\text{new}}$$

decode 时为常数(Nemotron):$\text{FLOPs}_O = 2 \times 8192 \times 8192 \times 1 = 134.2\text{M FLOPs}$

与 Q 投影相同——因为输入和输出的维度都是 $d \times d$。

3.2.5 单层 Full Attention Decode FLOPs 汇总

以 Nemotron Attention 层(GQA 32:1,T=1M)为例:

组件公式$T=1\text{M}$ 时 FLOPs占比
Q 投影$2 \times d \times (H_q \times D_h)$134.2M0.4%
K 投影$2 \times d \times (H_{kv} \times D_h)$4.2M0.01%
V 投影$2 \times d \times (H_{kv} \times D_h)$4.2M0.01%
QK 点积$2 \times H_q \times T \times D_h$16.4G49.7%
V 加权$2 \times H_q \times T \times D_h$16.4G49.7%
O 投影$2 \times d \times (H_q \times D_h)$134.2M0.4%
单层合计~33.1G100%

关键观察:在 1M 上下文下,Attention 层 99.4% 的计算量花在 QK 点积和 V 加权上——这两个 O(T) 项(decode 时)。投影部分是常数,可以忽略。任何想加速长上下文推理的架构,都是从这两个 O(T) 项下手。

3.2.6 GQA 对 FLOPs 的影响

GQA 降低了 K 和 V 投影的 FLOPs($H_{kv}$ 替代 $H_q$),但不降低 QK 点积和 V 加权的 FLOPs。原因是 K 和 V 在注意力计算前会被 repeat_kv 扩展到与 Q 相同的头数:

1
2
# 标准 GQA 实现(transformers 源码)
K = K.repeat_interleave(H_q // H_kv, dim=1)  # [B, H_kv, T, D] -> [B, H_q, T, D]

所以 QK 点积的规模仍然是 $2 \times H_q \times T \times D_h$——与 MHA 完全相同

GQA 节省的是:

  • K、V 投影的 FLOPs(节省比例 $\frac{H_q}{H_{kv}}$ 倍,如 64/2=32×)
  • KV cache(同样 32×)

GQA 节省的 不是

  • QK 点积的 FLOPs
  • V 加权的 FLOPs

GQA 就像“出版社印了 64 份杂志(Q head),但只审了 2 份稿子(KV head),审稿费省了 32×,但印杂志的成本(读者阅读 = QK 点积)没省——因为每份杂志都要发给所有读者看。”


3.3 MSA 稀疏 Attention FLOPs(MiniMax M3)

推导 M3 的 Multi-stage Sparse Attention 计算量,理解“用廉价 Index Branch 筛选 + 昂贵 Main Branch 只在筛选区域计算”的 FLOPs 逻辑。M3 在 1M 上下文时实现约 30× 的 decode 加速——这是稀疏 Attention 的标杆案例。

3.3.1 MSA 架构概述

M3 的 MSA 将 Attention 分为两个分支:

  • Index Branch(廉价筛选器):用少量 head($H_{\text{idx}} = 4$)在全部 T 个 token 上做 QK 评分 + max-pool + top-k,选出 16 个 block(每 block 128 token,共 $16 \times 128 = 2048$ 个候选 token)。
  • Main Branch(精准计算器):用全部 head($H_q = 64$)只在 2048 个入选 token 上做完整 Attention。

M3 有 60 层:3 层 Full Attention(Layer 0-2)+ 57 层 MSA(Layer 3-59)。

3.3.2 Index Branch FLOPs

训练提示:MSA Indexer 在源码中被 @torch.no_grad() 包裹(modeling_minimax_m3_vl.py:L695),训练时不计算梯度——所有 Indexer 操作的系数为 1(仅前向),不是 6 或 7。详见 §3.13

维度回顾:$d = 6144$,$H_{\text{idx}} = 4$,$D_{\text{idx}} = 128$,$H_q = 64$,$D_h = 128$。

(1) Index Q 投影

$$\text{FLOPs}_{\text{idx\_Q}} = 2 \times d \times (H_{\text{idx}} \times D_{\text{idx}}) \times T_{\text{new}}$$

Decode($T_{\text{new}}=1$):

$$\text{FLOPs}_{\text{idx\_Q}} = 2 \times 6144 \times (4 \times 128) \times 1 = 2 \times 6144 \times 512 = 6{,}291{,}456 \approx 6.3\text{M FLOPs}$$

(2) Index K 投影

Index K 只有一个 head 的维度(128),4 个 index head 共享同一个 K:

$$\text{FLOPs}_{\text{idx\_K}} = 2 \times d \times D_{\text{idx}} \times T_{\text{new}} = 2 \times 6144 \times 128 \times 1 = 1{,}572{,}864 \approx 1.6\text{M FLOPs}$$

(3) Index QK 评分(O(T²) in prefill,O(T) in decode)

这是 Index Branch 的计算主体。Index Branch 用 4 个 head 在全序列上做 QK 点积。

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)

$$\text{FLOPs}_{\text{idx\_QK}} = 2 \times H_{\text{idx}} \times 1 \times T \times D_{\text{idx}} = 2 \times 4 \times 1 \times 10^6 \times 128$$$$= 2 \times 512 \times 10^6 = 1{,}024{,}000{,}000 \approx 1.02\text{ GFLOPs}$$

对比 Full Attention 的 QK 点积(如果用全部 64 个 head 做全序列评分):

$$\text{FLOPs}_{\text{full\_QK}} = 2 \times 64 \times 1 \times 10^6 \times 128 = 16{,}384 \times 10^6 \approx 16.4\text{ GFLOPs}$$

Index Branch 的 QK 评分仅需要 1.02 GFLOPs,而 Full Attention 需要 16.4 GFLOPs——减少了 16×。原因直截了当:4 个 head vs 64 个 head,$64/4 = 16$。

这就是 Index Branch 设计的精妙之处:用 16× 更便宜的计算,筛选出哪些 token 值得做完整的 64-head Attention。

(4) Max-pool + Top-k

Max-pool 将分数按 block 聚合(每 128 token 一个 block,共 $T/128$ 个 block),再选出 top-16 个 block。这部分本质是遍历和排序,FLOPs $\approx T/128 \times \log(16)$,约 $10^4$ 级别,完全可忽略。

3.3.3 Main Branch FLOPs

Main Branch 的核心:只在入选的 2048 个 token 上做完整 Attention。

$$\text{访问 token 数} = \text{block\_size} \times \text{top\_k\_blocks} = 128 \times 16 = 2048$$

(1) Main QK 点积

$$\text{FLOPs}_{\text{main\_QK}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{selected}} \times D_h$$$$\text{Decode} = 2 \times 64 \times 1 \times 2048 \times 128 = 2 \times 64 \times 262{,}144$$$$= 33{,}554{,}432 \approx 33.6\text{M FLOPs}$$

关键对比:Full Attention 的 QK = $16.4\text{G FLOPs}$,MSA Main QK = $33.6\text{M FLOPs}$。加速比 = $16.4\text{G} / 33.6\text{M} \approx 488\times$(T=1M 时,仅 QK 部分)。

(2) Main V 加权

$$\text{FLOPs}_{\text{main\_V}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{selected}} \times D_h = 33.6\text{M FLOPs}$$

与 Main QK 对称。

3.3.4 MSA 单层 Decode FLOPs 汇总(T=1M)

组件FLOPs类别
Index Q 投影6.3M常数
Index K 投影1.6M常数
Index QK 评分1.02GO(T),但 16× 小
Index max-pool + top-k~0可忽略
Main Q 投影$2 \times 6144 \times (64 \times 128) = 100.7\text{M}$常数
Main K 投影$2 \times 6144 \times (4 \times 128) = 6.3\text{M}$常数(GQA 16:1)
Main V 投影6.3M常数
Main QK 点积33.6M常数(仅 2048 个 token)
Main V 加权33.6M常数
Main O 投影$2 \times 6144 \times (64 \times 128) = 100.7\text{M}$常数
总计~1.31G

对比 Full Attention 层的 $\approx 33.1\text{G FLOPs}$(相同 $d$, $H_q$ 配置在 T=1M 下),MSA 单层仅需 $\approx 1.31\text{G FLOPs}$——加速约 25×

MSA 单层最大的开销是 Index QK 评分(1.02G,占 78%),这一项仍然随 T 线性增长——但它是用 4 个 head 而非 64 个,系数差距是 16×。

3.3.5 总体加速比

Decode 场景(T=1M)

对于 M3 的 57 层 MSA + 3 层 Full Attention:

  • 3 层 Full Attention:$3 \times 33.1\text{G} \approx 99.3\text{G FLOPs}$($d=6144$, $H_q=64$, $H_{kv}=4$)
  • 57 层 MSA:$57 \times 1.31\text{G} \approx 74.7\text{G FLOPs}$
  • 总计:$\approx 174\text{G FLOPs}$ 用于 Attention 部分

假如同样的 60 层全部是 Full Attention:

  • $60 \times 33.1\text{G} \approx 1986\text{G FLOPs} \approx 1.99\text{T FLOPs}$
  • 加速比 $\approx 1986 / 174 \approx 11.4\times$(仅 Attention 部分)

Prefill 场景(T=1M,causal),加速更显著:

  • Index QK 的 O(T²) 部分:$2 \times 4 \times (10^6)^2/2 \times 128 \approx 5.12 \times 10^{14}$ FLOPs/层
  • Full Attention QK 的 O(T²) 部分:$2 \times 64 \times (10^6)^2/2 \times 128 \approx 8.19 \times 10^{15}$ FLOPs/层
  • Main Branch QK:$2 \times 64 \times 10^6 \times 2048 \times 128 \approx 3.36 \times 10^{13}$ FLOPs/层(常数,不随 T² 增长)
  • 加速比 $\approx 8.19 \times 10^{15} / (5.12 \times 10^{14} + 3.36 \times 10^{13}) \approx 15\times$(仅 QK 部分)

综合其他恒定开销,实际整体 decode 加速约 2-5×,Prefill 加速约 10-20×(取决于序列长度和 overhead 比例)。论文声称的 30× 是 decode 场景下 Attention 部分 QK+V 的加速。

MSA 的哲学是“先粗筛再精算”。花 1 GFLOPs(Index Branch)扫一眼全场,发现最有戏的 2048 个 token,然后花 67 MFLOPs(Main QK+V)在这 2048 个 token 上精算。而 Full Attention 要花 33 GFLOPs 在所有 1M 个 token 上精算。前者总花费 $\approx 1.1\text{G}$,后者 $\approx 33\text{G}$,高下立判。


3.4 MLA FLOPs(Kimi K2.5 / DeepSeek V4)

推导 Multi-head Latent Attention 的 FLOPs,区分低秩投影的线性节省和 QK 点积的不变性。MLA 的卖点是“省 KV cache”而非“省 FLOPs”——但低秩投影确实也节省了一部分线性 FLOPs。

3.4.1 MLA 计算流程回顾

以 Kimi K2.5 为例($d=7168$,$d_{kv}=512$,$d_q=1536$,$H=64$,$D_{\text{nope}}=128$,$D_{\text{rope}}=64$,$D_v=128$):

MLA 的两阶段计算

  1. 压缩阶段:hidden $\to$ latent($W_{kv\_a}$, $W_{q\_a}$)
  2. 解压阶段:latent $\to$ per-head K, V, Q($W_{kv\_b}$, $W_{q\_b}$)
  3. RoPE 直接投影:hidden $\to$ per-head Q/K rope($W_{q\_rope}$,不经过 latent)

3.4.2 KV 侧 FLOPs(线性项节省的来源)

(1) KV 压缩投影 $W_{kv\_a}$

$$W_{kv\_a}: [d] \to [d_{kv} + D_{\text{rope}}] = 7168 \to 512 + 64 = 576$$$$\text{FLOPs}_{kv\_a} = 2 \times d \times (d_{kv} + D_{\text{rope}}) \times T_{\text{new}}$$

Decode:$= 2 \times 7168 \times 576 \times 1 = 8{,}257{,}536 \approx 8.3\text{M FLOPs}$

这个投影产生两部分输出:

  • 前 512 维:压缩的 KV latent,进入 $W_{kv\_b}$ 解压
  • 后 64 维:K 的 RoPE 分量(不压缩),直接用于注意力计算

(2) KV 解压投影 $W_{kv\_b}$

$$W_{kv\_b}: [d_{kv}] \to [H \times (D_{\text{nope}} + D_v)] = 512 \to 64 \times (128 + 128) = 64 \times 256 = 16384$$$$\text{FLOPs}_{kv\_b} = 2 \times d_{kv} \times H \times (D_{\text{nope}} + D_v) \times T_{\text{new}}$$

Decode:$= 2 \times 512 \times 64 \times 256 \times 1 = 16{,}777{,}216 \approx 16.8\text{M FLOPs}$

这个投影从 512 维 latent 中“解压”出 64 个 head,每个 head 有 128 维 nope K 和 128 维 V。等效于用一个 $512 \times 16384$ 的矩阵做投影——但比直接从 $7168 \to 16384$(MHA 方式)的 $7168 \times 16384 = 117.4\text{M}$ 矩阵 小了 7×

3.4.3 Q 侧 FLOPs

(3) Q 压缩投影 $W_{q\_a}$

$$W_{q\_a}: [d] \to [d_q] = 7168 \to 1536$$$$\text{FLOPs}_{q\_a} = 2 \times d \times d_q \times T_{\text{new}}$$

Decode:$= 2 \times 7168 \times 1536 \times 1 = 22{,}020{,}096 \approx 22.0\text{M FLOPs}$

(4) Q nope 解压投影 $W_{q\_b}$

$$W_{q\_b}: [d_q] \to [H \times D_{\text{nope}}] = 1536 \to 64 \times 128 = 8192$$$$\text{FLOPs}_{q\_b} = 2 \times d_q \times H \times D_{\text{nope}} \times T_{\text{new}}$$

Decode:$= 2 \times 1536 \times 64 \times 128 \times 1 = 25{,}165{,}824 \approx 25.2\text{M FLOPs}$

(5) Q RoPE 直投投影 $W_{q\_rope}$

RoPE 分量必须直接从 hidden 维度投影,不能经过压缩——因为 RoPE 的旋转操作施加在维度对上,压缩会破坏这个结构。

$$W_{q\_rope}: [d] \to [H \times D_{\text{rope}}] = 7168 \to 64 \times 64 = 4096$$$$\text{FLOPs}_{q\_rope} = 2 \times d \times H \times D_{\text{rope}} \times T_{\text{new}}$$

Decode:$= 2 \times 7168 \times 64 \times 64 \times 1 = 58{,}720{,}256 \approx 58.7\text{M FLOPs}$

注意:$W_{q\_rope}$ 是 MLA 中第二大的单项 FLOPs(仅次于输出投影),因为 RoPE 部分不能享受低秩压缩的红利。

3.4.4 QK 点积与 V 加权(O(T²) 项——与 MHA 完全等同)

训练提示:QK 用 $D_{qk} = D_{\text{nope}} + D_{\text{rope}} = 192$ 维,V 用 $D_v = 128$ 维——两者维度不同,训练时系数也不同(QK: 4 passes,V: 3 passes),不能用统一的 $7 \times (D_{qk} + D_v)$,必须分开 $(4 \times D_{qk} + 3 \times D_v)$。详见 §3.12

MLA 的 QK 点积分为两部分:

(6a) nope 分量的 QK 点积

$$\text{FLOPs}_{QK_{\text{nope}}} = 2 \times H \times T_{\text{new}} \times T_{\text{total}} \times D_{\text{nope}}$$

Decode(T=1M):$= 2 \times 64 \times 1 \times 10^6 \times 128 = 16.4\text{G FLOPs}$

(6b) rope 分量的 QK 点积

$$\text{FLOPs}_{QK_{\text{rope}}} = 2 \times H \times T_{\text{new}} \times T_{\text{total}} \times D_{\text{rope}}$$

Decode(T=1M):$= 2 \times 64 \times 1 \times 10^6 \times 64 = 8.2\text{G FLOPs}$

(6c) 合计 QK 点积

$$\text{FLOPs}_{QK} = 2 \times H \times T \times (D_{\text{nope}} + D_{\text{rope}}) = 2 \times H \times T \times D_h$$$$= 2 \times 64 \times 10^6 \times 192 = 24.6\text{G FLOPs}$$

其中 $D_h = 128 + 64 = 192$。这与标准 MHA($D_h=192$)的 QK 点积 FLOPs 完全相等。

(7) V 加权

$$\text{FLOPs}_{V} = 2 \times H \times T_{\text{new}} \times T_{\text{total}} \times D_v$$

Decode(T=1M):$= 2 \times 64 \times 1 \times 10^6 \times 128 = 16.4\text{G FLOPs}$

3.4.5 输出投影

(8) 输出投影 $W_o$

$$W_o: [H \times D_v] \to [d] = (64 \times 128) = 8192 \to 7168$$$$\text{FLOPs}_o = 2 \times H \times D_v \times d \times T_{\text{new}}$$

Decode:$= 2 \times 64 \times 128 \times 7168 \times 1 = 117{,}440{,}512 \approx 117.4\text{M FLOPs}$

3.4.6 MLA 单层 Decode FLOPs 汇总(T=1M)

组件FLOPs类型vs MHA 同配置
$W_{kv\_a}$(KV 压缩)8.3M常数—(MLA 新增)
$W_{kv\_b}$(KV 解压)16.8M常数—(MLA 新增)
$W_{q\_a}$(Q 压缩)22.0M常数—(MLA 新增)
$W_{q\_b}$(Q nope 解压)25.2M常数—(MLA 新增)
$W_{q\_rope}$(Q RoPE 直投)58.7M常数MHA Q proj 176.2M → 节省 3×
QK 点积(nope + rope)24.6GO(T)相同
V 加权16.4GO(T)相同
$W_o$(输出投影)117.4M常数相同
单层合计~41.2G

MLA 单层节省的 FLOPs 主要来自于:用多个小矩阵(低秩)替代 Q、K、V 的直投大矩阵。$W_{kv\_a}$ + $W_{kv\_b}$ + $W_{q\_a}$ + $W_{q\_b}$ + $W_{q\_rope}$ 合计 $\approx 131\text{M FLOPs}$,而标准 MHA 的 Q+K+V 三个直投矩阵合计 $\approx 2 \times 7168 \times 64 \times 192 \times 3 \approx 528\text{M FLOPs}$。线性项节省约 4×

但 QK 点积(24.6G)+ V 加权(16.4G)= 41G——这部分在 T=1M 时占比超过 99%,且与标准 MHA 完全相同

3.4.7 关键结论

MLA 省的是 KV cache,不是 FLOPs 的主体。

  • 线性项(投影):MLA 将 QKV 投影从 $\approx 528\text{M}$ 降到 $\approx 131\text{M FLOPs/token}$,但这项在长上下文下只占总 FLOPs 的 $\sim 0.3\%$。
  • 平方项/长上下文项(QK + V):MLA 的 FLOPs 与 MHA 完全相同——$2 \times H \times T \times D_h$——因为最终 attention 计算的维度规模没有变。
  • KV Cache:MLA 将每个 token 的 KV cache 从 $H \times D_{qk} + H \times D_v = 64 \times 192 + 64 \times 128 = 20{,}480$ 个元素压到 $d_{kv} + D_{\text{rope}} = 512 + 64 = 576$ 个元素——压缩 35.6×。这才是 MLA 的主要价值。

MLA 就像“快递打包”——包裹运输时压缩(KV cache 小),但到了收件人手里必须拆开原样呈现(注意力计算时的 K、V 维度与 MHA 完全相同)。运费省了(显存),但收件人验货的工作量没少(FLOPs)。


3.5 Mamba-2 SSD FLOPs(Nemotron)

逐项拆解 Mamba-2 Structured State Space Duality 层的 FLOPs,展示为什么它是 O(T) 而非 O(T²)。Mamba-2 是 Nemotron 的核心非 Attention 序列建模层——48 个 Mamba 层的 FLOPs 特征决定了整个模型的长上下文行为。

3.5.1 Mamba-2 计算流程回顾

维度回顾(Nemotron):$d=8192$,$\text{expand}=2 \Rightarrow d_{\text{inner}}=16384$,$H_{\text{mamba}}=256$,$D_{\text{mamba}}=64$,$N=128$(ssm_state_size),$n_{\text{groups}}=8$,$C=128$(chunk size)。

验证自洽性:$d_{\text{inner}} = H_{\text{mamba}} \times D_{\text{mamba}} = 256 \times 64 = 16384$。$\checkmark$

Mamba-2 的 SSD 将序列分成大小为 C 的 chunk,每个 chunk 内部做因果 matmul(对角块),chunk 之间通过状态传递(非对角块)。总计算量分为四部分:

3.5.2 (a) in_proj 输入投影(线性项主力)

in_proj 一次性产生所有需要的分量:$\mathbf{x}$、$\mathbf{z}$、$\mathbf{B}$、$\mathbf{C}$、$\boldsymbol{\Delta}$。

投影维度:$d \to 2 \times d_{\text{inner}} + 2 \times n_{\text{groups}} \times N + H_{\text{mamba}}$
$= 8192 \to 2 \times 16384 + 2 \times 8 \times 128 + 256$
$= 8192 \to 32768 + 2048 + 256 = 35072$

$$\text{FLOPs}_{\text{in\_proj}} = 2 \times d \times 35072 \times T_{\text{new}}$$

Decode:$= 2 \times 8192 \times 35072 \times 1 = 574{,}619{,}648 \approx 574.6\text{M FLOPs}$

这是 Mamba-2 层单 token 计算中最大的一项。对比 Attention 的 Q 投影(134M),Mamba 的 in_proj 约大 4.3×——因为它是一次性投影出 5 个分量(x, z, B, C, Δ),相当于把 Attention 的 Q、K、V、外加两个额外的分量合并到一个矩阵里。

3.5.3 (b) conv1d 深度卷积(可忽略)

一维深度卷积,核大小 = 4,输入通道数 = $d_{\text{conv}} = d_{\text{inner}} + 2 \times n_{\text{groups}} \times N = 16384 + 2048 = 18432$。

$$\text{FLOPs}_{\text{conv1d}} = 2 \times d_{\text{conv}} \times \text{kernel} \times T_{\text{new}}$$

Decode:$= 2 \times 18432 \times 4 \times 1 = 147{,}456 \approx 0.15\text{M FLOPs}$

卷积核只有 4 个元素宽,而且是深度卷积(每个通道独立的 1D 卷积),所以计算量跟 in_proj 比可以忽略不计——就像“顺丰快递的包装费相对于货品价值”。

3.5.4 (c) SSD 对角块(chunk 内因果 matmul)

这是 Mamba-2 “Attention 等价” 的部分。在每个 chunk 内,SSD 做类似因果 Attention 的计算:

$$\text{FLOPs}_{\text{diag}} = 2 \times \frac{T}{C} \times \frac{C^2}{2} \times H_{\text{mamba}} \times D_{\text{mamba}} = T \times C \times H_{\text{mamba}} \times D_{\text{mamba}}$$

代入:$= T \times 128 \times 256 \times 64 = T \times 2{,}097{,}152$

Prefill(T=4096):$4096 \times 2{,}097{,}152 \approx 8.59 \times 10^9 \approx 8.6\text{G FLOPs}$

Decode($T_{\text{new}}=1$,但 chunk 内的因果 matmul 在 decode 时仅涉及当前 chunk 的累积状态):约 4.2M FLOPs(与 T 无关)。

这里需要澄清:在 decode 阶段,Mamba-2 不需要对每个新 token 重做所有 chunk 的内部计算——SSD 的递归特性意味着新 token 只需要更新当前 chunk 的对角块和状态传递。因此 decode 时这部分是常数。

3.5.5 (d) SSD 非对角块:chunk 间的状态传递

前面的对角块是每个 chunk “内部消化”——chunk 里的每个 token 看到前面 token 的计算。但 chunk 1 的最后一个 token 怎么看到 chunk 0 的第一个 token?这需要状态传递

Mamba-2 的 SSM 在每个 chunk 边界维护一个隐藏状态 $h \in \mathbb{R}^{H_{\text{mamba}} \times N}$($N = d_{state} = 128$)。这个状态向量"记住"了之前所有 chunk 的摘要。

当一个 chunk 结束时,它的状态 $h_{i}$ 需要"传递"给下一个 chunk。传递的数学操作是:下一个 chunk 的每个位置,将传入状态与当前 chunk 的 $C$(输出投影)相乘,得到对当前 chunk 内每个 token 的修正量。这个操作为每个 chunk 边界做一次 $h_i \times C_{i+1}$。

$h_i$ 的形状是 $[H_{\text{mamba}}, N]$,$C_{i+1}$(经过 decay 加权后)的形状也是 $[H_{\text{mamba}}, N]$。这不是简单的向量点积——Mamba-2 需要在 $N$ 维空间内做"状态混合",让 $N$ 维的每个分量都能影响当前 chunk 的输出。因此,实际的状态传递矩阵是一个 $[N, N]$ 的变换:

$$\text{FLOPs}_{\text{off-diag}} = 2 \times \underbrace{\frac{T}{C}}_{\text{chunk 数}} \times \underbrace{H_{\text{mamba}}}_{\text{heads}} \times \underbrace{N^2}_{\text{状态传递矩阵}}$$

代入 Nemotron 的数值:chunk 数 $= T/128$,$H_{\text{mamba}} = 256$,$N = 128$:

$$= 2 \times \frac{T}{128} \times 256 \times 128^2 = 2 \times \frac{T}{128} \times 256 \times 16{,}384$$$$= 2 \times \frac{T}{128} \times 4{,}194{,}304 = T \times 65{,}536 \approx 6.55 \times 10^4 \times T$$

Prefill(T=4096):$4096 \times 65{,}536 \approx 0.27\text{G FLOPs}$

Decode:约 $6.55 \times 10^4$ FLOPs(常数级别)。

对角块和非对角块加起来,就是 SSD 的完整 FLOPs。对角块做"块内注意"($O(C^2)$),非对角块做"块间传递"($O(N^2)$)。$C = 128$、$N = 128$ 时,$C^2 = N^2$——这是设计上的巧合,不是必然。如果 chunk_size 变了,对角块和非对角块的比例就会偏移。

3.5.6 (e) out_proj 输出投影

$$\text{FLOPs}_{\text{out\_proj}} = 2 \times d_{\text{inner}} \times d \times T_{\text{new}}$$

Decode:$= 2 \times 16384 \times 8192 \times 1 = 268{,}435{,}456 \approx 268.4\text{M FLOPs}$

3.5.7 Mamba-2 单层 FLOPs 汇总

Prefill(T=4096)

组件FLOPs占比复杂度
in_proj$574.6\text{M} \times 4096 = 2.35\text{T}$92.3%O(T)
conv1d$0.15\text{M} \times 4096 = 0.61\text{G}$~0%O(T)
SSD 对角块8.6G0.3%O(T×C)
SSD 非对角块0.27G~0%O(T)
out_proj$268.4\text{M} \times 4096 = 1.10\text{T}$7.4%O(T)
单层合计~3.46T FLOPs100%O(T)

48 层合计:$\approx 166\text{T FLOPs}$(prefill 4096 token)。全部是 O(T)——没有任何 O(T²) 项。

Decode($T_{\text{new}}=1$,$T=1\text{M}$)

组件FLOPs复杂度
in_proj574.6MO(1)
conv1d0.15MO(1)
SSD 对角块 (decode)~4.2MO(1)
SSD 非对角块 (decode)~0.07MO(1)
out_proj268.4MO(1)
单层合计~847MO(1)

48 层 Mamba-2 合计:$\approx 40.7\text{G FLOPs/token}$(与 T 无关!)

这是最关键的数字:Mamba-2 层的 decode FLOPs 与上下文长度完全无关——每 token 固定 $\approx 847\text{M FLOPs}$。而 Attention 层在 T=1M 时需要 $\approx 33.1\text{G FLOPs/token}$。

3.5.8 与 Attention 的对比:O(T) vs O(T²)

以 1M 上下文为例,单层对比

指标Full Attention (GQA)Mamba-2 SSD比率
线性项 (proj)277M843M0.33×(Mamba 更贵)
长上下文项 (QK/sSD)32.8G~4.3M7600×(Attention 更贵)
单层总计33.1G847M39×(Mamba 更快)

48 层 Mamba-2($\approx 40.7\text{G FLOPs}$) vs 48 层 Full Attention($\approx 48 \times 33.1\text{G} \approx 1.59\text{T FLOPs}$)——Mamba 快 39×

Mamba-2 的 SSD 是“聪明地算”——把 O(T²) 的 Attention 变成了 chunk 内 O(C²) 的因果 matmul(C=128,常数)。1M 个 token 被切成 ~7812 个 chunk,每个 chunk 内部做的计算量恒定。新 token 到来时,只更新当前 chunk 并传播状态。而 Attention 每来一个新 token,都要跟全部 1M 个历史 token 逐一“打招呼”。这就是 O(T) vs O(T²) 的本质区别。


3.6 Sliding Window Attention(SWA)FLOPs

Sliding Window Attention 是 MiMo-V2-Flash、Mistral 等模型使用的稀疏 Attention 方案。每个 token 只关注它前面固定窗口 $W$ 内的 token,而非全部 $T$ 个 token。

QK 点积的复杂度从 $O(T^2)$ 降到 $O(T \times W)$:

$$\text{FLOPs}_{\text{QK, SWA}} = 2 \times H_q \times T_{\text{new}} \times \min(T, W) \times D_h$$
  • Prefill(每个 token 看到前面 $W$ 个):$2 \times H_q \times T \times W \times D_h$
  • Decode(新 token 只往前看 $W$ 步):$2 \times H_q \times 1 \times W \times D_h$

以 MiMo-V2-Flash 为例:$H_q = 64$,$W = 131072$,$D_h = 128$。Prefill 时 $T=W=131\text{K}$:$2 \times 64 \times 131072 \times 131072 \times 128 \approx 2.8 \times 10^{14}$ FLOPs,是 Full Attention($8.4 \times 10^{14}$)的约 $1/3$。但 decode 时:$2 \times 64 \times 1 \times 131072 \times 128 = 2.15 \times 10^9$ FLOPs——与 Full Attention decode 完全相同(因为 decode 时 $T_{new}=1$,Full Attn 也只看全部 $T$ 个历史 token)。

SWA 省的是 prefill 而非 decode。它适合吞吐优先的短上下文场景,但在长上下文 decode 上没有优势。

SWA 的 $W$ 不是凭空取的——通常等于 max_position_embeddingssliding_window 字段。如果 config 中找不到 sliding_window 但模型声称是 SWA,查看 max_position_embeddings 是否与上下文窗口匹配。

3.7 Gated DeltaNet(Linear Attention)FLOPs

Gated DeltaNet 是 Qwen3.5-MoE 等模型使用的线性注意力方案。与 Mamba-2 共享核心思想——用固定大小的隐藏状态 $S_t \in \mathbb{R}^{H \times D_h \times D_h}$ 取代 Attention 的 $O(T^2)$ 点积。

DeltaNet 的更新规则(简化):

$$S_t = \alpha_t \cdot S_{t-1} + \beta_t \cdot (k_t \otimes v_t)$$

其中 $k_t \otimes v_t$ 是 key 和 value 的外积,形状为 $[H, D_h, D_h]$。$\alpha_t$ 是遗忘门(decay),$\beta_t$ 是输入门(input gate),两者都是通过投影从当前输入得到的标量。

输出:$y_t = S_t \cdot q_t$,其中 $S_t \cdot q_t$ 将一个 $[H, D_h, D_h]$ 矩阵与 $[H, D_h]$ 向量相乘,得到 $[H, D_h]$ 的注意力输出。

每 token FLOPs 分解

$$\text{FLOPs}_{\text{DeltaNet}} = \underbrace{2 \times H \times D_h^2}_{\text{外积 } k_t \otimes v_t} + \underbrace{2 \times H \times D_h^2}_{\text{状态乘 } S_t \cdot q_t} + \underbrace{2 \times H \times D_h^2}_{\text{状态更新 } S_t = \alpha S_{t-1} + \beta(k \otimes v)}$$

三项各 $2 \times H \times D_h^2$,合计 $6 \times H \times D_h^2$。全与 $T$ 无关——DeltaNet 的 decode FLOPs 是常数

以 Qwen3.5-MoE 为例($H = 64$,$D_h = 128$):$6 \times 64 \times 128^2 = 6 \times 64 \times 16384 \approx 6.3 \times 10^6$ FLOPs/token/layer。对比 Full Attention 的 decode($2 \times 64 \times 10^6 \times 128 \approx 1.6 \times 10^{10}$),DeltaNet 节省了约 2500×

与 Mamba-2 的核心差异:Mamba-2 通过 in_proj 一次性产生所有 SSM 参数($\Delta, B, C$),其输入投影的 FLOPs 远大于 SSM 核心计算。DeltaNet 的投影更简单(类似标准 Attention 的 QKV 投影),所以整体 FLOPs 更小。但 Mamba-2 的状态维度 $H \times N$($256 \times 128$)比 DeltaNet 的 $H \times D_h^2$($64 \times 128^2$)小得多——状态大小是 $O(H \times N)$ vs $O(H \times D_h^2)$,差了 $D_h$ 倍。

3.8 MoE Gating FLOPs

计算路由器(Router / Gate)的 FLOPs,证明它在总计算量中占比 <1%。很多人担心 MoE 的路由开销会抵消稀疏化的收益——这一页数值直接打消这个顾虑。

Router FLOPs

标准 sigmoid/softmax 路由器的核心计算是一个矩阵乘法:

$$\text{FLOPs}_{\text{router}} = 2 \times d \times E \times T_{\text{new}}$$

Nemotron($d=8192$,$E=512$,decode):

$$\text{FLOPs}_{\text{router}} = 2 \times 8192 \times 512 \times 1 = 8{,}388{,}608 \approx 8.4\text{M FLOPs}$$

M3($d=6144$,$E=128$,decode):

$$\text{FLOPs}_{\text{router}} = 2 \times 6144 \times 128 \times 1 = 1{,}572{,}864 \approx 1.6\text{M FLOPs}$$

对比单层 MoE 的专家计算量(激活 4-22 个专家,每个专家做 $2 \times d \times d_{ff}$ 或 $3 \times d \times d_{ff}$ 的 FFN):

  • Nemotron 单专家(ReLU$^2$,latent 空间):$2 \times 2048 \times 5120 \approx 21\text{M FLOPs}$
  • 激活 22 个专家:$\approx 462\text{M FLOPs}$

Router 的 8.4M FLOPs 占 462M 的 1.8%。在 M3(128 专家,激活 4 个)中占比更低。

DeepSeek V4 Flash 的 hash routing 稍复杂,但本质仍是查表+少量矩阵乘法,FLOPs 在百万量级,可忽略。

Router 就是给 512 扇门各配一把锁(一个 8192 维向量),新 token 来了用自己的 8192 维特征跟 512 把锁各算一次相似度。这个开销相当于一扇门打开后干活(一个专家 FFN)的几十分之一。Router 的 FLOPs 约等于半个 Attention 的 K 投影——在总计算量的大海里是一滴水。


3.9 Vision Encoder FLOPs

计算 ViT 编码器的 FLOPs,理解为什么视觉编码在总推理成本中的占比。多模态模型输入一张图时,ViT 要处理 576-2916 个 patch token——这部分计算量是“固定入场券”,与文本长度无关。

3.9.1 MiniMax M3 ViT FLOPs

M3 ViT:32 层,$d_{\text{vit}}=1280$,$H_{\text{vit}}=16$,$D_{\text{vit}}=80$,$d_{ff}^{\text{vit}}=5120$。

图像 token 数:$\left(\frac{2016}{14}\right)^2 = 144^2 = 20736$ patches,经过 pixel unshuffle($\times 4$ 压缩)后:$20736 / 4 = 5184$,再经 spatial merge:$5184 / 9 = 576$ tokens。本文取 576。

单层 Attention(标准 MHA):

$$\text{FLOPs}_{\text{ViT QKV}} = 4 \times 2 \times d_{\text{vit}} \times H_{\text{vit}} \times D_{\text{vit}} \times T_{\text{img}}$$$$= 8 \times 1280 \times 16 \times 80 \times 576 = 8 \times 1{,}638{,}400 \times 576$$$$= 8 \times 943{,}718{,}400 = 7{,}549{,}747{,}200 \approx 7.55\text{G FLOPs}$$

($4 \times 2 = 8$ 来自 Q、K、V、O 四个投影各 $2 \times m \times n \times k$)

QK 点积(causal 不适用,ViT 对图像做双向 Attention):

$$\text{FLOPs}_{\text{ViT QK}} = 2 \times H_{\text{vit}} \times T_{\text{img}}^2 \times D_{\text{vit}} = 2 \times 16 \times 576^2 \times 80$$$$= 2 \times 16 \times 331{,}776 \times 80 = 849{,}346{,}560 \approx 0.85\text{G FLOPs}$$

V 加权

$$\text{FLOPs}_{\text{ViT V}} = 2 \times H_{\text{vit}} \times T_{\text{img}}^2 \times D_{\text{vit}} = 0.85\text{G FLOPs}$$

单层 MLP(GELU,2 个矩阵):

$$\text{FLOPs}_{\text{ViT MLP}} = 2 \times 2 \times d_{\text{vit}} \times d_{ff}^{\text{vit}} \times T_{\text{img}}$$$$= 4 \times 1280 \times 5120 \times 576 = 4 \times 6{,}553{,}600 \times 576$$$$= 4 \times 3{,}774{,}873{,}600 = 15{,}099{,}494{,}400 \approx 15.1\text{G FLOPs}$$

单层合计:$7.55 + 0.85 + 0.85 + 15.1 \approx 24.35\text{G FLOPs}$

32 层合计:$32 \times 24.35\text{G} \approx 779\text{G FLOPs}$

加上 patch embedding、projector 等:$\approx 800\text{G FLOPs} = 0.8\text{T FLOPs}$(per image)。

对比文本骨干(60 层,prefill 4096 token,$\approx 100\text{T+ FLOPs}$),ViT 的 0.8T FLOPs 占比 <1%。

ViT 虽深(32 层),但维度小(1280 vs 6144)且 token 数固定(576 vs 4096+)。相当于“一辆 Smart 虽也能开到 120 迈,但跟重卡(文本骨干)不是一个吨位的”。

3.9.2 Kimi K2.5 ViT FLOPs(速算)

K2.5 ViT:27 层,$d_{\text{vit}}=1152$,$H_{\text{vit}}=16$,$D_{\text{vit}}=72$,$d_{ff}^{\text{vit}}=4304$。图像 token 数约 576-2916(取决于分辨率模式)。

用 576 token 近似:

$$\text{单层 Attn + MLP} \approx 8 \times 1152 \times 16 \times 72 \times 576 + 4 \times 1152 \times 4304 \times 576$$$$\approx 6.1\text{G} + 11.4\text{G} \approx 17.5\text{G FLOPs}$$

27 层:$\approx 0.47\text{T FLOPs}$。加上 PatchMerger 和投影器:$\approx 0.5-0.7\text{T FLOPs}$。


3.10 完整案例对比:1M 上下文下三种架构的 FLOPs

在同一张表中呈现纯 Full Attention、Nemotron Hybrid(Mamba + Attn)、M3 MSA 三种方案的 FLOPs 分解。这张表是 CH3 的终极输出——一行看懂 Mamba 和 MSA 为什么殊途同归地解决了 O(T²) 问题。

3.10.1 场景设定

  • 上下文长度:T = 1M tokens
  • 解码阶段:$T_{\text{new}} = 1$(单 token decode)
  • 对比模型:
    • 纯 Full Attn (hypothetical):60 层 Full Attention,$d=8192$,$H_q=64$,$H_{kv}=64$(MHA,无 GQA),$D_h=128$,SwiGLU FFN $d_{ff}=8192 \times 4 \approx 32768$(无 MoE 时 FFN 占比较小,此处简化用大维度)
    • Nemotron 3 Ultra (hybrid):48 层 Mamba-2 + 12 层 Attention(GQA 32:1,2 KV heads)+ 48 层 MoE(22/512 激活)。$d=8192$,$H_q=64$,$H_{kv}=2$,$D_h=128$。MoE 专家在 latent 空间计算。
    • M3 (MSA):57 层 MSA(GQA 16:1,4 KV heads)+ 3 层 Full Attention(GQA 16:1)+ 57 层 MoE(4/128 激活)。$d=6144$,$H_q=64$,$H_{kv}=4$,$D_h=128$。

3.10.2 逐项 FLOPs 分解(decode per token, T=1M)

Attention 部分(QK + V 加权)

模型Attention 层数单层 QK+V FLOPsAttn 部分合计
纯 Full Attn60$4 \times 64 \times 1\text{M} \times 128 = 32.8\text{G}$$60 \times 32.8\text{G} = 1.97\text{T}$
Nemotron Hybrid1232.8G (GQA 下 QK+V 仍为 $4 \times 64 \times T \times 128$)$12 \times 32.8\text{G} = 393.6\text{G}$
M3 MSA3 Full + 57 MSAFull: 32.8G(改用 $d=6144$,$H_q=64$,$H_{kv}=4$ 后实际 ~32.8G);MSA: Index QK 1.02G + Main QK+V 67.2M ≈ 1.09G$3 \times 32.8\text{G} + 57 \times 1.09\text{G} \approx 160.5\text{G}$

Mamba/SSD 部分

模型Mamba/SSD 层数单层 FLOPsMamba 部分合计
纯 Full Attn000
Nemotron Hybrid48~847M$48 \times 847\text{M} = 40.7\text{G}$
M3 MSA000

线性投影部分(QKV proj + O proj + in_proj + out_proj + FFN):

模型单层投影估算投影部分合计
纯 Full AttnQ(134M) + K(134M) + V(134M) + O(134M) + FFN(~1.6G) ≈ 2.14G$60 \times 2.14\text{G} \approx 128\text{G}$
Nemotron HybridAttn 投影(~277M) × 12 + Mamba 投影(~843M) × 48 + MoE FFN(~462M) × 48$\approx 3.3\text{G} + 40.5\text{G} + 22.2\text{G} \approx 66\text{G}$
M3 MSAMSA 投影(~220M) × 57 + Full Attn 投影(~220M) × 3 + MoE FFN(~220M) × 57$\approx 12.5\text{G} + 0.7\text{G} + 12.5\text{G} \approx 26\text{G}$

注:以上为近似量级估算。投影部分具体数值取决于 $d_{ff}$、MoE 专家数等配置细节,精确计算需代入各模型 config.json 的实际值。本表的重点是横比数量级差异。

3.10.3 总表

模型Attn QK+V 部分Mamba/SSD 部分线性投影总 FLOPs/token相对纯 Full Attn
纯 Full Attn (hypothetical)~1.97T0~128G~2.10T1×(基线)
Nemotron 3 Ultra (hybrid)~394G~41G~66G~501G~1/4
M3 (MSA)~161G0~26G~187G~1/11

核心发现:

  1. 纯 Full Attn 在 1M 上下文下几乎不可用:每产生一个 token 需要 2.1T FLOPs,单看 Attention QK+V 部分的 1.97T 占 94%。即使最强大的推理硬件也难以达到可接受的吞吐(2.1T / 989 TFLOPS(H100 FP16)$\approx 2.1$ 秒/ token)。

  2. Nemotron Hybrid 将 QK+V 开销砍到原来的 1/5(394G vs 1970G),因为 80% 的层(48/60)用 Mamba-2 完全避开了 O(T) Attention。但 12 个 Attention 层仍贡献了总 FLOPs 的 78%——12 个 Attention 层的成本超过了 48 个 Mamba 层的总和

  3. M3 MSA 更进一步:3 个 Full Attention 层占 98G 的 QK+V,57 个 MSA 层才占 62G(Index QK $57 \times 1.02\text{G} = 58.1\text{G}$ + Main QK+V $57 \times 0.067\text{G} = 3.8\text{G}$)。MSA 的 Index Branch 虽然仍是 O(T),但以 16× 的廉价系数执行。

  4. 殊途同归:Nemotron 用 Mamba-2(状态空间,O(1) decode),M3 用稀疏 Attention(O(T) 但系数极小)——两条不同的技术路线,但都在 1M 上下文上将 Attention 部分从 TFLOPs 量级压到了 GFLOPs 量级。原理不同,效果趋同。

3.10.4 不同上下文长度下的横比

为直观展示 O(T) vs O(1) 的差别,固定模型配置,变化 T。仅计算 Attention 相关的 QK+V 部分(不含投影和 FFN):

T纯 Full Attn QK+V (60层)Nemotron Hybrid Attn QK+V (12层)M3 QK+V (3 Full + 57 MSA)
4K$60 \times 4 \times 64 \times 4096 \times 128 = 8.05\text{G}$$12 \times 4 \times 64 \times 4096 \times 128 = 1.61\text{G}$3 Full: $3 \times 4 \times 64 \times 4096 \times 128 = 0.40\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 4096 \times 128 = 0.24\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~4.46G
32K$60 \times 4 \times 64 \times 32768 \times 128 = 64.4\text{G}$$12 \times 4 \times 64 \times 32768 \times 128 = 12.9\text{G}$3 Full: $3 \times 4 \times 64 \times 32768 \times 128 = 3.22\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 32768 \times 128 = 1.91\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~8.95G
128K$60 \times 4 \times 64 \times 131072 \times 128 = 258\text{G}$$12 \times 4 \times 64 \times 131072 \times 128 = 51.5\text{G}$3 Full: $3 \times 4 \times 64 \times 131072 \times 128 = 12.9\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 131072 \times 128 = 7.65\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~24.4G
1M$60 \times 4 \times 64 \times 1\text{M} \times 128 = 1.97\text{T}$$12 \times 4 \times 64 \times 1\text{M} \times 128 = 394\text{G}$3 Full: $3 \times 4 \times 64 \times 1\text{M} \times 128 = 98.3\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 1\text{M} \times 128 = 58.4\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~160.5G

注:M3 MSA 的 Main Branch 始终只在 2048 个入选 token 上做 Attention——与 T 无关,常数 3.82G。Index Branch 的 QK 评分随 T 线性增长但只有 4 个 head。Full Attention 的 3 层和 Index Branch 的 O(T) 项共同主导 M3 的长上下文成本。

观察

  • 4K 短上下文:三种方案差距较小(8.0G vs 1.6G vs 4.5G)。MSA 反而比纯 Full Attn(12 层)慢,因为 Index Branch 的额外开销 + Main Branch 选了 2048/4096=50% 的 token——稀疏化的好处在短序列上不明显。
  • 128K 中上下文:差距拉开(258G vs 52G vs 24G)。MSA Main Branch 仅访问 2048/131072 = 1.6% 的 token,而 Index Branch O(T) 项(7.7G)仍远小于 Full Attn O(T) 项(258G)。
  • 1M 长上下文:差距成为鸿沟(1970G vs 394G vs 161G)。MSA Main Branch 仅访问 2048/1M = 0.2% 的 token——近乎常数。M3 比纯 Full Attn 的 QK+V 部分快 ~12×,Nemotron Hybrid 快 ~5×。
  • 关键洞察:MSA 在超长上下文时 Main Branch 趋近于常数,Index Branch 成为唯一 O(T) 项。但因为 Index 只有 4 head,实际斜率仅为 Full Attn 的 1/16。MSA 本质是用 O(T) 斜率 1/16 的廉价计算替代全量 O(T)。

如果说短上下文(4K)是“在大厅里找人”,那长上下文(1M)就是“在鸟巢体育场里找人”。Full Attention 的做法是跟每一个观众对视一眼(O(T)),Mamba 的做法是先把体育场分片区,只跟片区组长沟通(chunk + state),MSA 的做法是先派几个侦察兵扫一眼观众席(Index Branch),找到目标区域后再派大队人马过去(Main Branch)。


3.11 速查表:FLOPs 公式汇总

给一张“查表即算”的公式大全。不需要重读整章,从这里抄公式代入 config.json 的数值即可。

组件公式适用场景
Q/K/V 投影$2 \times d \times (H_{\text{type}} \times D_h) \times T_{\text{new}}$Q 用 $H_q$,K/V 用 $H_{kv}$
QK 点积$2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$Prefill 时 $T_{\text{new}}=T_{\text{total}}$(causal 约 /2)
V 加权$2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$与 QK 等量级
O 投影$2 \times d \times (H_q \times D_h) \times T_{\text{new}}$与 Q 投影等量级
MLA $W_{kv\_a}$$2 \times d \times (d_{kv} + D_{\text{rope}}) \times T_{\text{new}}$MLA 模型
MLA $W_{kv\_b}$$2 \times d_{kv} \times H \times (D_{\text{nope}} + D_v) \times T_{\text{new}}$MLA 模型
MLA $W_{q\_a}$$2 \times d \times d_q \times T_{\text{new}}$MLA 模型
MLA $W_{q\_b}$$2 \times d_q \times H \times D_{\text{nope}} \times T_{\text{new}}$MLA 模型
MLA $W_{q\_rope}$$2 \times d \times H \times D_{\text{rope}} \times T_{\text{new}}$MLA 模型
MSA Index QK$2 \times H_{\text{idx}} \times T_{\text{new}} \times T_{\text{total}} \times D_{\text{idx}}$M3 式 MSA
MSA Main QK/V$2 \times H_q \times T_{\text{new}} \times T_{\text{selected}} \times D_h$$T_{\text{selected}} = \text{block\_size} \times \text{top\_k}$
Mamba-2 in_proj$2 \times d \times (2d_{\text{inner}} + 2n_{\text{groups}}N + H_{\text{mamba}}) \times T_{\text{new}}$Nemotron 式 Mamba-2
Mamba-2 SSD diag$T \times C \times H_{\text{mamba}} \times D_{\text{mamba}}$Prefill; decode 时为常数
Mamba-2 SSD off-diag$T / C \times H_{\text{mamba}} \times N^2 \times 2$Prefill; decode 时常数可忽略
Mamba-2 out_proj$2 \times d_{\text{inner}} \times d \times T_{\text{new}}$总是
Router$2 \times d \times E \times T_{\text{new}}$所有 MoE 模型
FFN (ReLU$^2$)$2 \times 2 \times d \times d_{ff} \times T_{\text{new}}$Nemotron
FFN (SwiGLU)$2 \times 3 \times d \times d_{ff} \times T_{\text{new}}$M3, K2.5
ViT Attn$4 \times 2 \times d_{\text{vit}} \times H_{\text{vit}} \times D_{\text{vit}} \times T_{\text{img}}$VL 模型视觉编码器
ViT MLP (GELU)$2 \times 2 \times d_{\text{vit}} \times d_{ff}^{\text{vit}} \times T_{\text{img}}$VL 模型视觉编码器

实战口诀

  1. 先确定场景:prefill 还是 decode?
  2. 线性项(投影 + FFN):直接代入 $T_{\text{new}}$(prefill = 输入长度,decode = 1)
  3. 平方项(QK + V):将 $T_{\text{new}}$ 和 $T_{\text{total}}$ 分开——prefill 时两者相等,decode 时 $T_{\text{new}}=1$ 但 $T_{\text{total}}$ 是全部历史
  4. 稀疏/MSA 项:把 $T_{\text{total}}$ 换成 $T_{\text{selected}}$(入选 token 数)
  5. Mamba 项:decode 时全部为常数,prefill 时乘以 $T$
  6. 把每层加起来,乘以层数,得到单 token FLOPs
  7. 乘以 bytes 和 batch size 得到总计算吞吐需求

CH3 常见计算错误

3.12 从推理到训练:系数体系

CH 3.1-3.11 描述的是前向(推理)FLOPs。训练时需要前向 + 反向,总 FLOPs 是前向的倍数。这个倍数不是笼统的 ×3——不同操作的系数不同,且受梯度重计算(gradient checkpointing)影响。

3.12.1 线性投影:系数 6

每个 nn.Linear($Y = X \cdot W$)在训练中执行 3 次 matmul:

Pass计算FLOPs
前向$Y = X \cdot W$$2 \times m \times n \times k$
反向(权重梯度)$\partial L/\partial W = (\partial L/\partial Y)^T \cdot X$$2 \times m \times n \times k$
反向(输入梯度)$\partial L/\partial X = \partial L/\partial Y \cdot W^T$$2 \times m \times n \times k$
合计$6 \times m \times n \times k$

所以训练 FLOPs = $6 \times \text{params} \times \text{tokens}$(训练 FLOPs = 6 × params × tokens 即由此而来)。

3.12.2 Attention QK 与 V:系数不同(4 vs 3)

Attention 的 Q@K^T 和 A@V 在训练中的 pass 数不同,原因是梯度重计算(Flash Attention 的核心机制)。

Flash Attention 前向时不存储注意力矩阵 $A = Q \cdot K^T$($S \times S$ 矩阵太大),反向时重算 Q@K^T 恢复 $A$。但 A@V 不需要重算——它直接用重算出的 $A$。

Q@K^T 的训练 pass 数推导

Pass计算维度FLOPs
前向$A = Q \cdot K^T$$[H,S,D_{qk}] \times [H,D_{qk},S]$$2 \times H \times S^2 \times D_{qk}$
反向($\partial L/\partial Q$)$\partial L/\partial A \cdot K$$[H,S,S] \times [H,S,D_{qk}]$$2 \times H \times S^2 \times D_{qk}$
反向($\partial L/\partial K$)$\partial L/\partial A^T \cdot Q$同上$2 \times H \times S^2 \times D_{qk}$
重计算前向$A = Q \cdot K^T$(恢复 $A$)同前向$2 \times H \times S^2 \times D_{qk}$
合计$4 \times H \times S^2 \times D_{qk}$

A@V 的训练 pass 数推导

Pass计算维度FLOPs
前向$O = A \cdot V$$[H,S,S] \times [H,S,D_v]$$2 \times H \times S^2 \times D_v$
反向($\partial L/\partial A$)$\partial L/\partial O \cdot V^T$$[H,S,D_v] \times [H,D_v,S]$$2 \times H \times S^2 \times D_v$
反向($\partial L/\partial V$)$A^T \cdot \partial L/\partial O$同上$2 \times H \times S^2 \times D_v$
合计无重计算$3 \times H \times S^2 \times D_v$

关键公式

$$\text{Attention FLOPs}_{\text{train}} = (4 \times D_{qk} + 3 \times D_v) \times H \times S^2 \times L$$
  • 标准 Attention($D_{qk} = D_v = D$):$(4+3) \times D = 7 \times D$ → 简记为系数 7
  • MLA($D_{qk} \neq D_v$):不能用 $7 \times (D_{qk} + D_v)$,必须分开算

如果不使用梯度重计算(关闭 gradient checkpointing):Q@K^T 的重计算 pass 消失,系数从 4 降到 3,总系数变为 $3 + 3 = 6$。

3.13 Indexer 与 Router 的 no_grad 特性

稀疏注意力模型(DSA/MSA/CSA-HCA)的 Indexer 和 MoE 的 Router 都包含一个 torch.topk() 操作——离散选择,数学上不可导。你无法对「选择第 42 号 token」这个动作求梯度。

这一步必然不在 autograd 图中。但 Indexer/Router 内部的可学习参数(线性投影)是否训练,是工程选择:

方案做法梯度来源
GLM-5/M3/V4 的选择@torch.no_grad() 包裹整个 indexer无——参数冻结
理论替代方案straight-through estimator梯度通过 topk 近似传播
理论替代方案REINFORCE / policy gradient梯度通过奖励信号传播

GLM-5 选择完全冻结 indexer(源码 modeling_glm_moe_dsa.py:L197),可能出于训练稳定性和计算成本考虑——indexer 的 O(S²) 如果要反向传播,计算量翻 3 倍。

MoE Router 的情况不同

MoE 的 router 和 DSA Indexer 有本质区别:

  • DSA/MSA Indexer:整个 forward 被 @torch.no_grad() 包裹 → 参数完全冻结
  • MoE Router:router 的 nn.Linear 在 autograd 图中 → 权重通过专家输出反向传播正常训练
  • MoE 的 no_grad 只包裹 dispatch 逻辑(token 到专家的 gather/scatter),不包裹 router 本身

从原理可推断的部分:MoE router 权重需要学习「哪个 token 给哪个专家」,这必须通过下游 loss 的梯度训练,所以 router Linear 不应该 no_grad。

从原理不可推断的部分:DSA Indexer 选择完全冻结(而非用 straight-through estimator),这是工程决策,只能从源码 @torch.no_grad() 确认。

对 FLOPs 计算的影响

操作训练系数原因
Linear 投影(标准)6前向 + 反向×2
Attention Q@K^T(标准)4前向 + 反向×2 + 重计算
Attention A@V(标准)3前向 + 反向×2
Indexer 全部操作no_grad1仅前向,无反向
MoE Router Linear6正常前向+反向
MoE dispatch(topk/gather)1no_grad 内,仅前向
TopK 比较1非 matmul,仅前向
Conv1d(depthwise)3前向 + 反向

3.14 IndexShare 对训练 FLOPs 的影响

GLM-5.2 的 IndexShare 机制(每 4 层共享 1 个 indexer)在推理时节省 indexer FLOPs。训练时的影响取决于实现:

  • GLM-5.2 的 indexer 在 @torch.no_grad() 下,系数 1 × 21 层(full)= 仅 21/78 的 indexer 前向 FLOPs
$$\text{Indexer FLOPs}_{\text{GLM-5.2}} = \frac{21}{78} \times \text{Indexer FLOPs}_{\text{GLM-5.1}}$$

这对总 FLOPs 的影响 <1%(indexer 本身占比小),但对推理延迟的改善显著(博客声称 2.9× per-token FLOPs 降低,因为推理时 indexer 的 O(S²) 在长上下文下占比大)。

#常见错误正确做法
1decode 时把 QKV 投影乘以 $T_{\text{total}}$decode 只投影 1 个新 token,投影 FLOPs 是常数
2GQA 下 QK 点积用 $H_{kv}$ 算QK 点积前 K 已经被 repeat_kv 扩展到 $H_q$,用 $H_q$ 算
3MLA 的 QK 点积以为能省 FLOPsMLA 省的是 KV cache(显存),不是 QK 点积 FLOPs——最终 attention 的 $D_h = D_{\text{nope}} + D_{\text{rope}}$ 与 MHA 相同
4把 prefill 的 causal /2 也用在 decodedecode 的 query 只有 1 个,不存在 causal mask 的对称简化,公式是 $T_{\text{new}} \times T_{\text{total}}$ 而非 $T^2/2$
5MSA 的 Index QK 以为不用算 O(T²)Index QK 仍然是 O(T²)(prefill)或 O(T)(decode),只是 head 数少(4 vs 64),系数省 16×
6Mamba-2 decode 时把 SSD 对角块按 O(T) 算Mamba-2 decode 是 O(1)——只需更新当前 chunk 的状态,不重算全部 chunk
7忘记乘 2(MAC 系数)深度学习框架中 1 MAC = 2 FLOPs,所有矩阵乘法公式必须有因子 2
8把参数数量当 FLOPs参数量是“存了多少数”,FLOPs 是“每次前向算多少下”,两者中间隔着序列长度 T(对 O(T) 项)或 T²(对 O(T²) 项)

下一章预告:CH 4 内存分析——KV Cache 大小推导、MLA/GQA 的缓存压缩比、显存带宽瓶颈(Roofline 模型)、batch size 与延迟的权衡。


系列导航(一)预备知识与参数分解 ← 当前 → (三)KV Cache 与推理显存(四)M3 实战 + Roofline(五)训练显存(六)通信分析(七)推理服务