diff --git a/demos/README.md b/demos/README.md index e7f610a756..f890ee4707 100644 --- a/demos/README.md +++ b/demos/README.md @@ -21,3 +21,25 @@ python3 demos/check_ckpt_for_gelu_shift.py \ `adam_vs_adamw.sh` trains two tiny Shakespeare models, one with Adam and one with AdamW, then compares their statistics using `view_model_stats.py`. + +## Snap-to-Grid Projections + +`snap_to_grid_demo.sh` prepares the Shakespeare character dataset, trains a +small model with snap-to-grid enabled, evaluates multiple grid sizes, and then +generates text with and without the projections. Run it from the repository +root: + +```bash +bash demos/snap_to_grid_demo.sh +``` + +You can override the default model or snap-to-grid settings by exporting +environment variables before running the script, for example: + +```bash +N_HEAD=3 N_EMBD=384 SNAP_SIZES="100 1000 10000" bash demos/snap_to_grid_demo.sh +``` + +Ensure that `N_EMBD` remains divisible by `N_HEAD`; otherwise the +multi-head attention projection will be invalid and the script will exit with +an explanatory error. diff --git a/demos/snap_to_grid_demo.sh b/demos/snap_to_grid_demo.sh new file mode 100755 index 0000000000..4df16ca0c4 --- /dev/null +++ b/demos/snap_to_grid_demo.sh @@ -0,0 +1,119 @@ +#!/bin/bash +# snap_to_grid_demo.sh +# Demonstrates training and sampling with snap-to-grid projections. + +set -euo pipefail + +DATASET="${DATASET:-shakespeare_char}" +DATA_DIR="data/${DATASET}" +OUT_DIR="${OUT_DIR:-out/snap_to_grid_demo}" +SNAP_SIZES_STR="${SNAP_SIZES:-"8 32"}" +IFS=' ' read -r -a SNAP_SIZES <<< "${SNAP_SIZES_STR}" + +# Model hyperparameters (override via env vars, e.g. `N_HEAD=3 ./snap_to_grid_demo.sh`). +N_LAYER=${N_LAYER:-4} +N_HEAD=${N_HEAD:-4} +N_EMBD=${N_EMBD:-128} + +if (( N_EMBD % N_HEAD != 0 )); then + echo "error: N_EMBD (${N_EMBD}) must be divisible by N_HEAD (${N_HEAD})." >&2 + echo "update N_EMBD or N_HEAD before running the demo." >&2 + exit 1 +fi + +mkdir -p "${OUT_DIR}" + +echo "=== Step 1: Ensure the ${DATASET} dataset is prepared ===" +if [ ! -f "${DATA_DIR}/train.bin" ] || [ ! -f "${DATA_DIR}/val.bin" ]; then + pushd "${DATA_DIR}" > /dev/null + if [ ! -f "input.txt" ]; then + echo "Downloading Shakespeare corpus..." + bash get_dataset.sh + fi + echo "Tokenizing dataset with tiktoken encoder..." + python3 prepare.py -t input.txt --method tiktoken + popd > /dev/null +else + echo "Found existing tokenized dataset artifacts." +fi + +CKPT_PATH="${OUT_DIR}/ckpt.pt" + +cat <&2 + exit 1 +fi + +cat < None: + for block in self.transformer['h']: + block.snap_to_grid_registry = registry + + def set_snap_to_grid_registry(self, registry: SnapToGridRegistry | None) -> None: + self.snap_to_grid_registry = registry + self.config.snap_to_grid_registry = registry + self._apply_snap_to_grid_registry(registry) + def get_num_params(self, non_embedding=True): """ Return the number of parameters in the model. diff --git a/sample.py b/sample.py index b27f115017..2fb17bd6b3 100644 --- a/sample.py +++ b/sample.py @@ -32,6 +32,7 @@ import lm_eval from benchmarks.gpt_lm_eval_wrapper import NanoGPTLM from benchmarks import run_all +from utils.snap_to_grid import generate_snap_to_grid_registry, save_registry def parse_args(): parser = argparse.ArgumentParser(description="Inference from trained models") @@ -96,6 +97,18 @@ def parse_args(): + # Snap-to-grid controls + parser.add_argument('--enable_snap_to_grid', default=False, action=argparse.BooleanOptionalAction, + help="Enable snap-to-grid projections before attention/MLP pre-norms.") + parser.add_argument('--snap_to_grid_layers', type=int, nargs='+', default=None, + help="Zero-indexed layers (>=1) to apply snap-to-grid. Defaults to all layers except the first.") + parser.add_argument('--snap_to_grid_components', type=str, default='both', choices=['attn', 'mlp', 'both'], + help="Apply snap-to-grid to attention, MLP, or both components.") + parser.add_argument('--snap_to_grid_sizes', type=int, nargs='+', default=None, + help="Number of random combinations per location. Multiple values evaluate each size independently.") + + + # Steering Vector Related parser.add_argument('--save_avg_vector', type=str, default=None, help="Path to save the average vector of the start text to an .npy file") parser.add_argument('--apply_vector_file1', type=str, default=None, help="First .npy file to load the vector for subtraction") @@ -1267,6 +1280,15 @@ def main(): ) return + snap_sizes = args.snap_to_grid_sizes or [] + base_registry = getattr(model, 'snap_to_grid_registry', None) + if args.snap_to_grid_layers is not None: + snap_layers = args.snap_to_grid_layers + else: + snap_layers = list(range(1, getattr(model.config, 'n_layer', 0))) + snap_component = args.snap_to_grid_components + dataset_idx = 0 + if args.eval_only: print("Running in eval_only mode...") dataset_name = args.eval_dataset @@ -1284,187 +1306,225 @@ def main(): print(f"Using validation dataset: {dataset_name}") print(f"Model block size: {model.config.block_size}") val_data = load_validation_data(model.config.block_size, dataset_name) - metrics = calculate_validation_loss( - model, - val_data, - model.config.block_size, - args.eval_iters, - args.device, - ptdtype, - ) + size_values = snap_sizes or [None] + for snap_size in size_values: + if snap_size is None: + model.set_snap_to_grid_registry(base_registry) + print("Evaluating without snap-to-grid") + else: + registry = generate_snap_to_grid_registry(model, snap_layers, snap_component, snap_size) + model.set_snap_to_grid_registry(registry) + snap_dir = os.path.join(args.out_dir, 'snap_to_grid') + filename = f"eval_{timestamp}_size{snap_size}.pt" + save_registry(os.path.join(snap_dir, filename), registry) + print(f"Evaluating with snap-to-grid size {snap_size}") + + metrics = calculate_validation_loss( + model, + val_data, + model.config.block_size, + args.eval_iters, + args.device, + ptdtype, + ) - val_loss = metrics.get("val", float("nan")) - print(f"Validation Loss: {val_loss:.4f}") - if metrics.get("elapsed_time_s") is not None: - print(f"Elapsed time: {metrics['elapsed_time_s']:.4f} seconds") - - summary: Dict[str, object] = dict(metrics) - summary.setdefault("eval_dataset", dataset_name) - summary.setdefault("timestamp", timestamp) - summary.setdefault("out_dir", args.out_dir) - summary.setdefault("init_from", args.init_from) - - write_eval_summary( - args.out_dir, - summary, - extra_dirs=[out_dir], - ) + val_loss = metrics.get("val", float("nan")) + print(f"Validation Loss (size={snap_size or 'baseline'}): {val_loss:.4f}") + if metrics.get("elapsed_time_s") is not None: + print(f"Elapsed time: {metrics['elapsed_time_s']:.4f} seconds") + + summary: Dict[str, object] = dict(metrics) + summary.setdefault("eval_dataset", dataset_name) + summary.setdefault("timestamp", timestamp) + summary.setdefault("out_dir", args.out_dir) + summary.setdefault("init_from", args.init_from) + summary["snap_to_grid_size"] = snap_size + + write_eval_summary( + args.out_dir, + summary, + extra_dirs=[out_dir], + ) + model.set_snap_to_grid_registry(base_registry) return - x = torch.tensor(start_ids, dtype=torch.long, device=args.device)[None, ...] - # Obtain vector from the specified layer and save it to a file if required - if args.save_avg_vector: - x = torch.tensor(start_ids, dtype=torch.long, device=args.device)[None, ...] - # Run the model to trigger vector extraction - with torch.no_grad(): - with ctx: - block_size = args.block_size if args.block_size else model.config.block_size - idx_cond = x if x.size(1) <= block_size else x[:, -block_size:] - logits, _ = model(idx_cond, dataset_idx=dataset_idx) - print(f"Obtained vector saved to {args.save_avg_vector}") - - if args.interactive: - interactive_generation(model, start_ids, args.device, args.max_new_tokens, args.temperature, args.top_k, args.stop_strings, decode, encode) - elif args.multicontext: - if not args.multicontext_datasets: - raise ValueError("Must specify --multicontext_datasets when using --multicontext") - if args.multicontext_start is None: - raise ValueError("Must specify --multicontext_start when using --multicontext") - if len(args.multicontext_datasets) != len(args.multicontext_start): - raise ValueError( - "Number of --multicontext_datasets must match number of --multicontext_start strings." - ) + if args.interactive and len(snap_sizes) > 1: + print("Multiple snap-to-grid sizes requested for interactive mode; only the first size will be used.") + snap_sizes = snap_sizes[:1] - dataset_names = list(args.multicontext_datasets) - start_strings = list(args.multicontext_start) - - dataset_meta: Dict[str, Dict[str, object]] = {} - decode_lookup: Dict[str, Callable[[Sequence[int]], str]] = {} - initial_tokens: Dict[str, torch.Tensor] = {} - - for dataset_name, start_str in zip(dataset_names, start_strings): - meta_path = os.path.join("data", dataset_name, "meta.pkl") - if not os.path.exists(meta_path): - raise FileNotFoundError(f"meta.pkl not found at {meta_path}") - with open(meta_path, "rb") as f: - dataset_meta[dataset_name] = pickle.load(f) - - encode_i, decode_i = get_tokenizer_functions(dataset_meta[dataset_name]) - token_ids = encode_i(start_str) - if len(token_ids) == 0: - if dataset_meta[dataset_name].get('tokenizer') == 'sinewave': - print( - f"Start string for dataset '{dataset_name}' produced no tokens; defaulting to '0'." - ) - token_ids = [0] - else: - raise ValueError( - f"Start string for dataset '{dataset_name}' produced no tokens. " - "Provide a valid prompt or comma-separated values for numerical tokenizers." - ) + size_values = snap_sizes or [None] - token_tensor = torch.tensor(token_ids, dtype=torch.long, device=args.device)[None, ...] - initial_tokens[dataset_name] = token_tensor - decode_lookup[dataset_name] = decode_i + def execute_generation_for_current_registry(): + local_start = torch.tensor(start_ids, dtype=torch.long, device=args.device)[None, ...] + if args.save_avg_vector: + with torch.no_grad(): + with ctx: + block_size = args.block_size if args.block_size else model.config.block_size + idx_cond = local_start if local_start.size(1) <= block_size else local_start[:, -block_size:] + model(idx_cond, dataset_idx=dataset_idx) + print(f"Obtained vector saved to {args.save_avg_vector}") + + if args.interactive: + interactive_generation(model, start_ids, args.device, args.max_new_tokens, args.temperature, args.top_k, args.stop_strings, decode, encode) + return + + if args.multicontext: + if not args.multicontext_datasets: + raise ValueError("Must specify --multicontext_datasets when using --multicontext") + if args.multicontext_start is None: + raise ValueError("Must specify --multicontext_start when using --multicontext") + if len(args.multicontext_datasets) != len(args.multicontext_start): + raise ValueError( + "Number of --multicontext_datasets must match number of --multicontext_start strings." + ) - block_size = args.block_size if args.block_size else model.config.block_size - with torch.no_grad(), ctx: - for sample_idx in range(args.num_samples): - if args.use_lsv and hasattr(args, 'lsv_size'): - model.set_lsv_index(sample_idx % args.lsv_size) - if args.lsv_scaling_factor is not None: - model.set_lsv_scaling_factor(args.lsv_scaling_factor) - if args.lsv_mixture is not None: - model.set_lsv_mode(2) - model.set_lsv_mixture(args.lsv_mixture) + dataset_names = list(args.multicontext_datasets) + start_strings = list(args.multicontext_start) + + dataset_meta: Dict[str, Dict[str, object]] = {} + decode_lookup: Dict[str, Callable[[Sequence[int]], str]] = {} + initial_tokens: Dict[str, torch.Tensor] = {} + + for dataset_name, start_str in zip(dataset_names, start_strings): + meta_path = os.path.join("data", dataset_name, "meta.pkl") + if not os.path.exists(meta_path): + raise FileNotFoundError(f"meta.pkl not found at {meta_path}") + with open(meta_path, "rb") as f: + dataset_meta[dataset_name] = pickle.load(f) + + encode_i, decode_i = get_tokenizer_functions(dataset_meta[dataset_name]) + token_ids = encode_i(start_str) + if len(token_ids) == 0: + if dataset_meta[dataset_name].get('tokenizer') == 'sinewave': + print( + f"Start string for dataset '{dataset_name}' produced no tokens; defaulting to '0'." + ) + token_ids = [0] else: - model.set_lsv_mode(1) - - token_state = {name: tensor.clone() for name, tensor in initial_tokens.items()} + raise ValueError( + f"Start string for dataset '{dataset_name}' produced no tokens. " + "Provide a valid prompt or comma-separated values for numerical tokenizers." + ) - for _ in range(args.max_new_tokens): - idx_cond_dict = {} - for name in dataset_names: - tokens = token_state[name] - idx_cond_dict[name] = tokens if tokens.size(1) <= block_size else tokens[:, -block_size:] - - logits_list, _ = model(None, token_dict=idx_cond_dict, target_dict=None) - - for i, name in enumerate(dataset_names): - if model.config.numerical_multicontext: - preds = logits_list[i][:, -1] - preds = preds.squeeze(-1) - if preds.ndim == 0: - preds = preds.unsqueeze(0) - rounded = preds.round() - min_val = 0.0 - max_val = None - meta_info = dataset_meta.get(name, {}) - tokenizer_name = meta_info.get('tokenizer') if isinstance(meta_info, dict) else None - if tokenizer_name == 'sinewave': - max_val = 255.0 - elif isinstance(meta_info, dict) and 'vocab_size' in meta_info: - max_val = float(meta_info['vocab_size'] - 1) - - if max_val is not None: - rounded = torch.clamp(rounded, min=min_val, max=max_val) + token_tensor = torch.tensor(token_ids, dtype=torch.long, device=args.device)[None, ...] + initial_tokens[dataset_name] = token_tensor + decode_lookup[dataset_name] = decode_i + + block_size = args.block_size if args.block_size else model.config.block_size + with torch.no_grad(), ctx: + for sample_idx in range(args.num_samples): + if args.use_lsv and hasattr(args, 'lsv_size'): + model.set_lsv_index(sample_idx % args.lsv_size) + if args.lsv_scaling_factor is not None: + model.set_lsv_scaling_factor(args.lsv_scaling_factor) + if args.lsv_mixture is not None: + model.set_lsv_mode(2) + model.set_lsv_mixture(args.lsv_mixture) + else: + model.set_lsv_mode(1) + + token_state = {name: tensor.clone() for name, tensor in initial_tokens.items()} + + for _ in range(args.max_new_tokens): + idx_cond_dict = {} + for name in dataset_names: + tokens = token_state[name] + idx_cond_dict[name] = tokens if tokens.size(1) <= block_size else tokens[:, -block_size:] + + logits_list, _ = model(None, token_dict=idx_cond_dict, target_dict=None) + + for i, name in enumerate(dataset_names): + if model.config.numerical_multicontext: + preds = logits_list[i][:, -1] + preds = preds.squeeze(-1) + if preds.ndim == 0: + preds = preds.unsqueeze(0) + rounded = preds.round() + min_val = 0.0 + max_val = None + meta_info = dataset_meta.get(name, {}) + tokenizer_name = meta_info.get('tokenizer') if isinstance(meta_info, dict) else None + if tokenizer_name == 'sinewave': + max_val = 255.0 + elif isinstance(meta_info, dict) and 'vocab_size' in meta_info: + max_val = float(meta_info['vocab_size'] - 1) + + if max_val is not None: + rounded = torch.clamp(rounded, min=min_val, max=max_val) + else: + rounded = torch.clamp(rounded, min=min_val) + + idx_next = rounded.to(torch.long).unsqueeze(-1) else: - rounded = torch.clamp(rounded, min=min_val) + cur_logits = logits_list[i][:, -1, :] / args.temperature + if args.top_k is not None: + top_k_val = ( + args.top_k[0] + if isinstance(args.top_k, (list, tuple)) + else args.top_k + ) + k = min(top_k_val, cur_logits.size(-1)) + v, _ = torch.topk(cur_logits, k) + cur_logits[cur_logits < v[:, [-1]]] = -float("inf") + + probs = F.softmax(cur_logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1) + + token_state[name] = torch.cat((token_state[name], idx_next), dim=1) + + output_dict: Dict[str, str] = {} + for name in dataset_names: + decode_fn = decode_lookup[name] + output_dict[name] = decode_fn(token_state[name][0].tolist()) + + for name, text in output_dict.items(): + key_color = "bold light_slate_blue" + text_color = "bold cyan" + print(f"\n[{key_color}]{name}:[/{key_color}]\n[{text_color}]{text}[/{text_color}]") + print("---------------") + + if args.sample_file: + with open(args.sample_file, "w") as file: + for name, text in output_dict.items(): + file.write(f"\n{name}: \n{text}\n") + return - idx_next = rounded.to(torch.long).unsqueeze(-1) - else: - cur_logits = logits_list[i][:, -1, :] / args.temperature - if args.top_k is not None: - top_k_val = ( - args.top_k[0] - if isinstance(args.top_k, (list, tuple)) - else args.top_k - ) - k = min(top_k_val, cur_logits.size(-1)) - v, _ = torch.topk(cur_logits, k) - cur_logits[cur_logits < v[:, [-1]]] = -float("inf") - - probs = F.softmax(cur_logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1) - - token_state[name] = torch.cat((token_state[name], idx_next), dim=1) - - output_dict: Dict[str, str] = {} - for name in dataset_names: - decode_fn = decode_lookup[name] - output_dict[name] = decode_fn(token_state[name][0].tolist()) - - for name, text in output_dict.items(): - key_color = "bold light_slate_blue" - text_color = "bold cyan" - print(f"\n[{key_color}]{name}:[/{key_color}]\n[{text_color}]{text}[/{text_color}]") - print("---------------") - - if args.sample_file: - with open(args.sample_file, "w") as file: - for name, text in output_dict.items(): - file.write(f"\n{name}: \n{text}\n") - else: sample_with_existing_model( - model, - torch.tensor(start_ids, dtype=torch.long, device=args.device)[None, ...], - decode, - device=args.device, - max_new_tokens=args.max_new_tokens, - temperature=args.temperature, - top_k=args.top_k, - num_samples=args.num_samples, - colorize_output=args.colorize_output, - colorize_mode=args.colorize_mode, - token_boundary=args.token_boundary, - show_heatmaps=args.show_heatmaps, - chart_type=args.chart_type, - last_k_tokens=args.last_k_tokens, - out_dir=out_dir, - sample_file=args.sample_file, - args=args, - dataset_idx=0, - ) + model, + local_start, + decode, + device=args.device, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + num_samples=args.num_samples, + colorize_output=args.colorize_output, + colorize_mode=args.colorize_mode, + token_boundary=args.token_boundary, + show_heatmaps=args.show_heatmaps, + chart_type=args.chart_type, + last_k_tokens=args.last_k_tokens, + out_dir=out_dir, + sample_file=args.sample_file, + args=args, + dataset_idx=0, + ) + + for snap_size in size_values: + if snap_size is None: + model.set_snap_to_grid_registry(base_registry) + print("Sampling without snap-to-grid") + else: + registry = generate_snap_to_grid_registry(model, snap_layers, snap_component, snap_size) + model.set_snap_to_grid_registry(registry) + snap_dir = os.path.join(args.out_dir, 'snap_to_grid') + filename = f"sample_{timestamp}_size{snap_size}.pt" + save_registry(os.path.join(snap_dir, filename), registry) + print(f"Sampling with snap-to-grid size {snap_size}") + + execute_generation_for_current_registry() + + model.set_snap_to_grid_registry(base_registry) if __name__ == "__main__": main() diff --git a/train.py b/train.py index a9dc592977..0ba25a3931 100644 --- a/train.py +++ b/train.py @@ -51,6 +51,8 @@ get_tokenizer_functions, ) +from utils.snap_to_grid import generate_snap_to_grid_registry, save_registry + from rich.progress import ( Progress, TextColumn, @@ -236,6 +238,22 @@ def setup(self): self.model_args['vocab_size'] = None self.model_args['eval_interval'] = self.args.eval_interval + # Normalise snap-to-grid configuration + snap_sizes = getattr(self.args, 'snap_to_grid_sizes', None) + snap_layers = getattr(self.args, 'snap_to_grid_layers', None) + if snap_sizes: + self.model_args['enable_snap_to_grid'] = True + self.model_args['snap_to_grid_sizes'] = snap_sizes + else: + self.model_args['snap_to_grid_sizes'] = snap_sizes or [] + if getattr(self.args, 'enable_snap_to_grid', False): + self.model_args['enable_snap_to_grid'] = True + if snap_layers is not None: + self.model_args['snap_to_grid_layers'] = snap_layers + else: + self.model_args['snap_to_grid_layers'] = [] + self.model_args['snap_to_grid_components'] = getattr(self.args, 'snap_to_grid_components', 'both') + # Training settings self.training_args = {action.dest: getattr(self.args, action.dest) for action in self.training_group._group_actions} if self.args.dataset_list is not None: @@ -517,50 +535,69 @@ def sample_and_print(self): # Do one iteration per lsv, default to one with no lsv sample_iterations = 1 + snap_sizes = getattr(self.args, 'snap_to_grid_sizes', None) or [] + if not snap_sizes: + snap_sizes = [None] + + base_registry = getattr(self.raw_model, 'snap_to_grid_registry', None) + layers = self._get_snap_to_grid_layers() + component = getattr(self.args, 'snap_to_grid_components', 'both') + self.model.eval() if self.args.dataset_list is not None: sample_iterations = len(self.args.dataset_list) - for i in range(sample_iterations): - if self.args.use_lsv: - self.model.set_lsv_index(i) - print(f"lsv index {i}") - - if hasattr(self, 'encode_dict'): - encode_fn = self.encode_dict[self.args.dataset_list[i]] - decode_fn = self.decode_dict[self.args.dataset_list[i]] + for snap_size in snap_sizes: + if snap_size is None: + self.raw_model.set_snap_to_grid_registry(base_registry) + self.console.print("[bold cyan]Sampling without snap-to-grid[/bold cyan]") else: - encode_fn = self.encode - decode_fn = self.decode + registry = generate_snap_to_grid_registry(self.raw_model, layers, component, snap_size) + self.raw_model.set_snap_to_grid_registry(registry) + self.console.print(f"[bold cyan]Sampling with snap-to-grid size {snap_size}[/bold cyan]") + + for i in range(sample_iterations): + if self.args.use_lsv: + self.model.set_lsv_index(i) + print(f"lsv index {i}") + + if hasattr(self, 'encode_dict'): + encode_fn = self.encode_dict[self.args.dataset_list[i]] + decode_fn = self.decode_dict[self.args.dataset_list[i]] + else: + encode_fn = self.encode + decode_fn = self.decode + + start_ids = torch.tensor(encode_fn(self.args.sample_start_tokens), dtype=torch.long, device=self.device)[None, ...] + + with torch.no_grad(): + sample_with_existing_model( + model=self.model, + start_ids=start_ids, + start_tokens=self.args.sample_start_tokens, + decode=decode_fn, + device=self.device, + out_dir=self.args.out_dir, + max_new_tokens=self.args.max_sample_tokens, + temperature=self.args.temperature, + top_k=self.args.top_k, + colorize_output=self.args.colorize_output, + colorize_mode=self.args.colorize_mode, + token_boundary=(self.args.token_boundary or None), + show_heatmaps=self.args.show_heatmaps, + sample_file=self.args.sample_file, + num_samples=self.args.num_samples, + iter_num=self.iter_num, + best_val_loss=self.best_val_loss, + run_name=self.args.tensorboard_run_name, + args=self.args, + writer=self.writer if self.args.tensorboard_log else None, + dataset_idx=i if hasattr(self, 'encode_dict') else None, + console=self.console, + ) - start_ids = torch.tensor(encode_fn(self.args.sample_start_tokens), dtype=torch.long, device=self.device)[None, ...] - - with torch.no_grad(): - sample_with_existing_model( - model=self.model, - start_ids=start_ids, - start_tokens=self.args.sample_start_tokens, - decode=decode_fn, - device=self.device, - out_dir=self.args.out_dir, - max_new_tokens=self.args.max_sample_tokens, - temperature=self.args.temperature, - top_k=self.args.top_k, - colorize_output=self.args.colorize_output, - colorize_mode=self.args.colorize_mode, - token_boundary=(self.args.token_boundary or None), - show_heatmaps=self.args.show_heatmaps, - sample_file=self.args.sample_file, - num_samples=self.args.num_samples, - iter_num=self.iter_num, - best_val_loss=self.best_val_loss, - run_name=self.args.tensorboard_run_name, - args=self.args, - writer=self.writer if self.args.tensorboard_log else None, - dataset_idx=i if hasattr(self, 'encode_dict') else None, - console=self.console, - ) + self.raw_model.set_snap_to_grid_registry(base_registry) # After sampling from the model, optionally run simple dataset benchmarks if self.args.dataset_benchmarks and self.args.max_sample_tokens: @@ -570,6 +607,41 @@ def sample_and_print(self): self.console.rule("[bold green]End Samples[/bold green]") self.console.print("\n"*8) + def _get_snap_to_grid_layers(self): + if self.args.snap_to_grid_layers is not None: + return self.args.snap_to_grid_layers + return list(range(1, getattr(self.raw_model.config, 'n_layer', 0))) + + def run_snap_to_grid_evaluations(self): + sizes = getattr(self.args, 'snap_to_grid_sizes', None) or [] + if not sizes and not getattr(self.args, 'enable_snap_to_grid', False): + return + if not sizes: + return + + layers = self._get_snap_to_grid_layers() + component = getattr(self.args, 'snap_to_grid_components', 'both') + + base_registry = getattr(self.raw_model, 'snap_to_grid_registry', None) + + for size in sizes: + registry = generate_snap_to_grid_registry(self.raw_model, layers, component, size) + self.raw_model.set_snap_to_grid_registry(registry) + losses = self.estimate_loss() + val_loss = losses.get('val') + if self.writer is not None: + self.writer.add_scalar(f'snap_to_grid/val_loss_size_{size}', val_loss, self.iter_num) + + snap_dir = os.path.join(self.args.out_dir, 'snap_to_grid') + filename = f"iter_{self.iter_num:06d}_size{size}.pt" + save_registry(os.path.join(snap_dir, filename), registry) + + self.console.print( + f"[bold cyan]Snap-to-grid[/bold cyan] size {size}: val loss {val_loss:.4f}" + ) + + self.raw_model.set_snap_to_grid_registry(base_registry) + def get_vocab_size_from_meta(self): # Data loader meta_path = os.path.join('data', self.args.dataset, 'meta.pkl') @@ -1648,6 +1720,7 @@ def train(self): print(f"saving checkpoint to {self.args.out_dir}") # Save checkpoint self.save_checkpoint('ckpt.pt') + self.run_snap_to_grid_evaluations() # Sample if self.args.max_sample_tokens: diff --git a/train_args.py b/train_args.py index d12d5f24da..9d76a21eef 100644 --- a/train_args.py +++ b/train_args.py @@ -1219,6 +1219,16 @@ def parse_args(): model_group.add_argument('--softmax_io_logging', default=False, action=argparse.BooleanOptionalAction, help="logs inputs and outputs of supported softmaxes") model_group.add_argument('--softmax_io_log_interval', default=1, type=int) model_group.add_argument('--consmax_beta_gamma_logging', default=False, action=argparse.BooleanOptionalAction, help="logs beta and gamma") + + # Snap-to-grid + model_group.add_argument('--enable_snap_to_grid', default=False, action=argparse.BooleanOptionalAction, + help='Enable snap-to-grid projections before pre-norm in attention/MLP layers (experimental).') + model_group.add_argument('--snap_to_grid_layers', type=int, nargs='+', default=None, + help='Zero-indexed layer IDs (>=1) that should use snap-to-grid. Defaults to all layers except the first.') + model_group.add_argument('--snap_to_grid_components', type=str, default='both', choices=['attn', 'mlp', 'both'], + help='Apply snap-to-grid to attention, MLP, or both components.') + model_group.add_argument('--snap_to_grid_sizes', type=int, nargs='+', default=None, + help='Number of random combinations to generate per location. Multiple values trigger evaluation for each size.') logging_group.add_argument('--create_statistics', default=False, action=argparse.BooleanOptionalAction) logging_group.add_argument('--plot_statistics', default=False, action=argparse.BooleanOptionalAction) diff --git a/utils/snap_to_grid.py b/utils/snap_to_grid.py new file mode 100644 index 0000000000..8d0a469992 --- /dev/null +++ b/utils/snap_to_grid.py @@ -0,0 +1,202 @@ +"""Utilities for constructing and applying snap-to-grid projections.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F + + +COMP_ATTENTION = "attn" +COMP_MLP = "mlp" + + +def _normalize_rows(weight: torch.Tensor, target_dim: Optional[int] = None) -> torch.Tensor: + """Return a row-wise L2 normalised view of ``weight``. + + When ``target_dim`` is provided the function ensures that the returned + matrix has exactly that many columns by transposing the weight matrix when + necessary. This prevents concatenation errors when mixing vectors whose + natural orientation differs (e.g. attention vs. MLP projection weights). + """ + + tensor = weight.detach().float().cpu() + + if target_dim is None: + matrix = tensor.reshape(tensor.shape[0], -1) + else: + # Prefer views whose trailing dimension already matches ``target_dim``. + if tensor.ndim >= 2 and tensor.shape[-1] == target_dim: + matrix = tensor.reshape(-1, target_dim) + elif tensor.ndim >= 2 and tensor.shape[0] == target_dim: + matrix = tensor.transpose(0, 1).reshape(-1, target_dim) + else: + flat = tensor.reshape(tensor.shape[0], -1) + if flat.shape[-1] != target_dim: + raise ValueError( + f"Unable to reshape tensor with shape {tuple(tensor.shape)} " + f"to have {target_dim} features." + ) + matrix = flat + + return F.normalize(matrix, p=2, dim=-1) + + +def _gather_base_vectors( + model: torch.nn.Module, + upto_layer: int, + component: str, +) -> Optional[torch.Tensor]: + """Collect all source vectors required for ``upto_layer``. + + Parameters + ---------- + model: + The GPT model containing the parameters. + upto_layer: + Index of the block (0-indexed) whose pre-norm input will snap to the grid. + component: + Either ``"attn"`` or ``"mlp"`` depending on the target sub-module. + + Returns + ------- + Optional[torch.Tensor] + A tensor of shape ``(N, d)`` where each row has unit norm. ``None`` when + there are no available vectors (should not normally happen). + """ + + vectors: List[torch.Tensor] = [] + upto_layer = max(int(upto_layer), 0) + target_dim = getattr(getattr(model, "config", None), "n_embd", None) + + for name, param in model.named_parameters(): + if param.ndim != 2 or not name.endswith("weight"): + continue + + if "transformer.wte" in name: + vectors.append(_normalize_rows(param, target_dim)) + continue + + if "transformer.h." not in name: + continue + + try: + layer_str = name.split("transformer.h.")[1].split(".")[0] + layer_idx = int(layer_str) + except (IndexError, ValueError): + continue + + if COMP_ATTENTION in name and name.endswith("attn.c_proj.weight"): + if layer_idx <= upto_layer: + vectors.append(_normalize_rows(param, target_dim)) + elif COMP_MLP in name and name.endswith("mlp.c_proj.weight"): + limit = upto_layer if component == COMP_MLP else upto_layer - 1 + if layer_idx <= limit: + vectors.append(_normalize_rows(param, target_dim)) + + if not vectors: + return None + + return torch.cat(vectors, dim=0) + + +def _ensure_component_list(component: str | Sequence[str]) -> List[str]: + if isinstance(component, str): + if component == "both": + return [COMP_ATTENTION, COMP_MLP] + return [component] + return list(component) + + +@dataclass +class SnapToGridRegistry: + """Mapping of layer/component pairs to snap-to-grid vectors.""" + + grids: Dict[Tuple[int, str], torch.Tensor] = field(default_factory=dict) + metadata: Dict[str, object] = field(default_factory=dict) + + def set_grid(self, layer_idx: int, component: str, grid: torch.Tensor) -> None: + self.grids[(layer_idx, component)] = grid.cpu() + + def get_grid(self, layer_idx: int, component: str) -> Optional[torch.Tensor]: + return self.grids.get((layer_idx, component)) + + def clear(self) -> None: + self.grids.clear() + + def state_dict(self) -> Dict[str, object]: + return {"grids": self.grids, "metadata": self.metadata} + + def load_state_dict(self, state: Dict[str, object]) -> None: + self.grids = {k: v.cpu() for k, v in state.get("grids", {}).items()} + self.metadata = state.get("metadata", {}) + + +def generate_snap_to_grid_registry( + model: torch.nn.Module, + layers: Optional[Iterable[int]], + component_selection: str | Sequence[str], + size: int, + generator: Optional[torch.Generator] = None, +) -> SnapToGridRegistry: + """Create a registry for the provided ``size`` across all ``layers``.""" + + if size is None or size <= 0: + registry = SnapToGridRegistry() + registry.metadata["size"] = 0 + return registry + + if layers is None: + n_layers = getattr(getattr(model, "config", None), "n_layer", None) + layers = range(1, n_layers or 0) + + component_list = _ensure_component_list(component_selection) + + registry = SnapToGridRegistry(metadata={"size": int(size)}) + + base_rng = generator or torch.Generator(device="cpu") + for layer_idx in layers: + if layer_idx <= 0: + continue + for component in component_list: + source = _gather_base_vectors(model, layer_idx, component) + if source is None or source.numel() == 0: + continue + coeffs = torch.randn((size, source.size(0)), generator=base_rng) + combos = coeffs @ source + combos = F.normalize(combos, p=2, dim=-1) + registry.set_grid(layer_idx, component, combos) + + return registry + + +def apply_snap_to_grid_tensor(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: + """Project activations ``x`` onto the closest vector from ``grid``.""" + + if grid is None: + return x + + target = grid.to(device=x.device, dtype=x.dtype, non_blocking=True) + x_norm = F.normalize(x, p=2, dim=-1) + target_norm = F.normalize(target, p=2, dim=-1) + sims = torch.matmul(x_norm, target_norm.t()) + best = sims.argmax(dim=-1) + snapped = target_norm.index_select(0, best.reshape(-1)) + snapped = snapped.reshape(*x.shape) + return snapped + + +def save_registry(path: str, registry: SnapToGridRegistry) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) + torch.save(registry.state_dict(), path) + + +def load_registry(path: str) -> SnapToGridRegistry: + data = torch.load(path, map_location="cpu") + registry = SnapToGridRegistry() + registry.load_state_dict(data) + return registry + diff --git a/variations/block_variations.py b/variations/block_variations.py index fd485b92f7..7e67a6b72b 100644 --- a/variations/block_variations.py +++ b/variations/block_variations.py @@ -11,6 +11,7 @@ from variations.norm_variations import norm_dictionary from variations.learned_confidence_variations import learned_confidence_dictionary from quantization.quantize import fake_quantize_act +from utils.snap_to_grid import apply_snap_to_grid_tensor # type alias for the forward function BlockForward = Callable[['Block', torch.Tensor, int], torch.Tensor] @@ -110,6 +111,9 @@ def attn_then_mlp_forward(block, x: torch.Tensor, iter_num: int) -> torch.Tensor # Make sure not to override skip connection x_attn_in = x + # Snap-to-grid before the attention pre-norm + x_attn_in = block._maybe_snap_to_grid("attn", x_attn_in) + # Attn Pre-LN if block.use_pre_ln_attn: x_attn_in = block.pre_ln_attn(x_attn_in) @@ -135,6 +139,9 @@ def attn_then_mlp_forward(block, x: torch.Tensor, iter_num: int) -> torch.Tensor # Make sure not to override skip connection x_mlp_in = x + # Snap-to-grid before the MLP pre-norm + x_mlp_in = block._maybe_snap_to_grid("mlp", x_mlp_in) + # MLP Pre-LN if block.use_pre_ln_mlp: x_mlp_in = block.pre_ln_mlp(x_mlp_in) @@ -364,7 +371,7 @@ def _setup_resid_scalers_sequential(self, config) -> None: class Block(nn.Module): """Transformer block supporting multiple normalization strategies.""" - def __init__(self, config, mlp=None, attn=None): + def __init__(self, config, layer_idx=None, mlp=None, attn=None): super().__init__() # Choose norm class for attention/MLP blocks @@ -384,6 +391,9 @@ def __init__(self, config, mlp=None, attn=None): self.use_flash_norm = getattr(config, "use_flash_norm", False) + self.layer_idx = layer_idx + self.snap_to_grid_registry = getattr(config, "snap_to_grid_registry", None) + if self.use_parallel_mlp: variant = "parallel_mlp" elif self.use_edgellm_asic: @@ -468,3 +478,11 @@ def _combine_resid(self, kind: str, x: torch.Tensor, out: torch.Tensor) -> torch alpha = self.alpha_fns[kind](out) return self.resid_fns[kind](x, out, alpha, self.residual_slerp_eps) + def _maybe_snap_to_grid(self, component: str, tensor: torch.Tensor) -> torch.Tensor: + if self.snap_to_grid_registry is None or self.layer_idx is None: + return tensor + grid = self.snap_to_grid_registry.get_grid(self.layer_idx, component) + if grid is None: + return tensor + return apply_snap_to_grid_tensor(tensor, grid) +