@klei22 klei22 commented Oct 19, 2025

This pull request introduces support for "snap-to-grid" projections throughout the codebase, enabling quantized projections in transformer models for both training and inference. It adds user-facing controls, updates model configuration and initialization, and ensures that snap-to-grid can be flexibly applied and evaluated. The changes are grouped below by theme.

Snap-to-Grid Feature Implementation

  • Added snap-to-grid configuration options (enable_snap_to_grid, snap_to_grid_layers, snap_to_grid_components, snap_to_grid_sizes) to the GPTConfig class in gpt_conf.py, allowing fine-grained control over which layers and components use quantized projections.
  • Integrated SnapToGridRegistry into model initialization in model.py, ensuring that the registry is created and attached to each transformer block when enabled, and added methods for setting and applying registries dynamically.
  • Updated the Block initialization in model.py to accept a layer_idx parameter, supporting per-layer snap-to-grid configuration.
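
As a rough illustration of how the four options could combine, the sketch below models them as dataclass fields with a hypothetical helper, is_targeted, that decides whether a given layer and component should be snapped. The field names come from the description above; the class name, the types, the defaults, the component labels "attn" and "mlp", and the helper itself are assumptions and may not match gpt_conf.py.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SnapToGridOptions:
    """Illustrative stand-in for the snap-to-grid fields on GPTConfig."""

    enable_snap_to_grid: bool = False
    # Assumption: None means "every layer".
    snap_to_grid_layers: Optional[List[int]] = None
    # Assumption: component labels; the PR may use different names.
    snap_to_grid_components: List[str] = field(default_factory=lambda: ["attn", "mlp"])
    snap_to_grid_sizes: List[int] = field(default_factory=list)

    def is_targeted(self, layer_idx: int, component: str) -> bool:
        """Hypothetical helper: should this layer/component be snapped?"""
        if not self.enable_snap_to_grid:
            return False
        if self.snap_to_grid_layers is not None and layer_idx not in self.snap_to_grid_layers:
            return False
        return component in self.snap_to_grid_components


opts = SnapToGridOptions(
    enable_snap_to_grid=True,
    snap_to_grid_layers=[0, 2],
    snap_to_grid_sizes=[64, 256],
)
```

With these settings, opts.is_targeted(0, "attn") is true while opts.is_targeted(1, "attn") is false, which is the kind of per-layer decision the layer_idx parameter on Block makes possible.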

User-Facing Controls and Documentation

  • Added a new demo script, demos/snap_to_grid_demo.sh, and updated the documentation in demos/README.md to explain how to run snap-to-grid experiments, override settings, and interpret results.

Training and Inference Integration

  • Updated train.py and sample.py to parse snap-to-grid arguments, normalize configuration, and generate/apply snap-to-grid registries for both training and sampling. This includes saving registry files and iterating over multiple grid sizes for evaluation and generation.

Validation and Error Handling

  • Added validation in model initialization to ensure that n_embd is divisible by n_head, preventing invalid configurations for multi-head attention when snap-to-grid is enabled.
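
The divisibility check can be sketched as below. This is a minimal standalone version; the function name check_head_divisibility and the error wording are hypothetical, and the actual check in model.py may be written differently.

```python
def check_head_divisibility(n_embd: int, n_head: int) -> int:
    """Return the per-head dimension, or raise if n_embd does not split evenly."""
    if n_head <= 0:
        raise ValueError(f"n_head must be positive, got {n_head}")
    if n_embd % n_head != 0:
        raise ValueError(
            f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        )
    return n_embd // n_head


# 384 embedding dimensions split across 6 heads gives 64 dimensions per head.
head_dim = check_head_divisibility(384, 6)
```

Failing early here gives a readable error instead of a shape mismatch deep inside the attention reshape.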

Utility Integration

  • Imported snap-to-grid utility functions (generate_snap_to_grid_registry, save_registry) into relevant scripts to support registry creation and persistence.

These changes collectively provide a robust framework for experimenting with snap-to-grid projections, including configuration, evaluation, and reproducibility.

Copilot AI left a comment

Pull Request Overview

This PR adds snap-to-grid projection capabilities to the transformer model, enabling quantized projections of activations onto a learned hypersphere grid before attention and MLP pre-normalization layers. The feature supports per-layer and per-component configuration, registry-based grid management, and evaluation across multiple grid sizes.

Key changes:

  • Added snap-to-grid projection infrastructure with SnapToGridRegistry for managing grids and apply_snap_to_grid_tensor for projecting activations
  • Integrated snap-to-grid into the Block forward pass with configurable layer and component targeting
  • Extended training and sampling workflows to generate, evaluate, and persist snap-to-grid registries across multiple grid sizes

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.

File | Description
--- | ---
variations/block_variations.py | Added _maybe_snap_to_grid method and snap-to-grid calls before attention/MLP pre-norms; added layer_idx parameter to Block.__init__
utils/snap_to_grid.py | New utility module providing registry management, grid generation from model weights, and activation projection functions
train_args.py | Added command-line arguments for snap-to-grid configuration (enable flag, layers, components, sizes)
train.py | Normalized snap-to-grid config in setup; added registry evaluation loop at checkpoints; integrated multi-size sampling
sample.py | Added snap-to-grid arguments; implemented multi-size evaluation and generation loops with registry persistence
model.py | Added validation for n_embd divisibility; initialized and propagated SnapToGridRegistry through blocks; added registry setter methods
gpt_conf.py | Added snap-to-grid configuration fields to GPTConfig dataclass
demos/snap_to_grid_demo.sh | New demo script training and evaluating a model with snap-to-grid enabled
demos/README.md | Added documentation for snap-to-grid demo usage and configuration

Comment on lines +481 to +487
def _maybe_snap_to_grid(self, component: str, tensor: torch.Tensor) -> torch.Tensor:
    if self.snap_to_grid_registry is None or self.layer_idx is None:
        return tensor
    grid = self.snap_to_grid_registry.get_grid(self.layer_idx, component)
    if grid is None:
        return tensor
    return apply_snap_to_grid_tensor(tensor, grid)

Copilot AI Oct 19, 2025

The _maybe_snap_to_grid method is called twice per forward pass (once for attention, once for MLP). Each call performs a dictionary lookup via get_grid. Consider caching the grid lookups during initialization or at the start of the forward pass to avoid repeated dictionary access in the hot path.
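
A minimal sketch of that caching idea follows, assuming the grids for a block do not change between calls to its registry setter. CountingRegistry and BlockSketch are hypothetical stand-ins, as are the component labels "attn" and "mlp"; only the method names get_grid and set_snap_to_grid_registry come from the PR. Whether this is measurably faster than the per-call lookup is untested here.

```python
from typing import Dict, List, Optional, Tuple

Grid = List[List[float]]


class CountingRegistry:
    """Hypothetical stand-in for SnapToGridRegistry that counts lookups."""

    def __init__(self, grids: Dict[Tuple[int, str], Grid]) -> None:
        self._grids = dict(grids)
        self.lookups = 0

    def get_grid(self, layer_idx: int, component: str) -> Optional[Grid]:
        self.lookups += 1
        return self._grids.get((layer_idx, component))


class BlockSketch:
    """Resolves its grids once, when the registry is attached."""

    def __init__(self, layer_idx: Optional[int]) -> None:
        self.layer_idx = layer_idx
        self.attn_grid: Optional[Grid] = None
        self.mlp_grid: Optional[Grid] = None

    def set_snap_to_grid_registry(self, registry: Optional[CountingRegistry]) -> None:
        # Clear stale grids first so detaching a registry disables snapping.
        self.attn_grid = None
        self.mlp_grid = None
        if registry is None or self.layer_idx is None:
            return
        self.attn_grid = registry.get_grid(self.layer_idx, "attn")
        self.mlp_grid = registry.get_grid(self.layer_idx, "mlp")

    def forward(self) -> Tuple[bool, bool]:
        # The hot path reads plain attributes; no registry call happens here.
        return self.attn_grid is not None, self.mlp_grid is not None


registry = CountingRegistry({(0, "attn"): [[1.0, 0.0]]})
block = BlockSketch(layer_idx=0)
block.set_snap_to_grid_registry(registry)
for _ in range(100):
    block.forward()
```

After the loop, registry.lookups is still 2: both lookups happened when the registry was attached. The cost of this design is staleness: if a registry were mutated in place after being attached, the cached grids would not reflect the change until the setter is called again.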

Copilot uses AI. Check for mistakes.
Comment on lines +179 to +181
if grid is None:
    return x

Copilot AI Oct 19, 2025

The grid is None check at line 179 is redundant because callers in Block._maybe_snap_to_grid already return early when grid is None (line 485). This guard can be removed to simplify the function and avoid unnecessary checks in the hot path.

Suggested change
- if grid is None:
-     return x

Comment on lines +551 to +560
for snap_size in snap_sizes:
    if snap_size is None:
        self.raw_model.set_snap_to_grid_registry(base_registry)
        self.console.print("[bold cyan]Sampling without snap-to-grid[/bold cyan]")
    else:
        encode_fn = self.encode
        decode_fn = self.decode
        registry = generate_snap_to_grid_registry(self.raw_model, layers, component, snap_size)
        self.raw_model.set_snap_to_grid_registry(registry)
        self.console.print(f"[bold cyan]Sampling with snap-to-grid size {snap_size}[/bold cyan]")

    for i in range(sample_iterations):

Copilot AI Oct 19, 2025

[nitpick] The nested loop structure (snap sizes → sample iterations) creates deep nesting and makes the sampling logic harder to follow. Consider extracting the inner sampling loop into a separate method like _run_sample_iteration(snap_size) to improve readability and maintainability.
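
The extraction suggested here could look roughly like the sketch below. It uses free functions with injected callables so it runs on its own; run_sample_iterations, sample_all_sizes, apply_size, and sample_once are hypothetical names, and in train.py the equivalent would presumably be methods that use self.raw_model and self.console directly.

```python
from typing import Callable, List, Optional, Sequence


def run_sample_iterations(
    snap_size: Optional[int],
    sample_iterations: int,
    sample_once: Callable[[Optional[int], int], str],
) -> List[str]:
    """Inner loop: draw sample_iterations samples for one snap size."""
    return [sample_once(snap_size, i) for i in range(sample_iterations)]


def sample_all_sizes(
    snap_sizes: Sequence[Optional[int]],
    sample_iterations: int,
    apply_size: Callable[[Optional[int]], None],
    sample_once: Callable[[Optional[int], int], str],
) -> List[str]:
    """Outer loop: switch registries, then delegate the sampling itself."""
    outputs: List[str] = []
    for snap_size in snap_sizes:
        apply_size(snap_size)  # None means sampling without snap-to-grid
        outputs.extend(run_sample_iterations(snap_size, sample_iterations, sample_once))
    return outputs


# Stub callables stand in for registry switching and text generation.
applied: List[Optional[int]] = []
outputs = sample_all_sizes(
    snap_sizes=[None, 64],
    sample_iterations=2,
    apply_size=applied.append,
    sample_once=lambda size, i: f"size={size} iter={i}",
)
```

Splitting the loops this way keeps the registry switching and the per-sample work at one level of nesting each, and lets the inner function be tested without building a registry.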

for component in component_list:
    source = _gather_base_vectors(model, layer_idx, component)
    if source is None or source.numel() == 0:
        continue

Copilot AI Oct 19, 2025

The grid generation logic at lines 168-170 creates random linear combinations of normalized source vectors, then re-normalizes. This process and its mathematical rationale (why random combinations are normalized twice) should be documented with a comment explaining the intended hypersphere sampling strategy.

Suggested change
-         continue
+         continue
+     # Generate random linear combinations of normalized source vectors,
+     # then re-normalize the result. This double normalization ensures
+     # that the resulting vectors are uniformly distributed on the unit
+     # hypersphere spanned by the source vectors (hypersphere sampling).
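
The sampling strategy described in this comment can be sketched in plain Python as follows. The function name generate_grid, the Gaussian mixing weights, and the seed parameter are assumptions, since the PR's actual weight distribution is not shown here. Note that the re-normalized points always lie on the unit sphere within the span of the sources, but they are uniformly distributed there only in special cases, such as isotropic Gaussian weights over orthonormal sources.

```python
import math
import random
from typing import List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    """Scale v to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def generate_grid(source: Sequence[Sequence[float]], grid_size: int, seed: int = 0) -> List[List[float]]:
    """Build grid_size unit vectors as random mixtures of the source vectors."""
    rng = random.Random(seed)
    # First normalization: put every source vector on the unit sphere.
    unit_sources = [normalize(v) for v in source]
    dim = len(unit_sources[0])
    grid: List[List[float]] = []
    while len(grid) < grid_size:
        # Assumption: standard-normal mixing weights, one per source vector.
        weights = [rng.gauss(0.0, 1.0) for _ in unit_sources]
        combo = [sum(w * s[d] for w, s in zip(weights, unit_sources)) for d in range(dim)]
        norm = math.sqrt(sum(x * x for x in combo))
        if norm < 1e-12:
            continue  # degenerate mixture; draw again
        # Second normalization: project the mixture back onto the unit sphere.
        grid.append([x / norm for x in combo])
    return grid


grid = generate_grid([[2.0, 0.0, 0.0], [0.0, 0.5, 0.0]], grid_size=4)
```

Because both source vectors in the example have a zero third coordinate, every grid point does too: the grid covers only the directions the sources span, not the full embedding space.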

Comment on lines +617 to +618
if not sizes and not getattr(self.args, 'enable_snap_to_grid', False):
    return

Copilot AI Oct 19, 2025

Lines 617-620 contain redundant early-return conditions. If sizes is empty (falsy), the function returns at line 620, making the additional check and not getattr(self.args, 'enable_snap_to_grid', False) at line 617 unnecessary. Simplify to a single check: if not sizes: return.

Suggested change
- if not sizes and not getattr(self.args, 'enable_snap_to_grid', False):
-     return

        "Number of --multicontext_datasets must match number of --multicontext_start strings."
    )
if args.interactive and len(snap_sizes) > 1:
    print("Multiple snap-to-grid sizes requested for interactive mode; only the first size will be used.")

Copilot AI Oct 19, 2025

The warning message at line 1352 is printed using print() while other user-facing messages use console.print(). For consistency with the rest of the codebase and to maintain styled output, this should use the console object defined earlier in the function.

Suggested change
-     print("Multiple snap-to-grid sizes requested for interactive mode; only the first size will be used.")
+     console.print("Multiple snap-to-grid sizes requested for interactive mode; only the first size will be used.")
