Add snap to hypersphere grid feature #659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Conversation
Pull Request Overview
This PR adds snap-to-grid projection capabilities to the transformer model, enabling quantized projections of activations onto a learned hypersphere grid before attention and MLP pre-normalization layers. The feature supports per-layer and per-component configuration, registry-based grid management, and evaluation across multiple grid sizes.
Key changes:
- Added snap-to-grid projection infrastructure, with `SnapToGridRegistry` for managing grids and `apply_snap_to_grid_tensor` for projecting activations
- Integrated snap-to-grid into the `Block` forward pass with configurable layer and component targeting
- Extended training and sampling workflows to generate, evaluate, and persist snap-to-grid registries across multiple grid sizes
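The PR summary does not show the body of `apply_snap_to_grid_tensor`, but one plausible reading of "quantized projections onto a hypersphere grid" is nearest-neighbor snapping by direction. Here is a minimal NumPy sketch of that idea; the norm-preserving behavior and the cosine-similarity nearest-neighbor choice are assumptions for illustration, not taken from the PR:

```python
import numpy as np

def snap_to_grid(x, grid):
    """Snap each row of x to the closest grid direction (by cosine
    similarity), keeping the original vector's magnitude.

    x    : (n, d) activations
    grid : (k, d) unit vectors sampled on the hypersphere
    """
    norms = np.linalg.norm(x, axis=-1, keepdims=True)   # (n, 1) magnitudes
    unit_x = x / np.maximum(norms, 1e-8)                # activation directions
    sims = unit_x @ grid.T                              # (n, k) cosine similarities
    nearest = grid[np.argmax(sims, axis=-1)]            # closest grid direction per row
    return nearest * norms                              # restore magnitudes

grid = np.array([[1.0, 0.0], [0.0, 1.0]])               # toy 2-point grid in 2-D
x = np.array([[3.0, 1.0], [0.5, 2.0]])
snapped = snap_to_grid(x, grid)                         # directions quantized, norms kept
```

Each activation keeps its length but its direction collapses onto one of the `k` grid points, which is what makes the projection "quantized".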
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `variations/block_variations.py` | Added `_maybe_snap_to_grid` method and snap-to-grid calls before attention/MLP pre-norms; added `layer_idx` parameter to `Block.__init__` |
| `utils/snap_to_grid.py` | New utility module providing registry management, grid generation from model weights, and activation projection functions |
| `train_args.py` | Added command-line arguments for snap-to-grid configuration (enable flag, layers, components, sizes) |
| `train.py` | Normalized snap-to-grid config in setup; added registry evaluation loop at checkpoints; integrated multi-size sampling |
| `sample.py` | Added snap-to-grid arguments; implemented multi-size evaluation and generation loops with registry persistence |
| `model.py` | Added validation for `n_embd` divisibility; initialized and propagated `SnapToGridRegistry` through blocks; added registry setter methods |
| `gpt_conf.py` | Added snap-to-grid configuration fields to the `GPTConfig` dataclass |
| `demos/snap_to_grid_demo.sh` | New demo script training and evaluating a model with snap-to-grid enabled |
| `demos/README.md` | Added documentation for snap-to-grid demo usage and configuration |
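Based on the table above, the four new `GPTConfig` fields might be declared along these lines. The defaults and list element types are assumptions, since the `gpt_conf.py` diff itself is not shown here:

```python
from dataclasses import dataclass, field

@dataclass
class GPTConfig:
    # (existing GPTConfig fields omitted)
    enable_snap_to_grid: bool = False
    snap_to_grid_layers: list = field(default_factory=list)      # e.g. [0, 3]; assumed empty = all layers
    snap_to_grid_components: list = field(default_factory=list)  # e.g. ["attn", "mlp"] (assumed names)
    snap_to_grid_sizes: list = field(default_factory=list)       # grid sizes to evaluate, e.g. [64, 256]

cfg = GPTConfig(enable_snap_to_grid=True, snap_to_grid_sizes=[64, 256])
```

Using `field(default_factory=list)` avoids the mutable-default pitfall of sharing one list across all config instances.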
```python
def _maybe_snap_to_grid(self, component: str, tensor: torch.Tensor) -> torch.Tensor:
    if self.snap_to_grid_registry is None or self.layer_idx is None:
        return tensor
    grid = self.snap_to_grid_registry.get_grid(self.layer_idx, component)
    if grid is None:
        return tensor
    return apply_snap_to_grid_tensor(tensor, grid)
```
Copilot AI commented on Oct 19, 2025:
The _maybe_snap_to_grid method is called twice per forward pass (once for attention, once for MLP). Each call performs a dictionary lookup via get_grid. Consider caching the grid lookups during initialization or at the start of the forward pass to avoid repeated dictionary access in the hot path.
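One way to act on this suggestion, sketched with toy stand-ins (the component names `"attn"`/`"mlp"` and the registry interface are assumed from the surrounding diff): resolve the grids once in `__init__` so the hot path does a single local dict access.

```python
class SnapToGridRegistry:
    """Toy stand-in for the PR's registry: maps (layer_idx, component) to a grid."""
    def __init__(self, grids):
        self._grids = grids

    def get_grid(self, layer_idx, component):
        return self._grids.get((layer_idx, component))

class Block:
    COMPONENTS = ("attn", "mlp")  # assumed component names

    def __init__(self, layer_idx, registry=None):
        self.layer_idx = layer_idx
        # Resolve per-component grids once at construction instead of
        # calling registry.get_grid() twice on every forward pass.
        self._cached_grids = {}
        if registry is not None and layer_idx is not None:
            for component in self.COMPONENTS:
                grid = registry.get_grid(layer_idx, component)
                if grid is not None:
                    self._cached_grids[component] = grid

    def _maybe_snap_to_grid(self, component, tensor):
        grid = self._cached_grids.get(component)
        if grid is None:
            return tensor  # component not gridded on this layer: pass through
        return ("snapped", grid)  # stand-in for apply_snap_to_grid_tensor(tensor, grid)

registry = SnapToGridRegistry({(0, "attn"): "grid-A"})
block = Block(layer_idx=0, registry=registry)
```

One caveat: the PR also lets `sample.py` swap registries at runtime via `set_snap_to_grid_registry`, so a real cache would have to be rebuilt inside that setter to stay consistent.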
```python
if grid is None:
    return x
```
Copilot AI commented on Oct 19, 2025:
The grid is None check at line 179 is redundant because callers in Block._maybe_snap_to_grid already return early when grid is None (line 485). This guard can be removed to simplify the function and avoid unnecessary checks in the hot path.
Suggested change:

```diff
-    if grid is None:
-        return x
```
```python
for snap_size in snap_sizes:
    if snap_size is None:
        self.raw_model.set_snap_to_grid_registry(base_registry)
        self.console.print("[bold cyan]Sampling without snap-to-grid[/bold cyan]")
    else:
        encode_fn = self.encode
        decode_fn = self.decode
        registry = generate_snap_to_grid_registry(self.raw_model, layers, component, snap_size)
        self.raw_model.set_snap_to_grid_registry(registry)
        self.console.print(f"[bold cyan]Sampling with snap-to-grid size {snap_size}[/bold cyan]")

    for i in range(sample_iterations):
```
Copilot AI commented on Oct 19, 2025:
[nitpick] The nested loop structure (snap sizes → sample iterations) creates deep nesting and makes the sampling logic harder to follow. Consider extracting the inner sampling loop into a separate method like _run_sample_iteration(snap_size) to improve readability and maintainability.
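A toy sketch of the suggested extraction (function and variable names are illustrative, not from the PR): the driver only manages the registry per snap size, while a helper owns the iteration loop.

```python
def _run_sample_iterations(model, sample_iterations, results):
    # Inner loop extracted from the driver, per the review suggestion.
    for _ in range(sample_iterations):
        results.append(model["registry"])  # stand-in for one sampling pass

def run_sampling(model, snap_sizes, sample_iterations, base_registry):
    results = []
    for snap_size in snap_sizes:
        # Per-size registry setup stays in the driver; sampling detail
        # lives in the helper, keeping the nesting shallow.
        model["registry"] = base_registry if snap_size is None else f"grid-{snap_size}"
        _run_sample_iterations(model, sample_iterations, results)
    return results

out = run_sampling({}, [None, 8], 2, "base")  # → ['base', 'base', 'grid-8', 'grid-8']
```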
```python
for component in component_list:
    source = _gather_base_vectors(model, layer_idx, component)
    if source is None or source.numel() == 0:
        continue
```
Copilot AI commented on Oct 19, 2025:
The grid generation logic at lines 168-170 creates random linear combinations of normalized source vectors, then re-normalizes. This process and its mathematical rationale (why random combinations are normalized twice) should be documented with a comment explaining the intended hypersphere sampling strategy.
Suggested change:

```diff
         continue
+    # Generate random linear combinations of normalized source vectors,
+    # then re-normalize the result. This double normalization ensures
+    # that the resulting vectors are uniformly distributed on the unit
+    # hypersphere spanned by the source vectors (hypersphere sampling).
```
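The strategy the comment describes — random linear combinations of normalized source vectors, re-normalized onto the sphere — can be sketched as follows. Function and argument names are illustrative; note that with Gaussian weights the resulting directions are uniform within the subspace spanned by the sources, rather than over the full hypersphere, unless the sources span the whole space:

```python
import numpy as np

def generate_grid(source, grid_size, seed=0):
    """Sample grid_size unit vectors from the span of the source vectors.

    source : (m, d) base vectors gathered from model weights
    """
    rng = np.random.default_rng(seed)
    # First normalization: put every source vector on the unit hypersphere
    # so no single weight vector dominates the combinations.
    unit_src = source / np.linalg.norm(source, axis=1, keepdims=True)
    # Random linear combinations with Gaussian weights...
    weights = rng.standard_normal((grid_size, source.shape[0]))
    combos = weights @ unit_src
    # ...then a second normalization projects the combinations back onto
    # the unit hypersphere.
    return combos / np.linalg.norm(combos, axis=1, keepdims=True)

grid = generate_grid(np.eye(4), grid_size=16)  # every row has unit norm
```

Gaussian weights are what make this a principled sphere-sampling scheme: an isotropic Gaussian is rotation-invariant, so normalizing it yields uniformly distributed directions within the span.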
```python
if not sizes and not getattr(self.args, 'enable_snap_to_grid', False):
    return
```
Copilot AI commented on Oct 19, 2025:
Lines 617-620 contain redundant early-return conditions. If sizes is empty (falsy), the function returns at line 620, making the additional check and not getattr(self.args, 'enable_snap_to_grid', False) at line 617 unnecessary. Simplify to a single check: if not sizes: return.
Suggested change:

```diff
-    if not sizes and not getattr(self.args, 'enable_snap_to_grid', False):
+    if not sizes:
         return
```
```python
        "Number of --multicontext_datasets must match number of --multicontext_start strings."
    )
if args.interactive and len(snap_sizes) > 1:
    print("Multiple snap-to-grid sizes requested for interactive mode; only the first size will be used.")
```
Copilot AI commented on Oct 19, 2025:
The warning message at line 1352 is printed using print() while other user-facing messages use console.print(). For consistency with the rest of the codebase and to maintain styled output, this should use the console object defined earlier in the function.
Suggested change:

```diff
-    print("Multiple snap-to-grid sizes requested for interactive mode; only the first size will be used.")
+    console.print("Multiple snap-to-grid sizes requested for interactive mode; only the first size will be used.")
```
This pull request introduces support for "snap-to-grid" projections throughout the codebase, enabling quantized projections in transformer models for both training and inference. It adds user-facing controls, updates model configuration and initialization, and ensures that snap-to-grid can be flexibly applied and evaluated. The changes are grouped below by theme.
Snap-to-Grid Feature Implementation
- Added new configuration fields (`enable_snap_to_grid`, `snap_to_grid_layers`, `snap_to_grid_components`, `snap_to_grid_sizes`) to the `GPTConfig` class in `gpt_conf.py`, allowing fine-grained control over which layers and components use quantized projections.
- Integrated `SnapToGridRegistry` into model initialization in `model.py`, ensuring that the registry is created and attached to each transformer block when enabled, and added methods for setting and applying registries dynamically. [1] [2] [3]
- Updated `Block` initialization in `model.py` to accept a `layer_idx` parameter, supporting per-layer snap-to-grid configuration.

User-Facing Controls and Documentation
- Added a demo script, `demos/snap_to_grid_demo.sh`, and updated the documentation in `demos/README.md` to explain how to run snap-to-grid experiments, override settings, and interpret results. [1] [2]

Training and Inference Integration
- Updated `train.py` and `sample.py` to parse snap-to-grid arguments, normalize configuration, and generate/apply snap-to-grid registries for both training and sampling. This includes saving registry files and iterating over multiple grid sizes for evaluation and generation. [1] [2] [3] [4] [5] [6] [7] [8] [9] [10]

Validation and Error Handling
- Added validation in `model.py` ensuring that `n_embd` is divisible by `n_head`, preventing invalid configurations for multi-head attention when snap-to-grid is enabled.

Utility Integration
- Imported snap-to-grid utilities (`generate_snap_to_grid_registry`, `save_registry`) into relevant scripts to support registry creation and persistence. [1] [2]

These changes collectively provide a robust framework for experimenting with snap-to-grid projections, including configuration, evaluation, and reproducibility.