diff --git a/experiments/activation/llm.py b/experiments/activation/llm.py
new file mode 100644
index 00000000..a571926d
--- /dev/null
+++ b/experiments/activation/llm.py
@@ -0,0 +1,703 @@
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from torch.cuda.amp import autocast, GradScaler
+import math
+import random
+import numpy as np
+from datasets import load_dataset
+from tqdm import tqdm
+import time
+from transformers import AutoTokenizer
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+import warnings
+import pickle
+from torchtune.modules import RotaryPositionalEmbeddings
+warnings.filterwarnings('ignore')
+import matplotlib.pyplot as plt
+
+def set_seed(seed: int = 42):
+    """Set all random seeds for reproducibility"""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    print(f"🌱 Set all seeds to {seed}")
+
+@dataclass
+class MoEModelConfig:
+    # Model architecture
+    d_model: int = 384
+    n_heads: int = 8
+    n_layers: int = 6
+    d_ff: int = 1536
+    batch_size: int = 24
+    max_steps: int = 1000
+
+    # Training parameters
+    gradient_accumulation_steps: int = 4
+    muon_lr: float = 0.01
+
+    # Data parameters
+    max_seq_len: int = 512
+    num_documents: int = 2000
+    max_tokens: int = 500000
+
+    # Evaluation
+    eval_every: int = 500
+    eval_steps: int = 100
+
+    # Regularization
+    weight_decay: float = 0.1
+    dropout: float = 0.1
+    grad_clip: float = 1.0
+
+    # Technical
+    use_amp: bool = True
+    vocab_size: Optional[int] = None
+    log_milestones: Tuple[int, ...] = (2000, 5000, 10000)
+
+    # MoE specific parameters
+    num_experts: int = 8
+    expert_top_k: int = 2
+    load_balancing_weight: float = 0.01
+
+    # Expert activation (overridden per run by the benchmark scripts)
+    activation: str = "silu"
+
+    def __post_init__(self):
+        self.d_k = self.d_model // self.n_heads
+        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
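+
+# ---------------------------------------------------------------------
+# Illustrative usage (a sketch, not executed by this module):
+# __post_init__ derives d_k and validates the head split; vocab_size
+# stays None until load_and_cache_data() fills it in from the tokenizer.
+# >>> cfg = MoEModelConfig(d_model=256, n_heads=4, max_steps=100)
+# >>> cfg.d_k                                              # 256 // 4 = 64
+# >>> cfg.batch_size * cfg.gradient_accumulation_steps     # effective batch size, 24 * 4 = 96
+# ---------------------------------------------------------------------
+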
+@torch.compile
+def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
+    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G."""
+    assert G.ndim >= 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+
+    for _ in range(steps):
+        A = X @ X.mT
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    return X
+
+class Muon(torch.optim.Optimizer):
+    """Muon - MomentUm Orthogonalized by Newton-schulz"""
+    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
+        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                g = p.grad
+                state = self.state[p]
+
+                if "momentum_buffer" not in state:
+                    state["momentum_buffer"] = torch.zeros_like(g)
+
+                buf = state["momentum_buffer"]
+                buf.lerp_(g, 1 - group["momentum"])
+                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
+                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+                p.add_(g.view_as(p), alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
+
+def load_and_cache_data(config: MoEModelConfig, cache_dir: str = "data_cache"):
+    """Load and cache tokenized data to avoid reprocessing"""
+    os.makedirs(cache_dir, exist_ok=True)
+    cache_file = f"{cache_dir}/tokenized_data_{config.num_documents}_{config.max_tokens}.pkl"
+
+    # Check if cached data exists
+    if os.path.exists(cache_file):
+        print(f"📦 Loading cached data from {cache_file}")
+        with open(cache_file, 'rb') as f:
+            cached_data = pickle.load(f)
+
+        texts = cached_data['texts']
+        tokenizer = cached_data['tokenizer']
+        tokens = cached_data['tokens']
+        config.vocab_size = tokenizer.vocab_size
+
+        print(f"✅ Loaded {len(texts)} documents, {len(tokens):,} tokens from cache")
+        return texts, tokenizer, tokens
+
+    print(f"🔄 Processing new data (will cache for future use)")
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M", token=False)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Load dataset
+    dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", streaming=True, token=False)
+
+    texts = []
+    for i, item in enumerate(dataset):
+        if i >= config.num_documents:
+            break
+        texts.append(item["text"][:3000])
+
+    print(f"Loaded {len(texts)} documents")
+
+    # Tokenize
+    print("Tokenizing texts...")
+    all_tokens = []
+    for text in tqdm(texts, desc="Tokenizing"):
+        tokens = tokenizer.encode(text, add_special_tokens=False)
+        all_tokens.extend(tokens)
+
+    tokens = all_tokens[:config.max_tokens]
+    print(f"Using {len(tokens):,} tokens")
+    config.vocab_size = tokenizer.vocab_size
+
+    # Cache the processed data
+    cached_data = {'texts': texts, 'tokenizer': tokenizer, 'tokens': tokens}
+    with open(cache_file, 'wb') as f:
+        pickle.dump(cached_data, f)
+
+    print(f"💾 Cached data to {cache_file}")
+    return texts, tokenizer, tokens
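+
+# ---------------------------------------------------------------------
+# Usage sketch (illustrative, not executed by this module): the first
+# call downloads and tokenizes the corpus; a second call with the same
+# num_documents/max_tokens hits the pickle cache under data_cache/.
+# >>> cfg = MoEModelConfig(num_documents=200, max_tokens=20000)
+# >>> texts, tokenizer, tokens = load_and_cache_data(cfg)   # tokenizes + caches
+# >>> texts, tokenizer, tokens = load_and_cache_data(cfg)   # loads from cache
+# ---------------------------------------------------------------------
+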
+class TextTokenDataset(Dataset):
+    def __init__(self, tokens: List[int], seq_len: int = 512):
+        self.tokens = tokens
+        self.seq_len = seq_len
+
+    def __len__(self):
+        return max(0, len(self.tokens) - self.seq_len)
+
+    def __getitem__(self, idx):
+        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
+        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
+        return x, y
+
+class Rotary(nn.Module):
+    def __init__(self, dim: int, max_seq_len: int):
+        super().__init__()
+        self.rope = RotaryPositionalEmbeddings(dim=dim, max_seq_len=max_seq_len, base=10000)
+
+    def forward(self, x_BTHD: torch.Tensor):
+        # torchtune's RotaryPositionalEmbeddings expects [batch, seq_len, num_heads, head_dim],
+        # which matches the [B, T, H, D] layout passed in here.
+        return self.rope(x_BTHD)
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, max_seq_len: int, dropout: float = 0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.d_k = d_model // n_heads
+
+        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
+        self.w_o = nn.Linear(d_model, d_model, bias=False)
+        self.rotary = Rotary(self.d_k, max_seq_len)
+        self.dropout = dropout
+
+    def forward(self, x):
+        batch_size, seq_len = x.size(0), x.size(1)
+
+        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        Q, K, V = qkv[0], qkv[1], qkv[2]  # [B, H, T, D]
+
+        # Apply RoPE on [B, T, H, D], then move heads back in front for attention
+        Q = self.rotary(Q.transpose(1, 2)).transpose(1, 2)
+        K = self.rotary(K.transpose(1, 2)).transpose(1, 2)
+
+        attn_output = F.scaled_dot_product_attention(
+            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
+        )
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
+        return self.w_o(attn_output)
+
+class Expert(nn.Module):
+    """Single expert network: a feed-forward layer with a configurable activation"""
+    def __init__(self, d_model: int, d_ff: int, activation: str = "silu", dropout: float = 0.1):
+        super().__init__()
+        self.activation = activation
+        # Gated activations (swiglu/geglu) split the hidden projection into value and gate halves
+        inner_dim = d_ff * 2 if activation in ("swiglu", "geglu") else d_ff
+        self.linear1 = nn.Linear(d_model, inner_dim, bias=False)
+        self.linear2 = nn.Linear(d_ff, d_model, bias=False)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        h = self.linear1(x)
+        if self.activation == "relu":
+            h = F.relu(h)
+        elif self.activation == "gelu":
+            h = F.gelu(h)
+        elif self.activation == "silu":
+            h = F.silu(h)
+        elif self.activation == "swiglu":
+            value, gate = h.chunk(2, dim=-1)
+            h = F.silu(value) * gate
+        elif self.activation == "geglu":
+            value, gate = h.chunk(2, dim=-1)
+            h = F.gelu(value) * gate
+        else:
+            raise ValueError(f"Unknown activation {self.activation}")
+        return self.linear2(self.dropout(h))
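+
+# ---------------------------------------------------------------------
+# Shape sketch (illustrative, not executed by this module): for the gated
+# variants linear1 maps d_model -> 2*d_ff and the activation halves it
+# again, so linear2 keeps the same d_ff -> d_model shape as the
+# non-gated experts.
+# >>> e = Expert(d_model=384, d_ff=1536, activation="swiglu")
+# >>> e(torch.randn(2, 8, 384)).shape   # torch.Size([2, 8, 384])
+# ---------------------------------------------------------------------
+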
+class TopKRouter(nn.Module):
+    """Router that selects top-k experts for each token"""
+    def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.gate = nn.Linear(d_model, num_experts, bias=False)
+        self.noise_std = 0.1  # Standard deviation for noise during training
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            x: Input tensor [batch_size, seq_len, d_model]
+
+        Returns:
+            - router_weights: Softmax weights for selected experts [batch_size, seq_len, top_k]
+            - expert_indices: Indices of selected experts [batch_size, seq_len, top_k]
+            - router_probs: Full probability distribution over experts (for load balancing loss)
+        """
+        batch_size, seq_len, d_model = x.shape
+
+        # Compute router logits
+        router_logits = self.gate(x)  # [batch_size, seq_len, num_experts]
+
+        # Add noise during training for exploration
+        if self.training and self.noise_std > 0:
+            noise = torch.randn_like(router_logits) * self.noise_std
+            router_logits = router_logits + noise
+
+        # Get full probability distribution (for load balancing loss)
+        router_probs = F.softmax(router_logits, dim=-1)
+
+        # Select top-k experts
+        top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)
+        top_k_weights = F.softmax(top_k_logits, dim=-1)
+
+        return top_k_weights, top_k_indices, router_probs
+
+class MixtureOfExperts(nn.Module):
+    """Mixture of Experts layer with top-k routing"""
+    def __init__(
+        self,
+        d_model: int,
+        d_ff: int,
+        num_experts: int = 8,
+        top_k: int = 2,
+        dropout: float = 0.1,
+        load_balancing_weight: float = 0.01,
+        activation: str = "silu"
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.load_balancing_weight = load_balancing_weight
+
+        # Create experts, each using the configured activation
+        self.experts = nn.ModuleList([
+            Expert(d_model, d_ff, activation, dropout) for _ in range(num_experts)
+        ])
+
+        # Create router
+        self.router = TopKRouter(d_model, num_experts, top_k)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+            x: Input tensor [batch_size, seq_len, d_model]
+
+        Returns:
+            - output: MoE output [batch_size, seq_len, d_model]
+            - aux_loss: Load balancing auxiliary loss (only during training)
+        """
+        batch_size, seq_len, d_model = x.shape
+
+        # Get routing decisions
+        router_weights, expert_indices, router_probs = self.router(x)
+
+        # Initialize output tensor
+        output = torch.zeros_like(x)
+
+        # Process each expert
+        for expert_idx in range(self.num_experts):
+            # Find tokens routed to this expert
+            expert_mask = (expert_indices == expert_idx).any(dim=-1)  # [batch_size, seq_len]
+
+            if expert_mask.any():
+                # Get tokens for this expert
+                expert_input = x[expert_mask]  # [num_tokens, d_model]
+
+                # Apply expert
+                expert_output = self.experts[expert_idx](expert_input)
+
+                # Find which top-k slot this expert occupies for the selected tokens
+                mask_for_expert = (expert_indices == expert_idx)  # [batch, seq, top_k]
+                positions = mask_for_expert[expert_mask].float().argmax(dim=-1)
+                # Gather the router weight for that slot only
+                expert_weights = router_weights[expert_mask].gather(
+                    -1, positions.unsqueeze(-1)
+                ).squeeze(-1)
+
+                # Add weighted expert output to result
+                output[expert_mask] += expert_weights.unsqueeze(-1) * expert_output
+
+        # Compute load balancing loss during training
+        aux_loss = None
+        if self.training:
+            aux_loss = self._compute_load_balancing_loss(router_probs, expert_indices)
+
+        return output, aux_loss
+
+    def _compute_load_balancing_loss(
+        self,
+        router_probs: torch.Tensor,
+        expert_indices: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute auxiliary loss to ensure balanced expert usage.
+        This encourages the router to distribute tokens evenly across experts.
+ """ + # Compute the fraction of tokens routed to each expert + expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).float() + tokens_per_expert = expert_mask.sum(dim=[0, 1, 2]) / expert_mask.sum() + + # Compute the average probability of routing to each expert + router_prob_mean = router_probs.mean(dim=[0, 1]) + + # Load balancing loss encourages uniform distribution + aux_loss = torch.sum(tokens_per_expert * router_prob_mean) * self.num_experts + + return aux_loss * self.load_balancing_weight + +class MoETransformerBlock(nn.Module): + def __init__( + self, + d_model: int, + n_heads: int, + d_ff: int, + max_seq_len: int, + num_experts: int = 8, + top_k: int = 2, + dropout: float = 0.1, + activation: str = "silu" # Add this parameter + ): + super().__init__() + + # Attention layer + self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout) + + # MoE layer with activation parameter + self.feed_forward = MixtureOfExperts( + d_model, d_ff, num_experts, top_k, dropout, activation=activation + ) + + # Normalization layers + self.norm1 = nn.RMSNorm(d_model) + self.norm2 = nn.RMSNorm(d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + # Self-attention + attn_out = self.attention(self.norm1(x)) + x = x + self.dropout(attn_out) + + # MoE feed-forward + ff_out, aux_loss = self.feed_forward(self.norm2(x)) + x = x + self.dropout(ff_out) + return x, aux_loss + + +class MoEMinimalLLM(nn.Module): + def __init__(self, config: MoEModelConfig): + super().__init__() + self.config = config + + # Token embeddings + self.token_embedding = nn.Embedding(config.vocab_size, config.d_model) + self.position_dropout = nn.Dropout(config.dropout) + + # Get activation from config, default to "silu" + activation = getattr(config, 'activation', 'silu') + + # Transformer blocks with MoE + self.transformer_blocks = nn.ModuleList([ + MoETransformerBlock( + config.d_model, + config.n_heads, + config.d_ff, + config.max_seq_len, + config.num_experts, + config.expert_top_k, + config.dropout, + activation # Pass the activation + ) + for i in range(config.n_layers) + ]) + + # Rest of the constructor remains the same... 
+ + # Output layers + self.norm = nn.RMSNorm(config.d_model) + self.output_dropout = nn.Dropout(config.dropout) + + # Language modeling head (tied with embeddings) + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.lm_head.weight = self.token_embedding.weight + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, x, return_aux_loss=True): + # Token embeddings + x = self.token_embedding(x) * math.sqrt(self.config.d_model) + x = self.position_dropout(x) + + # Collect auxiliary losses from MoE layers + aux_losses = [] + + # Pass through transformer blocks + for block in self.transformer_blocks: + x, aux_loss = block(x) + if aux_loss is not None and return_aux_loss: + aux_losses.append(aux_loss) + + # Output projection + x = self.norm(x) + x = self.output_dropout(x) + logits = self.lm_head(x) + + # Combine auxiliary losses + total_aux_loss = sum(aux_losses) if aux_losses else None + + if return_aux_loss: + return logits, total_aux_loss + return logits + +def evaluate_model(model: nn.Module, val_loader: DataLoader, config: MoEModelConfig): + """Evaluate model performance""" + model.eval() + total_loss = 0 + total_tokens = 0 + total_correct = 0 + + device = next(model.parameters()).device + + with torch.no_grad(): + for i, (x, y) in enumerate(val_loader): + if i >= config.eval_steps: + break + x, y = x.to(device), y.to(device) + + with autocast(enabled=config.use_amp): + # MoE model evaluation + logits = model(x, return_aux_loss=False) # Don't return aux loss during eval + loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1)) + + total_loss += loss.item() * y.numel() + total_tokens += y.numel() + + predictions = logits.argmax(dim=-1) + total_correct += (predictions == y).sum().item() + + avg_loss = total_loss / total_tokens + accuracy = total_correct / total_tokens + perplexity = math.exp(min(avg_loss, 20)) + + model.train() + return {'val_loss': avg_loss, 'val_accuracy': accuracy, 'val_perplexity': perplexity} + +def setup_muon_optimizer(model: nn.Module, config: MoEModelConfig): + """Setup Muon optimizer with hybrid approach""" + muon_params = [] + adamw_params = [] + + for name, param in model.named_parameters(): + if (param.ndim == 2 and + 'token_embedding' not in name and + 'norm' not in name and + param.requires_grad): + muon_params.append(param) + else: + adamw_params.append(param) + + print(f" Muon parameters: {sum(p.numel() for p in muon_params):,}") + print(f" AdamW parameters: {sum(p.numel() for p in adamw_params):,}") + + muon_optimizer = Muon(muon_params, lr=config.muon_lr, momentum=0.95) + adamw_optimizer = torch.optim.AdamW(adamw_params, lr=config.muon_lr*0.1, weight_decay=config.weight_decay) + + return [muon_optimizer, adamw_optimizer] + + +def train_moe_model(config: MoEModelConfig, train_loader: DataLoader, val_loader: DataLoader): + """Train the MoE model""" + print(f"\n๐Ÿš€ Training MoE model with {config.num_experts} experts (top-{config.expert_top_k})") + + # Initialize model + set_seed(42) + model = MoEMinimalLLM(config) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + active_params = sum(p.numel() 
for n, p in model.named_parameters() if 'expert' not in n) + expert_params = total_params - active_params + + print(f" ๐Ÿ“Š Total parameters: {total_params:,}") + print(f" ๐Ÿ“Š Active parameters: {active_params:,}") + print(f" ๐Ÿ“Š Expert parameters: {expert_params:,}") + print(f" ๐Ÿ“Š Parameter efficiency: {active_params/total_params:.1%} active per forward pass") + + # Setup optimizers + optimizers = setup_muon_optimizer(model, config) + + # Learning rate schedule + schedulers = [] + for optimizer in optimizers: + warmup_steps = config.max_steps // 20 + + def lr_lambda(step): + if step < warmup_steps: + return step / warmup_steps + else: + progress = (step - warmup_steps) / (config.max_steps - warmup_steps) + return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress)) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + schedulers.append(scheduler) + + scaler = GradScaler() if config.use_amp else None + + # ----------------------------- + # Metric storage for plotting + # ----------------------------- + train_losses = [] + aux_losses = [] + val_losses = [] + val_steps = [] + + # Training loop + model.train() + step = 0 + pbar = tqdm(total=config.max_steps, desc="Training MoE") + + while step < config.max_steps: + for batch_idx, (x, y) in enumerate(train_loader): + if step >= config.max_steps: + break + + x, y = x.to(device), y.to(device) + + # Forward pass + if config.use_amp: + with autocast(): + logits, aux_loss = model(x, return_aux_loss=True) + ce_loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1)) + total_loss = ce_loss + (aux_loss if aux_loss is not None else 0.0) + loss = total_loss / config.gradient_accumulation_steps + scaler.scale(loss).backward() + else: + logits, aux_loss = model(x, return_aux_loss=True) + ce_loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1)) + total_loss = ce_loss + (aux_loss if aux_loss is not None else 0.0) + loss = total_loss / config.gradient_accumulation_steps + loss.backward() + + # Record training losses + train_losses.append(ce_loss.item()) + aux_losses.append(aux_loss.item() if aux_loss is not None else 0.0) + + # Optimizer step + if (step + 1) % config.gradient_accumulation_steps == 0: + if config.use_amp: + for optimizer in optimizers: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + + for optimizer in optimizers: + scaler.step(optimizer) + optimizer.zero_grad() + for scheduler in schedulers: + scheduler.step() + scaler.update() + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + for optimizer in optimizers: + optimizer.step() + optimizer.zero_grad() + for scheduler in schedulers: + scheduler.step() + + # Logging + if step % 100 == 0: + with torch.no_grad(): + predictions = logits.argmax(dim=-1) + accuracy = (predictions == y).float().mean().item() + perplexity = math.exp(min(ce_loss.item(), 20)) + + pbar.set_postfix({ + 'loss': f'{ce_loss.item():.4f}', + 'aux': f'{aux_loss.item() if aux_loss is not None else 0:.4f}', + 'acc': f'{accuracy:.3f}', + 'ppl': f'{perplexity:.1f}' + }) + + # Evaluation + if step % config.eval_every == 0 and step > 0: + eval_metrics = evaluate_model(model, val_loader, config) + val_losses.append(eval_metrics["val_loss"]) + val_steps.append(step) + print(f"\nStep {step}: Val Loss: {eval_metrics['val_loss']:.4f}, " + f"Val Acc: {eval_metrics['val_accuracy']:.4f}, " + f"Val PPL: {eval_metrics['val_perplexity']:.2f}") + + # Milestone evaluations + if step in getattr(config, 
'log_milestones', ()): + eval_metrics = evaluate_model(model, val_loader, config) + print(f"\n๐Ÿงช Milestone {step}: Val Loss: {eval_metrics['val_loss']:.4f}") + + step += 1 + if step % 20 == 0: + pbar.update(20) + + pbar.close() + + # Final evaluation + final_eval = evaluate_model(model, val_loader, config) + print(f"\n๐Ÿ“Š Final Results:{config.activation}") + print(f" Val Loss: {final_eval['val_loss']:.4f}") + print(f" Val Accuracy: {final_eval['val_accuracy']:.4f}") + print(f" Val Perplexity: {final_eval['val_perplexity']:.2f}") + + # ----------------------------- + # Return model, final metrics, and training curves + # ----------------------------- + history = { + "train_losses": train_losses, + "aux_losses": aux_losses, + "val_losses": val_losses, + "val_steps": val_steps, + } + + return model, final_eval, history \ No newline at end of file diff --git a/experiments/activation/plots.py b/experiments/activation/plots.py new file mode 100644 index 00000000..f626a271 --- /dev/null +++ b/experiments/activation/plots.py @@ -0,0 +1,86 @@ +import matplotlib.pyplot as plt +import numpy as np +import math +# ------------------------------- +# 1. Subplots: train vs val loss per activation +# ------------------------------- +def plot_loss_subplots(results, save_path="loss_subplots.png"): + grouped = {} + for r in results: + grouped.setdefault(r["activation"], []).append(r) + + n_acts = len(grouped) + fig, axes = plt.subplots(n_acts, 1, figsize=(8, 3*n_acts), sharex=True) + + if n_acts == 1: + axes = [axes] + + for ax, (act, act_results) in zip(axes, grouped.items()): + for run in act_results: + steps = list(range(1, len(run["train_losses"]) + 1)) + ax.plot(steps,[math.log(i) for i in run["train_losses"]], label=f"{act} train (frac={run['frac']})") + if run["val_losses"]: + ax.plot(run["val_steps"],[math.log(i) for i in run["val_losses"]], linestyle="--", + label=f"{act} val (frac={run['frac']})") + ax.set_title(f"Activation: {act}") + ax.set_ylabel("Loss") + ax.legend() + + axes[-1].set_xlabel("Training Steps") + plt.tight_layout() + plt.savefig(save_path) + plt.show() + print(f"๐Ÿ“‰ Saved subplots to {save_path}") + + +# ------------------------------- +# 2. Bar plot: compute efficiency +# ------------------------------- +def plot_compute_bar(results, save_path="compute_bar.png"): + acts = [] + times = [] + for r in results: + if "time_per_step_sec" in r: + acts.append(f"{r['activation']} (frac={r['frac']})") + times.append(r["time_per_step_sec"]) + + plt.figure(figsize=(8,5)) + plt.bar(acts, times) + plt.ylabel("Time per Step (s)") + plt.title("Compute Cost per Activation") + plt.xticks(rotation=45, ha="right") + plt.tight_layout() + plt.savefig(save_path) + plt.show() + print(f"๐Ÿ“Š Saved compute bar plot to {save_path}") + + +# ------------------------------- +# 3. 
Scaling: val loss vs dataset size +# ------------------------------- +def plot_loss_vs_dataset(results, save_path="scaling_dataset.png"): + grouped = {} + for r in results: + grouped.setdefault(r["activation"], []).append(r) + + plt.figure(figsize=(8,6)) + for act, act_results in grouped.items(): + Ds = np.array([r["dataset_size"] for r in act_results], dtype=float) + Ls = np.array([r["val_loss"] for r in act_results], dtype=float) + + order = np.argsort(Ds) + Ds, Ls = Ds[order], Ls[order] + + plt.plot(Ds, Ls, "o-", label=f"{act}") + + plt.xscale("log") + plt.yscale("log") + plt.xlabel("Dataset Size (log)") + plt.ylabel("Validation Loss (log)") + plt.title("Scaling: Validation Loss vs Dataset Size") + plt.legend() + plt.grid(True, which="both") + plt.tight_layout() + plt.savefig(save_path) + plt.show() + print(f"๐Ÿ“‰ Saved scaling plot to {save_path}") diff --git a/experiments/activation/scaling.py b/experiments/activation/scaling.py new file mode 100644 index 00000000..59c7eed1 --- /dev/null +++ b/experiments/activation/scaling.py @@ -0,0 +1,128 @@ +import numpy as np +from torch.utils.data import Subset +from scipy.optimize import curve_fit +from torch.utils.data import DataLoader, random_split +import matplotlib.pyplot as plt +from llm import train_moe_model, MoEModelConfig + +# Scaling law function (only data term + constant since model size N is fixed) +def scaling_law(D, a, b, c): + """ + L(D) = a / (D^b) + c + - a : coefficient for data scaling + - b : scaling exponent (key metric!) + - c : irreducible error (loss floor) + """ + return a / (D ** b) + c + + +# ------------------------------- +# Extended Benchmark +# ------------------------------- +def benchmark_activation_scaling(act_name: str, dataset, tokenizer, dataset_fracs=[0.1, 0.3, 1.0], num_steps=200): + """ + Run scaling-law benchmark for one activation across different dataset sizes. + """ + results = [] + total_len = len(dataset) + + for frac in dataset_fracs: + size = int(total_len * frac) + indices = list(range(size)) + sub_dataset = Subset(dataset, indices) + val_size = size // 10 + train_size = size - val_size + + train_dataset, val_dataset = random_split(sub_dataset, [train_size, val_size]) + train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=16) + + # Config (small, fixed model size) + config = MoEModelConfig( + d_model=256, + n_heads=4, + n_layers=2, + d_ff=1024, + num_experts=4, + expert_top_k=2, + num_documents=size, + max_tokens=20000, + max_seq_len=128, + batch_size=16, + max_steps=num_steps, + eval_every=20, + ) + config.activation = act_name + config.vocab_size = len(tokenizer) + + # Train + collect results + _, final_eval, history = train_moe_model(config, train_loader, val_loader) + results.append({ + "activation": act_name, + "frac": frac, + "dataset_size": size, + "val_loss": final_eval["val_loss"], + "train_losses": history["train_losses"], + "val_losses": history["val_losses"], + "val_steps": history["val_steps"] + }) + + return results + + +# ------------------------------- +# Fit scaling exponents +# ------------------------------- +def fit_scaling_exponent(results): + """ + Fit power-law curve: L(D) = a / D^b + c + Returns exponent b and fit parameters per activation. 
+ """ + summary = {} + + grouped = {} + for r in results: + grouped.setdefault(r["activation"], []).append(r) + + for act, act_results in grouped.items(): + Ds = np.array([r["dataset_size"] for r in act_results], dtype=float) + Ls = np.array([r["val_loss"] for r in act_results], dtype=float) + + # Fit curve + popt, _ = curve_fit(scaling_law, Ds, Ls, maxfev=10000) + a, b, c = popt + summary[act] = {"a": a, "b": b, "c": c, "raw": (Ds, Ls)} + + return summary + + +# ------------------------------- +# Plot scaling curves +# ------------------------------- +def plot_scaling_curves(summary, save_path="scaling_laws.png"): + plt.figure(figsize=(8,6)) + + for act, params in summary.items(): + Ds, Ls = params["raw"] + sorted_idx = np.argsort(Ds) + Ds_sorted = Ds[sorted_idx] + Ls_sorted = Ls[sorted_idx] + + # Fitted curve + fit_D = np.logspace(np.log10(min(Ds)), np.log10(max(Ds)), 50) + fit_L = scaling_law(fit_D, params["a"], params["b"], params["c"]) + + plt.plot(Ds_sorted, Ls_sorted, "o", label=f"{act} (val)") + plt.plot(fit_D, fit_L, "--", label=f"{act} fit (b={params['b']:.2f})") + + plt.xscale("log") + plt.yscale("log") + plt.xlabel("Dataset Size (log)") + plt.ylabel("Validation Loss (log)") + plt.title("Scaling Laws: Activation Functions") + plt.legend() + plt.grid(True, which="both") + plt.tight_layout() + plt.savefig(save_path) + plt.show() + print(f"๐Ÿ“‰ Saved scaling-law plot to {save_path}") diff --git a/experiments/activation/test.py b/experiments/activation/test.py new file mode 100644 index 00000000..81a3e374 --- /dev/null +++ b/experiments/activation/test.py @@ -0,0 +1,208 @@ +import time +import matplotlib.pyplot as plt +from enum import Enum +import torch.nn.functional as F +import torch.nn as nn +import torch +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from llm import (MoEMinimalLLM, load_and_cache_data, TextTokenDataset, MoEModelConfig) +from scaling import * +# ------------------------------- +# Activations +# ------------------------------- +class NonlinearityType(Enum): + RELU = "relu" + SILU = "silu" + GELU = "gelu" + SWIGLU = "swiglu" + GEGLU = "geglu" + +def get_activation_fn(name: str): + if name == "relu": + return nn.ReLU() + elif name == "silu": + return nn.SiLU() + elif name == "gelu": + return nn.GELU() + elif name == "swiglu": + return lambda x: F.silu(x[..., :x.shape[-1]//2]) * x[..., x.shape[-1]//2:] + elif name == "geglu": + return lambda x: F.gelu(x[..., :x.shape[-1]//2]) * x[..., x.shape[-1]//2:] + else: + raise ValueError(f"Unknown activation {name}") + +# ------------------------------- +# Expert with configurable activation +# ------------------------------- +class Expert(nn.Module): + def __init__(self, d_model: int, d_ff: int, activation: str = "silu", dropout: float = 0.1): + super().__init__() + self.activation_name = activation + self.activation = get_activation_fn(activation) + + # For gated activations, need double dim + if activation in ["swiglu", "geglu"]: + assert d_ff % 2 == 0, "d_ff must be even for gated activations" + inner_dim = d_ff * 2 + else: + inner_dim = d_ff + + self.linear1 = nn.Linear(d_model, inner_dim, bias=False) + self.linear2 = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + h = self.linear1(x) + h = self.activation(h) + h = self.dropout(h) + return self.linear2(h) + +# ------------------------------- +# Training loop for activation benchmark +# ------------------------------- +def train_moe_with_activation(config, train_loader, val_loader): + 
print(f"\n๐Ÿš€ Training with activation = {config.activation}") + model = MoEMinimalLLM(config).to("cuda" if torch.cuda.is_available() else "cpu") + device = next(model.parameters()).device + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + train_losses, val_losses, val_steps = [], [], [] + + for step in tqdm(range(1, config.max_steps + 1), desc="Training"): + model.train() + total_loss = 0.0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, aux_loss = model(x, return_aux_loss=True) + ce_loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1)) + loss = ce_loss + (aux_loss if aux_loss is not None else 0.0) + + loss.backward() + optimizer.step() + total_loss += loss.item() + + avg_loss = total_loss / len(train_loader) + train_losses.append(avg_loss) + + # Validation + if step % config.eval_every == 0: + model.eval() + val_loss, correct, total = 0.0, 0, 0 + with torch.no_grad(): + for x, y in val_loader: + x, y = x.to(device), y.to(device) + logits, _ = model(x, return_aux_loss=True) + val_loss += F.cross_entropy( + logits.view(-1, config.vocab_size), y.view(-1) + ).item() + preds = logits.argmax(dim=-1) + correct += (preds == y).sum().item() + total += y.numel() + val_loss /= len(val_loader) + val_acc = correct / total + val_losses.append(val_loss) + val_steps.append(step) + + print(f"Step {step}: train_loss={avg_loss:.4f}, val_loss={val_loss:.4f}, acc={val_acc:.4f}") + + return model, { + "train_losses": train_losses, + "val_losses": val_losses, + "val_steps": val_steps, + } + +# ------------------------------- +# Benchmark function +# ------------------------------- +def benchmark_activation(act_name: str, num_steps: int = 100): + print(f"\n๐Ÿ”น Benchmarking activation: {act_name}") + + config = MoEModelConfig( + d_model=256, + n_heads=4, + n_layers=2, + d_ff=1024, + num_experts=4, + expert_top_k=2, + num_documents=200, + max_tokens=20000, + max_seq_len=128, + batch_size=16, + max_steps=num_steps, + eval_every=20, + ) + config.activation = act_name + + texts, tokenizer, tokens = load_and_cache_data(config) + config.vocab_size = len(tokenizer) + dataset = TextTokenDataset(tokens, config.max_seq_len) + val_size = len(dataset) // 10 + train_size = len(dataset) - val_size + train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size]) + train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=config.batch_size) + + start = time.time() + _, metrics = train_moe_with_activation(config, train_loader, val_loader) + elapsed = time.time() - start + + return { + "activation": act_name, + "train_time_sec": elapsed, + "time_per_step_sec": elapsed / num_steps, + **metrics, + } + +# ------------------------------- +# Plot results +# ------------------------------- +def plot_loss_curves(results, save_path="activation_benchmark.png"): + plt.figure(figsize=(10, 6)) + for r in results: + steps = list(range(1, len(r["train_losses"]) + 1)) + plt.plot(steps, r["train_losses"], label=f"{r['activation']} (train)") + if r["val_losses"]: + plt.plot(r["val_steps"], r["val_losses"], linestyle="--", label=f"{r['activation']} (val)") + plt.xlabel("Training Steps") + plt.ylabel("Loss") + plt.title("Activation Function Benchmark (MoE)") + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.savefig(save_path) + plt.show() + print(f"๐Ÿ“‰ Saved plot to {save_path}") + + +def main(): + # Load full dataset once + texts, tokenizer, tokens = 
load_and_cache_data(MoEModelConfig())
+    dataset = TextTokenDataset(tokens, 128)
+
+    all_results = []
+    for act in ["relu", "silu", "gelu", "swiglu", "geglu"]:
+        res = benchmark_activation_scaling(
+            act,
+            dataset,
+            tokenizer,
+            dataset_fracs=[0.1, 0.3, 1.0]
+        )
+        all_results.extend(res)
+
+    summary = fit_scaling_exponent(all_results)
+    print("\n📊 Scaling Exponents")
+    print(f"{'Activation':>10} | {'Exponent b':>10} | {'Irreducible c':>12}")
+    print("-"*36)
+    for act, p in summary.items():
+        print(f"{act:>10} | {p['b']:10.3f} | {p['c']:12.3f}")
+
+    return all_results
+
+if __name__ == '__main__':
+    main()
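+
+# ----------------------------------------------------------------------
+# Optional extra (a sketch, not called from main()): benchmark_activation
+# and plot_loss_curves above are currently unused; the hypothetical helper
+# below shows one way to combine them into a wall-clock comparison of the
+# activations on the small fixed-size config.
+# ----------------------------------------------------------------------
+def run_timing_benchmark(num_steps: int = 100):
+    results = []
+    for act in ["relu", "silu", "gelu", "swiglu", "geglu"]:
+        # Trains a small MoE for num_steps steps and records elapsed time
+        results.append(benchmark_activation(act, num_steps=num_steps))
+    plot_loss_curves(results, save_path="activation_benchmark.png")
+    for r in results:
+        print(f"{r['activation']:>8}: {r['time_per_step_sec']:.3f} s/step")
+    return results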