|
| 1 | +--- |
| 2 | +hero: |
| 3 | + title: "DeepSeek 稀疏注意力" |
| 4 | + subtitle: "⚡ 从二次方到近线性注意力 - 闪电索引器突破" |
| 5 | + tags: |
| 6 | + - "⏱️ 技术深度解析" |
| 7 | + - "📄 研究文章" |
| 8 | +--- |
| 9 | + |
| 10 | +[研究论文](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf) • [模型](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp) • [GPU内核](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) (TileLang用于研究,CUDA用于生产) |
| 11 | + |
| 12 | + |
| 13 | + |
| 14 | +新模型通过稀疏注意力使推理和训练成本显著降低,在保持性能的同时减少计算成本。 |
| 15 | + |
| 16 | +先决条件:注意力机制 |
| 17 | + |
| 18 | +**📺 推荐视频资源:** 为了全面理解注意力机制和DeepSeek的多头潜在注意力,请观看此视频:[DeepSeek V3 从零开始](https://youtu.be/TfEG0TwueTs) |
| 19 | + |
| 20 | +- **如果您是注意力机制的新手:** 从视频开头开始 |
| 21 | +- **如果您想专注于DeepSeek的多头潜在注意力(MLA):** 跳转到38:53或使用此直接链接:[https://youtu.be/TfEG0TwueTs?t=2333](https://youtu.be/TfEG0TwueTs?t=2333) |
| 22 | +- **注意:** 我会在本文/视频中再次解释MLA,但我建议两者都观看以获得更好的理解。 |
| 23 | + |
| 24 | +标准Transformer使用"注意力"机制,其中每个新生成的token都会回看序列中所有先前的token。 |
| 25 | + |
| 26 | +这在计算上非常昂贵。如果您有一个长度为$L$的序列,复杂度是$O(L^2)$,意味着所需的计算和内存呈二次方增长。 |
| 27 | + |
| 28 | +将文本长度从10,000个token加倍到20,000个token不仅使成本翻倍——而是四倍。这使得处理非常长的文档(如书籍或大型代码库)变得极其缓慢和昂贵。 |
| 29 | + |
| 30 | +DeepSeek稀疏注意力(DSA)不是让每个token关注所有先前的token,而是智能地选择一小部分固定大小($k$)的最相关先前token来关注。这将复杂度从$O(L^2)$改变为$O(L \cdot k)$,这更容易管理,因为$k$是一个小常数(例如2048),而$L$可以非常大(例如128,000)。 |
| 31 | + |
| 32 | +DSA由两个主要组件组成: |
| 33 | + |
| 34 | +闪电索引器将在每个token之间执行完全注意力,但它更小更快——ReLU激活非常快,键和查询的维度更小。 |
| 35 | + |
| 36 | +#### 组件1:闪电索引器 |
| 37 | + |
| 38 | +这是一个快速轻量级的机制,其唯一工作是找出哪些过去的token对当前token很重要。 |
| 39 | + |
| 40 | +- **工作原理:** 对于当前token($h_t$),索引器快速计算每个先前token($h_s$)的"索引分数"($I_{t,s}$)。这个分数表示token $s$对token $t$的预测相关性。 |
| 41 | +- **公式(1):** 公式1本质上是一个简化的注意力计算。它使用自己的一套小查询($q^I$)和键($k^I$)来计算这些分数。 |
| 42 | +- **为什么它是"闪电":** 它专为速度而设计。它使用简单的$\text{ReLU}$激活函数,可以用低精度数字(FP8)运行,使其在计算上非常便宜,即使它在技术上仍然查看所有先前的token($O(L^2)$操作,但非常非常快)。 |
| 43 | + |
| 44 | +### 1. 公式解释("什么") |
| 45 | + |
| 46 | +论文提供了两个关键公式来描述这个两步过程。 |
| 47 | + |
| 48 | +#### **公式(1):闪电索引器** |
| 49 | + |
| 50 | +$$ |
| 51 | +I_{t,s} = \sum_{j=1}^{H_I} w_{t,j}^I \cdot \text{ReLU}(q_{t,j}^I \cdot k_s^I) |
| 52 | +$$ |
| 53 | + |
| 54 | +这个公式计算**索引分数**($I_{t,s}$),它表示过去token $s$对当前token $t$的"相关性"。让我们分解一下: |
| 55 | + |
| 56 | +* $I_{t,s}$:最终重要性分数。分数越高意味着token $s$对token $t$更重要。 |
| 57 | +* $h_t$和$h_s$:这些是当前token($t$)和先前token($s$)的向量表示(隐藏状态)。 |
| 58 | +* $q_{t,j}^I$和$k_s^I$:这些是为索引器创建的特殊、轻量级**查询**和**键**向量(由$I$上标表示)。它们分别从$h_t$和$h_s$派生。 |
| 59 | +* $q_{t,j}^I \cdot k_s^I$:这是点积,注意力中的基本操作。它测量查询和键之间的相似性或兼容性。 |
| 60 | +* $\text{ReLU}(\cdots)$:一个简单的激活函数(修正线性单元)。计算非常快。如果点积为负,它变为0;否则保持不变。 |
| 61 | +* $w_{t,j}^I$:一个额外的权重,也从查询token $h_t$派生。它充当每个索引器头$j$的学习门或重要性因子。 |
| 62 | +* $\sum \cdots$:这将对所有索引器头($H^I$)的结果求和。索引器只有几个头以保持快速。 |
| 63 | + |
| 64 | +**简单来说:** 闪电索引器是一个迷你、简化的注意力机制。它的唯一工作是快速计算每对token的相关性分数,而不进行完整、昂贵的注意力计算。 |
| 65 | + |
| 66 | +#### **公式(2):主要注意力计算** |
| 67 | + |
| 68 | +$$ |
| 69 | +u_t = \text{Attn}(h_t, \{c_s | I_{t,s} \in \text{Top-k}(I_{t,:})\}) |
| 70 | +$$ |
| 71 | + |
| 72 | +这个公式描述了在完成选择后如何计算最终输出($u_t$)。 |
| 73 | + |
| 74 | +* $u_t$:当前token $t$的最终输出隐藏状态。 |
| 75 | +* $\text{Attn}(\cdots)$:这表示主要、强大的注意力机制(在这种情况下,多查询注意力)。 |
| 76 | +* $h_t$:来自当前token的查询。 |
| 77 | +* $\{c_s | I_{t,s} \in \text{Top-k}(I_{t,:})\}$:这是最重要的部分。它意味着:"仅当它们对应的索引分数$I_{t,s}$(在公式1中计算)是当前token $t$的$\text{top-k}$最高分数之一时,才使用键值条目$c_s$的集合。" |
| 78 | + |
| 79 | +**简单来说:** 主要注意力机制被告知忽略几乎所有先前的token,只关注闪电索引器识别为最重要的少数键值条目。 |
| 80 | + |
| 81 | +#### 组件2:细粒度Token选择 |
| 82 | +这个组件很简单:它获取闪电索引器计算的所有索引分数,并选择$\text{top-k}$最高分数。 |
| 83 | + |
| 84 | +- **功能:** 它充当守门员。它告诉主要、强大的注意力机制:"你不需要查看所有100,000个先前的token。我为你找到了2,048个最重要的。只看这些。" |
| 85 | + |
| 86 | +然后通过主要注意力模块计算最终注意力输出($u_t$),但只使用当前token的查询和选定的$k$个键值对。 |
| 87 | + |
| 88 | +### 步骤3:模型如何训练 |
| 89 | + |
| 90 | +他们没有从头训练这个模型。他们巧妙地适应了一个现有的、强大的模型(**DeepSeek-V3.1-Terminus**),该模型已经在长上下文中训练过。训练分几个阶段进行。 |
| 91 | + |
| 92 | +#### 阶段1:继续预训练(两个阶段) |
| 93 | + |
| 94 | +1. **密集预热阶段:** |
| 95 | + - **目标:** 教全新的闪电索引器"重要"token是什么样子。 |
| 96 | + - **方法:** 他们冻结了主模型并保持标准(密集)注意力活跃。然后他们只训练闪电索引器。索引器的目标是使其重要性分数与强大的、预训练主模型的注意力分数匹配。他们使用KL散度损失,这是一种测量两个概率分布相似程度的方法。本质上,他们告诉索引器:"学习预测主模型*会*关注什么。"这个阶段很短(1,000步)。 |
| 97 | + |
| 98 | +2. **稀疏训练阶段:** |
| 99 | + - **目标:** 使整个模型适应稀疏注意力模式。 |
| 100 | + - **方法:** 他们"开启"$\text{top-k}$选择器,使注意力稀疏。他们解冻主模型并一起训练所有内容。 |
| 101 | + * 主模型在其通常任务上训练:预测下一个词(语言建模损失)。它必须学会在选择器提供的有限上下文中表现良好。 |
| 102 | + * 闪电索引器继续用KL散度损失训练以与主模型的注意力对齐,但现在只在选定的$k$个token上。 |
| 103 | + * 这是主要训练阶段(15,000步,使用943.7亿个token)。 |
| 104 | + |
| 105 | +#### 阶段2:后训练 |
| 106 | +预训练完成后,他们使用强化学习(RL)为特定任务(如编码、数学、推理和遵循指令)微调模型。重要的是,他们使用了与原始DeepSeek-V3.1-Terminus模型**完全相同的数据和方法**。这确保了密集和稀疏模型之间的公平比较。 |
| 107 | + |
| 108 | +## 深度解析:多头潜在注意力(MLA)架构 |
| 109 | + |
| 110 | +让我们使用提供的公式和文本逐步分解多头潜在注意力(MLA)架构。 |
| 111 | + |
| 112 | +MLA的核心目标是通过巧妙的"压缩然后解压缩"策略大幅减少键值(KV)缓存的大小,这是处理长序列时的主要内存瓶颈。 |
| 113 | + |
| 114 | +该过程可以分为两个主要部分: |
| 115 | +1. 创建键和值(用于缓存)。 |
| 116 | +2. 创建查询(与缓存交互)。 |
| 117 | + |
| 118 | +--- |
| 119 | + |
| 120 | +### 步骤1:处理键和值(公式1-5) |
| 121 | + |
| 122 | +本节解释模型如何获取当前token($h_t$)的输入并创建将存储(以压缩形式)并由未来token使用的键和值向量。 |
| 123 | + |
| 124 | +#### 公式(1):压缩步骤 |
| 125 | +$$ |
| 126 | +c_t^{KV} = W^{DKV} \cdot h_t |
| 127 | +$$ |
| 128 | + |
| 129 | +- **作用:** 这是节省内存的最关键步骤。它获取当前token的大、高维输入向量($h_t$)并将其投影到称为**压缩潜在向量**($c_t^{KV}$)的更小、低维向量中。 |
| 130 | +- **$W^{DKV}$:** 这是一个学习的"下投影"矩阵。模型学习如何在训练期间最好地将$h_t$的信息压缩到$c_t^{KV}$中。 |
| 131 | +- **类比:** 将$h_t$视为高分辨率图像,将$c_t^{KV}$视为高度压缩的JPEG。JPEG存储起来要小得多,但保留了最重要的视觉信息。$c_t^{KV}$是与token的*内容*相关的唯一存储在缓存中的部分。 |
| 132 | + |
| 133 | +--- |
| 134 | + |
| 135 | +#### 公式(2)、(3)和(4):重构最终键 |
| 136 | + |
| 137 | +每个注意力头的最终键由两个独立的部分构成:一个"内容"部分和一个"位置"部分。 |
| 138 | + |
| 139 | +- **公式(2):解压缩"内容"键** |
| 140 | + $$ |
| 141 | + \begin{bmatrix} k_{t,1}^C \\ \vdots \\ k_{t,n_h}^C \end{bmatrix} = W^{UK} \cdot c_t^{KV} |
| 142 | + $$ |
| 143 | + * 这获取小潜在向量$c_t^{KV}$并将其*投影回*完整维度,为所有$n_h$个注意力头创建键的"内容"部分($k_t^C$)。 |
| 144 | + - **$W^{UK}$:** 这是键的学习"上投影"矩阵。它是解压缩器。 |
| 145 | + |
| 146 | +- **公式(3):创建"位置"键** |
| 147 | + $$ |
| 148 | + k_t^R = \text{RoPE}(W^{KR} \cdot h_t) |
| 149 | + $$ |
| 150 | + * 这部分处理token在序列中的位置。它获取*原始*高维输入$h_t$并应用变换($W^{KR}$)后跟**旋转位置嵌入(RoPE)**。 |
| 151 | + * 这创建了一个"解耦"键$k_t^R$,它纯粹编码位置信息。这是存储在缓存中的第二个也是最后一块。 |
| 152 | + |
| 153 | +- **公式(4):组合为最终键** |
| 154 | + $$ |
| 155 | + k_{t,i} = \begin{bmatrix} k_{t,i}^C \\ k_t^R \end{bmatrix} |
| 156 | + $$ |
| 157 | + * 特定注意力头$i$的最终键($k_{t,i}$)通过简单连接(粘在一起)内容部分($k_{t,i}^C$)和位置部分($k_t^R$)形成。 |
| 158 | + |
| 159 | +--- |
| 160 | + |
| 161 | +#### 公式(5):解压缩值 |
| 162 | +$$ |
| 163 | +\begin{bmatrix} v_{t,1}^C \\ \vdots \\ v_{t,n_h}^C \end{bmatrix} = W^{UV} \cdot c_t^{KV} |
| 164 | +$$ |
| 165 | + |
| 166 | +* 这与键解压缩非常相似。它使用*相同*的小潜在向量$c_t^{KV}$但*不同*的上投影矩阵($W^{UV}$)来重构所有$n_h$个头的完整大小值向量。 |
| 167 | +* 这表明$c_t^{KV}$是键和值信息的**联合**压缩。 |
| 168 | + |
| 169 | +**KV缓存的关键要点:** |
| 170 | +文本明确说明**只需要缓存蓝框向量($c_t^{KV}$和$k_t^R$)。** 这是MLA的魔力。不是为每个头存储巨大的键和值向量,您只需要存储一个微小的潜在向量($c_t^{KV}$)和一个位置向量($k_t^R$)。完整的键和值在需要时即时重构。 |
| 171 | + |
| 172 | +--- |
| 173 | + |
| 174 | +### 步骤2:处理查询(公式6-9) |
| 175 | + |
| 176 | +这个过程镜像键生成,但它是为将关注缓存中过去键的*当前*token的查询。 |
| 177 | + |
| 178 | +- **公式(6):压缩查询** |
| 179 | + $$ |
| 180 | + c_t^Q = W^{DQ} \cdot h_t |
| 181 | + $$ |
| 182 | + * 就像KV一样,输入$h_t$使用单独的下投影矩阵($W^{DQ}$)压缩成小潜在查询向量$c_t^Q$。 |
| 183 | + |
| 184 | +- **公式(7):解压缩"内容"查询** |
| 185 | + $$ |
| 186 | + \begin{bmatrix} q_{t,1}^C \\ \vdots \\ q_{t,n_h}^C \end{bmatrix} = W^{UQ} \cdot c_t^Q |
| 187 | + $$ |
| 188 | + * 小潜在查询$c_t^Q$被投影回以为每个头创建查询的"内容"部分($q_t^C$)。 |
| 189 | + |
| 190 | +- **公式(8):创建"位置"查询** |
| 191 | + $$ |
| 192 | + \begin{bmatrix} q_{t,1}^R \\ \vdots \\ q_{t,n_h}^R \end{bmatrix} = \text{RoPE}(W^{QR} \cdot c_t^Q) |
| 193 | + $$ |
| 194 | + * 查询的位置部分($q_t^R$)通过对*压缩*潜在查询$c_t^Q$的投影应用RoPE创建。 |
| 195 | + |
| 196 | +- **公式(9):组合为最终查询** |
| 197 | + $$ |
| 198 | + q_{t,i} = \begin{bmatrix} q_{t,i}^C \\ q_{t,i}^R \end{bmatrix} |
| 199 | + $$ |
| 200 | + * 每个头$i$的最终查询通过连接其内容和位置部分形成。 |
| 201 | + |
| 202 | +### 整个MLA流程总结 |
| 203 | + |
| 204 | +1. **对于每个token $t$:** 获取其输入嵌入$h_t$。 |
| 205 | +2. **压缩:** 创建一个微小的潜在向量$c_t^{KV}$,它联合表示键和值。 |
| 206 | +3. **获取位置:** 从$h_t$创建位置键$k_t^R$。 |
| 207 | +4. **缓存:** 在KV缓存中只存储$c_t^{KV}$和$k_t^R$。这是**内存节省**步骤。 |
| 208 | +5. **关注:** 当新token需要执行注意力时,它生成其查询($q_{t,i}$)。然后它检索所有先前token $s$的缓存$c_s^{KV}$和$k_s^R$,使用上投影矩阵即时重构它们的完整键和值,并计算注意力分数。 |
| 209 | + |
| 210 | +### MLA如何与DeepSeek稀疏注意力集成 |
| 211 | + |
| 212 | +这个架构的美妙之处在于MLA如何与DSA无缝协作: |
| 213 | + |
| 214 | +1. **DSA选择相关token:** 闪电索引器识别top-k最重要的先前token |
| 215 | +2. **MLA只处理选定的token:** 不是为所有128,000个先前token重构键和值,MLA只需要为选定的$\text{top-k}$个token解压缩缓存的$c_s^{KV}$和$k_s^R$ |
| 216 | +3. **内存效率倍增:** DSA减少要处理的token数量,而MLA减少每个token的内存占用 |
| 217 | + |
| 218 | +这种组合允许DeepSeek-V3.2处理极长序列(128,000+个token),同时保持计算效率和内存效率。 |
| 219 | +--- |
| 220 | + |
| 221 | +## 实验研究结果 |
| 222 | + |
| 223 | +*来自[开放超级智能实验室](https://opensuperintelligencelab.com/)研究的初步发现* |
| 224 | + |
| 225 | +### 研究问题 |
| 226 | + |
| 227 | +我们的实验旨在回答: |
| 228 | + |
| 229 | +1. **稀疏注意力是否改善标准注意力架构的性能?** |
| 230 | +2. **稀疏注意力在应用于已经高效的多头潜在注意力(MHLA)时是否提供额外好处?** |
| 231 | +3. **这些机制如何在不同序列长度上扩展?** |
| 232 | + |
| 233 | +未来研究(您可以参与): |
| 234 | +## 核心架构 |
| 235 | +1. **为什么我们需要索引器分数的额外权重?** ($w_{t,j}^I$的必要性) |
| 236 | +2. **不同序列长度的最优$k$值是什么?** |
| 237 | + |
| 238 | +## 闪电索引器 |
| 239 | +3. **索引器性能如何随序列长度扩展?** |
| 240 | +4. **扩展如何影响索引器准确性和计算效率?** |
| 241 | + |
| 242 | + |
| 243 | +### 实验1:标准注意力 vs 稀疏注意力 |
| 244 | + |
| 245 | +| 序列长度 | 标准损失 | 稀疏损失 | 改进 | 标准准确率 | 稀疏准确率 | |
| 246 | +|------------|---------------|-------------|-------------|--------------|------------| |
| 247 | +| 64 | 8.52 | **3.56** | **139%更好** | 4.3% | **53.2%** | |
| 248 | +| 128 | 7.28 | **3.00** | **143%更好** | 6.5% | **57.6%** | |
| 249 | +| 256 | 7.15 | **1.78** | **302%更好** | 7.6% | **68.4%** | |
| 250 | + |
| 251 | +**关键发现**:稀疏注意力显著优于标准注意力,对更长序列的好处增加。 |
| 252 | + |
| 253 | +### 实验2:MHLA密集 vs MHLA + 稀疏 |
| 254 | + |
| 255 | +| 序列长度 | MHLA损失 | MHLA+稀疏损失 | 改进 | MHLA准确率 | MHLA+稀疏准确率 | |
| 256 | +|------------|-----------|------------------|-------------|----------|-----------------| |
| 257 | +| 64 | 7.43 | **6.64** | **12%更好** | 9.2% | **15.5%** | |
| 258 | +| 128 | 6.85 | 6.97 | -2%更差 | 10.3% | 10.3% | |
| 259 | +| 256 | 6.61 | **6.55** | **1%更好** | 12.5% | **13.2%** | |
| 260 | +| 1024 | **4.10** | 6.91 | **-41%更差** | **32.2%** | 10.7% | |
| 261 | +| 2048 | 6.64 | **6.63** | **0%相同** | 11.9% | **14.4%** | |
| 262 | + |
| 263 | +**关键发现**:结果混合——稀疏帮助短序列但显著伤害长序列上的MHLA。 |
| 264 | + |
| 265 | +### 速度分析 |
| 266 | + |
| 267 | +**实验1**:相似的训练速度(两者每步约0.06秒) |
| 268 | +**实验2**:由于闪电索引器开销,稀疏版本慢1-4% |
| 269 | + |
| 270 | +### 研究洞察 |
| 271 | + |
| 272 | +**为什么稀疏帮助标准注意力:** |
| 273 | +- **强制选择性**充当正则化 |
| 274 | +- **减少密集注意力中的注意力稀释** |
| 275 | +- **通过关注相关token防止过拟合** |
| 276 | + |
| 277 | +**为什么稀疏可能不帮助MHLA:** |
| 278 | +- **冗余机制**:MHLA已经通过潜在空间压缩 |
| 279 | +- **冲突模式**:MHLA的学习压缩 vs 闪电索引器选择 |
| 280 | +- **双重压缩**:对长序列可能过于激进 |
| 281 | + |
| 282 | +### 限制和注意事项 |
| 283 | + |
| 284 | +这些是有限实验的初步结果。几个因素可能影响泛化性: |
| 285 | + |
| 286 | +- **有限的训练时间**:每个实验只有500-1000步 |
| 287 | +- **小模型大小**:512d模型可能不反映更大模型的行为 |
| 288 | +- **数据集**:TinyStories上的结果可能不泛化到其他领域 |
| 289 | +- **超参数**:没有为每个配置广泛调整 |
| 290 | + |
| 291 | +### 结论 |
| 292 | + |
| 293 | +我们的初步发现表明: |
| 294 | + |
| 295 | +1. **稀疏注意力显著改善标准注意力架构** |
| 296 | +2. **MHLA的潜在压缩可能已经提供稀疏性的大部分好处** |
| 297 | +3. **结合两种机制对长序列可能是冗余的甚至有害的** |
| 298 | + |
| 299 | +然而,这些结果需要更大的模型、更长的训练和多样化数据集的进一步验证。 |
| 300 | + |
| 301 | +### 关于开放超级智能实验室 |
| 302 | + |
| 303 | +[开放超级智能实验室](https://opensuperintelligencelab.com/)致力于推进开源AI研究。我们进行这些实验来理解大语言模型中的基本机制,并与社区透明地分享我们的发现。 |
| 304 | + |
| 305 | +我们的研究正在进行中,我们欢迎社区的协作和反馈。这些实验代表可能包含缺陷或限制的活跃研究,我们鼓励独立验证我们的发现。 |
| 306 | + |
| 307 | +--- |
| 308 | + |
| 309 | +*这项研究是我们对高效注意力机制持续调查的一部分。结果是初步的,随着我们进行更广泛的实验,可能会修订。* |
0 commit comments