Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
387 changes: 387 additions & 0 deletions app/blog/diffusion-transformer-representation-autoencoder/page.tsx

Large diffs are not rendered by default.

476 changes: 63 additions & 413 deletions app/page.tsx

Large diffs are not rendered by default.

22 changes: 17 additions & 5 deletions components/markdown-renderer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import remarkMath from 'remark-math';
import rehypeHighlight from 'rehype-highlight';
import rehypeKatex from 'rehype-katex';
import Image from 'next/image';
import 'highlight.js/styles/github-dark.css';
import 'highlight.js/styles/atom-one-dark.css';
import 'katex/dist/katex.min.css';
import '../styles/math-styles.css';

Expand Down Expand Up @@ -92,8 +92,10 @@ export function MarkdownRenderer({ content }: MarkdownRendererProps) {
);
},
pre: ({ children }) => (
<pre className="bg-slate-900 border border-gray-700 rounded-lg p-6 overflow-x-auto mb-10 text-sm">
{children}
<pre className="bg-slate-900 border border-gray-700 rounded-lg p-6 overflow-x-auto mb-10 text-sm font-mono leading-relaxed text-gray-200">
<code className="text-gray-200">
{children}
</code>
</pre>
),
// Custom blockquote styles
Expand All @@ -106,9 +108,19 @@ export function MarkdownRenderer({ content }: MarkdownRendererProps) {
img: ({ src, alt }) => {
if (!src) return null;

// Check if this is the architecture diagram that should be smaller
// Check if this is the architecture comparison diagram that should be larger
const isArchitectureComparison = alt?.includes('SD-VAE vs RAE') || (typeof src === 'string' && src.includes('architecture-comparison'));
// Check if this is other architecture diagrams that should be smaller
const isArchitectureDiagram = alt?.includes('Architecture') || (typeof src === 'string' && src.includes('architecture'));
const imageClassName = isArchitectureDiagram ? "w-1/2 h-auto mx-auto" : "w-full h-auto";

let imageClassName;
if (isArchitectureComparison) {
imageClassName = "w-full h-auto"; // Full width for comparison diagrams
} else if (isArchitectureDiagram) {
imageClassName = "w-1/2 h-auto mx-auto"; // Half width for other architecture diagrams
} else {
imageClassName = "w-full h-auto"; // Default full width
}

// Handle external images
if (typeof src === 'string' && src.startsWith('http')) {
Expand Down

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
275 changes: 275 additions & 0 deletions public/content/tiny-recursive-model/tiny-recursive-model-content-zh.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
---
hero:
title: "微型递归模型"
subtitle: "全新递归推理 AI 架构"
tags:
- "⏱️ 技术深度解析"
- "📄 研究文章"
---

## 全新的 AI 推理架构

700 万参数模型如何在数独、迷宫和 ARC-AGI 任务中击败万亿参数模型

**微型推理模型(TRM)** 使用一个 2 层 Transformer(仅 700 万参数),通过数百次重复使用相同的层来递归地推理问题。

它在数独极限版、迷宫、ARC-AGI 等任务中击败了比它大 100 倍的模型。

在本教程中,我们将学习 TRM 的工作原理并进行自己的实验。

---

## TRM 架构概览

![微型递归模型架构](/content/tiny-recursive-model/images/tiny-recursive-model-architecture.png)
*图示:微型递归模型架构展示了主处理块(4 层 Transformer)、输入组合(问题 (x)、答案 (y)、推理 (z))、损失计算的输出处理,以及递归更新机制,该机制通过最多 16 步迭代地优化推理和预测。*

上图展示了完整的 TRM 架构。该模型处理三个关键组件:
- **输入 (x)**:要解决的问题(例如迷宫布局)
- **预测 (y)**:模型当前的答案尝试
- **潜在变量 (z)**:模型的内部推理状态

这些组件被组合在一起并通过 4 层 Transformer 栈处理,输出用于计算交叉熵损失。关键创新是底部的递归更新机制,它通过多步迭代逐步优化推理 (z) 和预测 (y),以改进解决方案。

---

## TRM 如何工作

### 步骤 1:设置

让我们训练 TRM 来解决迷宫。

**1. 将迷宫表示为网格:**
首先,我们将迷宫显示为数字网格。网格中的每个单元格都有一个数字。

- `0` = 空路径
- `1` = 墙壁
- `2` = 起点
- `3` = 终点

举个具体例子,让我们追踪一个 3x3 的小迷宫。

- **`x_input`**(未解决的迷宫)
```
[[2, 0, 1],
[1, 0, 1],
[1, 0, 3]]
```
- **`y_true`**(正确的解决方案,用 `4` 表示路径)
```
[[2, 4, 1],
[1, 4, 1],
[1, 4, 3]]
```

**2. 标记化(Tokenization):**
术语 **token** 只是指我们数据的一个单元。在这种情况下,网格中的单个数字(`0`、`1`、`2`、`3` 或 `4`)就是一个 token。为了让网络更容易处理,我们将网格"展开"成一个长的一维列表。

对于我们的 3x3 示例,网格展开为包含 9 个 token 的列表。

**3. 嵌入:赋予数字意义:**
为了让模型理解像 `4` 和 `1` 这样的数字的含义,我们将为每个数字分配一个大的**向量嵌入**。向量嵌入是一个长向量(数字数组),模型可以修改它来存储关于墙壁、空路径等的信息。

这些向量将表示"墙壁"或"终点"的含义。

我建议你通过在 YouTube 上搜索或与 AI 聊天机器人交谈来复习一下向量嵌入(在大语言模型中,词、token 等的嵌入)是什么。

- **嵌入层**就像一个字典。
- 它包含我们每个数字的向量嵌入。
- `1`: `[0.3, -1.2, 0.7, 0.0, 1.5, -0.4, 0.9, 2.3]` ← "墙壁"的示例向量嵌入
- **输出:** 一个称为**向量**的长数字列表。这个向量以网络可以理解的方式表示"墙壁"的*含义*。网络本身在训练期间选择(学习)这个向量中的数字,这样它就可以"理解"它。

完成此步骤后,我们的输入迷宫不再是简单数字的列表。它是一个向量列表。对于我们的 3x3 迷宫,如果我们为每个 token 使用大小为 8 的向量,我们的输入变为:

- `x`:一个 `9x8` 的向量矩阵,表示迷宫。

这种丰富的表示就是我们输入主模型的内容。

---

### 步骤 2:核心架构:TRM 大脑

TRM 的"大脑"是一个称为 `net` 的微型 2 层 Transformer。它处理信息以产生输出。为了"思考",TRM 使用两个变量,它们的形状都与 `x` 相同:

- `y`:模型当前对解决方案的**最佳猜测**。可能是错误的
```
[[2, 4, 1],
[1, 4, 1],
[1, 0, 3]]
```
- `z`:一个**潜在思考**。`z` 告诉需要在 `y` 中改变什么才能将其变成正确的解决方案。`z` 多次通过 Transformer,让模型细化需要在 `y` 中改变的内容,这就是模型推理或思考的方式。然后将更改应用于 `y`。

对于我们的 3x3 示例,`z` 和 `y` 最初都是 `9x8` 的零矩阵。

---

### 步骤 3:学习过程,由内而外

TRM 通过一系列嵌套循环学习。让我们从核心开始,逐步向外构建。

#### 最内层循环:`latent_recursion`(核心思考)

这是微型 `net`(一个 2 层 Transformer)完成所有工作的地方。该过程分为两个阶段,它们重复形成思考和优化的循环。

**阶段 A:推理(更新草稿本 `z`)**
模型通过在 6 步循环中优化其内部规划 token `z` 来"思考"。目标是建立一个越来越好的改变 `y` 的计划。

1. **过程:** 在 6 个步骤中的每一个步骤中,`net` 接受三个输入:
- 迷宫本身(`x`)。
- 模型当前对解决方案的最佳猜测(`y`)——在开始时这可能全是零。
- 上一步的草稿本(`z`)。
2. **工作原理:**
- **组合输入:** 三个输入按元素相加(`x + y + z`)。这创建了一个丰富向量的单个序列,其中每个向量(表示迷宫中的一个单元格)包含有关迷宫布局(`x`)、当前猜测(`y`)和正在进行的思考过程(`z`)的组合信息。
- **用注意力思考:** 这个组合序列被输入到 2 层 Transformer 中。Transformer 的自注意力机制允许它同时查看所有单元格并识别关系。例如,它可以看到"起点"单元格如何与潜在路径单元格相关,并结合输入数据 `x` 和推理 `z`。
- **生成下一个思考:** 两个 Transformer 层处理这些信息并输出一个相同形状的新向量序列。这个输出*就是*新的 `z`。没有单独的"输出头"来生成它;两层执行的转换*就是*创建下一个、更精细的思考的行为。尽管输入是包含 `x` 和 `y` 的和,但网络学会产生一个用作下一步的有用的新 `z` 的输出。

这个过程重复 6 次,意味着信息连续六次通过相同的两层,每次传递都变得越来越复杂。
3. **示例追踪:** 在通过 Transformer 的几次传递后,`z` 可能编码低级特征,如墙壁位置。到第六次传递时,它可能表示更新答案(`y`)的高级计划。

- 有趣的是,相同的 2 个 Transformer 层用于检测低级特征、制定高级计划,以及后来更新 `y` 本身。这 2 层具有多重用途,这是神经网络的力量,它可以学习执行多个不太相关或不相关的转换,这只取决于输入数据。

**阶段 B:优化答案(更新猜测 `y`)**
在 6 步推理循环之后,使用最新的潜在思考 `z`,模型更新其答案 `y`。

- **工作原理:** 它将先前的答案(`y`)与最终的、精细的思考(`z`)通过将它们相加(`y + z`)组合在一起,并最后一次通过相同的 `net`。输出是新的、改进的 `y`。
- **关键是,`x` 不包含在此步骤中。** 这是一个刻意的设计选择,告诉单个 `net` 要执行哪个任务。
- `x` 存在于推理中(`x + y + z`)。
- `x` 不存在于答案优化中(`y + z`)。

我说"答案优化"的原因是,这个 6+1 循环会发生多次,每次"思考" 6 次并更新 `y` 一次。

#### 中间循环:`deep_recursion`(完整的思考过程)

现在我们理解了推理 + y 优化循环是如何工作的,让我们看看从头开始的完整思考过程,在这个过程中,整个循环重复 3 次以获得最佳的 `y`。

前面描述的内部循环(推理和 `y` 优化的 6+1 步骤)运行 `T` 次(例如,`T=3`)。状态(`y` 和 `z`)在这些运行之间**被保留**;它不会重置为零。

- **第 1 轮(热身):** 从空白(全零)的 `y` 和 `z` 开始(请记住,这是过程的绝对开始,所以没有 `y` 和 `z` 可以保留)。它运行完整的内部循环(6 步推理 + 1 步 `y` 优化)以产生更智能的 `y_1` 和 `z_1`。这是在"无梯度"模式下完成的,以节省速度和内存——神经网络在这里不学习。
- **第 2 轮(热身):** 它以 `y_1` 和 `z_1` 作为起点,再次运行内部循环以产生更好的 `y_2` 和 `z_2`。仍然没有梯度和学习。
- **第 3 轮(真正的):** 它从经过充分推理的 `y_2` 和 `z_2` 开始,最后一次运行内部循环,这次所有计算都被跟踪,以便模型可以通过反向传播学习。

在最终的可学习步骤之前预热模型的"思考"的这个过程是一个关键的优化。

#### 最外层循环:更多循环!

模型获得多次"机会"(最多 16 次)来解决同一个迷宫,每次机会后,它都会优化其 `net` 权重。状态(`y` 和 `z`)**从一个中间循环迭代保留**到下一个,如论文的伪代码所示。它允许模型获得多次"机会"(最多 16 次)来解决同一个迷宫,每次都在改进。

这只是重复中间循环最多 16 次。如果模型觉得它得到了正确的答案,它可以决定提前停止。

为什么我们需要这个循环:

在每次中间循环迭代之后,这个外部循环更新一次权重(记住中间循环中的第 3 轮执行反向传播)。

然后在下一次迭代中,它用更新的权重重复中间循环,允许模型在每次尝试中逐步改进其解决方案。

#### 知道何时停止思考(Q 头)

外部循环最多可以运行 16 次,但不必如此。继续思考它已经解决的迷宫将是浪费时间。

因此,模型有一个称为"Q 头"的小副脑。在每个完整的思考过程(每个中间循环)之后,这个 Q 头会给出一个分数。这个分数基本上是模型的信心:"我有多确定我做对了?"

如果信心分数足够高,外部循环就会停止(`break`),模型会继续下一个迷宫。

它学会正确获得这个信心分数,因为它是训练的一部分。如果它自信*并且*正确,它会得到奖励,如果它自信但错误,它会受到惩罚。论文称之为自适应计算时间(ACT)。

---

```python
# 初始化
y, z = zeros_like(x), zeros_like(x)

# 深度监督循环(最多 16 次)
for supervision_step in range(16):

# 深度递归:热身(2 次,无梯度)
with torch.no_grad():
for _ in range(2):
# 潜在递归
for _ in range(6):
z = net(x + y + z)
y = net(y + z)

# 深度递归:最终(1 次,有梯度)
for _ in range(6):
z = net(x + y + z)
y = net(y + z)

# 学习
y_pred = output_head(y)
loss = cross_entropy(y_pred, y_true)
loss.backward()
optimizer.step()

# 我们应该停止吗?
q = Q_head(y)
if q > 0:
break
```

---

### 步骤 4:消融研究——是什么让 TRM 起作用?

![完整消融研究](/content/tiny-recursive-model/images/complete_ablation_study.png)
*图示:在迷宫求解(30x30,困难)上进行 10 个 epoch 训练,四种 TRM 配置的训练损失比较。基线(蓝色实线)使用 TRM 的标准设计:2 层网络,H=3(中间循环),L=6(内部循环),带 EMA。消融测试:移除 EMA(红色虚线),减少递归深度(绿色点划线),以及使用更大的 4 层网络(品红色点线)。*

为了理解是什么使 TRM 有效,我们系统地测试变体,移除或更改关键组件。这些**消融研究**揭示了哪些设计选择是必不可少的。

#### 实验设置

我们在迷宫求解任务(30x30 困难迷宫,1000 个训练示例)上测试四种配置:

| 配置 | 层数 | H_cycles | L_cycles | EMA | 有效深度* |
|---------------|--------|----------|----------|-----|------------------|
| **基线 TRM** | 2 | 3 | 6 | 是 | 42 |
| **无 EMA** | 2 | 3 | 6 | 否 | 42 |
| **更少递归** | 2 | 2 | 2 | 是 | 12 |
| **更大的大脑** | 4 | 3 | 3 | 是 | 48 |

*有效深度 = T × (n+1) × 层数

#### 结果

**注意:** 这些是 10 个 epoch 的实验——与论文的 50,000 多个 epoch 的运行相比,训练量非常小。更长的训练可能会显著改变这些配置的相对性能,特别是对于泛化(如我们在下面的"更大的大脑"结果中看到的)。

| 配置 | 初始损失 | 最终损失 | 最小损失 | 改进 |
|---------------|--------------|------------|----------|-------------|
| 基线 | 1.789 | 1.062 | 1.045 | 40.6% |
| 无 EMA | 1.789 | 1.042 | 1.041 | 41.7% |
| 更少递归 | **2.100** | 1.100 | 1.042 | 47.6% |
| 更大的大脑(4 层) | 1.789 | **1.007** | **1.007** | **43.7%** |

#### 关键发现

**1. "更大的大脑"悖论:短期 vs. 长期性能**

4 层网络在我们的 10 个 epoch 实验中实现了**最佳的最终损失**(1.007),比 2 层基线好约 5%。这似乎与论文的"少即是多"主张相矛盾。

**为什么不同?**
- **短期**(10 个 epoch):更多容量 = 更快学习。4 层网络可能快速记住模式。
- **长期**(50,000 多个 epoch):更多容量 = 过拟合。2 层网络*被迫*学习可重用的推理策略,而不是记住特定的解决方案。

论文的核心论点:**被迫递归思考的小型网络比大型网络泛化得更好**,即使它们最初训练得更慢。选择 2 层架构是为了防止记忆并强制依赖递归。

**2. 递归深度是基础**

"更少递归"配置(H=2,L=2)显示出严重退化的性能:
- 在任何训练之前,初始损失高出 **17%**(2.100 对 1.789)
- 实现最差的最终损失(1.100),尽管改进了 47.6%

**论文说什么:** 将递归从 T=3,n=6 减少到 T=2,n=2,数独准确率从 87.4% 下降到 73.7% ——下降 14%。

**为什么这很重要:** 高初始损失揭示了浅递归*在设计上*削弱了模型的表征能力。即使有完美的训练,也没有足够的递归"思考步骤"来解决复杂问题。**你不能用更好的训练来弥补不足的递归深度。**

**3. EMA 的短期影响最小**

移除 EMA 几乎不影响 10 个 epoch 的性能(最终损失 1.042 对基线的 1.062,只有约 2% 的差异)。

**论文说什么:** 在数独极限版上,移除 EMA 使准确率从 87.4% 下降到 79.9% ——在完整训练后下降 8%。

**为什么不同?** EMA 是模型权重的**指数移动平均**,它在长期运行中稳定训练。在短期实验中,两个模型仍在探索,还没有遇到 EMA 防止的不稳定性。在 50,000 多个 epoch 中,EMA 防止了灾难性的发散和过拟合峰值,使其对最终性能至关重要。

---

感谢您阅读本教程,我们下一个教程见。

55 changes: 55 additions & 0 deletions styles/math-styles.css
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,58 @@
border-color: rgba(139, 92, 246, 0.5);
transition: all 0.2s ease;
}

/* Custom syntax highlighting for code blocks */
.hljs {
background: #1e293b !important; /* slate-800 */
color: #e2e8f0 !important; /* slate-200 */
}

.hljs-keyword {
color: #c084fc !important; /* purple-300 */
}

.hljs-string {
color: #34d399 !important; /* emerald-400 */
}

.hljs-comment {
color: #64748b !important; /* slate-500 */
font-style: italic;
}

.hljs-number {
color: #fbbf24 !important; /* amber-400 */
}

.hljs-function {
color: #60a5fa !important; /* blue-400 */
}

.hljs-variable {
color: #f472b6 !important; /* pink-400 */
}

.hljs-type {
color: #fb7185 !important; /* rose-400 */
}

.hljs-built_in {
color: #a78bfa !important; /* violet-400 */
}

.hljs-title {
color: #60a5fa !important; /* blue-400 */
}

.hljs-params {
color: #e2e8f0 !important; /* slate-200 */
}

.hljs-attr {
color: #34d399 !important; /* emerald-400 */
}

.hljs-value {
color: #fbbf24 !important; /* amber-400 */
}
Loading