22 changes: 22 additions & 0 deletions demos/README.md
@@ -21,3 +21,25 @@ python3 demos/check_ckpt_for_gelu_shift.py \

`adam_vs_adamw.sh` trains two tiny Shakespeare models, one with Adam and one
with AdamW, then compares their statistics using `view_model_stats.py`.

## Snap-to-Grid Projections

`snap_to_grid_demo.sh` prepares the Shakespeare character dataset, trains a
small model with snap-to-grid enabled, evaluates multiple grid sizes, and then
generates text with and without the projections. Run it from the repository
root:

```bash
bash demos/snap_to_grid_demo.sh
```

You can override the default model or snap-to-grid settings by exporting
environment variables before running the script, for example:

```bash
N_HEAD=3 N_EMBD=384 SNAP_SIZES="100 1000 10000" bash demos/snap_to_grid_demo.sh
```

Ensure that `N_EMBD` remains divisible by `N_HEAD`; otherwise the
multi-head attention projection will be invalid and the script will exit with
an explanatory error.
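
The guard mirrors the per-head split inside multi-head attention: each head
operates on `n_embd // n_head` channels, so the division must be exact. A
minimal standalone sketch of the same check (a hypothetical helper for
illustration, not code from this repository):

```python
def head_channels(n_embd: int, n_head: int) -> int:
    """Return the per-head channel count, failing fast on an uneven split."""
    if n_embd % n_head != 0:
        raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
    return n_embd // n_head

print(head_channels(384, 3))  # 128 channels per head -- the override example above
head_channels(128, 3)         # raises ValueError, just as the demo script exits
```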
119 changes: 119 additions & 0 deletions demos/snap_to_grid_demo.sh
@@ -0,0 +1,119 @@
#!/bin/bash
# snap_to_grid_demo.sh
# Demonstrates training and sampling with snap-to-grid projections.

set -euo pipefail

DATASET="${DATASET:-shakespeare_char}"
DATA_DIR="data/${DATASET}"
OUT_DIR="${OUT_DIR:-out/snap_to_grid_demo}"
SNAP_SIZES_STR="${SNAP_SIZES:-8 32}"
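# Split the space-separated sizes string into a bash array; it is expanded
# element-by-element below so each size becomes its own CLI token.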
IFS=' ' read -r -a SNAP_SIZES <<< "${SNAP_SIZES_STR}"

# Model hyperparameters (override via env vars, e.g. `N_HEAD=3 N_EMBD=384 ./snap_to_grid_demo.sh`).
N_LAYER=${N_LAYER:-4}
N_HEAD=${N_HEAD:-4}
N_EMBD=${N_EMBD:-128}

if (( N_EMBD % N_HEAD != 0 )); then
  echo "error: N_EMBD (${N_EMBD}) must be divisible by N_HEAD (${N_HEAD})." >&2
  echo "update N_EMBD or N_HEAD before running the demo." >&2
  exit 1
fi

mkdir -p "${OUT_DIR}"

echo "=== Step 1: Ensure the ${DATASET} dataset is prepared ==="
if [ ! -f "${DATA_DIR}/train.bin" ] || [ ! -f "${DATA_DIR}/val.bin" ]; then
pushd "${DATA_DIR}" > /dev/null
if [ ! -f "input.txt" ]; then
echo "Downloading Shakespeare corpus..."
bash get_dataset.sh
fi
echo "Tokenizing dataset with tiktoken encoder..."
python3 prepare.py -t input.txt --method tiktoken
  popd > /dev/null
else
  echo "Found existing tokenized dataset artifacts."
fi

CKPT_PATH="${OUT_DIR}/ckpt.pt"

cat <<CONFIG
=== Step 2: Train a tiny model with snap-to-grid enabled ===
- output directory: ${OUT_DIR}
- snap-to-grid sizes evaluated: ${SNAP_SIZES[*]}
CONFIG

python3 train.py \
--dataset "${DATASET}" \
--out_dir "${OUT_DIR}" \
--block_size 128 \
--batch_size 12 \
--n_layer "${N_LAYER}" \
--n_head "${N_HEAD}" \
--n_embd "${N_EMBD}" \
--max_iters 200 \
--eval_interval 100 \
--eval_iters 50 \
--log_interval 10 \
--learning_rate 3e-4 \
--enable_snap_to_grid \
--snap_to_grid_sizes "${SNAP_SIZES[@]}" \
--snap_to_grid_components both

if [ ! -f "${CKPT_PATH}" ]; then
echo "Expected checkpoint not found at ${CKPT_PATH}" >&2
exit 1
fi

cat <<CONFIG
=== Step 3: Evaluate validation loss with snap-to-grid registries ===
- checkpoint: ${CKPT_PATH}
- snap-to-grid sizes: ${SNAP_SIZES[*]}
CONFIG

python3 sample.py \
--out_dir "${OUT_DIR}" \
--init_from resume \
--eval_only \
--eval_dataset "${DATASET}" \
--eval_iters 50 \
--enable_snap_to_grid \
--snap_to_grid_sizes "${SNAP_SIZES[@]}"

cat <<CONFIG
=== Step 4: Generate baseline samples (snap-to-grid disabled) ===
CONFIG

python3 sample.py \
--out_dir "${OUT_DIR}" \
--init_from resume \
--start "ROMEO: " \
--num_samples 1 \
--max_new_tokens 64 \
--temperature 0.8 \
--top_k 200 \
--seed 1337

cat <<CONFIG
=== Step 5: Generate samples with snap-to-grid enabled ===
- snap-to-grid sizes: ${SNAP_SIZES[*]}
CONFIG

python3 sample.py \
--out_dir "${OUT_DIR}" \
--init_from resume \
--start "ROMEO: " \
--num_samples 1 \
--max_new_tokens 64 \
--temperature 0.8 \
--top_k 200 \
--seed 1337 \
--enable_snap_to_grid \
--snap_to_grid_sizes "${SNAP_SIZES[@]}"

cat <<MSG
Demo complete. Snap-to-grid registries and logs are stored under ${OUT_DIR}/snap_to_grid.
Check the TensorBoard logs for the snap_to_grid/val_loss_size_* series to compare validation performance.
MSG
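
Note that the script expands the `SNAP_SIZES` array with `"${SNAP_SIZES[@]}"`, so each grid size reaches `train.py` and `sample.py` as its own argv token. A list-valued flag consumes those tokens in one pass; the snippet below is an assumed sketch of such a parser, not the repository's actual argument handling:

```python
import argparse

# Assumed sketch of a list-valued flag; the real train.py parser may differ.
parser = argparse.ArgumentParser()
parser.add_argument("--enable_snap_to_grid", action="store_true")
parser.add_argument("--snap_to_grid_sizes", type=int, nargs="+", default=[])

# With SNAP_SIZES="8 32", the bash expansion produces exactly these tokens.
args = parser.parse_args(["--enable_snap_to_grid", "--snap_to_grid_sizes", "8", "32"])
print(args.snap_to_grid_sizes)  # [8, 32]
```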
6 changes: 6 additions & 0 deletions gpt_conf.py
@@ -160,6 +160,12 @@ class GPTConfig:
    plot_statistics: bool = False
    softmax_io_log_interval: int = 1

    # Snap-to-grid options
    enable_snap_to_grid: bool = False
    snap_to_grid_layers: List[int] = field(default_factory=list)
    snap_to_grid_components: str = "both"
    snap_to_grid_sizes: List[int] = field(default_factory=list)

    # Training options
    ## Gradient Checkpointing - More memory efficient (can do long contexts), but is slower
    use_gradient_checkpointing: bool = False
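All four new fields default to a disabled feature: `enable_snap_to_grid` is off and the layer and size lists are empty. A configuration that turns the feature on might look like the sketch below; the field names come from the diff above, but the construction itself is hypothetical (`GPTConfig` may have required fields that this hunk does not show):

```python
from gpt_conf import GPTConfig

# Hypothetical configuration; only the snap-to-grid fields are shown, and the
# meaning of an empty snap_to_grid_layers list (assumed: apply to all layers)
# is an inference, not something this diff documents.
cfg = GPTConfig(
    enable_snap_to_grid=True,
    snap_to_grid_layers=[0, 2],      # restrict the projections to specific blocks
    snap_to_grid_components="both",  # matches the field's default
    snap_to_grid_sizes=[8, 32],
)
```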
42 changes: 35 additions & 7 deletions model.py
@@ -45,6 +45,7 @@

from shared_param_utils import SharedParamGroupCreator
from variations.block_variations import Block
from utils.snap_to_grid import SnapToGridRegistry

class LearnedPositionEmbedding(nn.Module):
"""
@@ -68,7 +69,7 @@ def __init__(self, config):

        self.drop = nn.Dropout(config.dropout)
        # reuse the same Block init as GPT.transformer.h
        self.blocks = nn.ModuleList([Block(self.lpe_config) for _ in range(self.lpe_config.n_layer)])
        self.blocks = nn.ModuleList([Block(self.lpe_config, layer_idx=i) for i in range(self.lpe_config.n_layer)])

    def forward(self, b, t, x, iter_num=None):
        # add absolute position embeddings if used
@@ -92,6 +93,14 @@ def __init__(self, config):

        self.config = config

        n_head = getattr(config, "n_head", None)
        if n_head:
            if config.n_embd % n_head != 0:
                raise ValueError(
                    f"n_embd ({config.n_embd}) must be divisible by n_head ({n_head}); "
                    "adjust the configuration (e.g. change --n_embd or --n_head)."
                )

        self.uses_numerical_multicontext = bool(config.numerical_multicontext)
        if self.uses_numerical_multicontext:
            if not config.multicontext:
@@ -195,20 +204,30 @@ def __init__(self, config):


        self.transformer['drop'] = nn.Dropout(config.dropout)
        self.transformer['h'] = nn.ModuleList([Block(config, mlp=shared_mlp_array[i], attn=shared_attn_array[i]) for i in range(config.n_layer)])
        if getattr(config, "enable_snap_to_grid", False) and getattr(config, "snap_to_grid_registry", None) is None:
            config.snap_to_grid_registry = SnapToGridRegistry()

        self.transformer['h'] = nn.ModuleList([
            Block(config, layer_idx=i, mlp=shared_mlp_array[i], attn=shared_attn_array[i])
            for i in range(config.n_layer)
        ])
        self.transformer['ln_f'] = norm_dictionary[config.norm_variant_output](config)

        self.snap_to_grid_registry = getattr(config, "snap_to_grid_registry", None)
        self.config.snap_to_grid_registry = self.snap_to_grid_registry
        self._apply_snap_to_grid_registry(self.snap_to_grid_registry)

        # Optional post-embedding normalizations
        if self.config.norm_variant_wte is not None:
            self.transformer['post_embedding_norm'] = self.build_norm_from_variant(config, "norm_variant_wte", "norm_wte")
            self.transformer['post_embedding_norm'] = self.build_norm_from_variant(self.config, "norm_variant_wte", "norm_wte")
        if self.config.norm_variant_abs is not None:
            self.transformer['post_abs_norm'] = self.build_norm_from_variant(config, "norm_variant_abs", "norm_abs")
            self.transformer['post_abs_norm'] = self.build_norm_from_variant(self.config, "norm_variant_abs", "norm_abs")

        if self.config.use_abs_pos_embeddings:
            if config.quantize_wpe:
                pos_embd = QuantizedEmbedding(config.block_size, config.n_embd, config.quantize_wpe_method, config.quantize_wpe_bits)
            if self.config.quantize_wpe:
                pos_embd = QuantizedEmbedding(self.config.block_size, self.config.n_embd, self.config.quantize_wpe_method, self.config.quantize_wpe_bits)
            else:
                pos_embd = nn.Embedding(config.block_size, config.n_embd)
                pos_embd = nn.Embedding(self.config.block_size, self.config.n_embd)
            self.transformer['wpe'] = pos_embd

        # Select softmax variant for output layer
@@ -271,6 +290,15 @@ def __init__(self, config):
        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def _apply_snap_to_grid_registry(self, registry: SnapToGridRegistry | None) -> None:
        for block in self.transformer['h']:
            block.snap_to_grid_registry = registry

    def set_snap_to_grid_registry(self, registry: SnapToGridRegistry | None) -> None:
        self.snap_to_grid_registry = registry
        self.config.snap_to_grid_registry = registry
        self._apply_snap_to_grid_registry(registry)

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
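The `set_snap_to_grid_registry` hook keeps the model, its config, and every block pointing at the same registry object, which makes it possible to attach or detach the projections after construction. An assumed usage sketch follows; only the bare `SnapToGridRegistry()` constructor appears in this diff, so no other registry API is implied, and `model` is a hypothetical already-constructed `GPT` instance:

```python
from utils.snap_to_grid import SnapToGridRegistry

# Attach a fresh registry to the model, its config, and all blocks.
model.set_snap_to_grid_registry(SnapToGridRegistry())

# ... run evaluation or sampling with snap-to-grid active ...

# Detach it again, disabling the grid lookups in every block.
model.set_snap_to_grid_registry(None)
```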