diff --git a/app/blog/diffusion-transformer-representation-autoencoder/page.tsx b/app/blog/diffusion-transformer-representation-autoencoder/page.tsx new file mode 100644 index 0000000..5fce03a --- /dev/null +++ b/app/blog/diffusion-transformer-representation-autoencoder/page.tsx @@ -0,0 +1,387 @@ +'use client'; + +import Link from "next/link"; +import { useLanguage } from "@/components/providers/language-provider"; +import { MarkdownRenderer } from "@/components/markdown-renderer"; +import { useEffect, useState } from "react"; + +interface HeroData { + title: string; + subtitle: string; + tags: string[]; +} + +export default function DiffusionTransformerRAEProject() { + const { language } = useLanguage(); + const [markdownContent, setMarkdownContent] = useState(''); + const [heroData, setHeroData] = useState(null); + const [isLoading, setIsLoading] = useState(true); + const [copySuccess, setCopySuccess] = useState(false); + + useEffect(() => { + const fetchMarkdownContent = async () => { + try { + const filename = language === 'zh' ? 'diffusion-transformer-rae-content-zh.md' : 'diffusion-transformer-rae-content.md'; + const response = await fetch(`/content/diffusion-transformer-representation-autoencoder/${filename}`); + const content = await response.text(); + + // Parse frontmatter + const frontmatterMatch = content.match(/^---\n([\s\S]*?)\n---\n([\s\S]*)$/); + if (frontmatterMatch) { + const frontmatterContent = frontmatterMatch[1]; + const markdownBody = frontmatterMatch[2]; + + // Parse YAML-like frontmatter (simple parsing for our use case) + const heroData: HeroData = { + title: "47x Faster Image & Video Generation Training", + subtitle: "Diffusion Transformers with Representation Autoencoders", + tags: ["⏱️ Technical Deep Dive", "📄 Research Article"] + }; + + // Extract values from frontmatter + const lines = frontmatterContent.split('\n'); + let currentKey = ''; + let currentArray: string[] = []; + + for (const line of lines) { + const trimmedLine = line.trim(); + if (trimmedLine.startsWith('hero:')) continue; + + if (trimmedLine.includes(':')) { + const [key, ...valueParts] = trimmedLine.split(':'); + const value = valueParts.join(':').trim().replace(/^["']|["']$/g, ''); + + switch (key.trim()) { + case 'title': + heroData.title = value; + break; + case 'subtitle': + heroData.subtitle = value; + break; + case 'tags': + currentKey = 'tags'; + currentArray = []; + break; + } + } else if (trimmedLine.startsWith('- ')) { + if (currentKey === 'tags') { + const tagValue = trimmedLine.substring(2).replace(/^["']|["']$/g, ''); + currentArray.push(tagValue); + } + } else if (trimmedLine === '' && currentArray.length > 0) { + if (currentKey === 'tags') { + heroData.tags = currentArray; + currentArray = []; + currentKey = ''; + } + } + } + + // Handle final array + if (currentArray.length > 0 && currentKey === 'tags') { + heroData.tags = currentArray; + } + + setHeroData(heroData); + setMarkdownContent(markdownBody); + } else { + // Fallback if no frontmatter + setMarkdownContent(content); + } + } catch (error) { + console.error('Failed to fetch markdown content:', error); + setMarkdownContent('# Error loading content\n\nFailed to load the article content.'); + } finally { + setIsLoading(false); + } + }; + + fetchMarkdownContent(); + }, [language]); + + const handleCopyArticle = async () => { + try { + // Get the raw markdown content without frontmatter + const filename = language === 'zh' ? 
'diffusion-transformer-rae-content-zh.md' : 'diffusion-transformer-rae-content.md'; + const response = await fetch(`/content/diffusion-transformer-representation-autoencoder/${filename}`); + const content = await response.text(); + + // Remove frontmatter if present + let contentWithoutFrontmatter = content.replace(/^---\n[\s\S]*?\n---\n/, ''); + + // Remove image paths (markdown image syntax: ![alt text](image-path)) + contentWithoutFrontmatter = contentWithoutFrontmatter.replace(/!\[.*?\]\(.*?\)/g, ''); + + await navigator.clipboard.writeText(contentWithoutFrontmatter); + setCopySuccess(true); + setTimeout(() => setCopySuccess(false), 2000); + } catch (error) { + console.error('Failed to copy article:', error); + } + }; + + if (isLoading) { + return ( +
+
+
+

Loading article content...

+
+
+ ); + } + + return ( + <> + {/* Hero Section */} +
+ {/* Background effects */} +
+
+
+
+ + {/* Animated background particles */} +
+
+
+
+
+
+ +
+
+
+

+ + {heroData?.title || '47x Faster Image & Video Generation Training'} + +

+
+ {heroData?.subtitle || 'Diffusion Transformers with Representation Autoencoders'} +
+ + {/* Tags */} + {heroData?.tags && heroData.tags.length > 0 && ( +
+ {heroData.tags.map((tag, index) => ( + + {index > 0 && } + + {tag.includes('⏱️') && ( + + + + )} + {tag.includes('📄') && ( + + + + )} + {tag.replace(/[⏱️📄]/g, '').trim()} + + + ))} +
+ )} + + {/* Links to Paper and GitHub */} + + + {/* Glow effect for the title */} +
+ + {heroData?.title || '47x Faster Image & Video Generation Training'} + +
+
+
+
+
+ + {/* Main Content */} +
+
+ {/* Article Container */} +
+ {/* Content Card */} +
+ {/* Copy Button at Top */} +
+
+
+ + + {/* Tooltip */} +
+ {language === 'en' + ? 'Perfect for pasting into AI chatbots for self-studying! 🤖' + : '非常适合粘贴到AI聊天机器人进行自学!🤖' + } + {/* Tooltip arrow */} +
+
+
+
+
+ + {/* Article Body */} +
+
+ +
+
+ + {/* Article Footer */} +
+
+
+ + + + + Open Superintelligence Lab + +
+
+ Share + + {/* Copy Article Button */} +
+ + + {/* Tooltip */} +
+ {language === 'en' + ? 'Perfect for pasting into AI chatbots for self-studying! 🤖' + : '非常适合粘贴到AI聊天机器人进行自学!🤖' + } + {/* Tooltip arrow */} +
+
+
+ + + + + + + + + + + +
+
+
+
+ + {/* Navigation */} +
+ + + + + {language === 'en' ? 'Back to Home' : '返回首页'} + + +
+ Scroll to + +
+
+
+
+
+ + ); +} + + diff --git a/app/page.tsx b/app/page.tsx index f0d2e0d..19acf3b 100644 --- a/app/page.tsx +++ b/app/page.tsx @@ -2,15 +2,13 @@ import Link from "next/link"; import { useLanguage } from "@/components/providers/language-provider"; -import { useEffect, useState } from "react"; export default function Home() { const { language } = useLanguage(); - const [isLocalhost, setIsLocalhost] = useState(false); - useEffect(() => { - setIsLocalhost(window.location.hostname === 'localhost' || window.location.hostname === '127.0.0.1'); - }, []); + const getText = (en: string, zh: string) => { + return language === 'zh' ? zh : en; + }; return ( <> @@ -70,44 +68,34 @@ export default function Home() {

- {language === 'en' ? ( - <> - Open - Superintelligence - - ) : ( - <> - 开放 - 超级智能 - - )} + + {getText('Open', '开放')} + + + {getText('Superintelligence', '超级智能')} +

{/* Glow effect for the entire title */}
- {language === 'en' ? ( - <> - Open - Superintelligence - - ) : ( - <> - 开放 - 超级智能 - - )} + + {getText('Open', '开放')} + + + {getText('Superintelligence', '超级智能')} +
{/* Subtitle */}

- The Most Difficult Project In Human History + {getText('The Most Difficult Project In Human History', '人类历史上最困难的项目')}

{/* Glow effect for subtitle */}
- The Most Difficult Project In Human History + {getText('The Most Difficult Project In Human History', '人类历史上最困难的项目')}
@@ -137,15 +125,15 @@ export default function Home() {
- {language === 'en' ? 'Open Source' : '开源'} + {getText('Open Source', '开源')}
- {language === 'en' ? 'LLM Research' : '大模型研究'} + {getText('LLM Research', '大模型研究')}
- {language === 'en' ? 'Innovation' : '创新'} + {getText('Innovation', '创新')}
@@ -156,7 +144,7 @@ export default function Home() { className="group px-8 py-4 bg-gradient-to-r from-blue-600 to-purple-600 text-white font-semibold rounded-xl hover:from-blue-700 hover:to-purple-700 transition-all duration-300 transform hover:scale-105 shadow-lg hover:shadow-2xl hover:shadow-blue-500/25" > - Explore & Participate + {getText('Explore & Participate', '探索并参与')} @@ -167,7 +155,7 @@ export default function Home() { className="group px-8 py-4 border-2 border-slate-600 text-slate-300 font-semibold rounded-xl hover:border-blue-500 hover:text-blue-400 transition-all duration-300 transform hover:scale-105" > - Learn More + {getText('Learn More', '了解更多')} @@ -182,44 +170,16 @@ export default function Home() {
- {/* Road to AI Researcher Project */} - {/* -
- Learning Path -
-
- New -
- -
-

- Zero To AI Researcher - Full Course -

-

- A comprehensive journey into becoming an AI researcher, covering everything from foundational concepts to cutting-edge research methodologies -

-
- Open Superintelligence Lab - - Start Learning → - -
-
- */} - {/* DeepSeek Sparse Attention Project */}
- Research + {getText('Research', '研究')}
- Featured + {getText('Featured', '精选')}
@@ -227,12 +187,15 @@ export default function Home() { DeepSeek Sparse Attention - DeepSeek-V3.2-Exp

- Advanced research on DeepSeek's innovative sparse attention mechanisms for efficient long-context processing and memory optimization + {getText( + 'Advanced research on DeepSeek\'s innovative sparse attention mechanisms for efficient long-context processing and memory optimization', + 'DeepSeek创新稀疏注意力机制的高级研究,用于高效长上下文处理和内存优化' + )}

DeepSeek Research - Learn More → + {getText('Learn More', '了解更多')} →
@@ -244,23 +207,26 @@ export default function Home() { className="group relative bg-gradient-to-br from-slate-800/50 to-slate-700/50 backdrop-blur-sm border border-slate-600/50 rounded-xl p-6 hover:border-purple-500/50 hover:shadow-2xl hover:shadow-purple-500/10 transition-all duration-300" >
- Research + {getText('Research', '研究')}
- Latest + {getText('Latest', '最新')}

- Tiny Recursive Model + {getText('Tiny Recursive Model', '微型递归模型')}

- How a 7M parameter model beats 100x bigger models at Sudoku, Mazes, and ARC-AGI using recursive reasoning with a 2-layer transformer + {getText( + 'How a 7M parameter model beats 100x bigger models at Sudoku, Mazes, and ARC-AGI using recursive reasoning with a 2-layer transformer', + '7M参数模型如何使用2层transformer的递归推理在数独、迷宫和ARC-AGI上击败100倍大的模型' + )}

- AI Research + {getText('AI Research', 'AI研究')} - Learn More → + {getText('Learn More', '了解更多')} →
@@ -272,380 +238,64 @@ export default function Home() { className="group relative bg-gradient-to-br from-slate-800/50 to-slate-700/50 backdrop-blur-sm border border-slate-600/50 rounded-xl p-6 hover:border-green-500/50 hover:shadow-2xl hover:shadow-green-500/10 transition-all duration-300" >
- Research + {getText('Research', '研究')}
- Featured + {getText('Featured', '精选')}

- Pretrain LLM with NVFP4 + {getText('Pretrain LLM with NVFP4', '用NVFP4预训练LLM')}

- NVIDIA's breakthrough 4-bit training methodology achieving 2-3x speedup and 50% memory reduction without sacrificing model quality + {getText( + 'NVIDIA\'s breakthrough 4-bit training methodology achieving 2-3x speedup and 50% memory reduction without sacrificing model quality', + 'NVIDIA突破性4位训练方法实现2-3倍加速和50%内存减少,同时不牺牲模型质量' + )}

NVIDIA Research - Explore → + {getText('Explore', '探索')} →
- {/* The Most Difficult Project in Human History */} - {/* -
- Mission -
-
- New -
- -
-

- The Most Difficult Project in Human History -

-

- Learn to implement all latest AI tech (Mamba, Gated DeltaNet, RWKV) and become an expert at fast experimentation and research -

-
- Open Superintelligence Lab - - Start Journey → - -
-
- */} - - {/* MobileLLM-R1 Project - HIDDEN */} - {/*
- Research + {getText('Research', '研究')}
- Latest + {getText('New', '新')}
-

- MobileLLM-R1 - Sub-Billion Reasoning +

+ {getText('47x Faster Image Generation Training', '图像生成训练加速47倍')}

- Meta's breakthrough in small-scale reasoning: 950M parameter model achieving AIME 15.5 with only 2T high-quality tokens + {getText( + 'Diffusion Transformers with Representation Autoencoders achieve state-of-the-art FID 1.13 on ImageNet while training 47x faster (80 vs 1400 epochs)', + '扩散变换器与表示自编码器在ImageNet上实现最先进的FID 1.13,同时训练速度提升47倍(80轮对比1400轮)' + )}

- Meta AI Research - - Learn More → + {getText('MIT-Han Lab', 'MIT韩松实验室')} + + {getText('Learn More', '了解更多')} →
- */} +
- - {/* Ideas Section - Only visible on localhost */} - {isLocalhost && ( -
-
-
-

- {language === 'en' ? 'Research Ideas & Concepts' : '研究想法与概念'} -

-
- {/* GLM4-MoE Project */} - -
- Research -
-
- MoE -
- -
-

- GLM4-MoE -

-

- Advanced Mixture of Experts implementation with GLM4 architecture -

-
- THUDM Research - - Explore → - -
-
-
- - {/* DeepSeek Attention + GLM4-MoE Project */} - -
- Research -
-
- Hybrid -
- -
-

- DeepSeek Attention + GLM4-MoE -

-

- Combination of DeepSeek's sparse attention with GLM4's MoE -

-
- Open Superintelligence Lab - - Learn More → - -
-
-
- - {/* SLA Sparse-Linear Attention Project */} - -
- Research -
-
- New -
- -
-

- SLA: Sparse-Linear Attention -

-

- 20x speedup with minimal quality loss in diffusion transformers -

-
- Tsinghua University - - Learn More → - -
-
- - - {/* SDLM Sequential Diffusion Language Model Project */} - -
- Research -
-
- Latest -
- -
-

- SDLM: Sequential Diffusion Language Model -

-

- Adaptive generation length with 2x speedup and KV-cache compatibility -

-
- OpenGVLab - - Learn More → - -
-
- -
-
-
-
- )} - - {/* Localhost Todos Section - Only visible on localhost */} - {isLocalhost && ( -
-
-
-

- {language === 'en' ? 'Development Todos' : '开发待办事项'} -

- - {/* Dream-Coder 7B Todo */} -
-
-
-
-
-
-

- Dream-Coder 7B Exploration -

-

- Explore Dream-Coder 7B - diffusion LLM for code generation with 21.4% pass@1 on LiveCodeBench -

-
-
-
- Review GitHub repository: https://github.com/DreamLM/Dream-Coder -
-
-
- Understand diffusion LLM architecture for code -
-
-
- Test code generation capabilities -
-
-
- Experiment with flexible generation patterns -
-
-
- Install dependencies: transformers==4.46.2 and torch==2.5.1 -
-
-
-
-
- Features: Emergent any-order generation, variable-length infilling, open-source trained -
-
-
-
-
- - {/* NVFP4 Research Todo */} -
-
-
-
-
-
-

- NVFP4 LLM Pretraining Research -

-

- Research NVIDIA's NVFP4 (4-bit floating point) training methodology - 2-3x performance boost with 50% memory reduction -

-
-
-
- Study paper: https://arxiv.org/pdf/2509.25149 -
-
-
- https://github.com/NVIDIA/TransformerEngine/pull/2177/files - Code -
-
-
- Learn Random Hadamard Transforms (RHT) for outlier bounding -
-
-
- Study two-dimensional quantization scheme -
-
-
- Research stochastic rounding for unbiased gradients -
-
-
- Set up experimental environment for FP4 training -
-
-
- Study TransformerEngine NVFP4 implementation (PR #2177) -
-
-
- Test NVFP4 support with fusible operations -
-
-
-
-
- Results: 12B model on 10T tokens, MMLU-pro 62.58% (vs 62.62% FP8) - first successful 4-bit billion-parameter training -
-
-
-
-
- - {/* MobileLLM-R1 Research Todo */} -
-
-
-
-
-
-

- MobileLLM-R1 Sub-Billion Reasoning Research -

-

- Research Meta's MobileLLM-R1 - sub-billion parameter reasoning models with strong capabilities using only ~2T high-quality tokens -

-
-
-
- Study paper: https://arxiv.org/pdf/2509.24945 -
-
-
- Understand data curation and resampling techniques (~2T tokens) -
-
-
- Learn benchmark-free, self-evolving data optimization -
-
-
- Study data-model co-evolution strategy for mid-training -
-
-
- Analyze training recipe: 4.2T tokens from resampled ~2T tokens -
-
-
- Test MobileLLM-R1-950M model capabilities -
-
-
-
-
- Results: AIME 15.5 vs OLMo-2-1.48B (0.6), matches Qwen3-0.6B with only 11.7% of tokens (4.2T vs 36T) -
-
-
-
-
-
-
-
- )} ); -} \ No newline at end of file +} diff --git a/components/markdown-renderer.tsx b/components/markdown-renderer.tsx index 8a7e3e6..30b95ab 100644 --- a/components/markdown-renderer.tsx +++ b/components/markdown-renderer.tsx @@ -6,7 +6,7 @@ import remarkMath from 'remark-math'; import rehypeHighlight from 'rehype-highlight'; import rehypeKatex from 'rehype-katex'; import Image from 'next/image'; -import 'highlight.js/styles/github-dark.css'; +import 'highlight.js/styles/atom-one-dark.css'; import 'katex/dist/katex.min.css'; import '../styles/math-styles.css'; @@ -92,8 +92,10 @@ export function MarkdownRenderer({ content }: MarkdownRendererProps) { ); }, pre: ({ children }) => ( -
-              {children}
+            
+              
+                {children}
+              
             
), // Custom blockquote styles @@ -106,9 +108,19 @@ export function MarkdownRenderer({ content }: MarkdownRendererProps) { img: ({ src, alt }) => { if (!src) return null; - // Check if this is the architecture diagram that should be smaller + // Check if this is the architecture comparison diagram that should be larger + const isArchitectureComparison = alt?.includes('SD-VAE vs RAE') || (typeof src === 'string' && src.includes('architecture-comparison')); + // Check if this is other architecture diagrams that should be smaller const isArchitectureDiagram = alt?.includes('Architecture') || (typeof src === 'string' && src.includes('architecture')); - const imageClassName = isArchitectureDiagram ? "w-1/2 h-auto mx-auto" : "w-full h-auto"; + + let imageClassName; + if (isArchitectureComparison) { + imageClassName = "w-full h-auto"; // Full width for comparison diagrams + } else if (isArchitectureDiagram) { + imageClassName = "w-1/2 h-auto mx-auto"; // Half width for other architecture diagrams + } else { + imageClassName = "w-full h-auto"; // Default full width + } // Handle external images if (typeof src === 'string' && src.startsWith('http')) { diff --git a/public/content/diffusion-transformer-representation-autoencoder/diffusion-transformer-rae-content.md b/public/content/diffusion-transformer-representation-autoencoder/diffusion-transformer-rae-content.md new file mode 100644 index 0000000..294a409 --- /dev/null +++ b/public/content/diffusion-transformer-representation-autoencoder/diffusion-transformer-rae-content.md @@ -0,0 +1,390 @@ +--- +hero: + title: "47x Faster Image Generation Training" + subtitle: "Diffusion Transformers with Representation Autoencoders" + tags: + - "⏱️ Technical Deep Dive" + - "📄 Research Article" +--- + +This paper introduces a new method that allows Diffusion Transformers to: + +* 🚀 **Train up to 47x Faster:** Achieve state-of-the-art results in a fraction of the time (e.g., 80 epochs vs. 1400+). +* 🏆 **Achieve Better Image Quality:** Sets a new state-of-the-art FID score of 1.13 on ImageNet. + +## The New Foundation for Image (and later Video) Generation + +High-quality image generators like the **Diffusion Transformer (DiT)** are powerful, but for years they've been held back by a critical component: the autoencoder. This paper argues that the standard **Variational Autoencoder (VAE)** from Stable Diffusion is outdated and creates a bottleneck. + +They introduce the **Representation Autoencoder (RAE)**, a new approach that leverages powerful pretrained vision models to create a richer, more meaningful latent space for diffusion. + +In this deep dive, we'll explore how RAEs work, the challenges they solve, and how they set a new standard for generative modeling. + +--- + +### Step 1: The Core Problem: The "Old Way" Has a Bottleneck + +Modern image generators don't work with pixels directly. It's too slow and computationally expensive. Instead, they first use an **autoencoder** to compress a high-resolution image into a small, dense "latent" representation. The diffusion model then learns to generate these latents, which are finally decoded back into a full image. + +For years, the go-to choice has been the **VAE from Stable Diffusion (SD-VAE)**. While revolutionary at the time, the authors argue it's now holding back progress. They identify three key problems: + +* **Outdated Architecture:** The SD-VAE is built on older, less efficient convolutional network designs. +* **Weak Representations:** It's trained *only* to reconstruct images. 
This means its latent space is good at capturing textures and local details but lacks a deep understanding of the image's content (semantics). It doesn't inherently know that a "dog" and a "cat" are conceptually different, only that they have different pixel patterns. +* **Information Bottleneck:** The VAE aggressively compresses images into a very low-dimensional latent space, which limits the amount of detail that can be preserved. + +This bottleneck means that even if you improve the main diffusion model, its final output quality is capped by what the VAE can represent and reconstruct. + +![SD-VAE vs RAE Comparison](/content/diffusion-transformer-representation-autoencoder/images/architecture-comparison.png) +*Figure 2: Comparison of SD-VAE and RAE (DINOv2-B) architectures. The SD-VAE uses U-Net-like convolutional components with aggressive downsampling to a 4-dimensional latent space, requiring 135 GFlops for encoding and 310 GFlops for decoding. The RAE uses ViT blocks without compression to a 768-dimensional latent space, requiring only 22 GFlops for encoding and 106 GFlops for decoding - making it approximately 6x more efficient for encoding and 3x more efficient for decoding.* + +--- + +### Step 2: The Proposed Solution: The "New Way" with RAEs + +The paper introduces a new autoencoder called a **Representation Autoencoder (RAE)**. The idea is simple but powerful: instead of training an autoencoder from scratch just for reconstruction, why not leverage the massive progress made in visual representation learning? + +An RAE has two key parts: + +1. **A Frozen, Pretrained Encoder:** It uses a powerful, off-the-shelf vision model (like **DINOv2**, SigLIP, or MAE) that is already an expert at understanding images. These models are trained on massive datasets to produce rich, high-dimensional representations packed with semantic meaning. This encoder is **frozen**, meaning its weights are not changed during training. +2. **A Trained Decoder:** A lightweight, transformer-based decoder is then trained to do one job: perfectly reconstruct the original image from the rich features provided by the frozen encoder. + +This design creates a latent space that is both **semantically rich** (thanks to the expert encoder) and optimized for **high-fidelity reconstruction** (thanks to the trained decoder). + +Here's how the `RAE` is structured in code. The key idea is simple: combine a frozen pretrained encoder with a trainable decoder. 
```python:src/stage1/rae.py
class RAE(nn.Module):
    def __init__(self,
        # Choose which pretrained vision model to use as the encoder
        encoder_cls: str = 'Dinov2withNorm', # e.g., DINOv2, SigLIP, MAE
        encoder_params: dict = None, # Model-specific parameters

        # Configure the decoder architecture
        decoder_config_path: str = 'vit_mae-base', # HuggingFace model name
        latent_dim: int = 768, # Must match encoder output
        base_patches: int = 256, # Number of spatial patches (16×16)
    ):
        super().__init__()

        # === PART 1: The frozen, pretrained encoder ===
        # Load the architecture class from a registry
        EncoderClass = ARCHS[encoder_cls]
        # Instantiate with pretrained weights (will be frozen during training)
        self.encoder = EncoderClass(**(encoder_params or {}))

        # Store dimensions for later use
        self.latent_dim = latent_dim # e.g., 768 for DINOv2-B
        self.base_patches = base_patches # 16×16 = 256 spatial locations

        # === PART 2: The lightweight, trainable decoder ===
        # Load decoder config from HuggingFace (e.g., ViT-MAE architecture)
        decoder_config = AutoConfig.from_pretrained(decoder_config_path)
        # Adjust hidden size to match our encoder's output dimension
        decoder_config.hidden_size = self.latent_dim
        # Create the decoder (this will be trained)
        self.decoder = GeneralDecoder(decoder_config, num_patches=self.base_patches)
```

During training, the script explicitly sets the encoder to evaluation mode and disables its gradients, ensuring that only the decoder learns.

```python:src/train_stage1.py
# In the main training script
rae: RAE = instantiate_from_config(rae_config).to(device)
# Freeze the encoder
rae.encoder.eval()
rae.encoder.requires_grad_(False)
# Train the decoder
rae.decoder.train()
rae.decoder.requires_grad_(True)
```

---

### Step 3: Making RAEs Work: Solving New Challenges

Switching to RAEs isn't a simple drop-in replacement. Their rich, high-dimensional latent spaces create new problems for Diffusion Transformers, which were designed for the VAE's small, simple space. The paper identifies and solves three main issues:

**1. Challenge: A standard DiT struggles with RAE's high-dimensional tokens.**

* **Observation:** The authors first found that a standard DiT, which works well with low-dimensional VAE latents, fails to train properly on the high-dimensional latents from an RAE. A small DiT fails completely, and even a large one underperforms significantly.

* **Experiment (The "How"):** To understand why, they designed a simple test: can a DiT learn to perfectly reconstruct a *single* image encoded by RAE?
  * They found that the DiT could only succeed if its internal hidden dimension (its "width") was **greater than or equal to** the dimension of the RAE's output tokens (e.g., 768 for DINOv2-B).
  * If the DiT was too "narrow" (width < token dimension), it failed to reconstruct the image, no matter how "deep" they made the model (i.e., adding more layers didn't help).

* **Explanation (The "Why"):** The paper gives a theoretical reason for this width requirement.
  * The diffusion process works by adding noise. This noise spreads the data across the *entire* high-dimensional latent space. The data no longer lies on a simple, low-dimensional manifold.

> #### Deep Dive: The Dimensionality Bottleneck
>
> Every dimension in the rich representation encoder is important. If the diffusion transformer has fewer dimensions, it will lose information and will not be able to reconstruct the image.
>
> To understand the problem, let's use a simpler analogy with colors.
>
> * **The Latent Space:** Imagine the entire RGB color space (3 dimensions: Red, Green, Blue). This is our high-dimensional space.
> * **The "Manifold" of Valid Data:** Now, imagine our goal is to only generate shades of gray. These "valid" points (where R=G=B) form a straight line — a 1D manifold — running through the 3D color space.
>
> 1. **Noise Pushes Data Off the Manifold:** The diffusion process starts with a pure gray color (on the line) and adds random noise. This is like adding a bit of red and green, pushing the color off the "gray line" into the full 3D space (e.g., creating a muddy brown).
>
> 2. **The Denoising Task:** The DiT's job is to take that muddy brown color and figure out which shade of gray it came from.
>
> 3. **The Bottleneck:** A "narrow" DiT is like trying to solve this problem while being **red-green colorblind**. It can't see the 'R' and 'G' dimensions. It sees the muddy brown and the pure gray as having the same brightness, but it has lost the color information required to correctly "pull" the brown back to the gray line. This is an **information bottleneck**.
>
> **The solution:** The DiT's internal "width" or hidden size must be at least as large as the number of dimensions in the latent space (e.g., 768). This ensures it has enough dimensions to encode the same information and can reverse the noise process accurately.

* A DiT with a narrow width acts as an information bottleneck. The input and output linear projections of its transformer blocks constrain the model to operate within a lower-dimensional subspace - imagine each number in a vector as a dimension; if the DiT has fewer dimensions, it literally has less "storage" for information.
* This architectural limitation makes it mathematically impossible for the narrower model to fully represent the data and reverse the noise, leading to high error and poor results. This is formalized in the paper's **Theorem 1**.

* **Solution:** The straightforward solution is to ensure the DiT's width is scaled to be at least as large as the RAE's token dimension. A small numerical illustration of this bottleneck follows below.
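To make the bottleneck concrete, here is a tiny, self-contained numerical sketch. It is our own illustration, not code from the paper or its repository: it asks how well 768-dimensional tokens survive a round trip through a narrower linear map, which is essentially the constraint a too-narrow DiT's input and output projections impose, even before any denoising is learned.

```python
# Toy illustration (not from the RAE repo): squeeze 768-dim "tokens" through a
# narrower linear bottleneck and measure the best possible reconstruction error.
import torch

torch.manual_seed(0)
d_token = 768                          # RAE token dimension (e.g., DINOv2-B)
tokens = torch.randn(4096, d_token)    # stand-in for a batch of latent tokens

for width in [384, 768, 896]:
    down = torch.randn(d_token, width) / d_token ** 0.5  # like a block's input projection
    squeezed = tokens @ down                             # information now lives in `width` dims
    # Best possible linear map back up to 768 dims (least-squares optimal)
    up = torch.linalg.lstsq(squeezed, tokens).solution
    mse = (squeezed @ up - tokens).pow(2).mean().item()
    print(f"width={width:4d}  best-case reconstruction MSE ≈ {mse:.3f}")
```

With width 384 even the best possible linear reconstruction loses roughly half of the signal (MSE around 0.5 for unit-variance tokens), while widths of 768 and above recover it almost exactly; this mirrors the qualitative pattern in the overfitting experiment below.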
+> +> To understand the problem, let's use a simpler analogy with colors. +> +> * **The Latent Space:** Imagine the entire RGB color space (3 dimensions: Red, Green, Blue). This is our high-dimensional space. +> * **The "Manifold" of Valid Data:** Now, imagine our goal is to only generate shades of gray. These "valid" points (where R=G=B) form a straight line — a 1D manifold — running through the 3D color space. +> +> 1. **Noise Pushes Data Off the Manifold:** The diffusion process starts with a pure gray color (on the line) and adds random noise. This is like adding a bit of red and green, pushing the color off the "gray line" into the full 3D space (e.g., creating a muddy brown). +> +> 2. **The Denoising Task:** The DiT's job is to take that muddy brown color and figure out which shade of gray it came from. +> +> 3. **The Bottleneck:** A "narrow" DiT is like trying to solve this problem while being **red-green colorblind**. It can't see the 'R' and 'G' dimensions. It sees the muddy brown and the pure gray as having the same brightness, but it has lost the color information required to correctly "pull" the brown back to the gray line. This is an **information bottleneck**. +> +> **The solution:** The DiT's internal "width" or hidden size must be at least as large as the number of dimensions in the latent space (e.g., 768). This ensures has enough dimensions to encode / understand same information and can reverse the noise process accurately. + +* A DiT with a narrow width acts as an information bottleneck. The input and output linear projections of its transformer blocks constrain the model to operate within a lower-dimensional subspace - imagine each number in a vector as a dimension, if DiT has less dimensions it literally has less "storage" to store information. +* This architectural limitation makes it mathematically impossible for the narrower model to fully represent the data and reverse the noise, leading to high error and poor results. This is formalized in the paper's **Theorem 1**. + +* **Solution:** The straightforward solution is to ensure the DiT's width is scaled to be at least as large as the RAE's token dimension. + +--- + +#### 🔬 Experimental Validation: Single-Image Overfitting Test + +To verify this theory, we replicated the paper's single-image overfitting experiment using a real cat photo. The goal: train DiT models with different widths to reconstruct a single image encoded by DINOv2-B (768-dimensional tokens). + +**Setup:** +- **Image:** Real cat photo (256×256) +- **Encoder:** DINOv2-B with 768-dimensional tokens +- **Training:** 1200 steps with varying DiT widths +- **Test:** Can the model "overfit" and perfectly reconstruct this one image? + +**Results:** + +| DiT Width | Final Loss | Width ≥ 768? | Reconstruction Quality | Status | +|-----------|-----------|---------------|----------------------|--------| +| 384 | 0.671 | ❌ (384 < 768) | Poor, blurry | **Failed** | +| 768 | 0.197 | ✅ (768 = 768) | Good, recognizable | **Success** | +| 896 | 0.135 | ✅ (896 > 768) | **"Almost perfect"** | **Success** | + +![Overfitting to a single sample](/content/diffusion-transformer-representation-autoencoder/images/overfitting-to-single-image-dimension-vs-loss.png) +*Figure 3: Overfitting to a single sample. 
Left: increasing model width leads to lower loss and better sample quality; Right: changing model depth has marginal effect on overfitting results.* + +**Visual Evidence:** + +![Cat Reconstructions with 1200 training steps](/content/diffusion-transformer-representation-autoencoder/images/cat_reconstructions_1200steps.png) +*Left to right: Original cat, Width 384 (failed), Width 768 (good), Width 896 (almost perfect)* + +![Loss Curves](/content/diffusion-transformer-representation-autoencoder/images/cat_loss_curves_1200steps.png) +*Training loss over 1200 steps. Note how width 384 cannot converge, while 768 and 896 successfully minimize loss.* + +**Key Findings:** +1. ✅ **Width < 768 fails completely** - Loss stays high (~0.67) and reconstruction is poor +2. ✅ **Width = 768 works** - Loss drops to 0.20, producing recognizable reconstructions +3. ✅ **Width > 768 is better** - Loss drops to 0.14, achieving "almost perfect" reconstruction as stated in the paper + +This confirms the paper's Theorem 1: **DiT width must match or exceed the token dimension for successful generation in high-dimensional RAE latent spaces.** + +> 💡 **Important Note:** The paper states the DiT "reproduces the input **almost perfectly**" (not perfectly). Our results with loss ~0.14 for width 896 align perfectly with this expectation. + +--- + +**2. Challenge: Standard noise schedules are poorly suited for high dimensions.** + +* **Finding:** A standard noise schedule, which works well for VAEs, is too "easy" for the high-dimensional latents of RAEs. At the same noise level, the RAE's information-rich tokens are less corrupted than the VAE's, which impairs the model's training. + +> #### Deep Dive: The "Corrupted Message" Analogy +> +> Imagine trying to corrupt a secret message with random errors. +> +> 1. **Low Dimension (like VAE):** The message is a short phrase: `THE CAT SAT`. It has 11 characters. If you introduce 3 random errors (e.g., `THX CPT SQT`), the message is significantly damaged and hard to decipher. +> +> 2. **High Dimension (like RAE):** The message is a full paragraph with 768 characters. If you introduce the same 3 random errors, the overall meaning of the paragraph is barely affected. The original information is still overwhelmingly present. +> +> This is exactly what happens in diffusion. The RAE's 768-dimensional tokens are so information-dense that a standard level of noise doesn't corrupt them enough. The model is never forced to learn from truly difficult, noisy examples, so it fails to generalize. +> +* **Solution:** The paper implements a **dimension-dependent noise schedule shift**. This is like adjusting the difficulty of the training curriculum. It mathematically "shifts" the schedule to apply much stronger noise at earlier stages of training, forcing the model to work harder and learn more effectively from the high-dimensional RAE latents. + +--- + +#### 🔬 Experimental Validation: Noise Schedule Shift + +**Experiment Goal:** Test if the dimension-dependent noise schedule shift actually improves training on real data. 
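To make the mechanism concrete before describing the setup, here is a minimal sketch of what a dimension-dependent shift does to the sampled timesteps. It is our own illustration, assuming an SD3-style shifting function t' = αt / (1 + (α − 1)t) and a linear noising path where t = 1 is pure noise; the repository exposes this behaviour through the `time_dist_shift` argument shown later in this section, whose internals may differ.

```python
# Minimal sketch of a timestep shift (our illustration, not the repository's code).
import torch

alpha = 6.93  # dimension-dependent shift, sqrt(196608 / 4096)

def shift_timesteps(t: torch.Tensor, alpha: float) -> torch.Tensor:
    # With alpha > 1, uniformly sampled t gets pushed toward t ≈ 1 (stronger noise).
    return alpha * t / (1.0 + (alpha - 1.0) * t)

t = torch.rand(100_000)
t_shifted = shift_timesteps(t, alpha)
print(f"mean t without shift: {t.mean().item():.2f}")          # ≈ 0.50
print(f"mean t with shift:    {t_shifted.mean().item():.2f}")  # ≈ 0.79, i.e. mostly high noise

# Linear (flow-matching) path under this convention: t = 0 clean, t = 1 pure noise.
x0 = torch.randn(8, 768, 16, 16)        # stand-in for clean RAE latents
noise = torch.randn_like(x0)
tt = shift_timesteps(torch.rand(8), alpha).view(-1, 1, 1, 1)
x_t = (1.0 - tt) * x0 + tt * noise      # training inputs are noisier on average
```

Conceptually, this is what toggling `time_dist_shift` between 1.0 and 6.93 changes in the runs below.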
+ +**Setup:** We trained two identical DiT models on 2,000 CIFAR-10 images for 10 epochs: +- **Control (A):** Standard noise schedule (`time_dist_shift = 1.0`) +- **Experiment (B):** Dimension-dependent shift (`time_dist_shift = α = 6.93`) +- Both use: DINOv2-B encoder (768-dim tokens), DiT width=768, depth=12, AdamW optimizer + +##### Calculating Alpha (the Shift Parameter) + +**Understanding the Latent Shape:** + +The DINOv2-B encoder outputs a 3D tensor: `[batch, channels, height, width]` = `[B, 768, 16, 16]` +- **768 channels:** Each spatial location has a 768-dimensional feature vector +- **16 × 16 grid:** 256 spatial locations total +- **Total dimension:** 768 × 256 = **196,608 numbers** per image + +```python +# Step 1: Calculate effective dimension +effective_dim = 768 × (16 × 16) # channels × height × width + = 768 × 256 # channels × spatial_locations + = 196,608 # total numbers in the latent + +# Step 2: Compare to VAE baseline +base_dim = 4096 # Typical VAE latent dimension (e.g., Stable Diffusion uses ~4×64×64) + +# Step 3: Calculate scaling factor +alpha = sqrt(effective_dim / base_dim) + = sqrt(196,608 / 4,096) + = sqrt(48) + = 6.93 +``` + +**Why sqrt?** In high-dimensional spaces, variance scales with dimensionality. To maintain the same "relative noise strength," we scale by √(dimension_ratio), not the ratio itself. + +##### The Results: A Clear Winner + +| Configuration | Final Loss | Improvement | +|--------------|-----------|-------------| +| **WITHOUT shift** (α = 1.0) | 1.1326 | Baseline | +| **WITH shift** (α = 6.93) | 0.9668 | **14.6% better** ✅ | + +![CIFAR-10 Loss Comparison](/content/diffusion-transformer-representation-autoencoder/images/cifar10_loss_comparison.png) + +The model trained with the noise schedule shift (orange line) achieves consistently lower loss throughout all 10 epochs. This validates the paper's theory on real data, not just single-image overfitting. + +##### Why This Matters + +**The Intuition:** In high-dimensional spaces, the same amount of noise has less relative impact due to information being spread across more dimensions. + +- **VAE (4K dims):** 10% noise significantly corrupts the signal +- **RAE (196K dims, no shift):** Same 10% noise is relatively weaker—model has an easier task +- **RAE (196K dims, α=6.93 shift):** Noise scaled by ~7×, creating comparable difficulty to VAE + +The √48 ≈ 7× scaling compensates for how variance behaves in high dimensions, forcing the model to learn robust denoising instead of exploiting redundancy. + +--- + +##### Implementation: A Single Parameter Change + +The key difference was a single parameter, `time_dist_shift`, which alters the distribution of noise levels during training. A higher value shifts the distribution toward higher noise, forcing the model to learn a more robust denoising function. This single change yielded a 14.6% improvement in final loss on our CIFAR-10 test. + +```diff + # In the transport configuration + transport = create_transport( + path_type='Linear', + prediction='velocity', + time_dist_shift=6.93, # Dimension-dependent shift (Final Loss: 0.9668) + ) +``` + +> 💡 **Key Takeaway:** The dimension-dependent noise schedule shift is simple to implement (one parameter), theoretically grounded (scales with √dimension), and empirically validated (14.6% improvement on real data). For high-dimensional RAE latents, this adjustment is essential for effective diffusion training. + +--- + +**3. 
Challenge: The RAE decoder is fragile.** +* **Finding:** The RAE decoder is trained to reconstruct images from the "perfect," clean outputs of the encoder. However, a diffusion model at inference time generates slightly imperfect latents. This mismatch can degrade the final image quality. +* **Solution:** They use **noise-augmented decoding**. During the decoder's training, they add a small amount of random noise to the encoder's outputs. This makes the decoder more robust and better at handling the imperfect latents generated by the diffusion model. + +This robustness is achieved with a simple `noising` function applied to the latent code `z` during the `encode` step. + +```python:src/stage1/rae.py +class RAE(nn.Module): + # ... + def noising(self, x: torch.Tensor) -> torch.Tensor: + # Add a random amount of noise during training + noise_sigma = self.noise_tau * torch.rand(...) + noise = noise_sigma * torch.randn_like(x) + return x + noise + + def encode(self, x: torch.Tensor) -> torch.Tensor: + # ... + z = self.encoder(x) + # Apply noise augmentation only during training + if self.training and self.noise_tau > 0: + z = self.noising(z) + # ... + return z +``` + +--- + +#### 🔬 Experimental Validation: Decoder Fragility + +**The Problem:** RAE decoders are trained to reconstruct from *perfect* encoder outputs. But DiT models generate *imperfect* latents. How much does this hurt? + +##### Experiment: Testing Decoder Robustness + +We used the **pretrained RAE decoder** (trained with `noise_tau=0`, meaning NO noise augmentation) and tested its sensitivity to latent noise: + +**Setup:** +1. Took 6 diverse CIFAR-10 images +2. Encoded them to clean latents using frozen DINOv2 encoder +3. Added varying amounts of noise to latents (σ = 0.0 to 2.0) +4. Decoded with the pretrained decoder +5. Measured PSNR degradation + +**This simulates what happens when a DiT generates imperfect latents at inference.** + +##### Results: The Decoder IS Fragile + +| Latent Noise (σ) | Avg PSNR | Quality Degradation | +|-----------------|----------|---------------------| +| 0.0 (clean) | 25.87 dB | Baseline ✅ | +| 0.1 | 25.79 dB | -0.08 dB (minimal) | +| 0.3 | 25.66 dB | -0.21 dB (noticeable) | +| 0.5 | 24.40 dB | **-1.47 dB** ⚠️ | +| 1.0 | 22.97 dB | **-2.90 dB** ❌ | +| 2.0 | 19.68 dB | **-6.19 dB** ❌❌ | + +![Decoder Fragility Visual Comparison](/content/diffusion-transformer-representation-autoencoder/images/decoder_fragility_visual.png) + +**Visual Evidence:** The image above shows 6 CIFAR-10 examples reconstructed at different noise levels. Notice how quality degrades rapidly as latent noise increases. By σ=2.0, images are severely blurred. + +##### The Solution: Noise-Augmented Training (Validated!) 
We fine-tuned two decoder versions on 500 CIFAR-10 images for 15 epochs to test if noise augmentation actually helps:

**Decoder A:** Trained with `noise_tau = 0` (no augmentation) → expects perfect latents
**Decoder B:** Trained with `noise_tau = 0.5` (with augmentation) → handles noisy latents

**Results on Noisy Test Latents:**

| Latent Noise (σ) | No Aug PSNR | With Aug PSNR | Improvement |
|-----------------|-------------|---------------|-------------|
| 0.0 (clean) | 25.65 dB | 25.30 dB | -0.35 dB (baseline) |
| 0.3 | 24.87 dB | **25.73 dB** | **+0.86 dB** ✅ |
| 0.6 | 24.84 dB | 24.44 dB | -0.40 dB |
| 1.0 | 22.39 dB | **22.91 dB** | **+0.53 dB** ✅ |

![Decoder Robustness PSNR](/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_psnr.png)

**Key Findings:**

1. At moderate noise levels (σ=0.3, 1.0), noise augmentation provides **+0.53 to +0.86 dB improvements** ✅
2. On clean latents, the non-augmented decoder is slightly better (expected—it's specialized for this)
3. **The tradeoff is worth it:** Small loss on perfect inputs, but better handling of realistic DiT outputs

![Visual Comparison](/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_visual.png)

> 💡 **Key Takeaway:** Noise augmentation (`noise_tau = 0.5-0.8`) makes decoders measurably more robust (+0.5-0.9 dB) to the imperfect latents generated by DiT models, with minimal cost on clean inputs. This simple technique is essential for RAE-based generation.

---

### Step 4: A More Efficient Architecture: DiT DH

The short version: instead of making the entire DiT wider, make just the last few layers wider, and feed those wide layers both the initial noised latent and the output of the narrow backbone.

We've established that a DiT needs to be *wide* to handle the rich, high-dimensional tokens from an RAE. But making the entire transformer network wide is incredibly expensive, because compute and parameter count grow roughly quadratically with the hidden width. This creates a dilemma: how can we get the necessary width without a massive computational budget?

To solve this, the authors propose a clever architectural improvement called **DiT DH** (Diffusion Transformer with a DDT Head).

The core idea is to split the DiT into two specialized parts, creating a "best of both worlds" design (a minimal sketch follows this list):

1. **The Body (Deep & Narrow):** The first part of the network is a standard, deep DiT with a *narrow* hidden dimension (e.g., 768). This is the workhorse of the model. Its many layers are responsible for the complex, core processing: understanding the image's semantics, learning relationships between features, and performing the bulk of the denoising steps. Because it's narrow, it does this work very efficiently.

2. **The Head (Shallow & Wide):** At the very end of the network, they attach a **DDT Head**—a small number of transformer layers (e.g., 2) that are exceptionally *wide* (e.g., 2048). This head has one job: take the highly processed features from the deep body and perform the crucial final steps of the **reverse diffusion (denoising) process**. It's an active transformer module, not just a simple projection layer, that handles the final prediction in the high-dimensional RAE latent space. It provides the critical width needed to avoid the information bottleneck we discussed in Step 3, but only in the last few layers where it's absolutely necessary.

3. **Different Inputs:** This is the most critical distinction. A standard transformer block in the DiT backbone takes the output of the immediately preceding block as its main input. The DDT head, however, takes two distinct inputs:
   - The original noisy latent `x_t`.
   - The processed representation `z_t` from the entire main DiT backbone (M).
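Here is a minimal sketch of that wiring. It is our own simplification with invented module names (not the paper's or repository's code), and it omits timestep and class conditioning; the point is simply the shape of the design: a deep, narrow backbone plus a shallow, wide head that receives both the noisy latent `x_t` and the backbone's output.

```python
# Minimal sketch of the "narrow body + wide head" idea (our simplification,
# not the paper's code; timestep/class conditioning omitted for brevity).
import torch
import torch.nn as nn

class WideHead(nn.Module):
    """A few very wide transformer layers that see both x_t and the backbone output."""
    def __init__(self, token_dim: int = 768, head_dim: int = 2048, depth: int = 2):
        super().__init__()
        self.proj_x = nn.Linear(token_dim, head_dim)   # embed the noisy latent x_t
        self.proj_z = nn.Linear(token_dim, head_dim)   # embed backbone features z
        layer = nn.TransformerEncoderLayer(
            d_model=head_dim, nhead=16, dim_feedforward=4 * head_dim,
            batch_first=True, norm_first=True,
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        self.out = nn.Linear(head_dim, token_dim)      # final prediction back in token space

    def forward(self, x_t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        h = self.proj_x(x_t) + self.proj_z(z)          # the head's two distinct inputs
        return self.out(self.blocks(h))

class DiTDH(nn.Module):
    def __init__(self, backbone: nn.Module, token_dim: int = 768):
        super().__init__()
        self.backbone = backbone                       # deep & narrow body
        self.head = WideHead(token_dim=token_dim)      # shallow & wide DDT-style head

    def forward(self, x_t: torch.Tensor) -> torch.Tensor:
        z = self.backbone(x_t)                         # most of the compute, at width 768
        return self.head(x_t, z)                       # width 2048 only for the last layers

# Stand-in backbone: 12 narrow (768-wide) transformer layers
backbone = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072,
                               batch_first=True, norm_first=True),
    num_layers=12,
)
model = DiTDH(backbone)
x_t = torch.randn(2, 256, 768)                         # [batch, 256 tokens, 768 dims]
print(model(x_t).shape)                                # torch.Size([2, 256, 768])
```

Only two layers ever run at the 2048 width, so the extra cost is small compared with widening all twelve backbone layers.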
This design gives the model the width it needs to handle RAE's high-dimensional space without making the entire network wide. It's like having a specialized, high-bandwidth output port attached to an efficient processing core.

---

The paper makes a strong case that the VAE bottleneck is real and that RAEs are the solution. By effectively bridging the gap between state-of-the-art representation learning and generative modeling, RAEs offer clear advantages and should be considered the **new default foundation** for training future diffusion models.

---

Thank you for reading this tutorial and see you in the next one.
\ No newline at end of file diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/architecture-comparison.png b/public/content/diffusion-transformer-representation-autoencoder/images/architecture-comparison.png new file mode 100644 index 0000000..b249a86 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/architecture-comparison.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/cat_loss_curves_1200steps.png b/public/content/diffusion-transformer-representation-autoencoder/images/cat_loss_curves_1200steps.png new file mode 100644 index 0000000..50a1b2f Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/cat_loss_curves_1200steps.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/cat_reconstructions_1200steps.png b/public/content/diffusion-transformer-representation-autoencoder/images/cat_reconstructions_1200steps.png new file mode 100644 index 0000000..ff84135 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/cat_reconstructions_1200steps.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/cifar10_loss_comparison.png b/public/content/diffusion-transformer-representation-autoencoder/images/cifar10_loss_comparison.png new file mode 100644 index 0000000..6e64456 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/cifar10_loss_comparison.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/decoder_fragility_metrics.png b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_fragility_metrics.png new file mode 100644 index 0000000..4f82057 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_fragility_metrics.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/decoder_fragility_visual.png b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_fragility_visual.png new file mode 100644 index 0000000..0ace448 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_fragility_visual.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_psnr.png b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_psnr.png new file mode
100644 index 0000000..6e36f87 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_psnr.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_visual.png b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_visual.png new file mode 100644 index 0000000..1c27766 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/decoder_robustness_visual.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/overfitting-to-single-image-dimension-vs-loss.png b/public/content/diffusion-transformer-representation-autoencoder/images/overfitting-to-single-image-dimension-vs-loss.png new file mode 100644 index 0000000..60fe568 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/overfitting-to-single-image-dimension-vs-loss.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/pixabay_cat.png b/public/content/diffusion-transformer-representation-autoencoder/images/pixabay_cat.png new file mode 100644 index 0000000..eb21710 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/pixabay_cat.png differ diff --git a/public/content/diffusion-transformer-representation-autoencoder/images/pixabay_cat_recon.png b/public/content/diffusion-transformer-representation-autoencoder/images/pixabay_cat_recon.png new file mode 100644 index 0000000..23753c6 Binary files /dev/null and b/public/content/diffusion-transformer-representation-autoencoder/images/pixabay_cat_recon.png differ diff --git a/public/content/tiny-recursive-model/tiny-recursive-model-content-zh.md b/public/content/tiny-recursive-model/tiny-recursive-model-content-zh.md new file mode 100644 index 0000000..70735f4 --- /dev/null +++ b/public/content/tiny-recursive-model/tiny-recursive-model-content-zh.md @@ -0,0 +1,275 @@ +--- +hero: + title: "微型递归模型" + subtitle: "全新递归推理 AI 架构" + tags: + - "⏱️ 技术深度解析" + - "📄 研究文章" +--- + +## 全新的 AI 推理架构 + +700 万参数模型如何在数独、迷宫和 ARC-AGI 任务中击败万亿参数模型 + +**微型推理模型(TRM)** 使用一个 2 层 Transformer(仅 700 万参数),通过数百次重复使用相同的层来递归地推理问题。 + +它在数独极限版、迷宫、ARC-AGI 等任务中击败了比它大 100 倍的模型。 + +在本教程中,我们将学习 TRM 的工作原理并进行自己的实验。 + +--- + +## TRM 架构概览 + +![微型递归模型架构](/content/tiny-recursive-model/images/tiny-recursive-model-architecture.png) +*图示:微型递归模型架构展示了主处理块(4 层 Transformer)、输入组合(问题 (x)、答案 (y)、推理 (z))、损失计算的输出处理,以及递归更新机制,该机制通过最多 16 步迭代地优化推理和预测。* + +上图展示了完整的 TRM 架构。该模型处理三个关键组件: +- **输入 (x)**:要解决的问题(例如迷宫布局) +- **预测 (y)**:模型当前的答案尝试 +- **潜在变量 (z)**:模型的内部推理状态 + +这些组件被组合在一起并通过 4 层 Transformer 栈处理,输出用于计算交叉熵损失。关键创新是底部的递归更新机制,它通过多步迭代逐步优化推理 (z) 和预测 (y),以改进解决方案。 + +--- + +## TRM 如何工作 + +### 步骤 1:设置 + +让我们训练 TRM 来解决迷宫。 + +**1. 将迷宫表示为网格:** +首先,我们将迷宫显示为数字网格。网格中的每个单元格都有一个数字。 + +- `0` = 空路径 +- `1` = 墙壁 +- `2` = 起点 +- `3` = 终点 + +举个具体例子,让我们追踪一个 3x3 的小迷宫。 + +- **`x_input`**(未解决的迷宫) + ``` + [[2, 0, 1], + [1, 0, 1], + [1, 0, 3]] + ``` +- **`y_true`**(正确的解决方案,用 `4` 表示路径) + ``` + [[2, 4, 1], + [1, 4, 1], + [1, 4, 3]] + ``` + +**2. 标记化(Tokenization):** +术语 **token** 只是指我们数据的一个单元。在这种情况下,网格中的单个数字(`0`、`1`、`2`、`3` 或 `4`)就是一个 token。为了让网络更容易处理,我们将网格"展开"成一个长的一维列表。 + +对于我们的 3x3 示例,网格展开为包含 9 个 token 的列表。 + +**3. 
嵌入:赋予数字意义:** +为了让模型理解像 `4` 和 `1` 这样的数字的含义,我们将为每个数字分配一个大的**向量嵌入**。向量嵌入是一个长向量(数字数组),模型可以修改它来存储关于墙壁、空路径等的信息。 + +这些向量将表示"墙壁"或"终点"的含义。 + +我建议你通过在 YouTube 上搜索或与 AI 聊天机器人交谈来复习一下向量嵌入(在大语言模型中,词、token 等的嵌入)是什么。 + +- **嵌入层**就像一个字典。 +- 它包含我们每个数字的向量嵌入。 +- `1`: `[0.3, -1.2, 0.7, 0.0, 1.5, -0.4, 0.9, 2.3]` ← "墙壁"的示例向量嵌入 +- **输出:** 一个称为**向量**的长数字列表。这个向量以网络可以理解的方式表示"墙壁"的*含义*。网络本身在训练期间选择(学习)这个向量中的数字,这样它就可以"理解"它。 + +完成此步骤后,我们的输入迷宫不再是简单数字的列表。它是一个向量列表。对于我们的 3x3 迷宫,如果我们为每个 token 使用大小为 8 的向量,我们的输入变为: + +- `x`:一个 `9x8` 的向量矩阵,表示迷宫。 + +这种丰富的表示就是我们输入主模型的内容。 + +--- + +### 步骤 2:核心架构:TRM 大脑 + +TRM 的"大脑"是一个称为 `net` 的微型 2 层 Transformer。它处理信息以产生输出。为了"思考",TRM 使用两个变量,它们的形状都与 `x` 相同: + +- `y`:模型当前对解决方案的**最佳猜测**。可能是错误的 +``` +[[2, 4, 1], + [1, 4, 1], + [1, 0, 3]] +``` +- `z`:一个**潜在思考**。`z` 告诉需要在 `y` 中改变什么才能将其变成正确的解决方案。`z` 多次通过 Transformer,让模型细化需要在 `y` 中改变的内容,这就是模型推理或思考的方式。然后将更改应用于 `y`。 + +对于我们的 3x3 示例,`z` 和 `y` 最初都是 `9x8` 的零矩阵。 + +--- + +### 步骤 3:学习过程,由内而外 + +TRM 通过一系列嵌套循环学习。让我们从核心开始,逐步向外构建。 + +#### 最内层循环:`latent_recursion`(核心思考) + +这是微型 `net`(一个 2 层 Transformer)完成所有工作的地方。该过程分为两个阶段,它们重复形成思考和优化的循环。 + +**阶段 A:推理(更新草稿本 `z`)** +模型通过在 6 步循环中优化其内部规划 token `z` 来"思考"。目标是建立一个越来越好的改变 `y` 的计划。 + +1. **过程:** 在 6 个步骤中的每一个步骤中,`net` 接受三个输入: + - 迷宫本身(`x`)。 + - 模型当前对解决方案的最佳猜测(`y`)——在开始时这可能全是零。 + - 上一步的草稿本(`z`)。 +2. **工作原理:** + - **组合输入:** 三个输入按元素相加(`x + y + z`)。这创建了一个丰富向量的单个序列,其中每个向量(表示迷宫中的一个单元格)包含有关迷宫布局(`x`)、当前猜测(`y`)和正在进行的思考过程(`z`)的组合信息。 + - **用注意力思考:** 这个组合序列被输入到 2 层 Transformer 中。Transformer 的自注意力机制允许它同时查看所有单元格并识别关系。例如,它可以看到"起点"单元格如何与潜在路径单元格相关,并结合输入数据 `x` 和推理 `z`。 + - **生成下一个思考:** 两个 Transformer 层处理这些信息并输出一个相同形状的新向量序列。这个输出*就是*新的 `z`。没有单独的"输出头"来生成它;两层执行的转换*就是*创建下一个、更精细的思考的行为。尽管输入是包含 `x` 和 `y` 的和,但网络学会产生一个用作下一步的有用的新 `z` 的输出。 + + 这个过程重复 6 次,意味着信息连续六次通过相同的两层,每次传递都变得越来越复杂。 +3. **示例追踪:** 在通过 Transformer 的几次传递后,`z` 可能编码低级特征,如墙壁位置。到第六次传递时,它可能表示更新答案(`y`)的高级计划。 + + - 有趣的是,相同的 2 个 Transformer 层用于检测低级特征、制定高级计划,以及后来更新 `y` 本身。这 2 层具有多重用途,这是神经网络的力量,它可以学习执行多个不太相关或不相关的转换,这只取决于输入数据。 + +**阶段 B:优化答案(更新猜测 `y`)** +在 6 步推理循环之后,使用最新的潜在思考 `z`,模型更新其答案 `y`。 + +- **工作原理:** 它将先前的答案(`y`)与最终的、精细的思考(`z`)通过将它们相加(`y + z`)组合在一起,并最后一次通过相同的 `net`。输出是新的、改进的 `y`。 + - **关键是,`x` 不包含在此步骤中。** 这是一个刻意的设计选择,告诉单个 `net` 要执行哪个任务。 + - `x` 存在于推理中(`x + y + z`)。 + - `x` 不存在于答案优化中(`y + z`)。 + +我说"答案优化"的原因是,这个 6+1 循环会发生多次,每次"思考" 6 次并更新 `y` 一次。 + +#### 中间循环:`deep_recursion`(完整的思考过程) + +现在我们理解了推理 + y 优化循环是如何工作的,让我们看看从头开始的完整思考过程,在这个过程中,整个循环重复 3 次以获得最佳的 `y`。 + +前面描述的内部循环(推理和 `y` 优化的 6+1 步骤)运行 `T` 次(例如,`T=3`)。状态(`y` 和 `z`)在这些运行之间**被保留**;它不会重置为零。 + +- **第 1 轮(热身):** 从空白(全零)的 `y` 和 `z` 开始(请记住,这是过程的绝对开始,所以没有 `y` 和 `z` 可以保留)。它运行完整的内部循环(6 步推理 + 1 步 `y` 优化)以产生更智能的 `y_1` 和 `z_1`。这是在"无梯度"模式下完成的,以节省速度和内存——神经网络在这里不学习。 +- **第 2 轮(热身):** 它以 `y_1` 和 `z_1` 作为起点,再次运行内部循环以产生更好的 `y_2` 和 `z_2`。仍然没有梯度和学习。 +- **第 3 轮(真正的):** 它从经过充分推理的 `y_2` 和 `z_2` 开始,最后一次运行内部循环,这次所有计算都被跟踪,以便模型可以通过反向传播学习。 + +在最终的可学习步骤之前预热模型的"思考"的这个过程是一个关键的优化。 + +#### 最外层循环:更多循环! + +模型获得多次"机会"(最多 16 次)来解决同一个迷宫,每次机会后,它都会优化其 `net` 权重。状态(`y` 和 `z`)**从一个中间循环迭代保留**到下一个,如论文的伪代码所示。它允许模型获得多次"机会"(最多 16 次)来解决同一个迷宫,每次都在改进。 + +这只是重复中间循环最多 16 次。如果模型觉得它得到了正确的答案,它可以决定提前停止。 + +为什么我们需要这个循环: + +在每次中间循环迭代之后,这个外部循环更新一次权重(记住中间循环中的第 3 轮执行反向传播)。 + +然后在下一次迭代中,它用更新的权重重复中间循环,允许模型在每次尝试中逐步改进其解决方案。 + +#### 知道何时停止思考(Q 头) + +外部循环最多可以运行 16 次,但不必如此。继续思考它已经解决的迷宫将是浪费时间。 + +因此,模型有一个称为"Q 头"的小副脑。在每个完整的思考过程(每个中间循环)之后,这个 Q 头会给出一个分数。这个分数基本上是模型的信心:"我有多确定我做对了?" 
+ +如果信心分数足够高,外部循环就会停止(`break`),模型会继续下一个迷宫。 + +它学会正确获得这个信心分数,因为它是训练的一部分。如果它自信*并且*正确,它会得到奖励,如果它自信但错误,它会受到惩罚。论文称之为自适应计算时间(ACT)。 + +--- + +```python +# 初始化 +y, z = zeros_like(x), zeros_like(x) + +# 深度监督循环(最多 16 次) +for supervision_step in range(16): + + # 深度递归:热身(2 次,无梯度) + with torch.no_grad(): + for _ in range(2): + # 潜在递归 + for _ in range(6): + z = net(x + y + z) + y = net(y + z) + + # 深度递归:最终(1 次,有梯度) + for _ in range(6): + z = net(x + y + z) + y = net(y + z) + + # 学习 + y_pred = output_head(y) + loss = cross_entropy(y_pred, y_true) + loss.backward() + optimizer.step() + + # 我们应该停止吗? + q = Q_head(y) + if q > 0: + break +``` + +--- + +### 步骤 4:消融研究——是什么让 TRM 起作用? + +![完整消融研究](/content/tiny-recursive-model/images/complete_ablation_study.png) +*图示:在迷宫求解(30x30,困难)上进行 10 个 epoch 训练,四种 TRM 配置的训练损失比较。基线(蓝色实线)使用 TRM 的标准设计:2 层网络,H=3(中间循环),L=6(内部循环),带 EMA。消融测试:移除 EMA(红色虚线),减少递归深度(绿色点划线),以及使用更大的 4 层网络(品红色点线)。* + +为了理解是什么使 TRM 有效,我们系统地测试变体,移除或更改关键组件。这些**消融研究**揭示了哪些设计选择是必不可少的。 + +#### 实验设置 + +我们在迷宫求解任务(30x30 困难迷宫,1000 个训练示例)上测试四种配置: + +| 配置 | 层数 | H_cycles | L_cycles | EMA | 有效深度* | +|---------------|--------|----------|----------|-----|------------------| +| **基线 TRM** | 2 | 3 | 6 | 是 | 42 | +| **无 EMA** | 2 | 3 | 6 | 否 | 42 | +| **更少递归** | 2 | 2 | 2 | 是 | 12 | +| **更大的大脑** | 4 | 3 | 3 | 是 | 48 | + +*有效深度 = T × (n+1) × 层数 + +#### 结果 + +**注意:** 这些是 10 个 epoch 的实验——与论文的 50,000 多个 epoch 的运行相比,训练量非常小。更长的训练可能会显著改变这些配置的相对性能,特别是对于泛化(如我们在下面的"更大的大脑"结果中看到的)。 + +| 配置 | 初始损失 | 最终损失 | 最小损失 | 改进 | +|---------------|--------------|------------|----------|-------------| +| 基线 | 1.789 | 1.062 | 1.045 | 40.6% | +| 无 EMA | 1.789 | 1.042 | 1.041 | 41.7% | +| 更少递归 | **2.100** | 1.100 | 1.042 | 47.6% | +| 更大的大脑(4 层) | 1.789 | **1.007** | **1.007** | **43.7%** | + +#### 关键发现 + +**1. "更大的大脑"悖论:短期 vs. 长期性能** + +4 层网络在我们的 10 个 epoch 实验中实现了**最佳的最终损失**(1.007),比 2 层基线好约 5%。这似乎与论文的"少即是多"主张相矛盾。 + +**为什么不同?** +- **短期**(10 个 epoch):更多容量 = 更快学习。4 层网络可能快速记住模式。 +- **长期**(50,000 多个 epoch):更多容量 = 过拟合。2 层网络*被迫*学习可重用的推理策略,而不是记住特定的解决方案。 + +论文的核心论点:**被迫递归思考的小型网络比大型网络泛化得更好**,即使它们最初训练得更慢。选择 2 层架构是为了防止记忆并强制依赖递归。 + +**2. 递归深度是基础** + +"更少递归"配置(H=2,L=2)显示出严重退化的性能: +- 在任何训练之前,初始损失高出 **17%**(2.100 对 1.789) +- 实现最差的最终损失(1.100),尽管改进了 47.6% + +**论文说什么:** 将递归从 T=3,n=6 减少到 T=2,n=2,数独准确率从 87.4% 下降到 73.7% ——下降 14%。 + +**为什么这很重要:** 高初始损失揭示了浅递归*在设计上*削弱了模型的表征能力。即使有完美的训练,也没有足够的递归"思考步骤"来解决复杂问题。**你不能用更好的训练来弥补不足的递归深度。** + +**3. 
EMA 的短期影响最小** + +移除 EMA 几乎不影响 10 个 epoch 的性能(最终损失 1.042 对基线的 1.062,只有约 2% 的差异)。 + +**论文说什么:** 在数独极限版上,移除 EMA 使准确率从 87.4% 下降到 79.9% ——在完整训练后下降 8%。 + +**为什么不同?** EMA 是模型权重的**指数移动平均**,它在长期运行中稳定训练。在短期实验中,两个模型仍在探索,还没有遇到 EMA 防止的不稳定性。在 50,000 多个 epoch 中,EMA 防止了灾难性的发散和过拟合峰值,使其对最终性能至关重要。 + +--- + +感谢您阅读本教程,我们下一个教程见。 + diff --git a/styles/math-styles.css b/styles/math-styles.css index d7c41fd..bc96ad0 100644 --- a/styles/math-styles.css +++ b/styles/math-styles.css @@ -71,3 +71,58 @@ border-color: rgba(139, 92, 246, 0.5); transition: all 0.2s ease; } + +/* Custom syntax highlighting for code blocks */ +.hljs { + background: #1e293b !important; /* slate-800 */ + color: #e2e8f0 !important; /* slate-200 */ +} + +.hljs-keyword { + color: #c084fc !important; /* purple-300 */ +} + +.hljs-string { + color: #34d399 !important; /* emerald-400 */ +} + +.hljs-comment { + color: #64748b !important; /* slate-500 */ + font-style: italic; +} + +.hljs-number { + color: #fbbf24 !important; /* amber-400 */ +} + +.hljs-function { + color: #60a5fa !important; /* blue-400 */ +} + +.hljs-variable { + color: #f472b6 !important; /* pink-400 */ +} + +.hljs-type { + color: #fb7185 !important; /* rose-400 */ +} + +.hljs-built_in { + color: #a78bfa !important; /* violet-400 */ +} + +.hljs-title { + color: #60a5fa !important; /* blue-400 */ +} + +.hljs-params { + color: #e2e8f0 !important; /* slate-200 */ +} + +.hljs-attr { + color: #34d399 !important; /* emerald-400 */ +} + +.hljs-value { + color: #fbbf24 !important; /* amber-400 */ +}