diff --git a/explorations/norm_wte_abs_embd_scale.yaml b/explorations/norm_wte_abs_embd_scale.yaml
new file mode 100644
index 0000000000..5538137b4e
--- /dev/null
+++ b/explorations/norm_wte_abs_embd_scale.yaml
@@ -0,0 +1,107 @@
+# explorations/norm_wte_abs_embd_scale.yaml
+---
+# parameter_groups: define sets of overrides to apply on top of base params
+parameter_groups:
+  # baselines
+  - tensorboard_run_name: ["baseline_abs_pos"]
+    use_abs_pos_embeddings: [true]
+    use_rotary_embeddings: [false]
+  - tensorboard_run_name: ["baseline_rotary"]
+    use_abs_pos_embeddings: [false]
+    use_rotary_embeddings: [true]
+
+  # with scale
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_l2_default_w_scale"]
+  - norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["abs_l2_default_w_scale"]
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [false]
+    norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_and_abs_default_w_scale"]
+
+  # with scale and gain
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_l2_default_w_scale_and_gain"]
+  - norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["abs_l2_default_w_scale_and_gain"]
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true]
+    norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_and_abs_default_w_scale_and_gain"]
+
+  # With Radius Tests
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["1.0", "50.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true, false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_l2_radius_tests"]
+  - norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["1.0", "50.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true, false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["abs_l2_radius_tests"]
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["1.0", "20.0", "50.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true, false]
+    norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["1.0", "20.0", "50.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true, false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_and_abs_radius_tests"]
+
+# Collect Intermediate Activation Stats
+compute_model_stats: [true]
+print_model_stats: ["print_stats/${RUN_NAME}"]
+
+use_qk_norm: [true]
+use_qk_norm_scale: [true]
+use_peri_ln: [true, false]
+
+# base hyperparameters
+max_iters: [5000]
+eval_interval: [500]
+n_layer: [5]
+n_head: [5]
+n_embd: [384]
+block_size: [256]
+device: ["cuda"]
+dtype: ["bfloat16"]
+dataset: ["minipile"]
+
+# VRAM and Memory
+compile: [true]
+never_save_checkpoint: [true]
+
diff --git a/explorations/norm_wte_abs_sweep.yaml b/explorations/norm_wte_abs_sweep.yaml
new file mode 100644
index 0000000000..f408f586c1
--- /dev/null
+++ b/explorations/norm_wte_abs_sweep.yaml
@@ -0,0 +1,107 @@
+# explorations/norm_wte_abs_sweep.yaml
+---
+# parameter_groups: define sets of overrides to apply on top of base params
+parameter_groups:
+  # baselines
+  - tensorboard_run_name: ["baseline_abs_pos"]
+    use_abs_pos_embeddings: [true]
+    use_rotary_embeddings: [false]
+  - tensorboard_run_name: ["baseline_rotary"]
+    use_abs_pos_embeddings: [false]
+    use_rotary_embeddings: [true]
+
+  # with scale
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_l2_default_w_scale"]
+  - norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["abs_l2_default_w_scale"]
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [false]
+    norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_and_abs_default_w_scale"]
+
+  # with scale and gain
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_l2_default_w_scale_and_gain"]
+  - norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["abs_l2_default_w_scale_and_gain"]
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["20.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true]
+    norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["20.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_and_abs_default_w_scale_and_gain"]
+
+  # With Radius Tests
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["1.0", "50.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true, false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_l2_radius_tests"]
+  - norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["1.0", "50.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true, false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["abs_l2_radius_tests"]
+  - norm_variant_wte: ["hyperspherenorm"]
+    norm_wte_radius: ["1.0", "20.0", "50.0"]
+    norm_wte_scale: ["1.0", "2.0", "5.0"]
+    norm_wte_gain: [true, false]
+    norm_variant_abs: ["hyperspherenorm"]
+    norm_abs_radius: ["1.0", "20.0", "50.0"]
+    norm_abs_scale: ["1.0", "2.0", "5.0"]
+    norm_abs_gain: [true, false]
+    use_abs_pos_embeddings: [true]
+    tensorboard_run_name: ["wte_and_abs_radius_tests"]
+
+# Collect Intermediate Activation Stats
+compute_model_stats: [true]
+print_model_stats: ["print_stats/${RUN_NAME}"]
+
+use_qk_norm: [true]
+use_qk_norm_scale: [true]
+use_peri_ln: [true, false]
+
+# base hyperparameters
+max_iters: [5000]
+eval_interval: [500]
+n_layer: [6]
+n_head: [6]
+n_embd: [384]
+block_size: [256]
+device: ["cuda"]
+dtype: ["bfloat16"]
+dataset: ["minipile"]
+
+# VRAM and Memory
+compile: [true]
+never_save_checkpoint: [true]
+
diff --git a/gpt_conf.py b/gpt_conf.py
index c1a7b5ecae..431384ce79 100644
--- a/gpt_conf.py
+++ b/gpt_conf.py
@@ -366,8 +366,17 @@ class GPTConfig:
     # Layernorm Alternatives and Options
     norm_variant_attn: str = "rmsnorm"
     norm_variant_output: str = "rmsnorm"
+    norm_variant_wte: str | None = None
+    norm_wte_radius: float | None = None
+    norm_wte_scale: float | None = None
+    norm_wte_gain: bool | None = None
+    norm_variant_abs: str | None = None
+    norm_abs_radius: float | None = None
+    norm_abs_scale: float | None = None
+    norm_abs_gain: bool | None = None
+
     bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
     prmsnorm_pct: float = 0.0625
     krmsnorm_num: float = 10
@@ -375,8 +384,10 @@ class GPTConfig:
     krmsnorm_enable_gain: bool = True
     krmsnorm_selection_type: str = 'last'
     krmsnorm_recompute_percentage: float = 0.05
+
     hsnorm_gain: bool = False
-    hsnorm_radius: float = 1.0
+    hsnorm_radius: float | None = None
+    hsnorm_scale: float = 1.0
     hsnorm_radius_learning: bool = False

     dact_alpha_init: float = 1.0
@@ -385,6 +396,7 @@
     dact_use_beta: bool = True
     dact_use_alpha: bool = True
     use_embedding_scale: bool = False
+    embedding_scale_init: float | None = None

     # Activation Alternatives
diff --git a/model.py b/model.py
index 86a3005104..6d395abfc0 100644
--- a/model.py
+++ b/model.py
@@ -144,16 +144,10 @@ def __init__(self, config):

         # Embedding scale
         if config.use_embedding_scale:
-            self.embedding_scale = nn.Parameter(torch.sqrt(torch.tensor(config.n_embd)))
-
-        # Optional post-embedding normalizations
-        self.post_embedding_norm = None
-        if config.norm_variant_wte is not None:
-            self.post_embedding_norm = norm_dictionary[config.norm_variant_wte](config)
-
-        self.post_abs_pos_embedding_norm = None
-        if config.norm_variant_abs is not None:
-            self.post_abs_pos_embedding_norm = norm_dictionary[config.norm_variant_abs](config)
+            if config.embedding_scale_init is not None:
+                self.embedding_scale = nn.Parameter(torch.tensor([config.embedding_scale_init]))
+            else:
+                self.embedding_scale = nn.Parameter(torch.sqrt(torch.tensor([config.n_embd])))

         # Learned Steering Vectors
         self.use_lsv = config.use_lsv
@@ -204,6 +198,12 @@
             self.transformer['h'] = nn.ModuleList([Block(config, mlp=shared_mlp_array[i], attn=shared_attn_array[i]) for i in range(config.n_layer)])
         self.transformer['ln_f'] = norm_dictionary[config.norm_variant_output](config)

+        # Optional post-embedding normalizations
+        if self.config.norm_variant_wte is not None:
+            self.transformer['post_embedding_norm'] = self.build_norm_from_variant(config, "norm_variant_wte", "norm_wte")
+        if self.config.norm_variant_abs is not None:
+            self.transformer['post_abs_norm'] = self.build_norm_from_variant(config, "norm_variant_abs", "norm_abs")
+
         if self.config.use_abs_pos_embeddings:
             if config.quantize_wpe:
                 pos_embd = QuantizedEmbedding(config.block_size, config.n_embd, config.quantize_wpe_method, config.quantize_wpe_bits)
@@ -297,6 +297,15 @@ def update_block_size(self, new_block_size):
             if hasattr(block.attn, 'bias'):
                 block.attn.bias = torch.tril(torch.ones(new_block_size, new_block_size)).view(1, 1, new_block_size, new_block_size)

+    def build_norm_from_variant(self, config, variant_key: str, prefix: str):
+        """Helper to deep-copy config and override hsnorm parameters if present."""
+        norm_config = copy.deepcopy(config)
+        for attr in ("radius", "scale", "gain"):
+            src = f"{prefix}_{attr}"
+            if getattr(norm_config, src, None) is not None:
+                setattr(norm_config, f"hsnorm_{attr}", getattr(norm_config, src))
+        return norm_dictionary[getattr(config, variant_key)](norm_config)
+
     def _init_weights(self, module):
         """
         Custom weight initialization logic for GPT model.
@@ -564,21 +573,21 @@ def forward(self, idx, targets=None, iter_num=None, token_dict=None, target_dict
             tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

         x = None
-        if self.config.use_embedding_scale:
-            tok_emb = tok_emb * self.embedding_scale
-
         if self.n_embd_wte:
             tok_emb = self.transformer.scale_up(tok_emb)

-        if self.post_embedding_norm is not None:
-            tok_emb = self.post_embedding_norm(tok_emb)
+        if self.config.use_embedding_scale:
+            tok_emb = tok_emb * self.embedding_scale
+
+        if self.config.norm_variant_wte is not None:
+            tok_emb = self.transformer.post_embedding_norm(tok_emb)

         if self.config.use_abs_pos_embeddings:
             pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
             pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
             x = tok_emb + pos_emb
-            if self.post_abs_pos_embedding_norm is not None:
-                x = self.post_abs_pos_embedding_norm(x)
+            if self.config.norm_variant_abs is not None:
+                x = self.transformer.post_abs_norm(x)
             x = self.transformer.drop(x)
         else:
             x = self.transformer.drop(tok_emb)
@@ -697,21 +706,22 @@ def embed_tokens(self, idx, dataset_idx=None):
         else:
             tok_emb = self.transformer.wte(idx)

-        if self.config.use_embedding_scale:
-            tok_emb = tok_emb * self.embedding_scale
-
         if self.n_embd_wte:
             tok_emb = self.transformer.scale_up(tok_emb)

-        if self.post_embedding_norm is not None:
-            tok_emb = self.post_embedding_norm(tok_emb)
+        if self.config.use_embedding_scale:
+            tok_emb = tok_emb * self.embedding_scale
+
+        if self.config.norm_variant_wte is not None:
+            tok_emb = self.transformer.post_embedding_norm(tok_emb)

         if self.config.use_abs_pos_embeddings:
             t = idx.size(1)
             pos = torch.arange(0, t, dtype=torch.long, device=device)
             tok_emb = tok_emb + self.transformer.wpe(pos)
-            if self.post_abs_pos_embedding_norm is not None:
-                tok_emb = self.post_abs_pos_embedding_norm(tok_emb)
+            if self.config.norm_variant_abs is not None:
+                tok_emb = self.transformer.post_abs_norm(tok_emb)
+
         return self.transformer.drop(tok_emb)
diff --git a/train.py b/train.py
index 5c960b0507..ca2b7d0827 100644
--- a/train.py
+++ b/train.py
@@ -1891,9 +1891,8 @@ def train(self):
            if self.iter_num > self.args.max_iters:
                print(self.best_val_loss, self.best_iter, self.best_tokens)
                if self.args.only_save_checkpoint_at_end:
-                    if not self.args.never_save_checkpoint:
-                        self.save_checkpoint('ckpt.pt')
-                        print(f"Saved checkpoint to {self.args.out_dir}")
+                    self.save_checkpoint('ckpt.pt')
+                    print(f"Saved checkpoint to {self.args.out_dir}")

                # Sample if set
                if self.args.max_sample_tokens:
diff --git a/train_args.py b/train_args.py
index 488813ca11..9cd4e2900b 100644
--- a/train_args.py
+++ b/train_args.py
@@ -144,7 +144,7 @@ def parse_args():

    # Checkpoint args
    training_group.add_argument('--save_major_ckpt_interval', default=None, type=int, help="Interval for saving major checkpoints.")
-    training_group.add_argument('--only_save_checkpoint_at_end', default=False, action=argparse.BooleanOptionalAction)
+    training_group.add_argument('--only_save_checkpoint_at_end', default=False, action=argparse.BooleanOptionalAction, help="Note: overrides never_save_checkpoint. Useful for coupling with epoch or token training limits.")
    training_group.add_argument('--always_save_checkpoint', default=False, action=argparse.BooleanOptionalAction)
    training_group.add_argument('--never_save_checkpoint', default=False, action=argparse.BooleanOptionalAction, help="If set, disables saving of all checkpoints.")
    training_group.add_argument('--patience', default=None, type=int, help="if set, will stop training if the number of evaluations since val loss was seen to decrease exceeds 'patience' setting.")
@@ -637,6 +637,18 @@
    model_group.add_argument("--norm_variant_attn", type=str, default="rmsnorm", choices=norm_variations)
    model_group.add_argument("--norm_variant_output", type=str, default="rmsnorm", choices=norm_variations)

+    ### WTE and Abs Pos Embedding Post Norms (optional, and default None)
+    model_group.add_argument("--norm_variant_wte", type=str, default=None, choices=norm_variations)
+    model_group.add_argument("--norm_variant_abs", type=str, default=None, choices=norm_variations)
+
+    model_group.add_argument("--norm_wte_radius", type=float, default=None)
+    model_group.add_argument("--norm_wte_scale", type=float, default=None)
+    model_group.add_argument("--norm_wte_gain", type=bool, default=None, action=argparse.BooleanOptionalAction)
+
+    model_group.add_argument("--norm_abs_radius", type=float, default=None)
+    model_group.add_argument("--norm_abs_scale", type=float, default=None)
+    model_group.add_argument("--norm_abs_gain", type=bool, default=None, action=argparse.BooleanOptionalAction)
+
    ## Layernorm
    model_group.add_argument('--bias', default=False, action=argparse.BooleanOptionalAction,
                             help="only used for layernorm variation option")
@@ -652,6 +664,7 @@

    ## HyperSphereNorm
    model_group.add_argument("--hsnorm_gain", default=False, action=argparse.BooleanOptionalAction)
+    model_group.add_argument("--hsnorm_scale", type=float, default=1.0)
    model_group.add_argument("--hsnorm_radius", type=float, default=None)
    model_group.add_argument("--hsnorm_radius_learning", default=False, action=argparse.BooleanOptionalAction)

@@ -691,6 +704,7 @@
    model_group.add_argument("--dact_use_alpha", type=bool, default=True, action=argparse.BooleanOptionalAction)

    model_group.add_argument("--use_embedding_scale", type=bool, default=False, action=argparse.BooleanOptionalAction)
+    model_group.add_argument("--embedding_scale_init", type=float, default=None)

    # ACTIVATION VARIATIONS
    model_group.add_argument( "--activation_variant", type=str, default="gelu", choices=activation_variations)
diff --git a/variations/norm_variations.py b/variations/norm_variations.py
index 9a483dd728..3c4c854295 100644
--- a/variations/norm_variations.py
+++ b/variations/norm_variations.py
@@ -68,6 +68,7 @@ class HyperSphereNorm(nn.Module):

     def __init__(self, config):
         super().__init__()
+        ndim = config.n_embd

         if config.hsnorm_gain:
             self.gain = nn.Parameter(torch.ones(ndim))
@@ -81,15 +82,22 @@
         else:
             radius_init = math.sqrt(ndim)

+        # constant for loss scaling (default set to 1.0)
+        self.const_radius_factor = config.hsnorm_scale
+
+        # Set as constant or learned param
+        self.hsnorm_radius_learning = config.hsnorm_radius_learning
         if config.hsnorm_radius_learning:
-            self.radius = nn.Parameter(torch.tensor([radius_init]))
+            # div by const_radius_factor (no effect if is 1.0)
+            radius_init = radius_init / self.const_radius_factor
+            self.radius_init_factor = nn.Parameter(torch.tensor([radius_init]))
         else:
-            self.radius = radius_init
+            self.radius_init_factor = radius_init

     def forward(self, x):
+        radius = self.const_radius_factor * self.radius_init_factor
         hypersphere_norm = x.norm(2, dim=-1, keepdim=True)
-        return x / hypersphere_norm * self.radius
+        return x / hypersphere_norm * radius * self.gain

 class pRMSNorm(nn.Module):
     """Partial RMS Normalization"""
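
Below is a minimal standalone sketch (not part of the patch; class and argument names are illustrative stand-ins for the GPTConfig fields) of the radius handling introduced above: the effective radius used in forward is hsnorm_scale * radius_factor, where the radius factor is a learned parameter initialised to radius / scale when hsnorm_radius_learning is set, and the constant radius otherwise. The norm_wte_* / norm_abs_* overrides from the sweep YAMLs reach these fields through build_norm_from_variant, which copies them onto hsnorm_radius / hsnorm_scale / hsnorm_gain in a deep-copied config.

```python
# Illustrative sketch only: mirrors the HyperSphereNorm changes above with
# plain constructor arguments in place of the GPTConfig fields.
import math
import torch
import torch.nn as nn

class HyperSphereNormSketch(nn.Module):
    def __init__(self, n_embd, radius=None, scale=1.0, learn_radius=False, use_gain=False):
        super().__init__()
        radius_init = radius if radius is not None else math.sqrt(n_embd)
        self.const_radius_factor = scale  # corresponds to hsnorm_scale
        if learn_radius:
            # divide at init so scale * radius_factor starts out equal to radius_init
            self.radius_init_factor = nn.Parameter(torch.tensor([radius_init / scale]))
        else:
            self.radius_init_factor = radius_init
        self.gain = nn.Parameter(torch.ones(n_embd)) if use_gain else 1.0

    def forward(self, x):
        radius = self.const_radius_factor * self.radius_init_factor
        return x / x.norm(2, dim=-1, keepdim=True) * radius * self.gain

# A sweep entry such as norm_wte_radius=20.0, norm_wte_scale=2.0 is copied onto
# hsnorm_radius / hsnorm_scale by build_norm_from_variant, which here maps to:
norm = HyperSphereNormSketch(n_embd=384, radius=20.0, scale=2.0)
x = torch.randn(4, 384)
print(norm(x).norm(dim=-1))  # ~40.0 per row: with a fixed radius, the scale multiplies it
```

One consequence worth noting when reading the *_w_scale runs: with hsnorm_radius_learning enabled the initial effective radius stays at hsnorm_radius (the scale is divided out at init and mainly changes the parameterisation), whereas with a fixed radius the scale multiplies it directly.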