`records/track_10min_16mb/2026-03-21_NeuralCache_Research/README.md`

# Neural Cache: Cross-Window KV Cache for Extended Context at Eval Time

**Research proposal (no record claim)** | Base model: PR #287 reproduction (1.1284 BPB) | 8xH100 SXM

## The Idea

Standard sliding window evaluation processes each window independently. A window at position 10,000 has no memory of what happened at position 5,000 — even though those tokens were already evaluated. Neural Cache fixes this by **caching K/V pairs across windows**, extending effective context from 2,048 tokens to 50K+ tokens at zero artifact cost.

```
Standard sliding window (stride=64, seq=2048):
Window 1: [tokens 0-2047] -> score tokens 0-2047 (context: 2048)
Window 2: [tokens 64-2111] -> score tokens 2048-2111 (context: 2048)
Window 3: [tokens 128-2175] -> score tokens 2112-2175 (context: 2048)
...each window is INDEPENDENT. Token 2048 cannot see token 0.

Neural Cache (stride=64, seq=2048, cache=8192):
Window 1: [tokens 0-2047] -> score, cache K/V for stride tokens
Window 2: [tokens 64-2111] -> attend to cached K/V + current window
Window 3: [tokens 128-2175] -> attend to growing cache + current window
...token 8000 can attend to token 0 through the cache. Effective context: 10K+
```
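
The window arithmetic in the diagram can be sketched directly (a toy helper for illustration, not part of the eval code):

```python
seq_len, stride = 2048, 64

def window(i):
    """Token range fed to the model for the i-th window (end exclusive)."""
    return (i * stride, i * stride + seq_len)

def scored(i):
    """Newly scored tokens: the whole first window, then only the stride tail."""
    start, end = window(i)
    return (start, end) if i == 0 else (end - stride, end)

# window(1) covers tokens 64..2111 and scores 2048..2111,
# matching "Window 2" in the diagram above
```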

## Why This Should Work

1. **More context = better prediction.** Empirically established in this repo: seq2048 > seq1024 > seq512 (PR #136: -0.014 BPB from longer context). Neural Cache extends this principle beyond the training sequence length.

2. **Flash Attention natively supports it.** When `seqlen_k > seqlen_q`, FA3 treats the extra K/V as "earlier" context — exactly the KV-cache pattern used in LLM inference. No custom kernels needed.

3. **Backward-looking only.** The cache contains K/V from already-evaluated tokens. No future information leaks. This is the same principle as backward-looking TTT (PR #267, confirmed rule-compliant) but lighter weight — no gradient computation, just cached hidden states.

4. **Zero artifact cost.** No extra parameters, no model changes. Pure eval-time technique. ~50 lines of code.
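
Point 2 can be checked without FA3: exact causal attention gives identical outputs whether earlier keys arrive in the same pass or as a cache. A minimal PyTorch sketch using `scaled_dot_product_attention` as a stand-in (FA3's `causal=True` aligns the mask bottom-right, which matches the KV-cache pattern; SDPA needs the mask spelled out explicitly):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 1, 2, 12, 16
q_all = torch.randn(B, H, T, D)
k_all = torch.randn(B, H, T, D)
v_all = torch.randn(B, H, T, D)

# Full causal pass over all 12 tokens at once.
full = F.scaled_dot_product_attention(q_all, k_all, v_all, is_causal=True)

# Cached pass: the last 4 queries attend to 8 "cached" K/V plus their own
# causal block. Query at global position 8+i may see keys 0..8+i.
q_new = q_all[:, :, 8:]
pos_q = torch.arange(8, 12).unsqueeze(1)   # [4, 1]
pos_k = torch.arange(12).unsqueeze(0)      # [1, 12]
mask = pos_k <= pos_q                      # [4, 12] boolean, True = attend
cached = F.scaled_dot_product_attention(q_new, k_all, v_all, attn_mask=mask)

# Same result as the corresponding rows of the full pass.
assert torch.allclose(cached, full[:, :, 8:], atol=1e-4)
```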

## Implementation

The core idea: modify the attention forward pass to accept and prepend cached K/V.

```python
def attn_forward_with_cache(attn_module, x, kv_cache=None, cache_seqlen=0):
    # Compute Q, K, V for current window
    bsz, seqlen, _ = x.shape
    q, k, v = compute_qkv(attn_module, x)

    # Apply RoPE with position offset (critical for correctness)
    cos, sin = attn_module.rotary(cache_seqlen + seqlen, x.device, q.dtype)
    q = apply_rotary_emb(q, cos[cache_seqlen:], sin[cache_seqlen:])
    k = apply_rotary_emb(k, cos[cache_seqlen:], sin[cache_seqlen:])

    # Keep this window's K/V before the cache is prepended
    new_k, new_v = k, v

    # Prepend cached K/V from previous windows
    if kv_cache is not None:
        k = torch.cat([kv_cache[0], k], dim=1)  # [B, cache+seq, H, D]
        v = torch.cat([kv_cache[1], v], dim=1)

    # Flash Attention handles seqlen_k > seqlen_q natively
    y = flash_attn_func(q, k, v, causal=True)
    return y, (new_k, new_v)  # Return current K/V for future caching
```

The eval loop maintains a per-layer cache, only storing the `stride` newest tokens per window to avoid redundancy:

```python
layer_caches = [None] * num_layers
for window in sliding_windows:
    logits, new_caches = forward_with_cache(model, window, layer_caches)
    for layer_idx in range(num_layers):
        # Only cache the NEW tokens (stride=64), not the full 2048 window
        new_kv = tuple(t[:, -stride:] for t in new_caches[layer_idx])  # (k, v)
        # Append to existing cache, trim to max_cache_tokens
        layer_caches[layer_idx] = concat_and_trim(
            layer_caches[layer_idx], new_kv, max_tokens=8192)
    score_tokens(logits, window)
```
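
The `concat_and_trim` helper above is left undefined; one way it could look (a sketch, assuming each layer's cache is a `(k, v)` tuple of `[B, T, H, D]` tensors):

```python
import torch

def concat_and_trim(old_kv, new_kv, max_tokens):
    """Append new (k, v) along the sequence dim, keep only the newest max_tokens."""
    if old_kv is None:
        k, v = new_kv
    else:
        k = torch.cat([old_kv[0], new_kv[0]], dim=1)
        v = torch.cat([old_kv[1], new_kv[1]], dim=1)
    return (k[:, -max_tokens:], v[:, -max_tokens:])
```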

## RoPE Considerations

The model was trained with `train_seq_len=1024` and uses NTK-aware RoPE scaling (auto-scales base frequency for longer sequences). For cache positions beyond the training length, RoPE quality degrades gradually. This is a known limitation — the same issue affects any long-context evaluation.

Potential mitigations:
- **Cache only last N layers** (e.g., last 4 with XSA) — earlier layers handle local patterns that don't need extended context
- **Limit cache to 4096 tokens** — stays within 4x of training length where NTK scaling is still effective
- **Use RoPE base 50000** (as in PR #254) — extends the effective RoPE range
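
For reference, the usual NTK-aware rule stretches the RoPE base by the context ratio raised to `d/(d-2)`. A sketch of that common formula (the training script's exact auto-scaling rule is an assumption here):

```python
import math

def ntk_scaled_base(base, train_len, eval_len, head_dim):
    # NTK-aware scaling: stretch the base so the lowest RoPE frequency
    # covers the longer window while high frequencies stay near-unchanged.
    if eval_len <= train_len:
        return base
    scale = eval_len / train_len
    return base * scale ** (head_dim / (head_dim - 2))

# e.g. train_len=1024, context extended to 4096, head_dim=64:
# the base grows by 4 ** (64/62), roughly 4.2x
```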

## Rule Compliance

Per the organizer ruling on TTT (Mar 20):
> "You can't train on the validation tokens before you evaluate on those same tokens."

Neural Cache does NOT train on anything. It caches intermediate hidden states (K/V pairs) from **already-evaluated** tokens and uses them as additional context for future tokens. This is:
- **No weight modification** (unlike TTT)
- **Backward-looking only** (only uses K/V from scored tokens)
- **Equivalent to a longer context window** — evaluation methods are explicitly unrestricted

## Status: Untested Due to Compute Constraints

We implemented the full Neural Cache eval but encountered a bug in the model state after `torch.compile` — the custom forward path produced invalid results when called on the compiled `base_model`. The fix (using a fresh `eval_model` loaded from saved weights) was identified but we ran out of compute budget before re-running.

**The code is provided below for anyone to test.** Expected cost: one 8xH100 run (~$5) to train + eval with Neural Cache.

## Estimated Impact

- **Conservative:** 0.005-0.01 BPB (from context extension alone)
- **Optimistic:** 0.01-0.03 BPB (if the model effectively leverages 10K+ context)
- **Risk:** RoPE degradation beyond training length could limit gains

For reference, sliding window eval (extending context via overlap) gave -0.034 BPB (PR #77). Neural Cache extends context further via a complementary mechanism.
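
As a sanity check on the numbers above, BPB is just the mean token NLL converted to bits and renormalized per byte. Plugging in this PR's reported val_loss reproduces its BPB (the bytes-per-token ratio of roughly 2.44 is implied by that reported pair, not measured here):

```python
import math

def bits_per_byte(mean_nll_nats, bytes_per_token):
    # nats/token -> bits/token -> bits/byte
    return mean_nll_nats / math.log(2) / bytes_per_token

# val_loss 1.90519942 with bpb 1.1284 implies ~2.436 bytes per token
print(bits_per_byte(1.90519942, 2.436))
```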

## Reproduction

Base model: PR #287's recipe (XSA + EMA + 11L + SmearGate + BigramHash)

```bash
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=2048 XSA_LAST_N=4 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3000 \
ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

Our reproduction: 7,009 steps @ 85.6ms/step, **1.1284 BPB** sliding window (vs PR #287's 1.1271).

## Hardware

8x NVIDIA H100 80GB SXM, RunPod. Training: 600s. Standard eval: ~30s. Sliding window: ~85s. Neural Cache eval (estimated): ~300s for 1M token subset.

## Author

Xiaoan Liu | NYU | GitHub: @sseanliu
---
"""Neural Cache Evaluation: Cross-window KV caching for extended context.

Usage: Add this to the end of the training script's main() function,
AFTER the int6 sliding window eval creates `eval_model`.

# --- NEURAL CACHE EVAL ---
if master_process:
for cache_size in [0, 2048, 4096]:
nc_loss, nc_bpb = eval_neural_cache(
eval_model, rank, device, val_tokens, base_bytes_lut,
has_leading_space_lut, is_boundary_token_lut,
seq_len=args.train_seq_len, stride=64,
max_cache_tokens=cache_size, max_eval_tokens=1000000)
print(f"neural_cache cache={cache_size} bpb={nc_bpb:.6f}")

IMPORTANT: Use `eval_model` (fresh model loaded from saved weights),
NOT `base_model` (which has torch.compile applied and produces invalid results).
"""

import math
import time
import torch
import torch.nn.functional as F
from flash_attn_interface import flash_attn_func as flash_attn_3_func


def attn_forward_with_cache(attn_module, x, kv_cache=None, cache_seqlen=0):
"""Attention forward with KV cache prepended for extended context.

Args:
attn_module: CausalSelfAttention module
x: input [bsz, seqlen, dim] (already through attn_norm)
kv_cache: tuple (cached_k, cached_v) or None
cache_seqlen: number of tokens in cache (for RoPE position offset)

Returns:
output: [bsz, seqlen, dim]
new_kv: tuple (k, v) for current window
"""
# Import apply_rotary_emb from the training script
from train_gpt import apply_rotary_emb

bsz, seqlen, dim = x.shape
q = attn_module.c_q(x).reshape(bsz, seqlen, attn_module.num_heads, attn_module.head_dim)
k = attn_module.c_k(x).reshape(bsz, seqlen, attn_module.num_kv_heads, attn_module.head_dim)
v = attn_module.c_v(x).reshape(bsz, seqlen, attn_module.num_kv_heads, attn_module.head_dim)

q = F.rms_norm(q, (q.size(-1),))
k = F.rms_norm(k, (k.size(-1),))

# RoPE with position offset for cached context
total_len = cache_seqlen + seqlen
cos, sin = attn_module.rotary(total_len, x.device, q.dtype)
q = apply_rotary_emb(q, cos[cache_seqlen:total_len], sin[cache_seqlen:total_len])
k = apply_rotary_emb(k, cos[cache_seqlen:total_len], sin[cache_seqlen:total_len])

q = q * attn_module.q_gain.to(dtype=q.dtype)[None, None, :, None]

# Save current K/V before cache concatenation
new_k, new_v = k.clone(), v.clone()

# Prepend cached K/V from previous windows
if kv_cache is not None:
k = torch.cat([kv_cache[0], k], dim=1)
v = torch.cat([kv_cache[1], v], dim=1)

# flash_attn handles seqlen_k > seqlen_q with causal=True correctly:
# queries attend to all cached tokens + causal portion of current window
y = flash_attn_3_func(q, k, v, causal=True)

if attn_module.use_xsa:
y = attn_module._xsa_efficient(y, new_v)

y = y.reshape(bsz, seqlen, dim)
return attn_module.proj(y), (new_k, new_v)


def forward_logits_cached(model, input_ids, layer_caches=None, cache_seqlen=0):
"""Full forward pass with per-layer KV caches."""
x = model.tok_emb(input_ids)
if model.bigram is not None:
x = x + model.bigram(input_ids)
x = F.rms_norm(x, (x.size(-1),))
x = model.smear(x)
x0 = x

new_caches = []
skips = []
layer_idx = 0

for i in range(model.num_encoder_layers):
block = model.blocks[i]
mix = block.resid_mix.to(dtype=x.dtype)
x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0

lc = layer_caches[layer_idx] if layer_caches else None
attn_out, new_kv = attn_forward_with_cache(
block.attn, block.attn_norm(x), kv_cache=lc, cache_seqlen=cache_seqlen)
new_caches.append(new_kv)
layer_idx += 1

x = x + block.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
x = x + block.mlp_scale.to(dtype=x.dtype)[None, None, :] * block.mlp(block.mlp_norm(x))
skips.append(x)

for i in range(model.num_decoder_layers):
if skips:
x = x + model.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
block = model.blocks[model.num_encoder_layers + i]
mix = block.resid_mix.to(dtype=x.dtype)
x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0

lc = layer_caches[layer_idx] if layer_caches else None
attn_out, new_kv = attn_forward_with_cache(
block.attn, block.attn_norm(x), kv_cache=lc, cache_seqlen=cache_seqlen)
new_caches.append(new_kv)
layer_idx += 1

x = x + block.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
x = x + block.mlp_scale.to(dtype=x.dtype)[None, None, :] * block.mlp(block.mlp_norm(x))

x = model.final_norm(x)
if model.tie_embeddings:
logits_proj = F.linear(x, model.tok_emb.weight)
else:
logits_proj = model.lm_head(x)
logits = model.logit_softcap * torch.tanh(logits_proj / model.logit_softcap)
return logits, new_caches


def eval_neural_cache(
model, rank, device, val_tokens,
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
seq_len=2048, stride=64, max_cache_tokens=4096, max_eval_tokens=1000000,
):
"""Sliding window eval with cross-window KV caching.

Args:
model: GPT model (use eval_model, NOT base_model after torch.compile)
rank: distributed rank (only rank 0 runs this)
device: CUDA device
val_tokens: validation token tensor
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut: BPB lookup tables
seq_len: window size (default 2048)
stride: scoring stride (default 64)
max_cache_tokens: maximum cached K/V tokens per layer (0 = no caching)
max_eval_tokens: subset size for quick testing

Returns:
(val_loss, val_bpb) tuple
"""
if rank != 0:
return 0.0, 0.0

total_tokens = min(val_tokens.numel() - 1, max_eval_tokens)
num_layers = len(model.blocks)

loss_sum = 0.0
token_count = 0
byte_count = 0.0
layer_caches = [None] * num_layers
cache_seqlen = 0

model.eval()
t0 = time.perf_counter()

with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
for ws in range(0, total_tokens, stride):
end = min(ws + seq_len, total_tokens)
wlen = end - ws
if wlen < 1:
break

chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
x_in = chunk[:-1].unsqueeze(0)
y_tgt = chunk[1:].unsqueeze(0)

logits, new_caches = forward_logits_cached(
model, x_in, layer_caches=layer_caches, cache_seqlen=cache_seqlen)

# Update per-layer caches: only store the stride-worth of NEW tokens
for li in range(num_layers):
if max_cache_tokens == 0:
layer_caches[li] = None
continue
new_k, new_v = new_caches[li]
cache_k = new_k[:, -stride:]
cache_v = new_v[:, -stride:]
if layer_caches[li] is not None:
old_k, old_v = layer_caches[li]
cache_k = torch.cat([old_k, cache_k], dim=1)
cache_v = torch.cat([old_v, cache_v], dim=1)
if cache_k.size(1) > max_cache_tokens:
cache_k = cache_k[:, -max_cache_tokens:]
cache_v = cache_v[:, -max_cache_tokens:]
layer_caches[li] = (cache_k, cache_v)

cache_seqlen = min(ws + wlen, max_cache_tokens) if max_cache_tokens > 0 else 0

# Score only the NEW tokens
nll = F.cross_entropy(logits[0].float(), y_tgt[0], reduction="none")
s = 0 if ws == 0 else max(wlen - stride, 0)
scored_nll = nll[s:wlen].to(torch.float64)
loss_sum += scored_nll.sum().item()
token_count += wlen - s

tgt = y_tgt[0, s:wlen]
prev = x_in[0, s:wlen]
tb = base_bytes_lut[tgt].to(torch.float64)
tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
byte_count += tb.sum().item()

if ws % (stride * 500) == 0 and ws > 0:
elapsed = time.perf_counter() - t0
running_bpb = (loss_sum / token_count / math.log(2.0)) * (token_count / byte_count)
print(f" ncache pos={ws}/{total_tokens} bpb={running_bpb:.4f} "
f"cache={cache_seqlen} elapsed={elapsed:.0f}s")

elapsed = time.perf_counter() - t0
val_loss = loss_sum / token_count
bpb = (val_loss / math.log(2.0)) * (token_count / byte_count)
return val_loss, bpb
---
{
"track": "10min_16mb",
"date": "2026-03-21",
"name": "Neural Cache: Cross-Window KV Cache for Extended Eval Context (research proposal)",
"author": "Xiaoan Liu",
"github_id": "sseanliu",
"blurb": "Research proposal for extending effective eval context from 2K to 50K+ tokens by caching K/V pairs across sliding windows. Backward-looking, zero artifact cost, rule-compliant. Implementation provided but untested due to compute constraints. Base: PR #287 reproduction at 1.1284 BPB.",
"seed_results": {
"1337": {"val_loss": 1.90519942, "val_bpb": 1.12836940, "steps": 7009, "ms_per_step": 85.62}
},
"mean_val_bpb": 1.1284,
"artifact_bytes": 15532039,
"code_bytes": 71412,
"notes": "Non-record research submission. Neural Cache eval not yet validated — torch.compile interaction bug prevented valid results. Base reproduction of PR #287 confirms 1.1284 BPB (vs original 1.1271). FA3 + 8xH100 SXM."
}