diff --git a/explorations/energy_efficiency_zeus.yaml b/explorations/energy_efficiency_zeus.yaml
new file mode 100644
index 0000000000..74040e361d
--- /dev/null
+++ b/explorations/energy_efficiency_zeus.yaml
@@ -0,0 +1,106 @@
+# energy_efficiency_zeus.yaml
+---
+
+# Atomic Static Groupings
+named_static_groups:
+  - named_group: "qk_norm_scaled"
+    use_qk_norm: [true]
+    use_qk_norm_scale: [true]
+
+  - named_group: "peri_ln_on"
+    use_peri_ln: [true]
+  - named_group: "peri_ln_off"
+    use_peri_ln: [false]
+
+  - named_group: "softmax_flash_compiled"
+    softmax_variant_attn: ["softmax"]
+    softmax_variant_output: ["softmax"]
+    disable_flash_attention: [false]
+    compile: [true]
+  - named_group: "softmax_no_flash"
+    softmax_variant_attn: ["softmax"]
+    softmax_variant_output: ["softmax"]
+    disable_flash_attention: [true]
+    compile: [true]
+  - named_group: "relu2max_no_flash"
+    softmax_variant_attn: ["relu2max"]
+    softmax_variant_output: ["relu2max"]
+    disable_flash_attention: [true]
+  - named_group: "relu2max_no_flash_compiled"
+    softmax_variant_attn: ["relu2max"]
+    softmax_variant_output: ["relu2max"]
+    disable_flash_attention: [true]
+    compile: [true]
+  - named_group: "strongermax"
+    softmax_variant_attn: ["strongermax"]
+    softmax_variant_output: ["strongermax"]
+    disable_flash_attention: [true]
+    compile: [true]
+
+  - named_group: "rotary"
+    named_group_settings:
+      use_rotary_embeddings: [true]
+      use_abs_pos_embeddings: [false]
+  - named_group: "abs"
+    named_group_settings:
+      use_rotary_embeddings: [false]
+      use_abs_pos_embeddings: [true]
+  - named_group: "nope"
+    named_group_settings:
+      use_rotary_embeddings: [false]
+      use_abs_pos_embeddings: [false]
+
+# Higher Level Groupings
+named_variation_groups:
+  - named_group: "softmax_modes"
+    named_group_alternates:
+      - "softmax_flash_compiled"
+      - "softmax_no_flash"
+      - "relu2max_no_flash"
+      - "relu2max_no_flash_compiled"
+      - "strongermax"
+  - named_group: "peri_ln_modes"
+    named_group_alternates:
+      - "peri_ln_on"
+      - "peri_ln_off"
+  - named_group: "position"
+    named_group_alternates:
+      - "abs"
+      - "rotary"
+      - "nope"
+  - named_group: "activation_modes"
+    parameter_groups:
+      - activation_variant: ["gelu"]
+      - activation_variant: ["squared_relu"]
+  - named_group: "precision"
+    parameter_groups:
+      - dtype: ["float16"]
+      - dtype: ["bfloat16"]
+
+# Common_group: parameters applied to every run, but omitted from run names
+common_group:
+  dataset: ["minipile"]
+  n_layer: [6]
+  n_embd: [384]
+  n_head: [6]
+  block_size: [256]
+  batch_size: [64]
+  max_iters: [10000]
+  eval_interval: [1000]
+  eval_iters: [100]
+  device: ["cuda"]
+  never_save_checkpoint: [true]
+  sample_each_eval: [true]
+  max_sample_tokens: [256]
+  zeus_profile: [true]
+  zeus_profile_gpu: [true]
+
+# Parameter_groups: define sets of overrides to apply on top of base params
+parameter_groups:
+  - named_group_static: ["qk_norm_scaled"]
+    named_group_variations:
+      - "softmax_modes"
+      - "peri_ln_modes"
+      - "activation_modes"
+      - "position"
+      - "precision"
diff --git a/optimization_and_search/run_experiments.py b/optimization_and_search/run_experiments.py
index 70e6dafe55..789bfc05c4 100644
--- a/optimization_and_search/run_experiments.py
+++ b/optimization_and_search/run_experiments.py
@@ -25,6 +25,7 @@
     "btc_per_param",
     "peak_gpu_mb",
     "iter_latency_avg",
+    "avg_joules_inf",
     "avg_top1_prob",
     "avg_top1_correct",
     "avg_target_rank",
@@ -548,7 +549,7 @@ def read_metrics(out_dir: str) -> dict:
 
     line = path.read_text().strip()
     parts = [p.strip() for p in line.split(',')]
-    casts = [float, int, int, int, 
float, float, float, float, float, float, float, float, float, float, float, float, float] + casts = [float, int, int, int, float, float, float, float, float, float, float, float, float, float, float, float, float, float] return {k: typ(v) for k, typ, v in zip(METRIC_KEYS, casts, parts)} diff --git a/run_exploration_monitor.py b/run_exploration_monitor.py index 7c598c4746..161816b207 100644 --- a/run_exploration_monitor.py +++ b/run_exploration_monitor.py @@ -173,6 +173,7 @@ def on_mount(self) -> None: "num_params", "peak_gpu_mb", "iter_latency_avg", + "avg_joules_inf", "avg_top1_prob", "avg_top1_correct", "avg_target_rank", @@ -230,6 +231,7 @@ def get_cell(self, entry: Dict, col_name: str): "num_params", "peak_gpu_mb", "iter_latency_avg", + "avg_joules_inf", "avg_top1_prob", "avg_top1_correct", "avg_target_rank", diff --git a/sample.py b/sample.py index 4204cbd796..16f43da2c7 100644 --- a/sample.py +++ b/sample.py @@ -27,6 +27,7 @@ from torch.nn import functional as F from model import GPT, GPTConfig +from utils.energy_profiling import ZeusProfiler from utils.model_info import print_summary, print_module_structure, print_model_blocks from variations.model_variations import model_variation_dictionary @@ -36,6 +37,11 @@ def parse_args(): parser = argparse.ArgumentParser(description="Inference from trained models") parser.add_argument("--device", type=str, default="cuda", help="Device to run inference (e.g., 'cpu', 'cuda', 'cuda:0', 'cuda:1')") parser.add_argument("--out_dir", type=str, default="out", help="Directory to load checkpoint from") + parser.add_argument("--zeus_profile", default=False, action=argparse.BooleanOptionalAction, help="Enable Zeus energy profiling") + parser.add_argument("--zeus_profile_gpu", default=True, action=argparse.BooleanOptionalAction, help="Enable GPU energy profiling with Zeus") + parser.add_argument("--zeus_profile_cpu", default=False, action=argparse.BooleanOptionalAction, help="Enable CPU energy profiling with Zeus") + parser.add_argument("--zeus_gpu_indices", type=int, nargs="+", default=None, help="GPU indices to profile with Zeus") + parser.add_argument("--zeus_cpu_indices", type=int, nargs="+", default=None, help="CPU indices to profile with Zeus") parser.add_argument("--quantization_data_file", type=str, default=None, help="File name to export the quantized weights/activations, scale factor, and zero point") parser.add_argument("--init_from", type=str, default="resume", help="Either 'resume' (from an out_dir) or a GPT-2 variant (e.g., 'gpt2-xl')") parser.add_argument("--start", type=str, default="\n", help="Start text for generation. Can specify a file using 'FILE:prompt.txt'") @@ -474,6 +480,8 @@ def sample_with_existing_model( writer: Optional[object] = None, dataset_idx: Optional[int] = None, console: Console | None = None, + zeus_profiler: ZeusProfiler | None = None, + zeus_window_prefix: str = "inference", ): """ Generate text from an already-loaded GPT model. 
@@ -511,6 +519,8 @@ def sample_with_existing_model( modes_to_apply = valid_modes if colorize_mode == "all" else [colorize_mode] + energy_samples: List[float] = [] + for current_k in k_values: # Set a tag for logging/filenames based on the active sampling mode if args.softmax_threshold is not None: @@ -553,99 +563,97 @@ def sample_with_existing_model( scalar_rows: List[torch.Tensor] = [] ranks_list: List[int] = [] # NEW - with torch.no_grad(): - for _step in range(max_new_tokens): - idx_cond = ( - x - if x.size(1) <= model.config.block_size - else x[:, -model.config.block_size :] - ) - - model_logits, _ = model(idx_cond, dataset_idx=dataset_idx) - raw_logits_row = model_logits[:, -1, :] # Raw logits from model + window_name = f"{zeus_window_prefix}_{k_tag}_sample_{sample_idx}" + with zeus_profiler.window(window_name) if zeus_profiler else nullcontext() as zeus_window: + with torch.no_grad(): + for _step in range(max_new_tokens): + idx_cond = ( + x + if x.size(1) <= model.config.block_size + else x[:, -model.config.block_size :] + ) - # --- Apply Cosine Similarity Penalty (if enabled) --- - if args.cosine_penalty is not None: - N = 5 if len(args.cosine_penalty) < 1 else int(args.cosine_penalty[0]) - alpha = 1.0 if len(args.cosine_penalty) < 2 else args.cosine_penalty[1] + model_logits, _ = model(idx_cond, dataset_idx=dataset_idx) + raw_logits_row = model_logits[:, -1, :] # Raw logits from model - # Calculate original probabilities for comparison - probs_before = F.softmax(raw_logits_row / temperature, dim=-1) + # --- Apply Cosine Similarity Penalty (if enabled) --- + if args.cosine_penalty is not None: + N = 5 if len(args.cosine_penalty) < 1 else int(args.cosine_penalty[0]) + alpha = 1.0 if len(args.cosine_penalty) < 2 else args.cosine_penalty[1] + # Calculate original probabilities for comparison + probs_before = F.softmax(raw_logits_row / temperature, dim=-1) - # Apply penalty as long as there are tokens in the context and N > 0 - if x.size(1) > 0 and N > 0: - # Python's negative slicing gracefully handles cases where x.size(1) < N - last_n_tokens = x[0, -N:] + # Apply penalty as long as there are tokens in the context and N > 0 + if x.size(1) > 0 and N > 0: + # Python's negative slicing gracefully handles cases where x.size(1) < N + last_n_tokens = x[0, -N:] - embedding_matrix = model.transformer.wte.weight + embedding_matrix = model.transformer.wte.weight - # Normalize embeddings - last_n_embeds = F.normalize(embedding_matrix[last_n_tokens], p=2, dim=1) - all_embeds = F.normalize(embedding_matrix, p=2, dim=1) + # Normalize embeddings + last_n_embeds = F.normalize(embedding_matrix[last_n_tokens], p=2, dim=1) + all_embeds = F.normalize(embedding_matrix, p=2, dim=1) - # Calculate max cosine similarity for each candidate against the last N tokens - sim_matrix = torch.matmul(all_embeds, last_n_embeds.T) - max_sim_per_candidate, _ = torch.max(sim_matrix, dim=1) - penalty = alpha * max_sim_per_candidate - raw_logits_row = raw_logits_row - penalty + # Calculate max cosine similarity for each candidate against the last N tokens + sim_matrix = torch.matmul(all_embeds, last_n_embeds.T) + max_sim_per_candidate, _ = torch.max(sim_matrix, dim=1) + penalty = alpha * max_sim_per_candidate + raw_logits_row = raw_logits_row - penalty - # Calculate KL divergence to measure the change - probs_after = F.softmax(raw_logits_row / temperature, dim=-1) - # Add a small epsilon to avoid log(0) - kl_div = F.kl_div(torch.log(probs_after + 1e-9), probs_before, reduction='sum') - kl_divergences.append(kl_div.item()) + 
# Calculate KL divergence to measure the change + probs_after = F.softmax(raw_logits_row / temperature, dim=-1) + # Add a small epsilon to avoid log(0) + kl_div = F.kl_div(torch.log(probs_after + 1e-9), probs_before, reduction='sum') + kl_divergences.append(kl_div.item()) + logits = raw_logits_row / temperature # Scaled logits for sampling + full_row = logits[0].clone() # pre-mask - logits = raw_logits_row / temperature # Scaled logits for sampling - full_row = logits[0].clone() # pre-mask + # Apply the selected truncation logic + if args.softmax_threshold is not None: + # Calculate probabilities and find the threshold + probs = F.softmax(logits, dim=-1) + max_prob = torch.max(probs) + prob_threshold = max_prob * args.softmax_threshold + # Set probabilities of tokens below the threshold to 0 + probs[probs < prob_threshold] = 0 + topk_row = logits[0].clone() # post-mask - # Apply the selected truncation logic - if args.softmax_threshold is not None: - # Calculate probabilities and find the threshold - probs = F.softmax(logits, dim=-1) - max_prob = torch.max(probs) - prob_threshold = max_prob * args.softmax_threshold - # Set probabilities of tokens below the threshold to 0 - probs[probs < prob_threshold] = 0 + if args.softmax_threshold is not None: + # Calculate probabilities and find the threshold + probs = F.softmax(logits, dim=-1) + max_prob = torch.max(probs) + prob_threshold = max_prob * args.softmax_threshold + # Set probabilities of tokens below the threshold to 0 + probs[probs < prob_threshold] = 0 + # Sample from the modified, unnormalized distribution of probabilities + idx_next = torch.multinomial(probs, num_samples=1) + # For colorization, we can still use the unmasked logits + topk_row = logits[0].clone() + elif current_k is not None: + v, _ = torch.topk(logits, min(current_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("inf") + topk_row = logits[0].clone() # post-mask + probs = F.softmax(logits, dim=-1) # Re-softmax after masking + idx_next = torch.multinomial(probs, num_samples=1) + else: # No truncation / default case + topk_row = logits[0].clone() + probs = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1) + x = torch.cat((x, idx_next), dim=1) - topk_row = logits[0].clone() # post-mask + if colorize_output: + chosen = idx_next.item() + # rank: 1 = best + rank = (full_row > full_row[chosen]).sum().item() + 1 - if args.softmax_threshold is not None: - # Calculate probabilities and find the threshold - probs = F.softmax(logits, dim=-1) - max_prob = torch.max(probs) - prob_threshold = max_prob * args.softmax_threshold - # Set probabilities of tokens below the threshold to 0 - probs[probs < prob_threshold] = 0 - # Sample from the modified, unnormalized distribution of probabilities - idx_next = torch.multinomial(probs, num_samples=1) - # For colorization, we can still use the unmasked logits - topk_row = logits[0].clone() - elif current_k is not None: - v, _ = torch.topk(logits, min(current_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float("inf") - topk_row = logits[0].clone() # post-mask - probs = F.softmax(logits, dim=-1) # Re-softmax after masking - idx_next = torch.multinomial(probs, num_samples=1) - else: # No truncation / default case - topk_row = logits[0].clone() - probs = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1) - - x = torch.cat((x, idx_next), dim=1) - - if colorize_output: - chosen = idx_next.item() - # rank: 1 = best - rank = (full_row > full_row[chosen]).sum().item() + 1 - - 
tokens_for_color.append(chosen) - full_rows.append(full_row) - topk_rows.append(topk_row) - scalar_rows.append(full_row[chosen]) + tokens_for_color.append(chosen) + full_rows.append(full_row) + topk_rows.append(topk_row) + scalar_rows.append(full_row[chosen]) if args.show_minmax_chart: pre_temp_scalar_rows.append(raw_logits_row[0, chosen]) ranks_list.append(rank) @@ -666,6 +674,12 @@ def sample_with_existing_model( ) + if zeus_profiler: + total_energy = getattr(zeus_window, "total_energy_joules", None) + if total_energy is not None: + energy_samples.append(total_energy) + console.print(f"[bold cyan]Zeus energy[/bold cyan] {total_energy:.4f} J") + # ---------- Print summary statistics for this sample ------------------ if kl_divergences: avg_kl = np.mean(kl_divergences) @@ -677,7 +691,6 @@ def sample_with_existing_model( pre_temp_scalar_rows, out_dir, k_tag, sample_idx ) - # ---------- decode plain text ----------------------------------- plain_text = decode(x[0].tolist()) if token_boundary is not None: @@ -769,6 +782,16 @@ def sample_with_existing_model( f"{run_name}_{k_tag}" if run_name else k_tag, ) + if energy_samples: + avg_energy = sum(energy_samples) / len(energy_samples) + console.print(f"[bold cyan]Average Zeus energy[/bold cyan] {avg_energy:.4f} J") + return { + "avg_joules": avg_energy, + "per_sample_joules": energy_samples, + } + + return None + def interactive_generation(model, start_ids, device, max_new_tokens, temperature, top_k, stop_string, decode, encode): x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] @@ -1146,6 +1169,7 @@ def main(): device_type = 'cuda' if 'cuda' in args.device else 'cpu' ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[args.dtype] ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) + zeus_profiler = ZeusProfiler.from_args(args, device=args.device) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") out_dir = os.path.join(args.out_dir, timestamp) @@ -1383,6 +1407,7 @@ def main(): decode_lookup[dataset_name] = decode_i block_size = args.block_size if args.block_size else model.config.block_size + energy_samples: List[float] = [] with torch.no_grad(), ctx: for sample_idx in range(args.num_samples): if args.use_lsv and hasattr(args, 'lsv_size'): @@ -1397,52 +1422,60 @@ def main(): token_state = {name: tensor.clone() for name, tensor in initial_tokens.items()} - for _ in range(args.max_new_tokens): - idx_cond_dict = {} - for name in dataset_names: - tokens = token_state[name] - idx_cond_dict[name] = tokens if tokens.size(1) <= block_size else tokens[:, -block_size:] - - logits_list, _ = model(None, token_dict=idx_cond_dict, target_dict=None) - - for i, name in enumerate(dataset_names): - if model.config.numerical_multicontext: - preds = logits_list[i][:, -1] - preds = preds.squeeze(-1) - if preds.ndim == 0: - preds = preds.unsqueeze(0) - rounded = preds.round() - min_val = 0.0 - max_val = None - meta_info = dataset_meta.get(name, {}) - tokenizer_name = meta_info.get('tokenizer') if isinstance(meta_info, dict) else None - if tokenizer_name == 'sinewave': - max_val = 255.0 - elif isinstance(meta_info, dict) and 'vocab_size' in meta_info: - max_val = float(meta_info['vocab_size'] - 1) - - if max_val is not None: - rounded = torch.clamp(rounded, min=min_val, max=max_val) + window_name = f"multicontext_sample_{sample_idx}" + with zeus_profiler.window(window_name) if zeus_profiler else nullcontext() as zeus_window: + for _ in 
range(args.max_new_tokens): + idx_cond_dict = {} + for name in dataset_names: + tokens = token_state[name] + idx_cond_dict[name] = tokens if tokens.size(1) <= block_size else tokens[:, -block_size:] + + logits_list, _ = model(None, token_dict=idx_cond_dict, target_dict=None) + + for i, name in enumerate(dataset_names): + if model.config.numerical_multicontext: + preds = logits_list[i][:, -1] + preds = preds.squeeze(-1) + if preds.ndim == 0: + preds = preds.unsqueeze(0) + rounded = preds.round() + min_val = 0.0 + max_val = None + meta_info = dataset_meta.get(name, {}) + tokenizer_name = meta_info.get('tokenizer') if isinstance(meta_info, dict) else None + if tokenizer_name == 'sinewave': + max_val = 255.0 + elif isinstance(meta_info, dict) and 'vocab_size' in meta_info: + max_val = float(meta_info['vocab_size'] - 1) + + if max_val is not None: + rounded = torch.clamp(rounded, min=min_val, max=max_val) + else: + rounded = torch.clamp(rounded, min=min_val) + + idx_next = rounded.to(torch.long).unsqueeze(-1) else: - rounded = torch.clamp(rounded, min=min_val) - - idx_next = rounded.to(torch.long).unsqueeze(-1) - else: - cur_logits = logits_list[i][:, -1, :] / args.temperature - if args.top_k is not None: - top_k_val = ( - args.top_k[0] - if isinstance(args.top_k, (list, tuple)) - else args.top_k - ) - k = min(top_k_val, cur_logits.size(-1)) - v, _ = torch.topk(cur_logits, k) - cur_logits[cur_logits < v[:, [-1]]] = -float("inf") - - probs = F.softmax(cur_logits, dim=-1) - idx_next = torch.multinomial(probs, num_samples=1) - - token_state[name] = torch.cat((token_state[name], idx_next), dim=1) + cur_logits = logits_list[i][:, -1, :] / args.temperature + if args.top_k is not None: + top_k_val = ( + args.top_k[0] + if isinstance(args.top_k, (list, tuple)) + else args.top_k + ) + k = min(top_k_val, cur_logits.size(-1)) + v, _ = torch.topk(cur_logits, k) + cur_logits[cur_logits < v[:, [-1]]] = -float("inf") + + probs = F.softmax(cur_logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1) + + token_state[name] = torch.cat((token_state[name], idx_next), dim=1) + + if zeus_profiler: + total_energy = getattr(zeus_window, "total_energy_joules", None) + if total_energy is not None: + energy_samples.append(total_energy) + print(f"Zeus energy: {total_energy:.4f} J") output_dict: Dict[str, str] = {} for name in dataset_names: @@ -1459,6 +1492,10 @@ def main(): with open(args.sample_file, "w") as file: for name, text in output_dict.items(): file.write(f"\n{name}: \n{text}\n") + + if energy_samples: + avg_energy = sum(energy_samples) / len(energy_samples) + print(f"Average Zeus energy: {avg_energy:.4f} J") else: sample_with_existing_model( model, @@ -1479,8 +1516,9 @@ def main(): sample_file=args.sample_file, args=args, dataset_idx=0, + zeus_profiler=zeus_profiler, + zeus_window_prefix="sample", ) if __name__ == "__main__": main() - diff --git a/train.py b/train.py index 0b8e98595a..6534d678b7 100644 --- a/train.py +++ b/train.py @@ -46,6 +46,7 @@ compute_activation_stats, print_model_stats_table, ) +from utils.energy_profiling import ZeusProfiler from sample import ( sample_with_existing_model, @@ -118,6 +119,7 @@ def __init__(self, args, model_group, training_group, logging_group): self.latest_left_prob_95 = float('nan') self.latest_ln_f_cosine = float('nan') self.latest_ln_f_cosine_95 = float('nan') + self.latest_avg_joules_inf = float('nan') # store overall statistics for weights and activations self.latest_overall_weight_stats = { @@ -243,6 +245,7 @@ def setup(self): self.ptdtype = 
{"bfloat16" : torch.bfloat16, "float16" : torch.float16, "float32" : torch.float32}[self.args.dtype] self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=self.ptdtype) + self.zeus_profiler = ZeusProfiler.from_args(self.args, device=self.device) # Model settings # TODO only add if they are defined from the argparse @@ -581,6 +584,7 @@ def sample_and_print(self): sample_iterations = 1 self.model.eval() + energy_samples: list[float] = [] if self.args.dataset_list is not None: sample_iterations = len(self.args.dataset_list) @@ -600,7 +604,7 @@ def sample_and_print(self): start_ids = torch.tensor(encode_fn(self.args.sample_start_tokens), dtype=torch.long, device=self.device)[None, ...] with torch.no_grad(): - sample_with_existing_model( + energy_result = sample_with_existing_model( model=self.model, start_ids=start_ids, start_tokens=self.args.sample_start_tokens, @@ -623,12 +627,26 @@ def sample_and_print(self): writer=self.writer if self.args.tensorboard_log else None, dataset_idx=i if hasattr(self, 'encode_dict') else None, console=self.console, + zeus_profiler=self.zeus_profiler, + zeus_window_prefix="training_sample", ) + if energy_result and energy_result.get("per_sample_joules"): + energy_samples.extend(energy_result["per_sample_joules"]) # After sampling from the model, optionally run simple dataset benchmarks if self.args.dataset_benchmarks and self.args.max_sample_tokens: self.run_dataset_benchmarks() + if energy_samples: + self.latest_avg_joules_inf = sum(energy_samples) / len(energy_samples) + self.console.print( + f"[bold cyan]Average inference energy[/bold cyan] {self.latest_avg_joules_inf:.4f} J" + ) + if self.args.tensorboard_log and self.writer is not None: + self.writer.add_scalar("avg_joules_inf", self.latest_avg_joules_inf, self.iter_num) + else: + self.latest_avg_joules_inf = float('nan') + self.model.train() self.console.rule("[bold green]End Samples[/bold green]") self.console.print("\n"*8) @@ -1642,11 +1660,24 @@ def run_validation_step(self, running_mfu, current_epoch, current_dataset, num_s print(f"Saved major checkpoint to {self.args.out_dir}/{major_ckpt_name}") if losses['val'] < self.best_val_loss or self.args.always_save_checkpoint: - if losses['val'] < self.best_val_loss: + improved = losses['val'] < self.best_val_loss + if improved: self.best_val_loss = losses['val'] self.best_iter = self.iter_num self.best_tokens = self.tokens_trained peak_mb = self.peak_gpu_usage / (1024 ** 2) + num_steps_with_worse_loss = 0 + if self.iter_num > 0 and not self.args.never_save_checkpoint: + print(f"saving checkpoint to {self.args.out_dir}") + self.save_checkpoint('ckpt.pt') + + if self.args.max_sample_tokens: + if live: + live.stop() + self.sample_and_print() + if live: + live.start() + if improved: with open(os.path.join(self.args.out_dir, 'best_val_loss_and_iter.txt'), "w") as best_loss_file: chance_ratio = self.model_args['vocab_size']/math.exp(self.best_val_loss.item()) metrics = [ @@ -1658,6 +1689,7 @@ def run_validation_step(self, running_mfu, current_epoch, current_dataset, num_s f"{chance_ratio/self.model.num_param:.3e}", f"{peak_mb:.1f}", f"{self.iter_latency_avg:.1f}", + f"{self.latest_avg_joules_inf:.6f}", f"{self.latest_top1_prob:.6f}", f"{self.latest_top1_correct:.6f}", f"{self.latest_target_rank:.2f}", @@ -1679,17 +1711,6 @@ def run_validation_step(self, running_mfu, current_epoch, current_dataset, num_s f"{self.latest_overall_activation_stats['abs_max']:.6f}", ] best_loss_file.write(", ".join(metrics) + 
"\n") - num_steps_with_worse_loss = 0 - if self.iter_num > 0 and not self.args.never_save_checkpoint: - print(f"saving checkpoint to {self.args.out_dir}") - self.save_checkpoint('ckpt.pt') - - if self.args.max_sample_tokens: - if live: - live.stop() - self.sample_and_print() - if live: - live.start() if self.args.export_wte_npy: self.raw_model.export_wte(self.args.export_wte_npy) if self.args.export_scale_matrices_npz: @@ -2086,4 +2107,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/train_args.py b/train_args.py index 94767f9cc3..653fe61b5e 100644 --- a/train_args.py +++ b/train_args.py @@ -1337,6 +1337,11 @@ def parse_args(): logging_group.add_argument('--tensorboard_log_dir', type=str, default='logs') logging_group.add_argument('--tensorboard_run_name', type=str, default=None) logging_group.add_argument('--tensorboard_graph', default=True, action=argparse.BooleanOptionalAction) + logging_group.add_argument('--zeus_profile', default=False, action=argparse.BooleanOptionalAction, help="Enable Zeus energy profiling") + logging_group.add_argument('--zeus_profile_gpu', default=True, action=argparse.BooleanOptionalAction, help="Enable GPU energy profiling with Zeus") + logging_group.add_argument('--zeus_profile_cpu', default=False, action=argparse.BooleanOptionalAction, help="Enable CPU energy profiling with Zeus") + logging_group.add_argument('--zeus_gpu_indices', type=int, nargs="+", default=None, help="GPU indices to profile with Zeus") + logging_group.add_argument('--zeus_cpu_indices', type=int, nargs="+", default=None, help="CPU indices to profile with Zeus") # Metric logging toggles logging_group.add_argument('--log_btc_train', default=False, action=argparse.BooleanOptionalAction, help='Log better-than-chance training metrics') @@ -1449,4 +1454,3 @@ class LayerListAction(argparse.Action): """ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, list(values)) - diff --git a/utils/energy_profiling/__init__.py b/utils/energy_profiling/__init__.py new file mode 100644 index 0000000000..cb20b2c413 --- /dev/null +++ b/utils/energy_profiling/__init__.py @@ -0,0 +1,5 @@ +"""Utilities for optional Zeus energy profiling.""" + +from utils.energy_profiling.zeus_profiler import ZeusProfiler + +__all__ = ["ZeusProfiler"] diff --git a/utils/energy_profiling/zeus_profiler.py b/utils/energy_profiling/zeus_profiler.py new file mode 100644 index 0000000000..86534fb96e --- /dev/null +++ b/utils/energy_profiling/zeus_profiler.py @@ -0,0 +1,111 @@ +"""Reusable helpers for optional Zeus energy profiling.""" + +from __future__ import annotations + +from importlib.util import find_spec +from typing import Iterable, Optional + + +_ZEUS_AVAILABLE = find_spec("zeus.monitor") is not None +if _ZEUS_AVAILABLE: + from zeus.monitor import ZeusMonitor +else: + ZeusMonitor = None + +class ZeusWindow: + def __init__(self, monitor: ZeusMonitor | None, name: str, enabled: bool) -> None: + self._monitor = monitor + self._name = name + self._enabled = enabled + self._measurement = None + self._total_energy_joules = None + + def __enter__(self) -> "ZeusWindow": + if self._enabled and self._monitor is not None: + self._monitor.begin_window(self._name) + return self + + def __exit__(self, exc_type, exc, exc_tb) -> None: + if not self._enabled or self._monitor is None: + return + measurement = self._monitor.end_window(self._name) + total_energy = _extract_total_energy_joules(measurement) + self._measurement = measurement + self._total_energy_joules = total_energy + + 
+    @property
+    def total_energy_joules(self) -> float | None:
+        return self._total_energy_joules
+
+    @property
+    def measurement(self) -> object | None:
+        return self._measurement
+
+
+class ZeusProfiler:
+    def __init__(
+        self,
+        enabled: bool,
+        gpu_indices: Optional[Iterable[int]] = None,
+        cpu_indices: Optional[Iterable[int]] = None,
+    ) -> None:
+        self._enabled = enabled and _ZEUS_AVAILABLE
+        self._gpu_indices = list(gpu_indices) if gpu_indices is not None else None
+        self._cpu_indices = list(cpu_indices) if cpu_indices is not None else None
+        self._monitor = None
+        if self._enabled:
+            self._monitor = ZeusMonitor(
+                gpu_indices=self._gpu_indices,
+                cpu_indices=self._cpu_indices,
+            )
+        elif enabled and not _ZEUS_AVAILABLE:
+            print("Zeus profiling requested, but zeus.monitor is unavailable.")
+
+    @classmethod
+    def from_args(cls, args, device: str) -> "ZeusProfiler":
+        if not getattr(args, "zeus_profile", False):
+            return cls(enabled=False)
+
+        use_gpu = bool(getattr(args, "zeus_profile_gpu", True)) and "cuda" in device
+        use_cpu = bool(getattr(args, "zeus_profile_cpu", False))
+
+        gpu_indices = None
+        cpu_indices = None
+
+        if use_gpu:
+            gpu_indices = getattr(args, "zeus_gpu_indices", None)
+            if gpu_indices is None:
+                import torch
+
+                gpu_indices = [torch.cuda.current_device()]
+
+        if use_cpu:
+            cpu_indices = getattr(args, "zeus_cpu_indices", None)
+
+        return cls(enabled=use_gpu or use_cpu, gpu_indices=gpu_indices, cpu_indices=cpu_indices)
+
+    def window(self, name: str) -> ZeusWindow:
+        return ZeusWindow(self._monitor, name, self._enabled)
+
+    @property
+    def enabled(self) -> bool:
+        return self._enabled
+
+
+def _extract_total_energy_joules(measurement: object) -> float | None:
+    if measurement is None:
+        return None
+
+    if hasattr(measurement, "total_energy"):
+        total_energy = getattr(measurement, "total_energy")
+        if isinstance(total_energy, dict):
+            return float(sum(total_energy.values()))
+        return float(total_energy)
+
+    if hasattr(measurement, "energy"):
+        energy = getattr(measurement, "energy")
+        if isinstance(energy, dict):
+            return float(sum(energy.values()))
+        return float(energy)
+
+    return None
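
Usage sketch (illustrative only, not applied by this diff): the snippet below exercises ZeusProfiler the same way sample.py wraps its generation loop. The Namespace stand-in mirrors the --zeus_* flags added in train_args.py and sample.py; the window name "demo_inference" and the dummy CUDA workload are assumptions for demonstration. When zeus-ml is not installed or profiling is disabled, the profiler degrades to a no-op and total_energy_joules stays None, which is why the final print is guarded.

from argparse import Namespace

import torch
import torch.nn.functional as F

from utils.energy_profiling import ZeusProfiler

# Stand-in for the parsed CLI args; the fields match the new --zeus_* flags.
args = Namespace(
    zeus_profile=True,
    zeus_profile_gpu=True,
    zeus_profile_cpu=False,
    zeus_gpu_indices=None,
    zeus_cpu_indices=None,
)
profiler = ZeusProfiler.from_args(args, device="cuda")  # assumes a visible CUDA device

# Wrap an arbitrary workload in a named measurement window, as sample.py does
# around its token-generation loop.
with profiler.window("demo_inference") as window:
    x = torch.randn(64, 384, device="cuda")
    w = torch.randn(384, 384, device="cuda")
    for _ in range(100):
        x = F.gelu(x @ w)
    torch.cuda.synchronize()  # make sure the GPU work finishes inside the window

# total_energy_joules stays None when Zeus is unavailable or profiling is disabled.
if window.total_energy_joules is not None:
    print(f"Zeus energy for demo_inference: {window.total_energy_joules:.4f} J")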