diff --git a/scripts/performance/argument_parser.py b/scripts/performance/argument_parser.py index aaf399cce9..ea5b8bbb3e 100644 --- a/scripts/performance/argument_parser.py +++ b/scripts/performance/argument_parser.py @@ -772,5 +772,17 @@ def parse_cli_args(): testing_args.add_argument( "--memory_threshold", type=float, default=0.05, help="Memory validation threshold (default: 0.05 = 5%%)" ) + testing_args.add_argument( + "--eval_time_start_step", + type=int, + default=None, + help="Start step (0-indexed, inclusive) for timing average window. Overrides skip_first_percent_time when set.", + ) + testing_args.add_argument( + "--eval_time_end_step", + type=int, + default=None, + help="End step (0-indexed, exclusive) for timing average window. If None, averages to end.", + ) return parser diff --git a/scripts/performance/setup_experiment.py b/scripts/performance/setup_experiment.py index 6afd3bcd50..b9fd5f5832 100755 --- a/scripts/performance/setup_experiment.py +++ b/scripts/performance/setup_experiment.py @@ -132,6 +132,8 @@ def build_performance_config(args) -> Optional[Dict[str, Any]]: performance_params = { "timing_threshold": args.timing_threshold, "skip_first_percent_time": args.skip_first_percent_time, + "eval_time_start_step": args.eval_time_start_step, + "eval_time_end_step": args.eval_time_end_step, } for key, value in performance_params.items(): @@ -639,6 +641,8 @@ def main( performance_params={ "timing_threshold": args.timing_threshold, "skip_first_percent_time": args.skip_first_percent_time, + "eval_time_start_step": args.eval_time_start_step, + "eval_time_end_step": args.eval_time_end_step, }, memory_params={ "memory_threshold": args.memory_threshold, diff --git a/scripts/performance/utils/evaluate.py b/scripts/performance/utils/evaluate.py index 89f83c4453..c86d35d959 100644 --- a/scripts/performance/utils/evaluate.py +++ b/scripts/performance/utils/evaluate.py @@ -325,9 +325,12 @@ def validate_performance( config = default_config # Discard first N% of iterations for stable comparison - skip = max(1, int(len(steps) * config["skip_first_percent_time"])) - current_stable = current_gpu_util_values[skip:] - golden_stable = golden_gpu_util_values[skip:] + start = config.get("eval_time_start_step") + if start is None: + start = max(1, int(len(steps) * config["skip_first_percent_time"])) + end = config.get("eval_time_end_step") + current_stable = current_gpu_util_values[start:end] + golden_stable = golden_gpu_util_values[start:end] current_avg = float(np.nanmean(current_stable)) golden_avg = float(np.nanmean(golden_stable)) @@ -339,7 +342,8 @@ def validate_performance( is_improvement = signed_diff > config["timing_threshold"] logger.info( - f"GPU utilization comparison (excluding first {config['skip_first_percent_time'] * 100:.1f}% of iterations):" + f"GPU utilization comparison (steps [{start}:{end if end is not None else len(steps)}] " + f"out of {len(steps)} total):" ) logger.info(f" Current average GPU util: {current_avg:.4f}%") logger.info(f" Golden average GPU util: {golden_avg:.4f}%") @@ -677,9 +681,12 @@ def calc_convergence_and_performance( wandb_run=wandb_run, ) # Add iter-time averages for debugging (not used for pass/fail) - skip = max(1, int(len(steps) * performance_config.get("skip_first_percent_time", 0.1))) - performance_result["metrics"]["current_avg_iter_time_ms"] = float(np.nanmean(current_iter_time_values[skip:])) - performance_result["metrics"]["golden_avg_iter_time_ms"] = float(np.nanmean(golden_iter_time_values[skip:])) + start = performance_config.get("eval_time_start_step") + if start is None: + start = max(1, int(len(steps) * performance_config.get("skip_first_percent_time", 0.1))) + end = performance_config.get("eval_time_end_step") + performance_result["metrics"]["current_avg_iter_time_ms"] = float(np.nanmean(current_iter_time_values[start:end])) + performance_result["metrics"]["golden_avg_iter_time_ms"] = float(np.nanmean(golden_iter_time_values[start:end])) if not performance_result["passed"]: direction = performance_result["metrics"]["direction"] signed_diff = performance_result["metrics"]["signed_diff"]