107 changes: 107 additions & 0 deletions explorations/norm_wte_abs_embd_scale.yaml
@@ -0,0 +1,107 @@
# explorations/norm_wte_abs_embd_scale.yaml
---
# parameter_groups: define sets of overrides to apply on top of base params
parameter_groups:
# baselines
- tensorboard_run_name: ["baseline_abs_pos"]
use_abs_pos_embeddings: [true]
use_rotary_embeddings: [false]
- tensorboard_run_name: ["baseline_rotary"]
use_abs_pos_embeddings: [false]
use_rotary_embeddings: [true]

# with scale
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_l2_default_w_scale"]
- norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["abs_l2_default_w_scale"]
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [false]
norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_and_abs_default_w_scale"]

# with scale and gain
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_l2_default_w_scale_and_gain"]
- norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["abs_l2_default_w_scale_and_gain"]
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true]
norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_and_abs_default_w_scale_and_gain"]

# With Radius Tests
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["1.0", "50.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true, false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_l2_radius_tests"]
- norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["1.0", "50.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true, false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["abs_l2_radius_tests"]
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["1.0", "20.0", "50.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true, false]
norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["1.0", "20.0", "50.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true, false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_and_abs_radius_tests"]

# Collect Intermediate Activation Stats
compute_model_stats: [true]
print_model_stats: ["print_stats/${RUN_NAME}"]

use_qk_norm: [true]
use_qk_norm_scale: [true]
use_peri_ln: [true, false]

# base hyperparameters
max_iters: [5000]
eval_interval: [500]
n_layer: [5]
n_head: [5]
n_embd: [384]
block_size: [256]
device: ["cuda"]
dtype: ["bfloat16"]
dataset: ["minipile"]

# VRAM and Memory
compile: [true]
never_save_checkpoint: [true]

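For orientation, here is a minimal sketch of how a single parameter group above could expand into concrete runs, assuming the sweep launcher takes the Cartesian product of every list-valued key. The launcher itself is not part of this diff, and `expand_group` is a hypothetical helper, not repo code.

```python
# Hypothetical sketch: expand one parameter group into individual run configs,
# assuming each list-valued key is an independent sweep axis (Cartesian product).
from itertools import product

def expand_group(group: dict) -> list[dict]:
    keys = list(group)
    return [dict(zip(keys, combo)) for combo in product(*(group[k] for k in keys))]

group = {
    "norm_variant_wte": ["hyperspherenorm"],
    "norm_wte_radius": ["20.0"],
    "norm_wte_scale": ["1.0", "2.0", "5.0"],
    "norm_wte_gain": [False],
    "use_abs_pos_embeddings": [True],
    "tensorboard_run_name": ["wte_l2_default_w_scale"],
}

runs = expand_group(group)
print(len(runs))  # 3 runs, one per norm_wte_scale value
```

Under the same assumption, the base-level keys at the bottom of the file (for example `use_peri_ln: [true, false]`) would multiply every group further.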
107 changes: 107 additions & 0 deletions explorations/norm_wte_abs_sweep.yaml
@@ -0,0 +1,107 @@
# explorations/norm_wte_abs_sweep.yaml
---
# parameter_groups: define sets of overrides to apply on top of base params
parameter_groups:
# baselines
- tensorboard_run_name: ["baseline_abs_pos"]
use_abs_pos_embeddings: [true]
use_rotary_embeddings: [false]
- tensorboard_run_name: ["baseline_rotary"]
use_abs_pos_embeddings: [false]
use_rotary_embeddings: [true]

# with scale
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_l2_default_w_scale"]
- norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["abs_l2_default_w_scale"]
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [false]
norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_and_abs_default_w_scale"]

# with scale and gain
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_l2_default_w_scale_and_gain"]
- norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["abs_l2_default_w_scale_and_gain"]
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["20.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true]
norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["20.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_and_abs_default_w_scale_and_gain"]

# With Radius Tests
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["1.0", "50.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true, false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_l2_radius_tests"]
- norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["1.0", "50.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true, false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["abs_l2_radius_tests"]
- norm_variant_wte: ["hyperspherenorm"]
norm_wte_radius: ["1.0", "20.0", "50.0"]
norm_wte_scale: ["1.0", "2.0", "5.0"]
norm_wte_gain: [true, false]
norm_variant_abs: ["hyperspherenorm"]
norm_abs_radius: ["1.0", "20.0", "50.0"]
norm_abs_scale: ["1.0", "2.0", "5.0"]
norm_abs_gain: [true, false]
use_abs_pos_embeddings: [true]
tensorboard_run_name: ["wte_and_abs_radius_tests"]

# Collect Intermediate Activation Stats
compute_model_stats: [true]
print_model_stats: ["print_stats/${RUN_NAME}"]

use_qk_norm: [true]
use_qk_norm_scale: [true]
use_peri_ln: [true, false]

# base hyperparameters
max_iters: [5000]
eval_interval: [500]
n_layer: [6]
n_head: [6]
n_embd: [384]
block_size: [256]
device: ["cuda"]
dtype: ["bfloat16"]
dataset: ["minipile"]

# VRAM and Memory
compile: [true]
never_save_checkpoint: [true]

14 changes: 13 additions & 1 deletion gpt_conf.py
@@ -366,17 +366,28 @@ class GPTConfig:
# Layernorm Alternatives and Options
norm_variant_attn: str = "rmsnorm"
norm_variant_output: str = "rmsnorm"

norm_variant_wte: str | None = None
norm_wte_radius: float | None = None
norm_wte_scale: float | None = None
norm_wte_gain: bool | None = None

norm_variant_abs: str | None = None
norm_abs_radius: float | None = None
norm_abs_scale: float | None = None
norm_abs_gain: bool | None = None

bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
prmsnorm_pct: float = 0.0625
krmsnorm_num: float = 10
krmsnorm_quantize_type: str = 'int8'
krmsnorm_enable_gain: bool = True
krmsnorm_selection_type: str = 'last'
krmsnorm_recompute_percentage: float = 0.05

hsnorm_gain: bool = False
hsnorm_radius: float = 1.0
hsnorm_radius: float | None = None
hsnorm_scale: float = 1.0
hsnorm_radius_learning: bool = False

dact_alpha_init: float = 1.0
@@ -385,6 +396,7 @@ class GPTConfig:
dact_use_beta: bool = True
dact_use_alpha: bool = True
use_embedding_scale: bool = False
embedding_scale_init: float | None = None

# Activation Alternatives

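The `norm_wte_*` and `norm_abs_*` fields above feed the hypersphere norm selected by `norm_variant_wte` / `norm_variant_abs`, which in turn reads the shared `hsnorm_*` fields. The norm module itself is not part of this diff, so the following is only a hedged sketch of the assumed semantics (L2-normalize each token vector, rescale to `hsnorm_radius * hsnorm_scale`, optionally apply a learnable per-channel gain), not the repo's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HypersphereNormSketch(nn.Module):
    """Illustrative stand-in for the repo's hypersphere norm; the real module may differ."""
    def __init__(self, config):
        super().__init__()
        self.radius = config.hsnorm_radius if config.hsnorm_radius is not None else 1.0
        self.scale = config.hsnorm_scale
        self.gain = nn.Parameter(torch.ones(config.n_embd)) if config.hsnorm_gain else None

    def forward(self, x):
        # Project each token vector onto a sphere of the configured radius,
        # apply the fixed scale, then an optional learnable per-channel gain.
        x = F.normalize(x, p=2, dim=-1) * (self.radius * self.scale)
        if self.gain is not None:
            x = x * self.gain
        return x
```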
58 changes: 34 additions & 24 deletions model.py
@@ -144,16 +144,10 @@ def __init__(self, config):

# Embedding scale
if config.use_embedding_scale:
self.embedding_scale = nn.Parameter(torch.sqrt(torch.tensor(config.n_embd)))

# Optional post-embedding normalizations
self.post_embedding_norm = None
if config.norm_variant_wte is not None:
self.post_embedding_norm = norm_dictionary[config.norm_variant_wte](config)

self.post_abs_pos_embedding_norm = None
if config.norm_variant_abs is not None:
self.post_abs_pos_embedding_norm = norm_dictionary[config.norm_variant_abs](config)
if config.embedding_scale_init is not None:
self.embedding_scale = nn.Parameter(torch.tensor([config.embedding_scale_init]))
else:
self.embedding_scale = nn.Parameter(torch.sqrt(torch.tensor([config.n_embd])))

# Learned Steering Vectors
self.use_lsv = config.use_lsv
@@ -204,6 +198,12 @@ def __init__(self, config):
self.transformer['h'] = nn.ModuleList([Block(config, mlp=shared_mlp_array[i], attn=shared_attn_array[i]) for i in range(config.n_layer)])
self.transformer['ln_f'] = norm_dictionary[config.norm_variant_output](config)

# Optional post-embedding normalizations
if self.config.norm_variant_wte is not None:
self.transformer['post_embedding_norm'] = self.build_norm_from_variant(config, "norm_variant_wte", "norm_wte")
if self.config.norm_variant_abs is not None:
self.transformer['post_abs_norm'] = self.build_norm_from_variant(config, "norm_variant_abs", "norm_abs")

if self.config.use_abs_pos_embeddings:
if config.quantize_wpe:
pos_embd = QuantizedEmbedding(config.block_size, config.n_embd, config.quantize_wpe_method, config.quantize_wpe_bits)
@@ -297,6 +297,15 @@ def update_block_size(self, new_block_size):
if hasattr(block.attn, 'bias'):
block.attn.bias = torch.tril(torch.ones(new_block_size, new_block_size)).view(1, 1, new_block_size, new_block_size)

def build_norm_from_variant(self, config, variant_key: str, prefix: str):
"""Helper to deep-copy config and override hsnorm parameters if present."""
norm_config = copy.deepcopy(config)
for attr in ("radius", "scale", "gain"):
src = f"{prefix}_{attr}"
if getattr(norm_config, src, None) is not None:
setattr(norm_config, f"hsnorm_{attr}", getattr(norm_config, src))
return norm_dictionary[getattr(config, variant_key)](norm_config)

def _init_weights(self, module):
"""
Custom weight initialization logic for GPT model.
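To make the helper concrete, here is a hedged, standalone illustration of the mapping it performs; a `SimpleNamespace` stands in for `GPTConfig` and the values are illustrative. The per-embedding `norm_wte_*` (or `norm_abs_*`) overrides are copied onto the shared `hsnorm_*` fields of a deep-copied config, which `norm_dictionary[...]` would then receive.

```python
# Standalone illustration (values and SimpleNamespace are stand-ins, not repo code)
# of the override mapping performed by build_norm_from_variant.
import copy
from types import SimpleNamespace

cfg = SimpleNamespace(
    norm_variant_wte="hyperspherenorm",
    norm_wte_radius=20.0,
    norm_wte_scale=2.0,
    norm_wte_gain=False,
    hsnorm_radius=None,  # shared fields the norm module actually reads
    hsnorm_scale=1.0,
    hsnorm_gain=False,
)

norm_config = copy.deepcopy(cfg)
for attr in ("radius", "scale", "gain"):
    src = f"norm_wte_{attr}"
    if getattr(norm_config, src, None) is not None:
        setattr(norm_config, f"hsnorm_{attr}", getattr(norm_config, src))

print(norm_config.hsnorm_radius, norm_config.hsnorm_scale, norm_config.hsnorm_gain)
# -> 20.0 2.0 False; norm_dictionary["hyperspherenorm"](norm_config) would then see these.
```

Note the original `cfg` is left untouched, so `norm_wte_*` and `norm_abs_*` overrides never leak into each other.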
@@ -564,21 +573,21 @@ def forward(self, idx, targets=None, iter_num=None, token_dict=None, target_dict
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
x = None

if self.config.use_embedding_scale:
tok_emb = tok_emb * self.embedding_scale

if self.n_embd_wte:
tok_emb = self.transformer.scale_up(tok_emb)

if self.post_embedding_norm is not None:
tok_emb = self.post_embedding_norm(tok_emb)
if self.config.use_embedding_scale:
tok_emb = tok_emb * self.embedding_scale

if self.config.norm_variant_wte is not None:
tok_emb = self.transformer.post_embedding_norm(tok_emb)

if self.config.use_abs_pos_embeddings:
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = tok_emb + pos_emb
if self.post_abs_pos_embedding_norm is not None:
x = self.post_abs_pos_embedding_norm(x)
if self.config.norm_variant_abs is not None:
x = self.transformer.post_abs_norm(x)
x = self.transformer.drop(x)
else:
x = self.transformer.drop(tok_emb)
@@ -697,21 +706,22 @@ def embed_tokens(self, idx, dataset_idx=None):
else:
tok_emb = self.transformer.wte(idx)

if self.config.use_embedding_scale:
tok_emb = tok_emb * self.embedding_scale

if self.n_embd_wte:
tok_emb = self.transformer.scale_up(tok_emb)

if self.post_embedding_norm is not None:
tok_emb = self.post_embedding_norm(tok_emb)
if self.config.use_embedding_scale:
tok_emb = tok_emb * self.embedding_scale

if self.config.norm_variant_wte is not None:
tok_emb = self.transformer.post_embedding_norm(tok_emb)

if self.config.use_abs_pos_embeddings:
t = idx.size(1)
pos = torch.arange(0, t, dtype=torch.long, device=device)
tok_emb = tok_emb + self.transformer.wpe(pos)
if self.post_abs_pos_embedding_norm is not None:
tok_emb = self.post_abs_pos_embedding_norm(tok_emb)
if self.config.norm_variant_abs is not None:
tok_emb = self.transformer.post_abs_norm(tok_emb)


return self.transformer.drop(tok_emb)

5 changes: 2 additions & 3 deletions train.py
@@ -1891,9 +1891,8 @@ def train(self):
if self.iter_num > self.args.max_iters:
print(self.best_val_loss, self.best_iter, self.best_tokens)
if self.args.only_save_checkpoint_at_end:
if not self.args.never_save_checkpoint:
self.save_checkpoint('ckpt.pt')
print(f"Saved checkpoint to {self.args.out_dir}")
self.save_checkpoint('ckpt.pt')
print(f"Saved checkpoint to {self.args.out_dir}")

# Sample if set
if self.args.max_sample_tokens: