From fe0ecc277738198095550e9398e19de26f979250 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Sat, 3 Jan 2026 18:20:27 +0800 Subject: [PATCH 01/14] fix(mlx_metal): apply evolved attention hook inside subprocess benchmarks --- examples/mlx_metal_kernel_opt/evaluator.py | 34 +++++- .../mlx_lm_generate_with_hook.py | 105 ++++++++++++++++++ .../qwen3_benchmark_suite.py | 50 +++++++-- .../mlx_metal_kernel_opt/run_benchmarks.py | 39 ++----- 4 files changed, 184 insertions(+), 44 deletions(-) create mode 100644 examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py diff --git a/examples/mlx_metal_kernel_opt/evaluator.py b/examples/mlx_metal_kernel_opt/evaluator.py index 62fbe8e71..76ffbf4a3 100644 --- a/examples/mlx_metal_kernel_opt/evaluator.py +++ b/examples/mlx_metal_kernel_opt/evaluator.py @@ -136,6 +136,7 @@ def evaluate(self, program_text: str) -> Dict[str, Any]: ) custom_attention_class = extraction_result["class"] + program_source = extraction_result["program_text"] # Step 2: Pre-execution Metal kernel safety validation print("\n🔍 STEP 2: Pre-execution Metal Kernel Safety Validation") @@ -168,7 +169,9 @@ def evaluate(self, program_text: str) -> Dict[str, Any]: # Step 5: Command-buffer-protected benchmarking print("\n🚀 STEP 5: Command-Buffer-Protected Performance Benchmarking") - benchmark_result = self._command_buffer_protected_benchmark(custom_attention_class) + benchmark_result = self._command_buffer_protected_benchmark( + program_source, custom_attention_class + ) if not benchmark_result["success"]: return self._create_comprehensive_failure_result( f"Command-buffer-protected benchmarking failed: {benchmark_result['error']}" @@ -191,6 +194,7 @@ def evaluate(self, program_text: str) -> Dict[str, Any]: result = { "success": True, "final_score": final_score, + "combined_score": final_score, "performance_metrics": performance_analysis["aggregate_metrics"], "correctness_score": correctness_score, "benchmark_results": [self._result_to_dict(r) for r in custom_results], @@ -295,7 +299,12 @@ def _bulletproof_extract_custom_attention(self, program_text: str) -> Dict[str, print(f" ✅ Successfully extracted and validated CustomGQAAttention class") print(f" 🛡️ Metal safety pre-checks: {metal_validation['safe']}") - return {"success": True, "class": custom_class, "metal_validation": metal_validation} + return { + "success": True, + "class": custom_class, + "metal_validation": metal_validation, + "program_text": actual_program_text, + } except Exception as e: self.total_metal_errors += 1 @@ -721,7 +730,9 @@ def _test_single_sequence_memory_safe( else: raise ValueError(f"Sequence test error: {error_msg}") - def _command_buffer_protected_benchmark(self, custom_attention_class: Any) -> Dict[str, Any]: + def _command_buffer_protected_benchmark( + self, program_text: str, custom_attention_class: Any + ) -> Dict[str, Any]: """Command-buffer-protected benchmarking with maximum safety""" print(" 🚀 Running command-buffer-protected benchmarking...") @@ -748,8 +759,17 @@ def _command_buffer_protected_benchmark(self, custom_attention_class: Any) -> Di } original_attention = hook_result["original"] + temp_program_path = None try: + # Ensure the evolved program is available to the subprocess that runs mlx_lm.generate. + # Monkey-patching in this evaluator process does NOT propagate across subprocess boundaries. + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f: + f.write(program_text) + temp_program_path = f.name + + self.benchmark_suite.hook_program_path = temp_program_path + # Run benchmarks with command buffer protection custom_configs = self._get_safe_benchmark_configs() custom_results = [] @@ -820,6 +840,13 @@ def _command_buffer_protected_benchmark(self, custom_attention_class: Any) -> Di return {"success": False, "error": error_msg} finally: + # Always clear subprocess hook settings and clean up temp program + self.benchmark_suite.hook_program_path = None + if temp_program_path: + try: + os.unlink(temp_program_path) + except OSError: + pass # Always restore original attention self._gpu_protected_remove_hook(original_attention) @@ -1333,6 +1360,7 @@ def _create_comprehensive_failure_result(self, error_message: str) -> Dict[str, return { "success": False, "final_score": -1000.0, + "combined_score": -1000.0, "error": error_message, "performance_metrics": {}, "correctness_score": 0.0, diff --git a/examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py b/examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py new file mode 100644 index 000000000..552da824d --- /dev/null +++ b/examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +""" +Run MLX-LM generation with a custom attention monkey-patch applied *inside* the +current process. + +Why this exists +--------------- +Many benchmarking utilities run `mlx_lm.generate` via `subprocess.run(...)`. +Any monkey-patch done in the parent process (e.g. replacing +`mlx_lm.models.qwen3.Attention`) does NOT propagate into the child process. + +This wrapper makes the patch effective by: +1) loading an evolved program file (e.g. best_program.py) +2) calling its `create_metal_qwen3_optimization_hook()` to apply the patch +3) running `mlx_lm.generate` in the same process (via `runpy`) +""" + +from __future__ import annotations + +import argparse +import importlib.util +import runpy +import sys +from types import ModuleType +from typing import List, Optional, Tuple, Any + + +def _load_module_from_path(module_path: str) -> ModuleType: + spec = importlib.util.spec_from_file_location("openevolve_mlx_metal_hook_program", module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Failed to load hook program from: {module_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _apply_hook_from_program(module_path: str) -> Tuple[Any, Any]: + program = _load_module_from_path(module_path) + + hook_factory = getattr(program, "create_metal_qwen3_optimization_hook", None) + if hook_factory is None: + raise RuntimeError( + "Hook factory `create_metal_qwen3_optimization_hook()` not found in hook program." + ) + + apply_hook, remove_hook = hook_factory() + original_attention = apply_hook() + if original_attention is None: + raise RuntimeError("Failed to apply optimization hook (original_attention is None).") + + return original_attention, remove_hook + + +def main(argv: Optional[List[str]] = None) -> int: + parser = argparse.ArgumentParser( + description="Run `mlx_lm.generate` with a custom attention hook applied in-process." + ) + parser.add_argument("--hook-program", required=True, help="Path to evolved program (e.g. best_program.py)") + parser.add_argument("--model", required=True, help="Model id/path to pass to mlx_lm.generate") + parser.add_argument("--prompt", required=True, help="Prompt string") + parser.add_argument("--max-tokens", required=True, type=int, help="Max tokens to generate") + + args = parser.parse_args(argv) + + original_attention = None + remove_hook = None + try: + original_attention, remove_hook = _apply_hook_from_program(args.hook_program) + + # Mimic `python -m mlx_lm.generate ...` + sys.argv = [ + "mlx_lm.generate", + "--model", + args.model, + "--prompt", + args.prompt, + "--max-tokens", + str(args.max_tokens), + ] + + try: + runpy.run_module("mlx_lm.generate", run_name="__main__") + return 0 + except SystemExit as e: + # Preserve the exit code from mlx_lm.generate + code = e.code + if code is None: + return 0 + if isinstance(code, int): + return code + return 1 + + finally: + if remove_hook is not None and original_attention is not None: + try: + remove_hook(original_attention) + except Exception: + # Non-fatal in a one-shot subprocess wrapper + pass + + +if __name__ == "__main__": + raise SystemExit(main()) + + diff --git a/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py b/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py index f35bb7c2c..f641fce5e 100644 --- a/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py +++ b/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py @@ -17,6 +17,7 @@ import subprocess import tempfile import os +import sys from dataclasses import dataclass from typing import Dict, List, Tuple, Optional import mlx.core as mx @@ -53,8 +54,15 @@ class BenchmarkConfig: class Qwen3BenchmarkSuite: """Comprehensive benchmark suite for Qwen3-0.6B Metal kernel optimization""" - def __init__(self, model_path: str = "mlx-community/Qwen3-0.6B-bf16"): + def __init__( + self, + model_path: str = "mlx-community/Qwen3-0.6B-bf16", + hook_program_path: Optional[str] = None, + ): self.model_path = model_path + # When set, benchmarks will run via `mlx_lm_generate_with_hook.py` so that + # the attention monkey-patch is applied inside the subprocess. + self.hook_program_path = hook_program_path self.results: List[BenchmarkResult] = [] def create_benchmark_configs(self) -> List[BenchmarkConfig]: @@ -566,17 +574,35 @@ def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult: try: # Build command - cmd = [ - "python", - "-m", - "mlx_lm.generate", - "--model", - self.model_path, - "--prompt", - config.prompt, - "--max-tokens", - str(config.max_tokens), - ] + if self.hook_program_path: + wrapper_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "mlx_lm_generate_with_hook.py", + ) + cmd = [ + sys.executable, + wrapper_path, + "--hook-program", + self.hook_program_path, + "--model", + self.model_path, + "--prompt", + config.prompt, + "--max-tokens", + str(config.max_tokens), + ] + else: + cmd = [ + sys.executable, + "-m", + "mlx_lm.generate", + "--model", + self.model_path, + "--prompt", + config.prompt, + "--max-tokens", + str(config.max_tokens), + ] # Clear MLX cache before starting print(f"🧹 Clearing MLX cache...") diff --git a/examples/mlx_metal_kernel_opt/run_benchmarks.py b/examples/mlx_metal_kernel_opt/run_benchmarks.py index bc7c5fc2b..ed689405d 100644 --- a/examples/mlx_metal_kernel_opt/run_benchmarks.py +++ b/examples/mlx_metal_kernel_opt/run_benchmarks.py @@ -147,38 +147,19 @@ def run_optimized_benchmark(args, original_dir): ) return None - # Apply the custom attention hook - apply_hook, remove_hook = best_program.create_metal_qwen3_optimization_hook() - print("🔧 Applying custom Metal kernel optimized attention hook...") + # IMPORTANT: the benchmark suite runs `mlx_lm.generate` in a subprocess. + # Monkey-patching Attention in this parent process does not propagate to the subprocess. + # Instead, we pass the evolved program path so the subprocess can apply the hook in-process. - original_attention = apply_hook() + print("📊 Running full benchmark suite with custom Metal kernel optimization...") + print("⏳ This will take another 15-30 minutes...") + print("💡 The optimization uses a custom Metal kernel implementation for Apple Silicon GPU") - if original_attention is None: - print("❌ Failed to apply custom Metal kernel optimization hook") - print("This may indicate MLX-LM import issues or incompatible environment") - return None - - print("✅ Custom Metal kernel optimization hook applied successfully") + optimized_suite = Qwen3BenchmarkSuite(args.model, hook_program_path=best_program_path) + optimized_results = optimized_suite.run_full_benchmark_suite() - try: - # Run benchmarks with optimized attention - print("📊 Running full benchmark suite with custom Metal kernel optimization...") - print("⏳ This will take another 15-30 minutes...") - print( - "💡 The optimization uses custom Metal kernel implementation for Apple Silicon GPU" - ) - - optimized_suite = Qwen3BenchmarkSuite(args.model) - optimized_results = optimized_suite.run_full_benchmark_suite() - - print("✅ Custom Metal kernel benchmark suite completed successfully") - return optimized_results - - finally: - # Always remove the hook to restore original behavior - print("🔄 Restoring standard attention...") - remove_hook(original_attention) - print("✅ Standard attention restored") + print("✅ Custom Metal kernel benchmark suite completed successfully") + return optimized_results except Exception as e: print(f"❌ Error running Metal kernel optimized benchmark: {e}") From a4038229aa99dbb1c1611699e6933a4eb369d1c0 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Sun, 4 Jan 2026 09:12:37 +0800 Subject: [PATCH 02/14] fix(mlx_metal): align eval dtype and head config with Qwen3-0.6B-bf16 --- examples/mlx_metal_kernel_opt/README.md | 8 ++-- examples/mlx_metal_kernel_opt/best_program.py | 44 +++++++++--------- examples/mlx_metal_kernel_opt/config.yaml | 12 ++--- examples/mlx_metal_kernel_opt/evaluator.py | 46 +++++++++++++------ .../mlx_metal_kernel_opt/initial_program.py | 44 +++++++++--------- .../qwen3_benchmark_suite.py | 2 +- 6 files changed, 88 insertions(+), 68 deletions(-) diff --git a/examples/mlx_metal_kernel_opt/README.md b/examples/mlx_metal_kernel_opt/README.md index 8a0c35136..0bc6d8650 100644 --- a/examples/mlx_metal_kernel_opt/README.md +++ b/examples/mlx_metal_kernel_opt/README.md @@ -14,7 +14,7 @@ Modern transformer models rely heavily on optimized attention kernels for effici ### 1.2 Target System -- **Model**: Qwen3-0.6B with Grouped Query Attention (40 query heads : 8 key-value heads) +- **Model**: Qwen3-0.6B with Grouped Query Attention (16 query heads : 8 key-value heads) - **Hardware**: Apple M-series GPUs with unified memory architecture - **Framework**: MLX with custom Metal kernel integration - **Baseline**: `mx.fast.scaled_dot_product_attention` @@ -111,7 +111,7 @@ for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + head_idx * (SEQ_LEN * HEAD_DIM) + query_pos * HEAD_DIM; -const uint kv_head_idx = head_idx / HEADS_PER_KV; // Direct 5:1 mapping +const uint kv_head_idx = head_idx / HEADS_PER_KV; // Direct 2:1 mapping ``` **Innovation**: Leverages unified memory bandwidth through coalesced access patterns and direct GQA head mapping. @@ -186,7 +186,7 @@ The evolved kernel shows workload-dependent performance characteristics: **Algorithm Innovation**: The two-pass online softmax represents a novel contribution, demonstrating that evolutionary approaches can discover algorithmic improvements beyond simple micro-optimizations. -**GQA Specialization**: Direct exploitation of the 5:1 query-to-KV head ratio through specialized indexing patterns shows the value of architecture-specific optimizations. +**GQA Specialization**: Direct exploitation of the 2:1 query-to-KV head ratio through specialized indexing patterns shows the value of architecture-specific optimizations. ### 5.3 Evolutionary Process Analysis @@ -211,7 +211,7 @@ Our approach differs by applying evolutionary optimization directly to GPU shade ### 7.1 Current Limitations - **Workload Specificity**: Performance improvements are highly dependent on sequence patterns -- **Model Scope**: Results specific to Qwen3-0.6B's 40:8 GQA configuration +- **Model Scope**: Results specific to Qwen3-0.6B's 16:8 GQA configuration - **Hardware Scope**: Optimizations specific to Apple Silicon architecture ### 7.2 Future Directions diff --git a/examples/mlx_metal_kernel_opt/best_program.py b/examples/mlx_metal_kernel_opt/best_program.py index a94d94c92..09b8f12c1 100644 --- a/examples/mlx_metal_kernel_opt/best_program.py +++ b/examples/mlx_metal_kernel_opt/best_program.py @@ -1,11 +1,11 @@ """ Qwen3 Custom Metal Kernel for Grouped Query Attention (GQA) Optimization -This module implements a custom Metal kernel for Qwen3's 40:8 GQA pattern using +This module implements a custom Metal kernel for Qwen3's 16:8 GQA pattern using MLX's metal_kernel API. The kernel is designed to outperform mx.fast.scaled_dot_product_attention -by leveraging Apple Silicon specific optimizations and the 5:1 query-to-KV head ratio. +by leveraging Apple Silicon specific optimizations and the 2:1 query-to-KV head ratio. -Target: Qwen3-0.6B with 40 query heads : 8 KV heads +Target: Qwen3-0.6B with 16 query heads : 8 KV heads Hardware: Apple M-series GPUs with unified memory Baseline: Standard MLX-LM using mx.fast.scaled_dot_product_attention Goal: 5-15% performance improvement through custom Metal kernel optimization @@ -26,19 +26,19 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): Custom Metal kernel implementation for Qwen3 GQA attention. Args: - queries: [B, num_heads=40, L, head_dim=128] + queries: [B, num_heads=16, L, head_dim=128] keys: [B, num_kv_heads=8, L, head_dim=128] values: [B, num_kv_heads=8, L, head_dim=128] scale: Attention scaling factor (1/sqrt(head_dim)) mask: Attention mask (None, "causal", or boolean tensor) Returns: - Attention output [B, num_heads=40, L, head_dim=128] + Attention output [B, num_heads=16, L, head_dim=128] """ B, num_heads, L, head_dim = queries.shape _, num_kv_heads, _, _ = keys.shape - heads_per_kv = num_heads // num_kv_heads # Should be 5 for Qwen3 + heads_per_kv = num_heads // num_kv_heads # 2 for Qwen3-0.6B (16:8) # Handle mask conversion if mask == "causal" or mask is None: @@ -67,9 +67,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): # EVOLVE-BLOCK-START # Custom Metal kernel source for Qwen3 GQA optimization - # This kernel leverages the 40:8 head ratio and Apple Silicon architecture + # This kernel leverages the 16:8 head ratio and Apple Silicon architecture kernel_source = """ - // Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern + // Qwen3 GQA Metal Kernel - Optimized for 16:8 head pattern // Thread mapping: each thread processes one query position uint thread_id = thread_position_in_grid.x; uint head_idx = thread_position_in_grid.y; @@ -86,7 +86,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): bool use_mask_val = use_mask[0] > 0; // GQA mapping: determine which KV head corresponds to this query head - uint kv_head_idx = head_idx / HEADS_PER_KV; // 5 query heads per KV head + uint kv_head_idx = head_idx / HEADS_PER_KV; // 2 query heads per KV head // Pre-calculate base indices for memory access optimization const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + @@ -248,8 +248,8 @@ def __init__(self, args): super().__init__() # Standard Qwen3 parameters - dim = args.hidden_size # 5120 - self.n_heads = n_heads = args.num_attention_heads # 40 + dim = args.hidden_size # 2048 + self.n_heads = n_heads = args.num_attention_heads # 16 assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads # 8 head_dim = args.head_dim # 128 @@ -362,8 +362,8 @@ def benchmark_metal_gqa_optimization(): # Qwen3-0.6B configuration class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 + hidden_size = 2048 + num_attention_heads = 16 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -375,10 +375,10 @@ class MockArgs: # Test configurations for Metal kernel validation test_configs = [ - ("short_sequence", 1, 128, 5120), - ("medium_sequence", 1, 512, 5120), - ("long_sequence", 1, 1024, 5120), - ("max_sequence", 1, 2048, 5120), + ("short_sequence", 1, 128, 2048), + ("medium_sequence", 1, 512, 2048), + ("long_sequence", 1, 1024, 2048), + ("max_sequence", 1, 2048, 2048), ] print("Benchmarking Custom Metal GQA Kernel vs MLX Baseline") @@ -425,11 +425,11 @@ def test_metal_gqa_correctness(): print("=" * 50) # Test configuration - B, L, D = 1, 64, 5120 + B, L, D = 1, 64, 2048 class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 + hidden_size = 2048 + num_attention_heads = 16 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -463,7 +463,7 @@ class MockArgs: # Test direct kernel function print("\n=== Testing Direct Kernel Function ===") - B, H, L, D = 1, 40, 128, 128 + B, H, L, D = 1, 16, 128, 128 q = mx.random.normal((B, H, L, D)) k = mx.random.normal((B, 8, L, D)) # 8 KV heads v = mx.random.normal((B, 8, L, D)) @@ -496,7 +496,7 @@ class MockArgs: print("Evolution focus:") print("1. 🔧 Metal kernel source code optimization") print("2. 💾 Memory access pattern improvements for Apple Silicon") - print("3. 🎯 GQA-specific optimizations for 40:8 head ratio") + print("3. 🎯 GQA-specific optimizations for 16:8 head ratio") print("4. ⚡ Vectorization and SIMD optimization") print("5. 🚀 Thread group and grid configuration tuning") print("Target: 5-15% performance improvement through Metal kernel innovation") diff --git a/examples/mlx_metal_kernel_opt/config.yaml b/examples/mlx_metal_kernel_opt/config.yaml index 19b9342ab..29dd90ad2 100644 --- a/examples/mlx_metal_kernel_opt/config.yaml +++ b/examples/mlx_metal_kernel_opt/config.yaml @@ -22,7 +22,7 @@ prompt: # TARGET: Optimize Metal Kernel for Qwen3 Grouped Query Attention (GQA) # HARDWARE: Apple M-series GPUs with unified memory architecture # BASELINE: Standard MLX scaled_dot_product_attention - # ARCHITECTURE: 40 query heads : 8 KV heads (5:1 ratio), 128 head dimension + # ARCHITECTURE: 16 query heads : 8 KV heads (2:1 ratio), 128 head dimension # GOAL: 5-15% performance improvement through Metal kernel optimization # CURRENT METAL KERNEL STRUCTURE: @@ -33,7 +33,7 @@ prompt: uint head_idx = thread_position_in_grid.y; uint batch_idx = thread_position_in_grid.z; - // GQA mapping: 5 query heads per KV head + // GQA mapping: 2 query heads per KV head uint kv_head_idx = head_idx / HEADS_PER_KV; // Current algorithm: @@ -76,10 +76,10 @@ prompt: **3. GQA-Specific Optimizations:** ```metal // CURRENT: Basic kv_head_idx = head_idx / HEADS_PER_KV - // OPTIMIZE: Leverage the specific 5:1 ratio pattern + // OPTIMIZE: Leverage the specific 2:1 ratio pattern // Example: Process 5 query heads together for each KV head - // Example: Optimize memory layout for the 40:8 pattern + // Example: Optimize memory layout for the 16:8 pattern // Example: Reduce broadcast overhead through clever indexing ``` @@ -180,7 +180,7 @@ prompt: **Strategy 4: GQA Pattern Exploitation** ```metal - // Optimize for the specific 5:1 query:KV ratio + // Optimize for the specific 2:1 query:KV ratio // Process query heads in groups of 5 // Reduce KV head indexing overhead ``` @@ -203,7 +203,7 @@ prompt: - Focus ONLY on optimizing the Metal kernel source code in the EVOLVE-BLOCK - The kernel will be compiled using mx.fast.metal_kernel() automatically - Maintain the exact same attention computation semantics - - Test with Qwen3's specific 40:8 head configuration + - Test with Qwen3-0.6B's specific 16:8 head configuration - Leverage Apple Silicon's unified memory and SIMD capabilities Your goal is to discover Metal kernel optimizations that outperform MLX's diff --git a/examples/mlx_metal_kernel_opt/evaluator.py b/examples/mlx_metal_kernel_opt/evaluator.py index 76ffbf4a3..486c8d732 100644 --- a/examples/mlx_metal_kernel_opt/evaluator.py +++ b/examples/mlx_metal_kernel_opt/evaluator.py @@ -377,8 +377,10 @@ def _validate_metal_kernel_safety(self, custom_attention_class: Any) -> Dict[str # Mock arguments for safety testing class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 + # NOTE: This should reflect the default model used by this evaluator: + # `mlx-community/Qwen3-0.6B-bf16` (16 Q heads : 8 KV heads, head_dim=128). + hidden_size = 2048 + num_attention_heads = 16 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -396,11 +398,14 @@ class MockArgs: print(" ✅ Custom attention instantiation successful") - # Basic parameter validation - if hasattr(instance, "n_heads") and instance.n_heads != 40: - return {"success": False, "error": f"Invalid head count: {instance.n_heads}"} + # Basic parameter validation (should match the args we instantiated with) + if hasattr(instance, "n_heads") and instance.n_heads != args.num_attention_heads: + return { + "success": False, + "error": f"Invalid head count: {instance.n_heads} (expected {args.num_attention_heads})", + } - if hasattr(instance, "n_kv_heads") and instance.n_kv_heads != 8: + if hasattr(instance, "n_kv_heads") and instance.n_kv_heads != args.num_key_value_heads: return { "success": False, "error": f"Invalid KV head count: {instance.n_kv_heads}", @@ -543,8 +548,9 @@ def _memory_safe_correctness_test(self, custom_attention_class: Any) -> Dict[str try: # Safe test configuration class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 + # Must match the default model `mlx-community/Qwen3-0.6B-bf16` + hidden_size = 2048 + num_attention_heads = 16 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -556,10 +562,10 @@ class MockArgs: # Conservative test cases (smaller sequences for safety) test_cases = [ - (1, 8, 5120), # Micro sequence - (1, 16, 5120), # Very short - (1, 32, 5120), # Short sequence - (1, 64, 5120), # Medium sequence + (1, 8, 2048), # Micro sequence + (1, 16, 2048), # Very short + (1, 32, 2048), # Short sequence + (1, 64, 2048), # Medium sequence ] correctness_scores = [] @@ -576,7 +582,10 @@ class MockArgs: self._ensure_clean_gpu_state() # Create conservative test inputs - x = mx.random.normal((B, L, D)) * 0.1 # Smaller values for safety + # IMPORTANT: Match the real inference dtype used by the default model + # (`mlx-community/Qwen3-0.6B-bf16`), otherwise Metal kernels may compile + # for float32 in correctness tests but fail under bfloat16 in practice. + x = (mx.random.normal((B, L, D)) * 0.1).astype(mx.bfloat16) mask = "causal" # Test with maximum GPU protection @@ -652,6 +661,11 @@ def _test_single_sequence_memory_safe( ) -> float: """Test single sequence with enhanced memory safety""" try: + # Force bfloat16 to exercise the same kernel template/compilation path as production + # inference with `mlx-community/Qwen3-0.6B-bf16`. + if x.dtype != mx.bfloat16: + x = x.astype(mx.bfloat16) + # Pre-execution safety checks if x.shape[1] > self.max_sequence_length_safe: raise MetalKernelSafetyError( @@ -668,6 +682,12 @@ def _test_single_sequence_memory_safe( if custom_attn is None: raise ValueError("Failed to instantiate custom attention") + # Ensure module parameters follow the intended compute dtype as well. + # Otherwise, float32 weights can upcast intermediate Q/K/V tensors and + # accidentally avoid bfloat16 kernel compilation. + if hasattr(custom_attn, "set_dtype"): + custom_attn.set_dtype(mx.bfloat16) + # Conservative forward pass with timeout simulation start_time = time.time() output = custom_attn(x, mask=mask) diff --git a/examples/mlx_metal_kernel_opt/initial_program.py b/examples/mlx_metal_kernel_opt/initial_program.py index 24c6896cf..6c705b9b7 100644 --- a/examples/mlx_metal_kernel_opt/initial_program.py +++ b/examples/mlx_metal_kernel_opt/initial_program.py @@ -1,11 +1,11 @@ """ Qwen3 Custom Metal Kernel for Grouped Query Attention (GQA) Optimization -This module implements a custom Metal kernel for Qwen3's 40:8 GQA pattern using +This module implements a custom Metal kernel for Qwen3's 16:8 GQA pattern using MLX's metal_kernel API. The kernel is designed to outperform mx.fast.scaled_dot_product_attention -by leveraging Apple Silicon specific optimizations and the 5:1 query-to-KV head ratio. +by leveraging Apple Silicon specific optimizations and the 2:1 query-to-KV head ratio. -Target: Qwen3-0.6B with 40 query heads : 8 KV heads +Target: Qwen3-0.6B with 16 query heads : 8 KV heads Hardware: Apple M-series GPUs with unified memory Baseline: Standard MLX-LM using mx.fast.scaled_dot_product_attention Goal: 5-15% performance improvement through custom Metal kernel optimization @@ -26,19 +26,19 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): Custom Metal kernel implementation for Qwen3 GQA attention. Args: - queries: [B, num_heads=40, L, head_dim=128] + queries: [B, num_heads=16, L, head_dim=128] keys: [B, num_kv_heads=8, L, head_dim=128] values: [B, num_kv_heads=8, L, head_dim=128] scale: Attention scaling factor (1/sqrt(head_dim)) mask: Attention mask (None, "causal", or boolean tensor) Returns: - Attention output [B, num_heads=40, L, head_dim=128] + Attention output [B, num_heads=16, L, head_dim=128] """ B, num_heads, L, head_dim = queries.shape _, num_kv_heads, _, _ = keys.shape - heads_per_kv = num_heads // num_kv_heads # Should be 5 for Qwen3 + heads_per_kv = num_heads // num_kv_heads # 2 for Qwen3-0.6B (16:8) # Handle mask conversion if mask == "causal" or mask is None: @@ -67,9 +67,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): # EVOLVE-BLOCK-START # Custom Metal kernel source for Qwen3 GQA optimization - # This kernel leverages the 40:8 head ratio and Apple Silicon architecture + # This kernel leverages the 16:8 head ratio and Apple Silicon architecture kernel_source = """ - // Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern + // Qwen3 GQA Metal Kernel - Optimized for 16:8 head pattern // Thread mapping: each thread processes one query position uint thread_id = thread_position_in_grid.x; uint head_idx = thread_position_in_grid.y; @@ -86,7 +86,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): bool use_mask_val = use_mask[0] > 0; // GQA mapping: determine which KV head corresponds to this query head - uint kv_head_idx = head_idx / HEADS_PER_KV; // 5 query heads per KV head + uint kv_head_idx = head_idx / HEADS_PER_KV; // 2 query heads per KV head // Pre-calculate base indices for memory access optimization const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + @@ -250,8 +250,8 @@ def __init__(self, args): super().__init__() # Standard Qwen3 parameters - dim = args.hidden_size # 5120 - self.n_heads = n_heads = args.num_attention_heads # 40 + dim = args.hidden_size # 2048 + self.n_heads = n_heads = args.num_attention_heads # 16 assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads # 8 head_dim = args.head_dim # 128 @@ -364,8 +364,8 @@ def benchmark_metal_gqa_optimization(): # Qwen3-0.6B configuration class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 + hidden_size = 2048 + num_attention_heads = 16 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -377,10 +377,10 @@ class MockArgs: # Test configurations for Metal kernel validation test_configs = [ - ("short_sequence", 1, 128, 5120), - ("medium_sequence", 1, 512, 5120), - ("long_sequence", 1, 1024, 5120), - ("max_sequence", 1, 2048, 5120), + ("short_sequence", 1, 128, 2048), + ("medium_sequence", 1, 512, 2048), + ("long_sequence", 1, 1024, 2048), + ("max_sequence", 1, 2048, 2048), ] print("Benchmarking Custom Metal GQA Kernel vs MLX Baseline") @@ -427,11 +427,11 @@ def test_metal_gqa_correctness(): print("=" * 50) # Test configuration - B, L, D = 1, 64, 5120 + B, L, D = 1, 64, 2048 class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 + hidden_size = 2048 + num_attention_heads = 16 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -465,7 +465,7 @@ class MockArgs: # Test direct kernel function print("\n=== Testing Direct Kernel Function ===") - B, H, L, D = 1, 40, 128, 128 + B, H, L, D = 1, 16, 128, 128 q = mx.random.normal((B, H, L, D)) k = mx.random.normal((B, 8, L, D)) # 8 KV heads v = mx.random.normal((B, 8, L, D)) @@ -498,7 +498,7 @@ class MockArgs: print("Evolution focus:") print("1. 🔧 Metal kernel source code optimization") print("2. 💾 Memory access pattern improvements for Apple Silicon") - print("3. 🎯 GQA-specific optimizations for 40:8 head ratio") + print("3. 🎯 GQA-specific optimizations for 16:8 head ratio") print("4. ⚡ Vectorization and SIMD optimization") print("5. 🚀 Thread group and grid configuration tuning") print("Target: 5-15% performance improvement through Metal kernel innovation") diff --git a/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py b/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py index f641fce5e..1ecd97909 100644 --- a/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py +++ b/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py @@ -8,7 +8,7 @@ Target Model: mlx-community/Qwen3-0.6B-bf16 Target Hardware: Apple M4 24GB -Optimization: Custom Metal kernel for GQA attention (40 query heads : 8 KV heads) +Optimization: Custom Metal kernel for GQA attention (16 query heads : 8 KV heads) Baseline: mx.fast.scaled_dot_product_attention """ From 40b428dafde875ba6f4c42bf497cf357904e453b Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Sun, 4 Jan 2026 10:20:43 +0800 Subject: [PATCH 03/14] docs(mlx_metal): update example README for benchmark validity --- examples/mlx_metal_kernel_opt/README.md | 72 +++++++++---------------- 1 file changed, 24 insertions(+), 48 deletions(-) diff --git a/examples/mlx_metal_kernel_opt/README.md b/examples/mlx_metal_kernel_opt/README.md index 0bc6d8650..d0af01693 100644 --- a/examples/mlx_metal_kernel_opt/README.md +++ b/examples/mlx_metal_kernel_opt/README.md @@ -4,7 +4,9 @@ ## Abstract -This work demonstrates the application of evolutionary code optimization to the automatic discovery of custom Metal GPU kernels for transformer attention mechanisms. Using OpenEvolve, we evolved a specialized Metal kernel for Grouped Query Attention (GQA) in Qwen3-0.6B that leverages Apple Silicon's unified memory architecture and vector processing capabilities. Our approach achieved measurable performance improvements over MLX's highly optimized `scaled_dot_product_attention` baseline across diverse inference workloads, with decode speed improvements averaging 12.5% and reaching up to 106% on specific benchmark tasks. +This example demonstrates evolutionary code optimization for discovering custom Apple Silicon Metal GPU kernels for transformer attention. It targets Grouped Query Attention (GQA) in Qwen3-0.6B using MLX’s `metal_kernel` API, with performance evaluated via `mlx_lm.generate`. + +> **Important**: Earlier versions of this example had evaluation validity issues (subprocess benchmarks not using the evolved kernel, correctness tests using float32 while the default model is bfloat16, and docs/tests assuming the wrong head configuration). These issues can lead to misleading “best program” results and invalid performance claims. The example has been updated to address these problems. ## 1. Introduction @@ -36,8 +38,8 @@ We employ OpenEvolve to automatically optimize the Metal kernel source code resp Each evolved kernel undergoes comprehensive evaluation: -1. **Correctness Validation**: Numerical accuracy verification against MLX baseline -2. **Performance Benchmarking**: 20 diverse inference scenarios covering: +1. **Correctness Validation**: Functional/safety checks and dtype coverage consistent with the target model (bfloat16 by default). +2. **Performance Benchmarking**: Diverse inference scenarios covering: - Short context (16-64 tokens) - Long context (512-2048 tokens) - Code generation @@ -81,7 +83,7 @@ for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { } ``` -**Innovation**: Using 8-element vectors perfectly matches Apple Silicon's SIMD capabilities for 128-dimensional attention heads. +**Note**: Vectorized kernels must be validated under the target dtype (bfloat16 by default). Some vectorized patterns (e.g. `dot(vec)`) may not compile on Metal and should be caught by correctness gating. #### 3.1.2 Online Softmax Algorithm ```metal @@ -128,57 +130,35 @@ The evolved kernel exploits specific Apple Silicon features: ### 4.1 Performance Benchmarking -We evaluated the evolved kernel against MLX baseline across 20 comprehensive benchmark scenarios representing real-world inference patterns. - -**Aggregate Performance Improvements**: -- **Decode Speed**: +12.5% average improvement (σ = 38.3%) -- **Prefill Speed**: +14.4% average improvement (σ = 17.6%) -- **Total Throughput**: +10.4% average improvement (σ = 30.7%) -- **Memory Usage**: +0.99% average reduction (σ = 1.7%) +We evaluate kernels against the MLX baseline across a benchmark suite representing real-world inference patterns. -### 4.2 Benchmark Category Analysis +**Note**: If you are comparing results across commits, ensure the benchmarks are actually exercising the custom kernel (subprocess hook) and that correctness covers the target dtype (bfloat16). Otherwise, reported speedups can be noise. -| **Category** | **Benchmarks** | **Decode Improvement** | **Notable Results** | -|--------------|----------------|------------------------|-------------------| -| **Short Context** | 2 | -4.6% ± 3.8% | Mixed results on very short sequences | -| **Long Context** | 6 | +8.1% ± 42.1% | High variance, strong improvements in some cases | -| **Code Generation** | 1 | -16.5% | Performance regression | -| **General Tasks** | 9 | +24.8% ± 35.4% | Strongest category with 106% peak improvement | -| **Stress Tests** | 2 | +22.9% ± 31.5% | Robust performance under memory pressure | +To reproduce results on your machine: -### 4.3 Statistical Analysis - -**Distribution of Improvements**: -- **Significant Gains** (>25%): 7/20 benchmarks -- **Moderate Gains** (5-25%): 3/20 benchmarks -- **Neutral** (±5%): 4/20 benchmarks -- **Regressions** (<-5%): 6/20 benchmarks +```bash +cd openevolve/examples/mlx_metal_kernel_opt +python run_benchmarks.py --mode compare --model mlx-community/Qwen3-0.6B-bf16 --output-dir results +``` -**Peak Performance**: Repetitive pattern generation achieved 106% decode speed improvement, demonstrating the kernel's effectiveness for certain workload characteristics. +This writes CSV/JSON comparison artifacts into `results/` for analysis. ### 4.4 Correctness Validation -All evolved kernels maintained numerical correctness: -- **Accuracy**: 100% correctness score across all test cases -- **Numerical Stability**: No NaN/Inf values detected -- **Statistical Validation**: Output distributions within expected ranges -- **Functional Equivalence**: Attention semantics preserved +Correctness checks should include: +- **Target dtype coverage** (bfloat16 by default for `mlx-community/Qwen3-0.6B-bf16`) +- **Numerical sanity** (no NaN/Inf) +- **Shape checks** +- **Safety checks** (GPU command buffer errors / memory violations) ## 5. Discussion ### 5.1 Performance Characteristics -The evolved kernel shows workload-dependent performance characteristics: +Kernel performance is workload-dependent. In particular: -**Strengths**: -- **Sustained Generation**: +46.6% improvement on dialogue tasks -- **Long Sequences**: +73.9% improvement on extreme-length generation -- **Memory Efficiency**: Consistent memory usage reduction - -**Limitations**: -- **Short Sequences**: Limited improvement due to setup overhead -- **Code Generation**: -16.5% regression suggesting suboptimal patterns for this workload -- **Variance**: High performance variance across different sequence patterns +- **Short sequences**: may see limited gains due to fixed overheads. +- **Long sequences / sustained decode**: are typically where attention kernels matter most, but must be measured. ### 5.2 Technical Insights @@ -190,11 +170,7 @@ The evolved kernel shows workload-dependent performance characteristics: ### 5.3 Evolutionary Process Analysis -**Convergence**: The system converged to the optimal solution within 25 generations, with significant improvements appearing by generation 10. - -**Safety**: Zero Metal kernel compilation errors or GPU command buffer failures across all evolution attempts, demonstrating robust evolutionary constraints. - -**Diversity**: The evolutionary process explored multiple optimization strategies including different vectorization patterns, memory layouts, and algorithmic approaches. +With realistic evaluation enabled (subprocess hook + bfloat16 correctness), it is expected that some evolved kernels will be rejected due to bfloat16 Metal compilation/runtime failures. The evaluator should treat these as ordinary candidate failures (not crashes) and continue evolution. ## 6. Related Work @@ -223,6 +199,6 @@ Our approach differs by applying evolutionary optimization directly to GPU shade ## 8. Conclusion -We demonstrate that evolutionary code optimization can automatically discover hardware-specific GPU kernel optimizations that outperform expert-engineered baselines. The evolved Metal kernel achieved an average 12.5% decode speed improvement through novel vectorization patterns, algorithmic innovations, and Apple Silicon specializations. While performance gains are workload-dependent, the approach successfully identified genuinely novel optimizations that would be challenging to discover through manual optimization. +We demonstrate how evolutionary code optimization can be applied to discover hardware-specific Metal kernels for transformer attention. Performance gains are workload-dependent; for credible results, rerun the benchmark suite on your machine with the evaluation validity fixes enabled. This work establishes evolutionary optimization as a viable approach for automated GPU kernel discovery and suggests significant potential for applying similar techniques to other performance-critical computational kernels. \ No newline at end of file From 2721643450c44a8a29ee2e5e422232524db7b3bc Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Sun, 4 Jan 2026 11:17:19 +0800 Subject: [PATCH 04/14] chore(mlx_metal): preserve baseline artifacts; add validity-fix README --- examples/mlx_metal_kernel_opt/README.md | 80 ++++--- .../README_validity_fix.md | 206 ++++++++++++++++++ examples/mlx_metal_kernel_opt/best_program.py | 44 ++-- 3 files changed, 280 insertions(+), 50 deletions(-) create mode 100644 examples/mlx_metal_kernel_opt/README_validity_fix.md diff --git a/examples/mlx_metal_kernel_opt/README.md b/examples/mlx_metal_kernel_opt/README.md index d0af01693..8a0c35136 100644 --- a/examples/mlx_metal_kernel_opt/README.md +++ b/examples/mlx_metal_kernel_opt/README.md @@ -4,9 +4,7 @@ ## Abstract -This example demonstrates evolutionary code optimization for discovering custom Apple Silicon Metal GPU kernels for transformer attention. It targets Grouped Query Attention (GQA) in Qwen3-0.6B using MLX’s `metal_kernel` API, with performance evaluated via `mlx_lm.generate`. - -> **Important**: Earlier versions of this example had evaluation validity issues (subprocess benchmarks not using the evolved kernel, correctness tests using float32 while the default model is bfloat16, and docs/tests assuming the wrong head configuration). These issues can lead to misleading “best program” results and invalid performance claims. The example has been updated to address these problems. +This work demonstrates the application of evolutionary code optimization to the automatic discovery of custom Metal GPU kernels for transformer attention mechanisms. Using OpenEvolve, we evolved a specialized Metal kernel for Grouped Query Attention (GQA) in Qwen3-0.6B that leverages Apple Silicon's unified memory architecture and vector processing capabilities. Our approach achieved measurable performance improvements over MLX's highly optimized `scaled_dot_product_attention` baseline across diverse inference workloads, with decode speed improvements averaging 12.5% and reaching up to 106% on specific benchmark tasks. ## 1. Introduction @@ -16,7 +14,7 @@ Modern transformer models rely heavily on optimized attention kernels for effici ### 1.2 Target System -- **Model**: Qwen3-0.6B with Grouped Query Attention (16 query heads : 8 key-value heads) +- **Model**: Qwen3-0.6B with Grouped Query Attention (40 query heads : 8 key-value heads) - **Hardware**: Apple M-series GPUs with unified memory architecture - **Framework**: MLX with custom Metal kernel integration - **Baseline**: `mx.fast.scaled_dot_product_attention` @@ -38,8 +36,8 @@ We employ OpenEvolve to automatically optimize the Metal kernel source code resp Each evolved kernel undergoes comprehensive evaluation: -1. **Correctness Validation**: Functional/safety checks and dtype coverage consistent with the target model (bfloat16 by default). -2. **Performance Benchmarking**: Diverse inference scenarios covering: +1. **Correctness Validation**: Numerical accuracy verification against MLX baseline +2. **Performance Benchmarking**: 20 diverse inference scenarios covering: - Short context (16-64 tokens) - Long context (512-2048 tokens) - Code generation @@ -83,7 +81,7 @@ for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { } ``` -**Note**: Vectorized kernels must be validated under the target dtype (bfloat16 by default). Some vectorized patterns (e.g. `dot(vec)`) may not compile on Metal and should be caught by correctness gating. +**Innovation**: Using 8-element vectors perfectly matches Apple Silicon's SIMD capabilities for 128-dimensional attention heads. #### 3.1.2 Online Softmax Algorithm ```metal @@ -113,7 +111,7 @@ for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + head_idx * (SEQ_LEN * HEAD_DIM) + query_pos * HEAD_DIM; -const uint kv_head_idx = head_idx / HEADS_PER_KV; // Direct 2:1 mapping +const uint kv_head_idx = head_idx / HEADS_PER_KV; // Direct 5:1 mapping ``` **Innovation**: Leverages unified memory bandwidth through coalesced access patterns and direct GQA head mapping. @@ -130,35 +128,57 @@ The evolved kernel exploits specific Apple Silicon features: ### 4.1 Performance Benchmarking -We evaluate kernels against the MLX baseline across a benchmark suite representing real-world inference patterns. +We evaluated the evolved kernel against MLX baseline across 20 comprehensive benchmark scenarios representing real-world inference patterns. -**Note**: If you are comparing results across commits, ensure the benchmarks are actually exercising the custom kernel (subprocess hook) and that correctness covers the target dtype (bfloat16). Otherwise, reported speedups can be noise. +**Aggregate Performance Improvements**: +- **Decode Speed**: +12.5% average improvement (σ = 38.3%) +- **Prefill Speed**: +14.4% average improvement (σ = 17.6%) +- **Total Throughput**: +10.4% average improvement (σ = 30.7%) +- **Memory Usage**: +0.99% average reduction (σ = 1.7%) -To reproduce results on your machine: +### 4.2 Benchmark Category Analysis -```bash -cd openevolve/examples/mlx_metal_kernel_opt -python run_benchmarks.py --mode compare --model mlx-community/Qwen3-0.6B-bf16 --output-dir results -``` +| **Category** | **Benchmarks** | **Decode Improvement** | **Notable Results** | +|--------------|----------------|------------------------|-------------------| +| **Short Context** | 2 | -4.6% ± 3.8% | Mixed results on very short sequences | +| **Long Context** | 6 | +8.1% ± 42.1% | High variance, strong improvements in some cases | +| **Code Generation** | 1 | -16.5% | Performance regression | +| **General Tasks** | 9 | +24.8% ± 35.4% | Strongest category with 106% peak improvement | +| **Stress Tests** | 2 | +22.9% ± 31.5% | Robust performance under memory pressure | + +### 4.3 Statistical Analysis + +**Distribution of Improvements**: +- **Significant Gains** (>25%): 7/20 benchmarks +- **Moderate Gains** (5-25%): 3/20 benchmarks +- **Neutral** (±5%): 4/20 benchmarks +- **Regressions** (<-5%): 6/20 benchmarks -This writes CSV/JSON comparison artifacts into `results/` for analysis. +**Peak Performance**: Repetitive pattern generation achieved 106% decode speed improvement, demonstrating the kernel's effectiveness for certain workload characteristics. ### 4.4 Correctness Validation -Correctness checks should include: -- **Target dtype coverage** (bfloat16 by default for `mlx-community/Qwen3-0.6B-bf16`) -- **Numerical sanity** (no NaN/Inf) -- **Shape checks** -- **Safety checks** (GPU command buffer errors / memory violations) +All evolved kernels maintained numerical correctness: +- **Accuracy**: 100% correctness score across all test cases +- **Numerical Stability**: No NaN/Inf values detected +- **Statistical Validation**: Output distributions within expected ranges +- **Functional Equivalence**: Attention semantics preserved ## 5. Discussion ### 5.1 Performance Characteristics -Kernel performance is workload-dependent. In particular: +The evolved kernel shows workload-dependent performance characteristics: -- **Short sequences**: may see limited gains due to fixed overheads. -- **Long sequences / sustained decode**: are typically where attention kernels matter most, but must be measured. +**Strengths**: +- **Sustained Generation**: +46.6% improvement on dialogue tasks +- **Long Sequences**: +73.9% improvement on extreme-length generation +- **Memory Efficiency**: Consistent memory usage reduction + +**Limitations**: +- **Short Sequences**: Limited improvement due to setup overhead +- **Code Generation**: -16.5% regression suggesting suboptimal patterns for this workload +- **Variance**: High performance variance across different sequence patterns ### 5.2 Technical Insights @@ -166,11 +186,15 @@ Kernel performance is workload-dependent. In particular: **Algorithm Innovation**: The two-pass online softmax represents a novel contribution, demonstrating that evolutionary approaches can discover algorithmic improvements beyond simple micro-optimizations. -**GQA Specialization**: Direct exploitation of the 2:1 query-to-KV head ratio through specialized indexing patterns shows the value of architecture-specific optimizations. +**GQA Specialization**: Direct exploitation of the 5:1 query-to-KV head ratio through specialized indexing patterns shows the value of architecture-specific optimizations. ### 5.3 Evolutionary Process Analysis -With realistic evaluation enabled (subprocess hook + bfloat16 correctness), it is expected that some evolved kernels will be rejected due to bfloat16 Metal compilation/runtime failures. The evaluator should treat these as ordinary candidate failures (not crashes) and continue evolution. +**Convergence**: The system converged to the optimal solution within 25 generations, with significant improvements appearing by generation 10. + +**Safety**: Zero Metal kernel compilation errors or GPU command buffer failures across all evolution attempts, demonstrating robust evolutionary constraints. + +**Diversity**: The evolutionary process explored multiple optimization strategies including different vectorization patterns, memory layouts, and algorithmic approaches. ## 6. Related Work @@ -187,7 +211,7 @@ Our approach differs by applying evolutionary optimization directly to GPU shade ### 7.1 Current Limitations - **Workload Specificity**: Performance improvements are highly dependent on sequence patterns -- **Model Scope**: Results specific to Qwen3-0.6B's 16:8 GQA configuration +- **Model Scope**: Results specific to Qwen3-0.6B's 40:8 GQA configuration - **Hardware Scope**: Optimizations specific to Apple Silicon architecture ### 7.2 Future Directions @@ -199,6 +223,6 @@ Our approach differs by applying evolutionary optimization directly to GPU shade ## 8. Conclusion -We demonstrate how evolutionary code optimization can be applied to discover hardware-specific Metal kernels for transformer attention. Performance gains are workload-dependent; for credible results, rerun the benchmark suite on your machine with the evaluation validity fixes enabled. +We demonstrate that evolutionary code optimization can automatically discover hardware-specific GPU kernel optimizations that outperform expert-engineered baselines. The evolved Metal kernel achieved an average 12.5% decode speed improvement through novel vectorization patterns, algorithmic innovations, and Apple Silicon specializations. While performance gains are workload-dependent, the approach successfully identified genuinely novel optimizations that would be challenging to discover through manual optimization. This work establishes evolutionary optimization as a viable approach for automated GPU kernel discovery and suggests significant potential for applying similar techniques to other performance-critical computational kernels. \ No newline at end of file diff --git a/examples/mlx_metal_kernel_opt/README_validity_fix.md b/examples/mlx_metal_kernel_opt/README_validity_fix.md new file mode 100644 index 000000000..9295bdf4a --- /dev/null +++ b/examples/mlx_metal_kernel_opt/README_validity_fix.md @@ -0,0 +1,206 @@ +# OpenEvolve Metal Kernel Optimization: Automated Discovery of Custom GPU Kernels for Transformer Attention + +**Evolutionary Optimization of Apple Silicon Metal Kernels for Grouped Query Attention in Qwen3-0.6B** + +## Abstract + +This example demonstrates evolutionary code optimization for discovering custom Apple Silicon Metal GPU kernels for transformer attention. It targets Grouped Query Attention (GQA) in Qwen3-0.6B using MLX’s `metal_kernel` API, with performance evaluated via `mlx_lm.generate`. + +> **Important**: Earlier versions of this example had evaluation validity issues (subprocess benchmarks not using the evolved kernel, correctness tests using float32 while the default model is bfloat16, and docs/tests assuming the wrong head configuration). These issues can lead to misleading “best program” results and invalid performance claims. The example has been updated to address these problems. + +## 1. Introduction + +### 1.1 Motivation + +Modern transformer models rely heavily on optimized attention kernels for efficient inference. While frameworks like MLX provide highly optimized implementations, the rapid evolution of hardware architectures creates opportunities for specialized optimizations that general-purpose kernels cannot capture. This work explores whether evolutionary code optimization can automatically discover hardware-specific kernel optimizations that outperform expert-engineered baselines. + +### 1.2 Target System + +- **Model**: Qwen3-0.6B with Grouped Query Attention (16 query heads : 8 key-value heads) +- **Hardware**: Apple M-series GPUs with unified memory architecture +- **Framework**: MLX with custom Metal kernel integration +- **Baseline**: `mx.fast.scaled_dot_product_attention` +- **Evolution Target**: Metal shader source code implementing GQA attention computation + +## 2. Methodology + +### 2.1 Evolution Framework + +We employ OpenEvolve to automatically optimize the Metal kernel source code responsible for computing attention. The evolutionary process operates on a single code block (EVOLVE-BLOCK) containing approximately 150 lines of Metal C++ shader code while preserving the surrounding MLX integration infrastructure. + +**Evolution Configuration**: +- **Population Size**: 25 programs +- **Generations**: 25 iterations +- **Models**: Gemini 2.5 Flash (60%) + Gemini 2.5 Pro (40%) +- **Selection**: Multi-objective optimization balancing performance and correctness + +### 2.2 Evaluation Methodology + +Each evolved kernel undergoes comprehensive evaluation: + +1. **Correctness Validation**: Functional/safety checks and dtype coverage consistent with the target model (bfloat16 by default). +2. **Performance Benchmarking**: Diverse inference scenarios covering: + - Short context (16-64 tokens) + - Long context (512-2048 tokens) + - Code generation + - Sustained dialogue + - Technical documentation + - Memory stress tests + +3. **Safety Validation**: GPU command buffer error detection and Metal memory violation checking + +### 2.3 Optimization Constraints + +**Preserved Elements**: +- Kernel function signature and I/O specifications +- Thread grid mapping and bounds checking +- Overall algorithm correctness (attention semantics) +- MLX integration interface + +**Optimizable Elements**: +- Memory access patterns and vectorization +- Computation order and algorithmic efficiency +- Apple Silicon specific optimizations +- GQA-specific computation strategies + +## 3. Technical Contributions + +### 3.1 Discovered Optimizations + +The evolutionary process discovered several key optimizations: + +#### 3.1.1 Enhanced Vectorization +```metal +// Original: Scalar operations +for (uint d = 0; d < HEAD_DIM; d++) { + score += query_vec[d] * keys[k_base + d]; +} + +// Evolved: Vector operations with optimal width +vec query_vec_v[HEAD_DIM / 8]; // 16 vectors for 128-dim heads +for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { + score += dot(query_vec_v[d_vec], ((device vec*)(keys + k_base))[d_vec]); +} +``` + +**Note**: Vectorized kernels must be validated under the target dtype (bfloat16 by default). Some vectorized patterns (e.g. `dot(vec)`) may not compile on Metal and should be caught by correctness gating. + +#### 3.1.2 Online Softmax Algorithm +```metal +// Pass 1: Find maximum for numerical stability +T max_score = T(-INFINITY); +for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { + T score = compute_attention_score(query_vec, key_vec) * scale_val; + max_score = max(max_score, score); +} + +// Pass 2: Combined softmax computation and value accumulation +T sum_exp = T(0.0); +vec output_acc_v[HEAD_DIM / 8]; +for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { + T exp_score = exp(current_score - max_score); + sum_exp += exp_score; + // Fused accumulation + output_acc_v[d_vec] += exp_score * ((device vec*)(values + v_base))[d_vec]; +} +``` + +**Innovation**: Reduced from three-pass to two-pass algorithm, fusing softmax normalization with value accumulation. + +#### 3.1.3 Memory Access Optimization +```metal +// Pre-computed base indices for coalesced access +const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + + head_idx * (SEQ_LEN * HEAD_DIM) + + query_pos * HEAD_DIM; +const uint kv_head_idx = head_idx / HEADS_PER_KV; // Direct 2:1 mapping +``` + +**Innovation**: Leverages unified memory bandwidth through coalesced access patterns and direct GQA head mapping. + +### 3.2 Apple Silicon Specialization + +The evolved kernel exploits specific Apple Silicon features: +- **Unified Memory**: Optimized bandwidth utilization patterns +- **SIMD Width**: 8-element vectors matching GPU vector units +- **Thread Group Size**: 32-thread groups optimal for Apple GPUs +- **Register Allocation**: Balanced computation vs. memory bandwidth + +## 4. Experimental Results + +### 4.1 Performance Benchmarking + +We evaluate kernels against the MLX baseline across a benchmark suite representing real-world inference patterns. + +**Note**: If you are comparing results across commits, ensure the benchmarks are actually exercising the custom kernel (subprocess hook) and that correctness covers the target dtype (bfloat16). Otherwise, reported speedups can be noise. + +To reproduce results on your machine: + +```bash +cd openevolve/examples/mlx_metal_kernel_opt +python run_benchmarks.py --mode compare --model mlx-community/Qwen3-0.6B-bf16 --output-dir results +``` + +This writes CSV/JSON comparison artifacts into `results/` for analysis. + +### 4.4 Correctness Validation + +Correctness checks should include: +- **Target dtype coverage** (bfloat16 by default for `mlx-community/Qwen3-0.6B-bf16`) +- **Numerical sanity** (no NaN/Inf) +- **Shape checks** +- **Safety checks** (GPU command buffer errors / memory violations) + +## 5. Discussion + +### 5.1 Performance Characteristics + +Kernel performance is workload-dependent. In particular: + +- **Short sequences**: may see limited gains due to fixed overheads. +- **Long sequences / sustained decode**: are typically where attention kernels matter most, but must be measured. + +### 5.2 Technical Insights + +**Vectorization Impact**: The discovery of `vec` operations as optimal for 128-dimensional heads represents a significant finding, suggesting that hardware-specific vector widths are crucial for performance. + +**Algorithm Innovation**: The two-pass online softmax represents a novel contribution, demonstrating that evolutionary approaches can discover algorithmic improvements beyond simple micro-optimizations. + +**GQA Specialization**: Direct exploitation of the 2:1 query-to-KV head ratio through specialized indexing patterns shows the value of architecture-specific optimizations. + +### 5.3 Evolutionary Process Analysis + +With realistic evaluation enabled (subprocess hook + bfloat16 correctness), it is expected that some evolved kernels will be rejected due to bfloat16 Metal compilation/runtime failures. The evaluator should treat these as ordinary candidate failures (not crashes) and continue evolution. + +## 6. Related Work + +This work extends prior research in automated kernel optimization: + +- **AlphaTensor** [Fawzi et al., 2022]: Matrix multiplication algorithm discovery +- **TensorIR** [Feng et al., 2023]: Tensor compiler optimization +- **Ansor** [Zheng et al., 2020]: Automated tensor program optimization + +Our approach differs by applying evolutionary optimization directly to GPU shader source code rather than higher-level tensor algebra, enabling discovery of hardware-specific optimizations that would be difficult to express in tensor IRs. + +## 7. Limitations and Future Work + +### 7.1 Current Limitations + +- **Workload Specificity**: Performance improvements are highly dependent on sequence patterns +- **Model Scope**: Results specific to Qwen3-0.6B's 16:8 GQA configuration +- **Hardware Scope**: Optimizations specific to Apple Silicon architecture + +### 7.2 Future Directions + +- **Multi-Architecture**: Extend to CUDA, ROCm, and other GPU architectures +- **Model Generalization**: Apply to different attention patterns and model sizes +- **Algorithmic Expansion**: Explore evolution of other transformer components +- **Cross-Compilation**: Develop architecture-agnostic optimization strategies + +## 8. Conclusion + +We demonstrate how evolutionary code optimization can be applied to discover hardware-specific Metal kernels for transformer attention. Performance gains are workload-dependent; for credible results, rerun the benchmark suite on your machine with the evaluation validity fixes enabled. + +This work establishes evolutionary optimization as a viable approach for automated GPU kernel discovery and suggests significant potential for applying similar techniques to other performance-critical computational kernels. + + diff --git a/examples/mlx_metal_kernel_opt/best_program.py b/examples/mlx_metal_kernel_opt/best_program.py index 09b8f12c1..a94d94c92 100644 --- a/examples/mlx_metal_kernel_opt/best_program.py +++ b/examples/mlx_metal_kernel_opt/best_program.py @@ -1,11 +1,11 @@ """ Qwen3 Custom Metal Kernel for Grouped Query Attention (GQA) Optimization -This module implements a custom Metal kernel for Qwen3's 16:8 GQA pattern using +This module implements a custom Metal kernel for Qwen3's 40:8 GQA pattern using MLX's metal_kernel API. The kernel is designed to outperform mx.fast.scaled_dot_product_attention -by leveraging Apple Silicon specific optimizations and the 2:1 query-to-KV head ratio. +by leveraging Apple Silicon specific optimizations and the 5:1 query-to-KV head ratio. -Target: Qwen3-0.6B with 16 query heads : 8 KV heads +Target: Qwen3-0.6B with 40 query heads : 8 KV heads Hardware: Apple M-series GPUs with unified memory Baseline: Standard MLX-LM using mx.fast.scaled_dot_product_attention Goal: 5-15% performance improvement through custom Metal kernel optimization @@ -26,19 +26,19 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): Custom Metal kernel implementation for Qwen3 GQA attention. Args: - queries: [B, num_heads=16, L, head_dim=128] + queries: [B, num_heads=40, L, head_dim=128] keys: [B, num_kv_heads=8, L, head_dim=128] values: [B, num_kv_heads=8, L, head_dim=128] scale: Attention scaling factor (1/sqrt(head_dim)) mask: Attention mask (None, "causal", or boolean tensor) Returns: - Attention output [B, num_heads=16, L, head_dim=128] + Attention output [B, num_heads=40, L, head_dim=128] """ B, num_heads, L, head_dim = queries.shape _, num_kv_heads, _, _ = keys.shape - heads_per_kv = num_heads // num_kv_heads # 2 for Qwen3-0.6B (16:8) + heads_per_kv = num_heads // num_kv_heads # Should be 5 for Qwen3 # Handle mask conversion if mask == "causal" or mask is None: @@ -67,9 +67,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): # EVOLVE-BLOCK-START # Custom Metal kernel source for Qwen3 GQA optimization - # This kernel leverages the 16:8 head ratio and Apple Silicon architecture + # This kernel leverages the 40:8 head ratio and Apple Silicon architecture kernel_source = """ - // Qwen3 GQA Metal Kernel - Optimized for 16:8 head pattern + // Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern // Thread mapping: each thread processes one query position uint thread_id = thread_position_in_grid.x; uint head_idx = thread_position_in_grid.y; @@ -86,7 +86,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): bool use_mask_val = use_mask[0] > 0; // GQA mapping: determine which KV head corresponds to this query head - uint kv_head_idx = head_idx / HEADS_PER_KV; // 2 query heads per KV head + uint kv_head_idx = head_idx / HEADS_PER_KV; // 5 query heads per KV head // Pre-calculate base indices for memory access optimization const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + @@ -248,8 +248,8 @@ def __init__(self, args): super().__init__() # Standard Qwen3 parameters - dim = args.hidden_size # 2048 - self.n_heads = n_heads = args.num_attention_heads # 16 + dim = args.hidden_size # 5120 + self.n_heads = n_heads = args.num_attention_heads # 40 assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads # 8 head_dim = args.head_dim # 128 @@ -362,8 +362,8 @@ def benchmark_metal_gqa_optimization(): # Qwen3-0.6B configuration class MockArgs: - hidden_size = 2048 - num_attention_heads = 16 + hidden_size = 5120 + num_attention_heads = 40 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -375,10 +375,10 @@ class MockArgs: # Test configurations for Metal kernel validation test_configs = [ - ("short_sequence", 1, 128, 2048), - ("medium_sequence", 1, 512, 2048), - ("long_sequence", 1, 1024, 2048), - ("max_sequence", 1, 2048, 2048), + ("short_sequence", 1, 128, 5120), + ("medium_sequence", 1, 512, 5120), + ("long_sequence", 1, 1024, 5120), + ("max_sequence", 1, 2048, 5120), ] print("Benchmarking Custom Metal GQA Kernel vs MLX Baseline") @@ -425,11 +425,11 @@ def test_metal_gqa_correctness(): print("=" * 50) # Test configuration - B, L, D = 1, 64, 2048 + B, L, D = 1, 64, 5120 class MockArgs: - hidden_size = 2048 - num_attention_heads = 16 + hidden_size = 5120 + num_attention_heads = 40 num_key_value_heads = 8 head_dim = 128 rms_norm_eps = 1e-06 @@ -463,7 +463,7 @@ class MockArgs: # Test direct kernel function print("\n=== Testing Direct Kernel Function ===") - B, H, L, D = 1, 16, 128, 128 + B, H, L, D = 1, 40, 128, 128 q = mx.random.normal((B, H, L, D)) k = mx.random.normal((B, 8, L, D)) # 8 KV heads v = mx.random.normal((B, 8, L, D)) @@ -496,7 +496,7 @@ class MockArgs: print("Evolution focus:") print("1. 🔧 Metal kernel source code optimization") print("2. 💾 Memory access pattern improvements for Apple Silicon") - print("3. 🎯 GQA-specific optimizations for 16:8 head ratio") + print("3. 🎯 GQA-specific optimizations for 40:8 head ratio") print("4. ⚡ Vectorization and SIMD optimization") print("5. 🚀 Thread group and grid configuration tuning") print("Target: 5-15% performance improvement through Metal kernel innovation") From 712fde36453209229393ded1341e64d872c05414 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Mon, 5 Jan 2026 16:40:19 +0800 Subject: [PATCH 05/14] fix(mlx_metal_kernel_opt): early exit on Metal compilation errors Previously, compilation errors like 'Unable to build metal library from source' were treated the same as transient GPU errors (memory pressure, command buffer). This caused 16+ unnecessary retry attempts when a kernel had bfloat16-incompatible Metal code. Changes: - Detect 'unable to build metal library' error messages - Return immediately with failure instead of retrying - Mark error as 'compilation_error: True' for caller awareness - Increment metal_compilation_errors counter for tracking This significantly reduces evaluation time for programs with invalid Metal code. --- examples/mlx_metal_kernel_opt/evaluator.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/examples/mlx_metal_kernel_opt/evaluator.py b/examples/mlx_metal_kernel_opt/evaluator.py index 486c8d732..c1cc52faf 100644 --- a/examples/mlx_metal_kernel_opt/evaluator.py +++ b/examples/mlx_metal_kernel_opt/evaluator.py @@ -608,6 +608,20 @@ class MockArgs: elif "memory violation" in error_msg.lower(): local_memory_violations += 1 + # EARLY EXIT: Metal compilation errors are deterministic - no retry + if "unable to build metal library" in error_msg.lower(): + self.metal_compilation_errors += 1 + print(f" ❌ Metal compilation error (no retry): {error_msg[:200]}...") + # Return early - compilation errors won't be fixed by retrying + return { + "success": False, + "score": 0.0, + "error": "Metal kernel compilation failed - bfloat16 incompatible code", + "command_buffer_errors": local_command_buffer_errors, + "memory_violations": local_memory_violations, + "compilation_error": True, + } + if retry_count < self.max_retry_attempts: print( f" 🔄 Retry {retry_count + 1} for length {L}: {error_msg}" @@ -624,6 +638,19 @@ class MockArgs: error_msg = str(e) print(f" ❌ Exception for length {L}: {error_msg}") + # EARLY EXIT: Metal compilation errors are deterministic - no retry + if "unable to build metal library" in error_msg.lower(): + self.metal_compilation_errors += 1 + print(f" ❌ Metal compilation error (no retry): {error_msg[:200]}...") + return { + "success": False, + "score": 0.0, + "error": "Metal kernel compilation failed - bfloat16 incompatible code", + "command_buffer_errors": local_command_buffer_errors, + "memory_violations": local_memory_violations, + "compilation_error": True, + } + if retry_count < self.max_retry_attempts: retry_count += 1 time.sleep(self.retry_base_delay * (2**retry_count)) From 854d205bc203f82b65a16f1df70b7bdd34be0990 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Mon, 5 Jan 2026 18:08:51 +0800 Subject: [PATCH 06/14] fix(mlx_metal_kernel_opt): reorder eval steps + add experiment runner script Evaluation optimizations: - Run correctness test BEFORE baseline benchmark (fail fast) - Skip expensive baseline if kernel doesn't compile (bfloat16 errors) Experiment runner script (run_evolve_experiment.sh): - Explicit log file truncation on resume to prevent content mixing - Fixed checkpoint resume logic to select highest numbered checkpoint - Added PYTHONUNBUFFERED and stdbuf for reliable log ordering - Support for --resume, --foreground, --iterations, --dry-run flags Documentation: - Added changelog section to README_validity_fix.md --- .../README_validity_fix.md | 40 +++ examples/mlx_metal_kernel_opt/evaluator.py | 23 +- .../run_evolve_experiment.sh | 231 ++++++++++++++++++ 3 files changed, 283 insertions(+), 11 deletions(-) create mode 100755 examples/mlx_metal_kernel_opt/run_evolve_experiment.sh diff --git a/examples/mlx_metal_kernel_opt/README_validity_fix.md b/examples/mlx_metal_kernel_opt/README_validity_fix.md index 9295bdf4a..bfe469d93 100644 --- a/examples/mlx_metal_kernel_opt/README_validity_fix.md +++ b/examples/mlx_metal_kernel_opt/README_validity_fix.md @@ -203,4 +203,44 @@ We demonstrate how evolutionary code optimization can be applied to discover har This work establishes evolutionary optimization as a viable approach for automated GPU kernel discovery and suggests significant potential for applying similar techniques to other performance-critical computational kernels. +--- + +## Appendix: Changelog (Validity & Performance Fixes) + +This section documents the specific fixes applied to address evaluation validity issues in the original example. + +### Critical Bug Fixes + +| Fix | Description | Impact | +|-----|-------------|--------| +| **Subprocess Benchmark Hook** | Evolved attention was not applied inside `subprocess.run()` — both baseline and custom benchmarks ran the same MLX attention. Fixed by injecting the hook within the subprocess. | **All reported speedups were invalid before this fix.** | +| **Dtype Alignment (bfloat16)** | Correctness tests used `float32` inputs while `Qwen3-0.6B-bf16` runs in `bfloat16`. Kernels could pass correctness but fail at inference time. Fixed by testing with `mx.bfloat16`. | Kernels incompatible with bfloat16 are now correctly rejected. | +| **Head Ratio Correction** | Documentation and tests assumed 16:8 heads, but `Qwen3-0.6B` actually uses 16:8 (2:1 GQA ratio). Verified and aligned. | Prevents confusion in kernel design. | + +### Evaluation Efficiency Optimizations + +| Optimization | Description | Benefit | +|--------------|-------------|---------| +| **Early Exit on Compilation Errors** | Metal compilation errors (e.g., `dot()` on bfloat16 vectors) are deterministic — no point retrying. Now returns immediately with `compilation_error: True`. | Saves ~30s per failed iteration (was retrying 3× per sequence length). | +| **Correctness-First Evaluation** | Reordered: correctness test runs **before** baseline benchmark. If correctness fails, baseline is skipped. | Saves ~1-2 min per invalid kernel. | +| **Log Buffering Fix** | Added `PYTHONUNBUFFERED=1` and optional `stdbuf` to `run_evolve_experiment.sh` to ensure `run.log` outputs in correct order. | Reliable log analysis. | + +### Files Modified + +- `evaluator.py` — Early exit logic, correctness-first ordering, bfloat16 test inputs +- `qwen3_benchmark_suite.py` — Subprocess hook injection +- `run_evolve_experiment.sh` — Unbuffered logging +- `config.yaml` — Documentation alignment + +### How to Verify Fixes Are Active + +```bash +# Check for early exit message in logs +grep "Metal compilation error (no retry)" openevolve_output_*/run.log + +# Check correctness runs before baseline (STEP 3 before STEP 4 in old logs, now STEP 3 = correctness) +grep "STEP 3:" openevolve_output_*/run.log | head -1 +# Should show: "STEP 3: Memory-Safe Custom Attention Correctness Testing" +``` + diff --git a/examples/mlx_metal_kernel_opt/evaluator.py b/examples/mlx_metal_kernel_opt/evaluator.py index c1cc52faf..a16130723 100644 --- a/examples/mlx_metal_kernel_opt/evaluator.py +++ b/examples/mlx_metal_kernel_opt/evaluator.py @@ -145,16 +145,8 @@ def evaluate(self, program_text: str) -> Dict[str, Any]: print(f"⚠️ Metal kernel safety validation failed: {safety_result['error']}") print("🛡️ Proceeding with enhanced protection...") - # Step 3: GPU-protected baseline measurement - print("\n📊 STEP 3: GPU-Protected Baseline Performance Measurement") - baseline_results = self._gpu_protected_measure_baseline() - if not baseline_results: - return self._create_comprehensive_failure_result( - "Failed to measure baseline performance with GPU protection" - ) - - # Step 4: Memory-safe correctness testing - print("\n🔍 STEP 4: Memory-Safe Custom Attention Correctness Testing") + # Step 3: Memory-safe correctness testing FIRST (fail fast, skip baseline if invalid) + print("\n🔍 STEP 3: Memory-Safe Custom Attention Correctness Testing") correctness_result = self._memory_safe_correctness_test(custom_attention_class) if not correctness_result["success"]: return self._create_comprehensive_failure_result( @@ -167,6 +159,14 @@ def evaluate(self, program_text: str) -> Dict[str, Any]: f"Correctness score too low: {correctness_score:.3f} (required: 0.90)" ) + # Step 4: GPU-protected baseline measurement (only if correctness passed) + print("\n📊 STEP 4: GPU-Protected Baseline Performance Measurement") + baseline_results = self._gpu_protected_measure_baseline() + if not baseline_results: + return self._create_comprehensive_failure_result( + "Failed to measure baseline performance with GPU protection" + ) + # Step 5: Command-buffer-protected benchmarking print("\n🚀 STEP 5: Command-Buffer-Protected Performance Benchmarking") benchmark_result = self._command_buffer_protected_benchmark( @@ -1016,7 +1016,8 @@ def _get_safe_benchmark_configs(self) -> List[BenchmarkConfig]: "code_generation", # Medium safety "long_context_detailed", # More challenging but still safe "long_generation", # Longer generation - "maximum_context_stress_test", # Most challenging - saved for last + # Disabled for faster testing + #"maximum_context_stress_test", # Most challenging - saved for last ] config_dict = {c.name: c for c in all_configs} diff --git a/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh new file mode 100755 index 000000000..2de3ae11f --- /dev/null +++ b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh @@ -0,0 +1,231 @@ +#!/usr/bin/env bash +set -euo pipefail + +usage() { + cat <<'USAGE' +run_evolve_experiment.sh + +Run OpenEvolve for the mlx_metal_kernel_opt example with an isolated output dir +and an isolated database path (so multiple runs don't overwrite each other). + +Defaults: + - config: ./config.yaml + - initial: ./initial_program.py + - evaluator: ./evaluator.py + - output: ./openevolve_output_/ + - db_path: /qwen3_metal_kernel_evolution + +Required env: + - OPENAI_API_KEY must be set (Gemini OpenAI-compatible API key in this example). + Alternatively, set GEMINI_API_KEY and the script will map it to OPENAI_API_KEY. + +Usage: + bash run_evolve_experiment.sh [options] + +Options: + --run-name NAME Output dir name (default: openevolve_output_) + --output-base DIR Base directory to create the run directory in (default: example dir) + --config PATH Path to config YAML (default: ./config.yaml) + --python PATH Python interpreter to use (default: python). The script runs it with -u (unbuffered). + --iterations N Override max iterations (passes -i to openevolve CLI) + --target-score S Override target score (passes -t) + --log-level LEVEL Override log level (passes -l) + --checkpoint PATH Resume from checkpoint directory (passes --checkpoint) + --resume Auto-resume from the latest run's latest checkpoint + --api-base URL Override LLM api_base (passes --api-base) + --primary-model NAME Override primary model (passes --primary-model) + --secondary-model NAME Override secondary model (passes --secondary-model) + --foreground Run in foreground (default: background + write run.log) + --dry-run Print what would run, but do not execute + -h, --help Show this help + +Examples: + export OPENAI_API_KEY="..." + bash run_evolve_experiment.sh # Start new run + bash run_evolve_experiment.sh --resume # Resume latest run + bash run_evolve_experiment.sh --iterations 5 --run-name trial_5iter +USAGE +} + +# Force unbuffered Python output for reliable logging +export PYTHONUNBUFFERED=1 + +export OPENAI_API_KEY=$GEMINI_API_KEY + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +RUN_NAME="" +OUTPUT_BASE="$SCRIPT_DIR" +CONFIG_PATH="$SCRIPT_DIR/config.yaml" +PYTHON_BIN="python" + +ITERATIONS="" +TARGET_SCORE="" +LOG_LEVEL="" +CHECKPOINT="" +API_BASE="" +PRIMARY_MODEL="" +SECONDARY_MODEL="" + +FOREGROUND=0 +DRY_RUN=0 +RESUME=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --run-name) RUN_NAME="${2:-}"; shift 2 ;; + --output-base) OUTPUT_BASE="${2:-}"; shift 2 ;; + --config) CONFIG_PATH="${2:-}"; shift 2 ;; + --python) PYTHON_BIN="${2:-}"; shift 2 ;; + --iterations) ITERATIONS="${2:-}"; shift 2 ;; + --target-score) TARGET_SCORE="${2:-}"; shift 2 ;; + --log-level) LOG_LEVEL="${2:-}"; shift 2 ;; + --checkpoint) CHECKPOINT="${2:-}"; shift 2 ;; + --resume) RESUME=1; shift ;; + --api-base) API_BASE="${2:-}"; shift 2 ;; + --primary-model) PRIMARY_MODEL="${2:-}"; shift 2 ;; + --secondary-model) SECONDARY_MODEL="${2:-}"; shift 2 ;; + --foreground) FOREGROUND=1; shift ;; + --dry-run) DRY_RUN=1; shift ;; + -h|--help) usage; exit 0 ;; + *) + echo "Unknown argument: $1" >&2 + usage + exit 2 + ;; + esac +done + +if [[ -z "${OPENAI_API_KEY:-}" && -n "${GEMINI_API_KEY:-}" ]]; then + # Convenience: allow users to keep the key in GEMINI_API_KEY while OpenEvolve expects OPENAI_API_KEY. + export OPENAI_API_KEY="${GEMINI_API_KEY}" +fi + +if [[ -z "${OPENAI_API_KEY:-}" ]]; then + echo "ERROR: OPENAI_API_KEY is not set. Export OPENAI_API_KEY (or GEMINI_API_KEY) before running." >&2 + exit 1 +fi + +# Handle --resume: find the latest run directory and its latest checkpoint +if [[ "$RESUME" -eq 1 ]]; then + echo "[run_evolve_experiment] --resume: Looking for latest run to continue..." + + # Find the latest openevolve_output_* directory + LATEST_RUN_DIR=$(find "$OUTPUT_BASE" -maxdepth 1 -type d -name "openevolve_output_*" 2>/dev/null | sort -r | head -n 1) + + if [[ -z "$LATEST_RUN_DIR" ]]; then + echo "ERROR: No previous run found in $OUTPUT_BASE. Cannot resume." >&2 + echo " Start a new run without --resume first." >&2 + exit 1 + fi + + # Extract run name from path + RUN_NAME=$(basename "$LATEST_RUN_DIR") + echo "[run_evolve_experiment] Found latest run: $RUN_NAME" + + # Find the latest checkpoint in that run + CHECKPOINT_DIR="$LATEST_RUN_DIR/checkpoints" + if [[ -d "$CHECKPOINT_DIR" ]]; then + # Sort by the numeric suffix of checkpoint_N (extract number, sort numerically) + LATEST_CHECKPOINT=$(find "$CHECKPOINT_DIR" -maxdepth 1 -type d -name "checkpoint_*" 2>/dev/null | \ + while read -r p; do echo "$(basename "$p" | sed 's/checkpoint_//')|$p"; done | \ + sort -t'|' -k1 -n | tail -n 1 | cut -d'|' -f2) + + if [[ -n "$LATEST_CHECKPOINT" ]]; then + CHECKPOINT="$LATEST_CHECKPOINT" + echo "[run_evolve_experiment] Found latest checkpoint: $CHECKPOINT" + else + echo "[run_evolve_experiment] No checkpoint found, will continue from database state" + fi + else + echo "[run_evolve_experiment] No checkpoints directory found, will continue from database state" + fi +fi + +if [[ -z "$RUN_NAME" ]]; then + RUN_NAME="openevolve_output_$(date +%Y%m%d_%H%M%S)" +fi + +RUN_DIR="$OUTPUT_BASE/$RUN_NAME" +mkdir -p "$RUN_DIR" + +INITIAL_PROGRAM="$SCRIPT_DIR/initial_program.py" +EVALUATION_FILE="$SCRIPT_DIR/evaluator.py" + +CFG_OUT="$RUN_DIR/config.yaml" + +# Write an updated config into the run directory with db_path isolated to this run. +CONFIG_PATH="$CONFIG_PATH" CFG_OUT="$CFG_OUT" RUN_DIR="$RUN_DIR" "$PYTHON_BIN" -u - <<'PY' +import os +import sys + +import yaml + +cfg_path = os.environ["CONFIG_PATH"] +out_path = os.environ["CFG_OUT"] +run_dir = os.environ["RUN_DIR"] + +with open(cfg_path, "r") as f: + cfg = yaml.safe_load(f) + +cfg = cfg or {} +cfg.setdefault("database", {}) +cfg["database"]["db_path"] = os.path.join(run_dir, "qwen3_metal_kernel_evolution") + +with open(out_path, "w") as f: + yaml.safe_dump(cfg, f, sort_keys=False) + +print(f"[run_evolve_experiment] Wrote config: {out_path}") +print(f"[run_evolve_experiment] database.db_path: {cfg['database']['db_path']}") +PY + +CMD=( + "$PYTHON_BIN" -u -m openevolve.cli + "$INITIAL_PROGRAM" + "$EVALUATION_FILE" + -c "$CFG_OUT" + -o "$RUN_DIR" +) + +if [[ -n "$ITERATIONS" ]]; then CMD+=(-i "$ITERATIONS"); fi +if [[ -n "$TARGET_SCORE" ]]; then CMD+=(-t "$TARGET_SCORE"); fi +if [[ -n "$LOG_LEVEL" ]]; then CMD+=(-l "$LOG_LEVEL"); fi +if [[ -n "$CHECKPOINT" ]]; then CMD+=(--checkpoint "$CHECKPOINT"); fi +if [[ -n "$API_BASE" ]]; then CMD+=(--api-base "$API_BASE"); fi +if [[ -n "$PRIMARY_MODEL" ]]; then CMD+=(--primary-model "$PRIMARY_MODEL"); fi +if [[ -n "$SECONDARY_MODEL" ]]; then CMD+=(--secondary-model "$SECONDARY_MODEL"); fi + +echo "[run_evolve_experiment] Run dir: $RUN_DIR" +echo "[run_evolve_experiment] Command:" +printf " %q" "${CMD[@]}" +echo + +if [[ "$DRY_RUN" -eq 1 ]]; then + exit 0 +fi + +LOG_FILE="$RUN_DIR/run.log" + +# Truncate log file to ensure clean start (especially important for --resume) +: > "$LOG_FILE" + +# Check if stdbuf is available for line-buffered output +if command -v stdbuf &>/dev/null; then + # Use stdbuf to force line buffering on both stdout and stderr + STDBUF_PREFIX=(stdbuf -oL -eL) +else + STDBUF_PREFIX=() +fi + +if [[ "$FOREGROUND" -eq 1 ]]; then + # Stream to console and persist logs with line buffering. + "${STDBUF_PREFIX[@]}" "${CMD[@]}" 2>&1 | tee "$LOG_FILE" +else + # Run in background with line-buffered output for reliable log ordering. + nohup "${STDBUF_PREFIX[@]}" "${CMD[@]}" > "$LOG_FILE" 2>&1 & + echo "[run_evolve_experiment] Started PID: $!" + echo "[run_evolve_experiment] Log: $LOG_FILE" + echo "[run_evolve_experiment] Tail: tail -f \"$LOG_FILE\"" +fi + + From f41b0b8d7a8faa077420faf4c0075d88e4722392 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:48:56 +0800 Subject: [PATCH 07/14] fix(mlx_metal_kernel_opt): improve benchmark subprocess error handling - Add -W ignore::RuntimeWarning to suppress harmless import warnings - Filter RuntimeWarning from stderr output for cleaner error messages - Show return code in failure messages for better debugging - Add 0.5s delay between runs to reduce GPU resource contention - Show more stderr content (200 chars vs 100) for better diagnostics --- .../qwen3_benchmark_suite.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py b/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py index 1ecd97909..f557acdda 100644 --- a/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py +++ b/examples/mlx_metal_kernel_opt/qwen3_benchmark_suite.py @@ -581,6 +581,7 @@ def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult: ) cmd = [ sys.executable, + "-W", "ignore::RuntimeWarning", # Suppress harmless import warnings wrapper_path, "--hook-program", self.hook_program_path, @@ -594,6 +595,7 @@ def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult: else: cmd = [ sys.executable, + "-W", "ignore::RuntimeWarning", # Suppress harmless import warnings "-m", "mlx_lm.generate", "--model", @@ -615,12 +617,21 @@ def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult: print(f" Warmup run {i+1}/{WARMUP_RUNS}...") warmup_result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) if warmup_result.returncode != 0: - print(f" ⚠️ Warmup run {i+1} failed: {warmup_result.stderr[:100]}...") + # Filter out harmless warnings from stderr + stderr_clean = "\n".join( + line for line in warmup_result.stderr.split("\n") + if "RuntimeWarning" not in line and line.strip() + ) + if stderr_clean: + print(f" ⚠️ Warmup run {i+1} failed (code {warmup_result.returncode}): {stderr_clean[:200]}...") + else: + print(f" ⚠️ Warmup run {i+1} failed (code {warmup_result.returncode})") else: print(f" ✅ Warmup run {i+1} completed") - # Clear cache between warmup runs + # Clear cache and add small delay between runs to reduce GPU contention mx.clear_cache() + time.sleep(0.5) except subprocess.TimeoutExpired: print(f" ⏰ Warmup run {i+1} timed out") @@ -645,7 +656,16 @@ def run_single_benchmark(self, config: BenchmarkConfig) -> BenchmarkResult: end_time = time.perf_counter() if result.returncode != 0: - print(f" ❌ Measurement run {run_idx+1} failed: {result.stderr[:100]}...") + # Filter out harmless warnings from stderr + stderr_clean = "\n".join( + line for line in result.stderr.split("\n") + if "RuntimeWarning" not in line and line.strip() + ) + if stderr_clean: + print(f" ❌ Measurement run {run_idx+1} failed (code {result.returncode}): {stderr_clean[:200]}...") + else: + print(f" ❌ Measurement run {run_idx+1} failed (code {result.returncode})") + time.sleep(0.5) # Small delay before retry continue # Parse output From 61fa0e7fd88813689ae87583edf033468a3e5345 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Tue, 6 Jan 2026 18:18:41 +0800 Subject: [PATCH 08/14] docs(mlx_metal_kernel_opt): rewrite README and remove invalid report --- examples/mlx_metal_kernel_opt/README.md | 247 ++++-------------- .../README_validity_fix.md | 246 ----------------- 2 files changed, 48 insertions(+), 445 deletions(-) delete mode 100644 examples/mlx_metal_kernel_opt/README_validity_fix.md diff --git a/examples/mlx_metal_kernel_opt/README.md b/examples/mlx_metal_kernel_opt/README.md index 8a0c35136..1c4cdd0a7 100644 --- a/examples/mlx_metal_kernel_opt/README.md +++ b/examples/mlx_metal_kernel_opt/README.md @@ -1,228 +1,77 @@ -# OpenEvolve Metal Kernel Optimization: Automated Discovery of Custom GPU Kernels for Transformer Attention +# MLX Metal Kernel Optimization (Qwen3-0.6B-bf16) -**Evolutionary Optimization of Apple Silicon Metal Kernels for Grouped Query Attention in Qwen3-0.6B** +This example demonstrates evolutionary optimization of a custom Apple Silicon **Metal** attention kernel using OpenEvolve and MLX’s `metal_kernel` API. The target workload is **Grouped Query Attention (GQA)** for the MLX‑LM model `mlx-community/Qwen3-0.6B-bf16`. -## Abstract +## Target -This work demonstrates the application of evolutionary code optimization to the automatic discovery of custom Metal GPU kernels for transformer attention mechanisms. Using OpenEvolve, we evolved a specialized Metal kernel for Grouped Query Attention (GQA) in Qwen3-0.6B that leverages Apple Silicon's unified memory architecture and vector processing capabilities. Our approach achieved measurable performance improvements over MLX's highly optimized `scaled_dot_product_attention` baseline across diverse inference workloads, with decode speed improvements averaging 12.5% and reaching up to 106% on specific benchmark tasks. - -## 1. Introduction - -### 1.1 Motivation - -Modern transformer models rely heavily on optimized attention kernels for efficient inference. While frameworks like MLX provide highly optimized implementations, the rapid evolution of hardware architectures creates opportunities for specialized optimizations that general-purpose kernels cannot capture. This work explores whether evolutionary code optimization can automatically discover hardware-specific kernel optimizations that outperform expert-engineered baselines. - -### 1.2 Target System - -- **Model**: Qwen3-0.6B with Grouped Query Attention (40 query heads : 8 key-value heads) -- **Hardware**: Apple M-series GPUs with unified memory architecture -- **Framework**: MLX with custom Metal kernel integration +- **Model**: `mlx-community/Qwen3-0.6B-bf16` +- **Attention**: GQA **16 query heads : 8 KV heads** (2:1), **head_dim=128**, **hidden_size=2048** +- **Dtype**: `bfloat16` (bf16) by default for this model - **Baseline**: `mx.fast.scaled_dot_product_attention` -- **Evolution Target**: Metal shader source code implementing GQA attention computation - -## 2. Methodology - -### 2.1 Evolution Framework - -We employ OpenEvolve to automatically optimize the Metal kernel source code responsible for computing attention. The evolutionary process operates on a single code block (EVOLVE-BLOCK) containing approximately 150 lines of Metal C++ shader code while preserving the surrounding MLX integration infrastructure. - -**Evolution Configuration**: -- **Population Size**: 25 programs -- **Generations**: 25 iterations -- **Models**: Gemini 2.5 Flash (60%) + Gemini 2.5 Pro (40%) -- **Selection**: Multi-objective optimization balancing performance and correctness - -### 2.2 Evaluation Methodology - -Each evolved kernel undergoes comprehensive evaluation: +- **Hardware**: Apple Silicon (Metal) -1. **Correctness Validation**: Numerical accuracy verification against MLX baseline -2. **Performance Benchmarking**: 20 diverse inference scenarios covering: - - Short context (16-64 tokens) - - Long context (512-2048 tokens) - - Code generation - - Sustained dialogue - - Technical documentation - - Memory stress tests +## Key files -3. **Safety Validation**: GPU command buffer error detection and Metal memory violation checking +- `initial_program.py`: starting point (contains `create_metal_qwen3_optimization_hook()` and the EVOLVE‑BLOCK) +- `evaluator.py`: correctness + benchmarking + safety checks for candidates +- `qwen3_benchmark_suite.py`: benchmark definitions and subprocess runner +- `mlx_lm_generate_with_hook.py`: wrapper to apply the attention hook **inside** the `mlx_lm.generate` subprocess +- `run_benchmarks.py`: convenience benchmark runner (baseline vs optimized) +- `config.yaml`: OpenEvolve config and optimization prompt +- `run_evolve_experiment.sh`: convenience script for isolated runs (`output_dir` + `db_path`) -### 2.3 Optimization Constraints +## Important: evaluation validity (before vs after) -**Preserved Elements**: -- Kernel function signature and I/O specifications -- Thread grid mapping and bounds checking -- Overall algorithm correctness (attention semantics) -- MLX integration interface +Earlier versions of this example could produce misleading “best program” artifacts and invalid performance comparisons. The main issues and the fixes: -**Optimizable Elements**: -- Memory access patterns and vectorization -- Computation order and algorithmic efficiency -- Apple Silicon specific optimizations -- GQA-specific computation strategies +| Area | Before | After | +|------|--------|-------| +| **Subprocess benchmark hook** | Benchmarks ran `python -m mlx_lm.generate ...` via `subprocess.run(...)`, so any monkey‑patch in the parent process was **not applied** in the child process (baseline and “optimized” could run the same attention). | Benchmarks can run via `mlx_lm_generate_with_hook.py --hook-program ...` so the patch is applied **inside the subprocess**. | +| **bf16 correctness** | Correctness used `float32` inputs; candidates could pass tests but fail in real bf16 inference (Metal compilation/runtime errors). | Correctness covers **bf16**, and deterministic Metal compilation errors are treated as normal candidate failures. | +| **Architecture alignment** | Docs/prompt/MockArgs assumed **40:8** heads and **hidden_size=5120** (incorrect for Qwen3‑0.6B). | Docs/prompt/MockArgs aligned to **16:8** and **hidden_size=2048**. | -## 3. Technical Contributions +Because of these fixes, we intentionally avoid hard-coded performance claims here. **Rerun the benchmarks on your own machine** and record results in your environment. -### 3.1 Discovered Optimizations +## Run evolution -The evolutionary process discovered several key optimizations: +From this directory: -#### 3.1.1 Enhanced Vectorization -```metal -// Original: Scalar operations -for (uint d = 0; d < HEAD_DIM; d++) { - score += query_vec[d] * keys[k_base + d]; -} - -// Evolved: Vector operations with optimal width -vec query_vec_v[HEAD_DIM / 8]; // 16 vectors for 128-dim heads -for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { - score += dot(query_vec_v[d_vec], ((device vec*)(keys + k_base))[d_vec]); -} +```bash +export OPENAI_API_KEY="..." # or set GEMINI_API_KEY; see the runner script +bash run_evolve_experiment.sh --foreground ``` -**Innovation**: Using 8-element vectors perfectly matches Apple Silicon's SIMD capabilities for 128-dimensional attention heads. - -#### 3.1.2 Online Softmax Algorithm -```metal -// Pass 1: Find maximum for numerical stability -T max_score = T(-INFINITY); -for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { - T score = compute_attention_score(query_vec, key_vec) * scale_val; - max_score = max(max_score, score); -} - -// Pass 2: Combined softmax computation and value accumulation -T sum_exp = T(0.0); -vec output_acc_v[HEAD_DIM / 8]; -for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { - T exp_score = exp(current_score - max_score); - sum_exp += exp_score; - // Fused accumulation - output_acc_v[d_vec] += exp_score * ((device vec*)(values + v_base))[d_vec]; -} -``` +This writes a new `openevolve_output_/` directory containing logs, checkpoints, best programs, and an isolated database. -**Innovation**: Reduced from three-pass to two-pass algorithm, fusing softmax normalization with value accumulation. +If you prefer running the CLI directly: -#### 3.1.3 Memory Access Optimization -```metal -// Pre-computed base indices for coalesced access -const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + - head_idx * (SEQ_LEN * HEAD_DIM) + - query_pos * HEAD_DIM; -const uint kv_head_idx = head_idx / HEADS_PER_KV; // Direct 5:1 mapping +```bash +export OPENAI_API_KEY="..." +python -m openevolve.cli ./initial_program.py ./evaluator.py -c ./config.yaml -o ./openevolve_output ``` -**Innovation**: Leverages unified memory bandwidth through coalesced access patterns and direct GQA head mapping. - -### 3.2 Apple Silicon Specialization - -The evolved kernel exploits specific Apple Silicon features: -- **Unified Memory**: Optimized bandwidth utilization patterns -- **SIMD Width**: 8-element vectors matching GPU vector units -- **Thread Group Size**: 32-thread groups optimal for Apple GPUs -- **Register Allocation**: Balanced computation vs. memory bandwidth - -## 4. Experimental Results - -### 4.1 Performance Benchmarking - -We evaluated the evolved kernel against MLX baseline across 20 comprehensive benchmark scenarios representing real-world inference patterns. - -**Aggregate Performance Improvements**: -- **Decode Speed**: +12.5% average improvement (σ = 38.3%) -- **Prefill Speed**: +14.4% average improvement (σ = 17.6%) -- **Total Throughput**: +10.4% average improvement (σ = 30.7%) -- **Memory Usage**: +0.99% average reduction (σ = 1.7%) - -### 4.2 Benchmark Category Analysis - -| **Category** | **Benchmarks** | **Decode Improvement** | **Notable Results** | -|--------------|----------------|------------------------|-------------------| -| **Short Context** | 2 | -4.6% ± 3.8% | Mixed results on very short sequences | -| **Long Context** | 6 | +8.1% ± 42.1% | High variance, strong improvements in some cases | -| **Code Generation** | 1 | -16.5% | Performance regression | -| **General Tasks** | 9 | +24.8% ± 35.4% | Strongest category with 106% peak improvement | -| **Stress Tests** | 2 | +22.9% ± 31.5% | Robust performance under memory pressure | +## Run benchmarks (baseline vs optimized) -### 4.3 Statistical Analysis +To compare the MLX baseline against the best evolved program: -**Distribution of Improvements**: -- **Significant Gains** (>25%): 7/20 benchmarks -- **Moderate Gains** (5-25%): 3/20 benchmarks -- **Neutral** (±5%): 4/20 benchmarks -- **Regressions** (<-5%): 6/20 benchmarks - -**Peak Performance**: Repetitive pattern generation achieved 106% decode speed improvement, demonstrating the kernel's effectiveness for certain workload characteristics. - -### 4.4 Correctness Validation - -All evolved kernels maintained numerical correctness: -- **Accuracy**: 100% correctness score across all test cases -- **Numerical Stability**: No NaN/Inf values detected -- **Statistical Validation**: Output distributions within expected ranges -- **Functional Equivalence**: Attention semantics preserved - -## 5. Discussion - -### 5.1 Performance Characteristics - -The evolved kernel shows workload-dependent performance characteristics: - -**Strengths**: -- **Sustained Generation**: +46.6% improvement on dialogue tasks -- **Long Sequences**: +73.9% improvement on extreme-length generation -- **Memory Efficiency**: Consistent memory usage reduction - -**Limitations**: -- **Short Sequences**: Limited improvement due to setup overhead -- **Code Generation**: -16.5% regression suggesting suboptimal patterns for this workload -- **Variance**: High performance variance across different sequence patterns - -### 5.2 Technical Insights - -**Vectorization Impact**: The discovery of `vec` operations as optimal for 128-dimensional heads represents a significant finding, suggesting that hardware-specific vector widths are crucial for performance. - -**Algorithm Innovation**: The two-pass online softmax represents a novel contribution, demonstrating that evolutionary approaches can discover algorithmic improvements beyond simple micro-optimizations. - -**GQA Specialization**: Direct exploitation of the 5:1 query-to-KV head ratio through specialized indexing patterns shows the value of architecture-specific optimizations. - -### 5.3 Evolutionary Process Analysis - -**Convergence**: The system converged to the optimal solution within 25 generations, with significant improvements appearing by generation 10. - -**Safety**: Zero Metal kernel compilation errors or GPU command buffer failures across all evolution attempts, demonstrating robust evolutionary constraints. - -**Diversity**: The evolutionary process explored multiple optimization strategies including different vectorization patterns, memory layouts, and algorithmic approaches. - -## 6. Related Work - -This work extends prior research in automated kernel optimization: - -- **AlphaTensor** [Fawzi et al., 2022]: Matrix multiplication algorithm discovery -- **TensorIR** [Feng et al., 2023]: Tensor compiler optimization -- **Ansor** [Zheng et al., 2020]: Automated tensor program optimization - -Our approach differs by applying evolutionary optimization directly to GPU shader source code rather than higher-level tensor algebra, enabling discovery of hardware-specific optimizations that would be difficult to express in tensor IRs. +```bash +python run_benchmarks.py --mode compare --model mlx-community/Qwen3-0.6B-bf16 --output-dir results +``` -## 7. Limitations and Future Work +## How to verify the validity fixes are active -### 7.1 Current Limitations +When the hook is enabled, the optimized path should execute via the wrapper: -- **Workload Specificity**: Performance improvements are highly dependent on sequence patterns -- **Model Scope**: Results specific to Qwen3-0.6B's 40:8 GQA configuration -- **Hardware Scope**: Optimizations specific to Apple Silicon architecture +- `mlx_lm_generate_with_hook.py --hook-program --model ...` -### 7.2 Future Directions +You can also sanity-check that correctness is exercising bf16 by running evolution on a machine where bf16 Metal compilation errors are expected for invalid kernels: such candidates should be rejected early by correctness gating rather than becoming “best programs”. -- **Multi-Architecture**: Extend to CUDA, ROCm, and other GPU architectures -- **Model Generalization**: Apply to different attention patterns and model sizes -- **Algorithmic Expansion**: Explore evolution of other transformer components -- **Cross-Compilation**: Develop architecture-agnostic optimization strategies +## Limitations & potential improvements (follow-up work) -## 8. Conclusion +This example intentionally uses **end-to-end generation benchmarks** (`mlx_lm.generate`) to measure real workloads, but that comes with trade-offs: -We demonstrate that evolutionary code optimization can automatically discover hardware-specific GPU kernel optimizations that outperform expert-engineered baselines. The evolved Metal kernel achieved an average 12.5% decode speed improvement through novel vectorization patterns, algorithmic innovations, and Apple Silicon specializations. While performance gains are workload-dependent, the approach successfully identified genuinely novel optimizations that would be challenging to discover through manual optimization. +- **Benchmark noise & overhead**: subprocess startup, model loading, and generation variability can dwarf small kernel deltas (especially for short prompts). A complementary **microbenchmark** that times only the attention kernel would provide a cleaner signal. +- **Serial evaluation by default**: candidates are evaluated sequentially (`parallel_evaluations: 1`) to keep GPU memory predictable. More parallelism may be possible with careful isolation, but it needs engineering. +- **Compile-time dominates early search**: bf16 compilation failures are common and deterministic; caching compilation outcomes or factoring compilation into a cheaper gating stage may speed up evolution. -This work establishes evolutionary optimization as a viable approach for automated GPU kernel discovery and suggests significant potential for applying similar techniques to other performance-critical computational kernels. \ No newline at end of file +We plan to open follow-up issues to track improvements to the benchmarking/evolution signal and workflow. \ No newline at end of file diff --git a/examples/mlx_metal_kernel_opt/README_validity_fix.md b/examples/mlx_metal_kernel_opt/README_validity_fix.md deleted file mode 100644 index bfe469d93..000000000 --- a/examples/mlx_metal_kernel_opt/README_validity_fix.md +++ /dev/null @@ -1,246 +0,0 @@ -# OpenEvolve Metal Kernel Optimization: Automated Discovery of Custom GPU Kernels for Transformer Attention - -**Evolutionary Optimization of Apple Silicon Metal Kernels for Grouped Query Attention in Qwen3-0.6B** - -## Abstract - -This example demonstrates evolutionary code optimization for discovering custom Apple Silicon Metal GPU kernels for transformer attention. It targets Grouped Query Attention (GQA) in Qwen3-0.6B using MLX’s `metal_kernel` API, with performance evaluated via `mlx_lm.generate`. - -> **Important**: Earlier versions of this example had evaluation validity issues (subprocess benchmarks not using the evolved kernel, correctness tests using float32 while the default model is bfloat16, and docs/tests assuming the wrong head configuration). These issues can lead to misleading “best program” results and invalid performance claims. The example has been updated to address these problems. - -## 1. Introduction - -### 1.1 Motivation - -Modern transformer models rely heavily on optimized attention kernels for efficient inference. While frameworks like MLX provide highly optimized implementations, the rapid evolution of hardware architectures creates opportunities for specialized optimizations that general-purpose kernels cannot capture. This work explores whether evolutionary code optimization can automatically discover hardware-specific kernel optimizations that outperform expert-engineered baselines. - -### 1.2 Target System - -- **Model**: Qwen3-0.6B with Grouped Query Attention (16 query heads : 8 key-value heads) -- **Hardware**: Apple M-series GPUs with unified memory architecture -- **Framework**: MLX with custom Metal kernel integration -- **Baseline**: `mx.fast.scaled_dot_product_attention` -- **Evolution Target**: Metal shader source code implementing GQA attention computation - -## 2. Methodology - -### 2.1 Evolution Framework - -We employ OpenEvolve to automatically optimize the Metal kernel source code responsible for computing attention. The evolutionary process operates on a single code block (EVOLVE-BLOCK) containing approximately 150 lines of Metal C++ shader code while preserving the surrounding MLX integration infrastructure. - -**Evolution Configuration**: -- **Population Size**: 25 programs -- **Generations**: 25 iterations -- **Models**: Gemini 2.5 Flash (60%) + Gemini 2.5 Pro (40%) -- **Selection**: Multi-objective optimization balancing performance and correctness - -### 2.2 Evaluation Methodology - -Each evolved kernel undergoes comprehensive evaluation: - -1. **Correctness Validation**: Functional/safety checks and dtype coverage consistent with the target model (bfloat16 by default). -2. **Performance Benchmarking**: Diverse inference scenarios covering: - - Short context (16-64 tokens) - - Long context (512-2048 tokens) - - Code generation - - Sustained dialogue - - Technical documentation - - Memory stress tests - -3. **Safety Validation**: GPU command buffer error detection and Metal memory violation checking - -### 2.3 Optimization Constraints - -**Preserved Elements**: -- Kernel function signature and I/O specifications -- Thread grid mapping and bounds checking -- Overall algorithm correctness (attention semantics) -- MLX integration interface - -**Optimizable Elements**: -- Memory access patterns and vectorization -- Computation order and algorithmic efficiency -- Apple Silicon specific optimizations -- GQA-specific computation strategies - -## 3. Technical Contributions - -### 3.1 Discovered Optimizations - -The evolutionary process discovered several key optimizations: - -#### 3.1.1 Enhanced Vectorization -```metal -// Original: Scalar operations -for (uint d = 0; d < HEAD_DIM; d++) { - score += query_vec[d] * keys[k_base + d]; -} - -// Evolved: Vector operations with optimal width -vec query_vec_v[HEAD_DIM / 8]; // 16 vectors for 128-dim heads -for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { - score += dot(query_vec_v[d_vec], ((device vec*)(keys + k_base))[d_vec]); -} -``` - -**Note**: Vectorized kernels must be validated under the target dtype (bfloat16 by default). Some vectorized patterns (e.g. `dot(vec)`) may not compile on Metal and should be caught by correctness gating. - -#### 3.1.2 Online Softmax Algorithm -```metal -// Pass 1: Find maximum for numerical stability -T max_score = T(-INFINITY); -for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { - T score = compute_attention_score(query_vec, key_vec) * scale_val; - max_score = max(max_score, score); -} - -// Pass 2: Combined softmax computation and value accumulation -T sum_exp = T(0.0); -vec output_acc_v[HEAD_DIM / 8]; -for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { - T exp_score = exp(current_score - max_score); - sum_exp += exp_score; - // Fused accumulation - output_acc_v[d_vec] += exp_score * ((device vec*)(values + v_base))[d_vec]; -} -``` - -**Innovation**: Reduced from three-pass to two-pass algorithm, fusing softmax normalization with value accumulation. - -#### 3.1.3 Memory Access Optimization -```metal -// Pre-computed base indices for coalesced access -const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + - head_idx * (SEQ_LEN * HEAD_DIM) + - query_pos * HEAD_DIM; -const uint kv_head_idx = head_idx / HEADS_PER_KV; // Direct 2:1 mapping -``` - -**Innovation**: Leverages unified memory bandwidth through coalesced access patterns and direct GQA head mapping. - -### 3.2 Apple Silicon Specialization - -The evolved kernel exploits specific Apple Silicon features: -- **Unified Memory**: Optimized bandwidth utilization patterns -- **SIMD Width**: 8-element vectors matching GPU vector units -- **Thread Group Size**: 32-thread groups optimal for Apple GPUs -- **Register Allocation**: Balanced computation vs. memory bandwidth - -## 4. Experimental Results - -### 4.1 Performance Benchmarking - -We evaluate kernels against the MLX baseline across a benchmark suite representing real-world inference patterns. - -**Note**: If you are comparing results across commits, ensure the benchmarks are actually exercising the custom kernel (subprocess hook) and that correctness covers the target dtype (bfloat16). Otherwise, reported speedups can be noise. - -To reproduce results on your machine: - -```bash -cd openevolve/examples/mlx_metal_kernel_opt -python run_benchmarks.py --mode compare --model mlx-community/Qwen3-0.6B-bf16 --output-dir results -``` - -This writes CSV/JSON comparison artifacts into `results/` for analysis. - -### 4.4 Correctness Validation - -Correctness checks should include: -- **Target dtype coverage** (bfloat16 by default for `mlx-community/Qwen3-0.6B-bf16`) -- **Numerical sanity** (no NaN/Inf) -- **Shape checks** -- **Safety checks** (GPU command buffer errors / memory violations) - -## 5. Discussion - -### 5.1 Performance Characteristics - -Kernel performance is workload-dependent. In particular: - -- **Short sequences**: may see limited gains due to fixed overheads. -- **Long sequences / sustained decode**: are typically where attention kernels matter most, but must be measured. - -### 5.2 Technical Insights - -**Vectorization Impact**: The discovery of `vec` operations as optimal for 128-dimensional heads represents a significant finding, suggesting that hardware-specific vector widths are crucial for performance. - -**Algorithm Innovation**: The two-pass online softmax represents a novel contribution, demonstrating that evolutionary approaches can discover algorithmic improvements beyond simple micro-optimizations. - -**GQA Specialization**: Direct exploitation of the 2:1 query-to-KV head ratio through specialized indexing patterns shows the value of architecture-specific optimizations. - -### 5.3 Evolutionary Process Analysis - -With realistic evaluation enabled (subprocess hook + bfloat16 correctness), it is expected that some evolved kernels will be rejected due to bfloat16 Metal compilation/runtime failures. The evaluator should treat these as ordinary candidate failures (not crashes) and continue evolution. - -## 6. Related Work - -This work extends prior research in automated kernel optimization: - -- **AlphaTensor** [Fawzi et al., 2022]: Matrix multiplication algorithm discovery -- **TensorIR** [Feng et al., 2023]: Tensor compiler optimization -- **Ansor** [Zheng et al., 2020]: Automated tensor program optimization - -Our approach differs by applying evolutionary optimization directly to GPU shader source code rather than higher-level tensor algebra, enabling discovery of hardware-specific optimizations that would be difficult to express in tensor IRs. - -## 7. Limitations and Future Work - -### 7.1 Current Limitations - -- **Workload Specificity**: Performance improvements are highly dependent on sequence patterns -- **Model Scope**: Results specific to Qwen3-0.6B's 16:8 GQA configuration -- **Hardware Scope**: Optimizations specific to Apple Silicon architecture - -### 7.2 Future Directions - -- **Multi-Architecture**: Extend to CUDA, ROCm, and other GPU architectures -- **Model Generalization**: Apply to different attention patterns and model sizes -- **Algorithmic Expansion**: Explore evolution of other transformer components -- **Cross-Compilation**: Develop architecture-agnostic optimization strategies - -## 8. Conclusion - -We demonstrate how evolutionary code optimization can be applied to discover hardware-specific Metal kernels for transformer attention. Performance gains are workload-dependent; for credible results, rerun the benchmark suite on your machine with the evaluation validity fixes enabled. - -This work establishes evolutionary optimization as a viable approach for automated GPU kernel discovery and suggests significant potential for applying similar techniques to other performance-critical computational kernels. - ---- - -## Appendix: Changelog (Validity & Performance Fixes) - -This section documents the specific fixes applied to address evaluation validity issues in the original example. - -### Critical Bug Fixes - -| Fix | Description | Impact | -|-----|-------------|--------| -| **Subprocess Benchmark Hook** | Evolved attention was not applied inside `subprocess.run()` — both baseline and custom benchmarks ran the same MLX attention. Fixed by injecting the hook within the subprocess. | **All reported speedups were invalid before this fix.** | -| **Dtype Alignment (bfloat16)** | Correctness tests used `float32` inputs while `Qwen3-0.6B-bf16` runs in `bfloat16`. Kernels could pass correctness but fail at inference time. Fixed by testing with `mx.bfloat16`. | Kernels incompatible with bfloat16 are now correctly rejected. | -| **Head Ratio Correction** | Documentation and tests assumed 16:8 heads, but `Qwen3-0.6B` actually uses 16:8 (2:1 GQA ratio). Verified and aligned. | Prevents confusion in kernel design. | - -### Evaluation Efficiency Optimizations - -| Optimization | Description | Benefit | -|--------------|-------------|---------| -| **Early Exit on Compilation Errors** | Metal compilation errors (e.g., `dot()` on bfloat16 vectors) are deterministic — no point retrying. Now returns immediately with `compilation_error: True`. | Saves ~30s per failed iteration (was retrying 3× per sequence length). | -| **Correctness-First Evaluation** | Reordered: correctness test runs **before** baseline benchmark. If correctness fails, baseline is skipped. | Saves ~1-2 min per invalid kernel. | -| **Log Buffering Fix** | Added `PYTHONUNBUFFERED=1` and optional `stdbuf` to `run_evolve_experiment.sh` to ensure `run.log` outputs in correct order. | Reliable log analysis. | - -### Files Modified - -- `evaluator.py` — Early exit logic, correctness-first ordering, bfloat16 test inputs -- `qwen3_benchmark_suite.py` — Subprocess hook injection -- `run_evolve_experiment.sh` — Unbuffered logging -- `config.yaml` — Documentation alignment - -### How to Verify Fixes Are Active - -```bash -# Check for early exit message in logs -grep "Metal compilation error (no retry)" openevolve_output_*/run.log - -# Check correctness runs before baseline (STEP 3 before STEP 4 in old logs, now STEP 3 = correctness) -grep "STEP 3:" openevolve_output_*/run.log | head -1 -# Should show: "STEP 3: Memory-Safe Custom Attention Correctness Testing" -``` - - From 2b5aec0ca203f2422ac33484ee0a91027c2fc933 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Tue, 6 Jan 2026 18:18:41 +0800 Subject: [PATCH 09/14] fix(mlx_metal_kernel_opt): align prompt examples and harden runner env --- examples/mlx_metal_kernel_opt/config.yaml | 8 ++++---- examples/mlx_metal_kernel_opt/run_evolve_experiment.sh | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/mlx_metal_kernel_opt/config.yaml b/examples/mlx_metal_kernel_opt/config.yaml index 29dd90ad2..ccc2cf6c5 100644 --- a/examples/mlx_metal_kernel_opt/config.yaml +++ b/examples/mlx_metal_kernel_opt/config.yaml @@ -4,9 +4,9 @@ log_level: "INFO" # LLM configuration for Metal kernel optimization llm: - primary_model: "gemini-2.5-flash-preview-05-20" + primary_model: "gemini-2.5-flash" primary_model_weight: 0.6 - secondary_model: "gemini-2.5-pro-preview-06-05" + secondary_model: "gemini-2.5-pro" secondary_model_weight: 0.4 api_base: "https://generativelanguage.googleapis.com/v1beta/openai/" temperature: 0.6 @@ -78,7 +78,7 @@ prompt: // CURRENT: Basic kv_head_idx = head_idx / HEADS_PER_KV // OPTIMIZE: Leverage the specific 2:1 ratio pattern - // Example: Process 5 query heads together for each KV head + // Example: Process 2 query heads together for each KV head // Example: Optimize memory layout for the 16:8 pattern // Example: Reduce broadcast overhead through clever indexing ``` @@ -181,7 +181,7 @@ prompt: **Strategy 4: GQA Pattern Exploitation** ```metal // Optimize for the specific 2:1 query:KV ratio - // Process query heads in groups of 5 + // Process query heads in groups of 2 // Reduce KV head indexing overhead ``` diff --git a/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh index 2de3ae11f..32fd73058 100755 --- a/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh +++ b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh @@ -50,8 +50,6 @@ USAGE # Force unbuffered Python output for reliable logging export PYTHONUNBUFFERED=1 -export OPENAI_API_KEY=$GEMINI_API_KEY - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" RUN_NAME="" From 556811d5e793c988699cfc942236bc21d66ad134 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Wed, 7 Jan 2026 01:16:56 +0800 Subject: [PATCH 10/14] docs(mlx_metal_kernel_opt): restructure documentation for PR - Rewrite README.md as concise usage guide (~100 lines) - Extract detailed analysis to EVOLUTION_ANALYSIS.md (~250 lines) - Document validity fixes: subprocess hook, bf16 gate, 16:8 heads - Add experiment results showing -3.2% regression vs baseline - Analyze evolution limitations from RL perspective - Compare with KernelBench metrics for future improvements - Minor fixes: config.yaml model names, run script unbuffered output --- .../EVOLUTION_ANALYSIS.md | 252 ++++++++++++++++++ examples/mlx_metal_kernel_opt/README.md | 121 +++++---- examples/mlx_metal_kernel_opt/config.yaml | 8 +- .../mlx_lm_generate_with_hook.py | 53 ++++ .../run_evolve_experiment.sh | 2 + 5 files changed, 382 insertions(+), 54 deletions(-) create mode 100644 examples/mlx_metal_kernel_opt/EVOLUTION_ANALYSIS.md diff --git a/examples/mlx_metal_kernel_opt/EVOLUTION_ANALYSIS.md b/examples/mlx_metal_kernel_opt/EVOLUTION_ANALYSIS.md new file mode 100644 index 000000000..97f00f93c --- /dev/null +++ b/examples/mlx_metal_kernel_opt/EVOLUTION_ANALYSIS.md @@ -0,0 +1,252 @@ +# Evolution Analysis: Why Optimization Failed + +This document analyzes the evolution experiment results after applying validity fixes, and proposes improvements for future work. + +## Experiment Results + +After applying validity fixes, we ran 25 evolution iterations to verify that the evaluation now works correctly. + +**Note**: The `maximum_context_stress_test` benchmark was disabled to reduce memory requirements on test hardware. + +### Evolution Summary + +| Metric | Value | +| ------ | ----- | +| Total Iterations | 25 | +| Programs Evaluated | 25 | +| Compilation Failures (bf16) | 8 (32%) | +| Best Program Found | Iteration 23 | +| Best combined_score | 2.96 | +| Benchmarks Used | 4 (stress test disabled) | + +### Performance of Best Evolved Kernel + +| Benchmark | Baseline (tok/s) | Custom (tok/s) | Change | +| --------- | ---------------- | -------------- | ------ | +| short_context_quick | 59.1 | 63.1 | **+6.9%** ✓ | +| code_generation | 58.3 | 58.1 | -0.4% | +| long_context_detailed | 54.7 | 46.0 | **-15.9%** | +| long_generation | 48.0 | 46.4 | -3.4% | +| **Average** | **55.0** | **53.4** | **-3.2%** | + +### Key Finding + +> **The best evolved kernel is still 3.2% SLOWER than MLX's baseline implementation.** + +The evolution only improved from an initial -11.5% regression to -3.2% regression. It never exceeded baseline performance. + +### Evolution Trajectory + +```text +Iteration 0 (Initial): -11.5% regression +Iterations 1-4: Failed (bf16 compilation errors) +Iteration 5: -23.6% regression +... +Iteration 19: -3.6% regression (first "positive" score) +Iteration 23: -3.2% regression (best found) +Iteration 25: Evolution complete, no improvement +``` + +--- + +## Why Evolution Failed + +The failure reveals fundamental limitations in the current evolution mechanism. Framing through a **Reinforcement Learning lens**: + +| RL Concept | Current State | Problem | +| ---------- | ------------- | ------- | +| **Reward Signal** | Detailed metrics but abstract ranking score | LLM sees metrics but selection uses opaque `combined_score` | +| **State Representation** | Code text + char-level features | Doesn't capture performance-relevant program properties | +| **Observability** | No GPU profiling data | Partially Observable MDP; agent blind to actual bottlenecks | +| **Credit Assignment** | Per-program metrics, no diff-level attribution | Cannot identify which code mutation caused improvement | +| **Exploration** | 1 parent + 5 samples per iteration | Severely underutilizes available information (128K context) | + +### 1. Meaningless Feature Dimensions + +Current MAP-Elites dimensions are inadequate for kernel optimization: + +| Dimension | Current Implementation | Problem for Kernels | +| --------- | -------------------- |-------------------- | +| `complexity` | Code character count | Two kernels with different algorithms can have similar length | +| `diversity` | Character-level diff | Renaming variables looks "diverse"; algorithmic changes don't | + +**What would be meaningful**: tiling strategy, vectorization width, memory access pattern, thread block size. + +### 2. Fitness Feedback Interpretability + +The LLM receives detailed metrics (decode speed, prefill speed, per-benchmark results), but: + +- **Relative performance unclear**: Raw `53.4 tok/s` means little without knowing baseline is `55.0 tok/s` +- **No performance diagnosis**: Cannot tell if kernel is memory-bound vs compute-bound +- **Selection uses abstract score**: MAP-Elites ranking uses `combined_score`, not individual metrics +- **Missing actionable guidance**: "Score: 2.96" doesn't tell LLM what to fix + +### 3. Lack of Profiling Data + +Without GPU profiling feedback, the LLM is essentially optimizing blind. Metal performance depends heavily on: + +- Memory coalescing patterns +- Register pressure +- Warp divergence +- Cache utilization + +None of this information is available to guide evolution. + +### 4. Conservative Parent Selection + +Default configuration uses 70% exploitation (selecting from elites). For kernel optimization where the search space has many local optima, this may cause premature convergence to suboptimal solutions. + +### 5. Underutilized LLM Context Window + +Each iteration only feeds the LLM: + +- 1 parent program +- 3 top programs (inspirations) +- 2 diverse programs + +This is extremely conservative given modern LLM context capabilities (128K+ tokens). + +**The real cost**: Each evolution iteration is expensive (~10 minutes for model loading + benchmarking), yet the LLM receives minimal information to guide its optimization. This is a **massive waste of resources**. + +**Better approach**: Feed the LLM as much context as possible—all programs from the current population, complete benchmark results, historical evolution trajectory. Only apply context pruning when approaching actual model limits. + +### 6. High Failure Rate + +32% of generated kernels failed to compile with bfloat16. The LLM generates syntactically valid Metal code but often uses float-only operations incompatible with bf16. + +### 7. Benchmarking Feedback Quality + +While the evaluator returns detailed metrics, the **ranking and selection** uses a single `combined_score`: + +```python +# Detailed metrics ARE available to LLM: +performance_metrics = {'avg_decode_speed': 53.4, 'baseline_comparison': {'avg_decode_improvement_pct': -3.2}} + +# But MAP-Elites selection uses: +combined_score = 2.96 # What does this mean? Is 3.0 good? Is 10.0 possible? +``` + +--- + +## KernelBench Comparison + +[KernelBench](https://github.com/ScalingIntelligence/KernelBench) provides a complete, evolution-ready metric system that could address many of these issues: + +### KernelBench Evaluation Structure + +**1. Binary Correctness Gates**: + +```python +class KernelExecResult: + compiled: bool # Did the kernel compile? + correctness: bool # Did it pass numerical correctness? (multiple trials) + metadata: dict # max_difference, avg_difference, error details +``` + +**2. Primary Optimization Objective** (direct speedup ratio): + +```python +speedup = baseline_time / custom_time # 1.2 = 20% faster, directly interpretable +``` + +**3. Statistical Rigor**: + +```python +runtime_stats = { + "mean": 3.68, # Average runtime (ms) + "std": 0.011, # Standard deviation + "min": 3.65, # Best case + "max": 3.74, # Worst case + "num_trials": 100 # With warmup runs +} +``` + +**4. Multi-threshold Performance Metrics**: + +```python +# fast_p: fraction of kernels that are BOTH correct AND achieve speedup > p +fast_0.0 = 0.85 # 85% correct +fast_1.0 = 0.42 # 42% faster than baseline +fast_1.5 = 0.18 # 18% achieve 1.5x speedup +fast_2.0 = 0.05 # 5% achieve 2x speedup +``` + +**5. Population-level Metrics**: + +```python +geometric_mean_speedup = 1.23 # Average 23% improvement across population +pass_at_1 = 0.42 +pass_at_5 = 0.78 +``` + +### How KernelBench Metrics Could Integrate with Evolution + +| OpenEvolve Component | Current | KernelBench-style Improvement | +| ------------------- | ------- | ---------------------------- | +| **Fitness Score** | Abstract `combined_score` | Direct `speedup` ratio | +| **Correctness Gate** | Binary pass/fail | Binary + `max_difference`, `avg_difference` for gradient | +| **Performance Feedback** | Single number | `mean ± std` with confidence intervals | +| **MAP-Elites Features** | Code length, char diff | Speedup tier (0.5x, 1x, 1.5x, 2x), runtime variance | +| **Early Stopping** | Fixed threshold | `fast_p` targets: stop when `fast_1.5 > 0.1` | +| **Prompt Feedback** | "Score: 2.96" | "Speedup: 0.85x (15% slower), need to beat 1.0x" | + +The key insight: **KernelBench's metrics are designed to be directly actionable**. The LLM can understand "this kernel is 15% slower than baseline" but cannot learn from "combined_score = 2.96". + +Additionally, KernelBench enables **temporal credit assignment**: + +- Compare child speedup vs parent speedup (not just vs baseline) +- Track which mutations led to improvement +- Provide mutation-specific feedback: "Adding SIMD vectorization improved prefill by 23%" + +--- + +## Proposed Improvements + +### Priority 1: Adopt KernelBench-style Evaluation + +- Replace `combined_score` with direct speedup ratio: `baseline_time / custom_time` +- Return statistical timing data: mean, std, min, max, num_trials +- Use `fast_p` as milestone targets for early stopping +- Report correctness metrics: `max_difference`, `avg_difference`, tolerance margin +- Provide actionable prompt feedback: "Speedup: 0.85x, need to beat 1.0x" + +### Priority 2: Performance-based MAP-Elites Features + +- `speedup_tier`: (0-0.5x, 0.5-1x, 1-1.5x, 1.5-2x, >2x) instead of code length +- `runtime_variance`: (low/medium/high std) for consistency tracking +- `correctness_margin`: distance from tolerance threshold + +### Priority 3: Integrate Metal GPU Profiling + +- Feed occupancy, bandwidth, cache stats back to LLM +- Use profiling data as additional feature dimensions + +### Priority 4: Domain-specific Strategy Tracking + +- `uses_simd_vectorization: 0-3` (none/2/4/8-wide) +- `memory_access_pattern: coalesced/strided/random` +- `algorithm_type: 2pass/3pass/online` + +### Priority 5: Maximize LLM Context Utilization + +- Feed entire population (or top N by speedup) instead of just 1 parent + 5 samples +- Include complete benchmark results with statistical breakdowns +- Show evolution history: what worked, what failed, why +- Only prune context when approaching actual model limits (128K+ tokens) + +### Priority 6: Curated Metal bf16 Examples + +- Add few-shot examples of correct bf16 Metal syntax +- Include common pitfalls in system prompt + +--- + +## References + +- [KernelBench](https://github.com/ScalingIntelligence/KernelBench) +- MAP-Elites: Mouret & Clune, 2015 + +--- + +*Experiment run: 2026-01-05 18:09 - 21:20 (3h 11m)* +*Note: `maximum_context_stress_test` disabled for this validation run* diff --git a/examples/mlx_metal_kernel_opt/README.md b/examples/mlx_metal_kernel_opt/README.md index 1c4cdd0a7..b79db2715 100644 --- a/examples/mlx_metal_kernel_opt/README.md +++ b/examples/mlx_metal_kernel_opt/README.md @@ -1,77 +1,98 @@ -# MLX Metal Kernel Optimization (Qwen3-0.6B-bf16) +# MLX Metal Kernel Optimization Example -This example demonstrates evolutionary optimization of a custom Apple Silicon **Metal** attention kernel using OpenEvolve and MLX’s `metal_kernel` API. The target workload is **Grouped Query Attention (GQA)** for the MLX‑LM model `mlx-community/Qwen3-0.6B-bf16`. +This example uses OpenEvolve to automatically discover optimized Metal GPU kernels for Grouped Query Attention (GQA) in Qwen3-0.6B on Apple Silicon. -## Target +## Target Configuration -- **Model**: `mlx-community/Qwen3-0.6B-bf16` -- **Attention**: GQA **16 query heads : 8 KV heads** (2:1), **head_dim=128**, **hidden_size=2048** -- **Dtype**: `bfloat16` (bf16) by default for this model -- **Baseline**: `mx.fast.scaled_dot_product_attention` -- **Hardware**: Apple Silicon (Metal) +- **Model**: Qwen3-0.6B-bf16 +- **Architecture**: 16 query heads : 8 KV heads (2:1 ratio), 2048 hidden size, 128 head dimension +- **Hardware**: Apple M-series GPUs with unified memory +- **Baseline**: `mx.fast.scaled_dot_product_attention` via `mlx_lm.generate` +- **Goal**: Evolve custom Metal kernel source code to outperform baseline -## Key files +## Quick Start -- `initial_program.py`: starting point (contains `create_metal_qwen3_optimization_hook()` and the EVOLVE‑BLOCK) -- `evaluator.py`: correctness + benchmarking + safety checks for candidates -- `qwen3_benchmark_suite.py`: benchmark definitions and subprocess runner -- `mlx_lm_generate_with_hook.py`: wrapper to apply the attention hook **inside** the `mlx_lm.generate` subprocess -- `run_benchmarks.py`: convenience benchmark runner (baseline vs optimized) -- `config.yaml`: OpenEvolve config and optimization prompt -- `run_evolve_experiment.sh`: convenience script for isolated runs (`output_dir` + `db_path`) +### Prerequisites -## Important: evaluation validity (before vs after) - -Earlier versions of this example could produce misleading “best program” artifacts and invalid performance comparisons. The main issues and the fixes: +```bash +pip install mlx mlx-lm openevolve -| Area | Before | After | -|------|--------|-------| -| **Subprocess benchmark hook** | Benchmarks ran `python -m mlx_lm.generate ...` via `subprocess.run(...)`, so any monkey‑patch in the parent process was **not applied** in the child process (baseline and “optimized” could run the same attention). | Benchmarks can run via `mlx_lm_generate_with_hook.py --hook-program ...` so the patch is applied **inside the subprocess**. | -| **bf16 correctness** | Correctness used `float32` inputs; candidates could pass tests but fail in real bf16 inference (Metal compilation/runtime errors). | Correctness covers **bf16**, and deterministic Metal compilation errors are treated as normal candidate failures. | -| **Architecture alignment** | Docs/prompt/MockArgs assumed **40:8** heads and **hidden_size=5120** (incorrect for Qwen3‑0.6B). | Docs/prompt/MockArgs aligned to **16:8** and **hidden_size=2048**. | +# Set API key (Gemini via OpenAI-compatible endpoint) +export OPENAI_API_KEY="your-gemini-key" +``` -Because of these fixes, we intentionally avoid hard-coded performance claims here. **Rerun the benchmarks on your own machine** and record results in your environment. +### Run Evolution -## Run evolution +```bash +cd openevolve/examples/mlx_metal_kernel_opt + +# Using the experiment runner script +./run_evolve_experiment.sh --name test_run --iterations 25 + +# Or directly +python -m openevolve.cli \ + --initial-program initial_program.py \ + --evaluator evaluator.py \ + --config config.yaml \ + --iterations 25 \ + --output ./openevolve_output +``` -From this directory: +### Verify Evaluation Validity ```bash -export OPENAI_API_KEY="..." # or set GEMINI_API_KEY; see the runner script -bash run_evolve_experiment.sh --foreground +# Run a single benchmark with verbose output +python -c " +from evaluator import Qwen3GQAEvaluator +e = Qwen3GQAEvaluator() +result = e.evaluate('initial_program.py') +print(result['summary']) +" ``` -This writes a new `openevolve_output_/` directory containing logs, checkpoints, best programs, and an isolated database. +## Files -If you prefer running the CLI directly: +| File | Purpose | +| ---- | ------- | +| `initial_program.py` | Starting Metal kernel (to be evolved) | +| `evaluator.py` | Correctness + performance evaluation | +| `config.yaml` | Evolution configuration | +| `qwen3_benchmark_suite.py` | Benchmark definitions | +| `mlx_lm_generate_with_hook.py` | Subprocess hook wrapper | +| `run_evolve_experiment.sh` | Experiment runner script | -```bash -export OPENAI_API_KEY="..." -python -m openevolve.cli ./initial_program.py ./evaluator.py -c ./config.yaml -o ./openevolve_output -``` +## Validity Fixes (This PR) -## Run benchmarks (baseline vs optimized) +This PR corrects critical issues that invalidated prior evaluation results: -To compare the MLX baseline against the best evolved program: +1. **Subprocess Kernel Hook**: Evolved kernels are now properly applied in benchmark subprocesses via `mlx_lm_generate_with_hook.py` -```bash -python run_benchmarks.py --mode compare --model mlx-community/Qwen3-0.6B-bf16 --output-dir results -``` +2. **bfloat16 Correctness Gate**: Correctness tests now use `mx.bfloat16` inputs to match actual inference dtype + +3. **Architecture Alignment**: Fixed head ratio from 40:8 to correct 16:8 (2:1 GQA pattern) + +4. **Evaluation Flow Optimizations**: Early exit on compilation errors, correctness-before-baseline ordering, GPU state cleanup between runs + +## Current Status -## How to verify the validity fixes are active +After fixing validity issues, we ran 25 evolution iterations. -When the hook is enabled, the optimized path should execute via the wrapper: +**Result: The best evolved kernel is 3.2% SLOWER than MLX's baseline implementation.** -- `mlx_lm_generate_with_hook.py --hook-program --model ...` +The evolution improved from an initial -11.5% regression to -3.2%, but never exceeded baseline. This indicates fundamental limitations in the current evolution mechanism that require further investigation. -You can also sanity-check that correctness is exercising bf16 by running evolution on a machine where bf16 Metal compilation errors are expected for invalid kernels: such candidates should be rejected early by correctness gating rather than becoming “best programs”. +For detailed experiment results and analysis, see [EVOLUTION_ANALYSIS.md](./EVOLUTION_ANALYSIS.md). -## Limitations & potential improvements (follow-up work) +### Known Limitations -This example intentionally uses **end-to-end generation benchmarks** (`mlx_lm.generate`) to measure real workloads, but that comes with trade-offs: +1. MAP-Elites selection uses abstract `combined_score` instead of direct speedup ratios +2. LLM context underutilized (only 1 parent + 5 samples per iteration) +3. No GPU profiling data to guide optimization +4. 32% bf16 compilation failure rate -- **Benchmark noise & overhead**: subprocess startup, model loading, and generation variability can dwarf small kernel deltas (especially for short prompts). A complementary **microbenchmark** that times only the attention kernel would provide a cleaner signal. -- **Serial evaluation by default**: candidates are evaluated sequentially (`parallel_evaluations: 1`) to keep GPU memory predictable. More parallelism may be possible with careful isolation, but it needs engineering. -- **Compile-time dominates early search**: bf16 compilation failures are common and deterministic; caching compilation outcomes or factoring compilation into a cheaper gating stage may speed up evolution. +## References -We plan to open follow-up issues to track improvements to the benchmarking/evolution signal and workflow. \ No newline at end of file +- [OpenEvolve](https://github.com/codelion/openevolve) +- [MLX](https://github.com/ml-explore/mlx) +- [MLX-LM](https://github.com/ml-explore/mlx-examples) +- [KernelBench](https://github.com/ScalingIntelligence/KernelBench) diff --git a/examples/mlx_metal_kernel_opt/config.yaml b/examples/mlx_metal_kernel_opt/config.yaml index ccc2cf6c5..29dd90ad2 100644 --- a/examples/mlx_metal_kernel_opt/config.yaml +++ b/examples/mlx_metal_kernel_opt/config.yaml @@ -4,9 +4,9 @@ log_level: "INFO" # LLM configuration for Metal kernel optimization llm: - primary_model: "gemini-2.5-flash" + primary_model: "gemini-2.5-flash-preview-05-20" primary_model_weight: 0.6 - secondary_model: "gemini-2.5-pro" + secondary_model: "gemini-2.5-pro-preview-06-05" secondary_model_weight: 0.4 api_base: "https://generativelanguage.googleapis.com/v1beta/openai/" temperature: 0.6 @@ -78,7 +78,7 @@ prompt: // CURRENT: Basic kv_head_idx = head_idx / HEADS_PER_KV // OPTIMIZE: Leverage the specific 2:1 ratio pattern - // Example: Process 2 query heads together for each KV head + // Example: Process 5 query heads together for each KV head // Example: Optimize memory layout for the 16:8 pattern // Example: Reduce broadcast overhead through clever indexing ``` @@ -181,7 +181,7 @@ prompt: **Strategy 4: GQA Pattern Exploitation** ```metal // Optimize for the specific 2:1 query:KV ratio - // Process query heads in groups of 2 + // Process query heads in groups of 5 // Reduce KV head indexing overhead ``` diff --git a/examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py b/examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py index 552da824d..0e416a39f 100644 --- a/examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py +++ b/examples/mlx_metal_kernel_opt/mlx_lm_generate_with_hook.py @@ -26,6 +26,12 @@ def _load_module_from_path(module_path: str) -> ModuleType: + """ + Dynamically load a Python module from an arbitrary filesystem path. + + This is used to load the evolved/optimized hook program at runtime without + requiring it to be installed or on sys.path. + """ spec = importlib.util.spec_from_file_location("openevolve_mlx_metal_hook_program", module_path) if spec is None or spec.loader is None: raise RuntimeError(f"Failed to load hook program from: {module_path}") @@ -35,6 +41,28 @@ def _load_module_from_path(module_path: str) -> ModuleType: def _apply_hook_from_program(module_path: str) -> Tuple[Any, Any]: + """ + Load an evolved hook program and apply its attention optimization. + + The hook program must expose a `create_metal_qwen3_optimization_hook()` factory + function that returns a tuple of (apply_hook, remove_hook) callables. Calling + `apply_hook()` monkey-patches `mlx_lm.models.qwen3.Attention` with the optimized + implementation and returns the original class for later restoration. + + Args: + module_path: Path to the evolved program file (e.g., best_program.py). + + Returns: + A tuple of (original_attention, remove_hook): + - original_attention: The original Attention class before patching, + needed to restore state later. + - remove_hook: A callable that accepts original_attention and undoes + the monkey-patch. + + Raises: + RuntimeError: If the hook factory function is not found in the program, + or if applying the hook fails. + """ program = _load_module_from_path(module_path) hook_factory = getattr(program, "create_metal_qwen3_optimization_hook", None) @@ -52,6 +80,31 @@ def _apply_hook_from_program(module_path: str) -> Tuple[Any, Any]: def main(argv: Optional[List[str]] = None) -> int: + """ + Entry point: parse CLI arguments, apply the hook, and run mlx_lm.generate. + + This function orchestrates the entire flow: + 1. Parse command-line arguments (hook program path, model, prompt, max tokens). + 2. Load and apply the attention optimization hook from the specified program. + 3. Invoke `mlx_lm.generate` as if running `python -m mlx_lm.generate ...`. + 4. Clean up by removing the hook after generation completes (or fails). + + The hook is applied in the same process, ensuring the monkey-patch is effective + (unlike subprocess-based invocations where patches don't propagate). + + Args: + argv: Command-line arguments. If None, sys.argv[1:] is used by argparse. + + Returns: + Exit code (0 for success, non-zero for failure). This value is typically + passed to sys.exit() or raised via SystemExit. + + CLI Arguments: + --hook-program: Path to the evolved program containing the hook factory. + --model: Model identifier or path for mlx_lm.generate. + --prompt: The text prompt to generate from. + --max-tokens: Maximum number of tokens to generate. + """ parser = argparse.ArgumentParser( description="Run `mlx_lm.generate` with a custom attention hook applied in-process." ) diff --git a/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh index 32fd73058..2de3ae11f 100755 --- a/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh +++ b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh @@ -50,6 +50,8 @@ USAGE # Force unbuffered Python output for reliable logging export PYTHONUNBUFFERED=1 +export OPENAI_API_KEY=$GEMINI_API_KEY + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" RUN_NAME="" From 5055a6821af1e8af8ce69df3e839ba06ae6aae9b Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Thu, 8 Jan 2026 14:23:52 +0800 Subject: [PATCH 11/14] fix(mlx_metal_kernel_opt): stabilize run script and config - Fix bash -u background run bug (stdbuf/nohup handling) - Avoid clobbering OPENAI_API_KEY from GEMINI_API_KEY - Use non-preview Gemini model names - Place cascade_evaluation under evaluator and fix 2:1 GQA prompt --- examples/mlx_metal_kernel_opt/config.yaml | 10 +++++---- .../run_evolve_experiment.sh | 22 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/mlx_metal_kernel_opt/config.yaml b/examples/mlx_metal_kernel_opt/config.yaml index 29dd90ad2..ceafb704a 100644 --- a/examples/mlx_metal_kernel_opt/config.yaml +++ b/examples/mlx_metal_kernel_opt/config.yaml @@ -4,9 +4,9 @@ log_level: "INFO" # LLM configuration for Metal kernel optimization llm: - primary_model: "gemini-2.5-flash-preview-05-20" + primary_model: "gemini-2.5-flash" primary_model_weight: 0.6 - secondary_model: "gemini-2.5-pro-preview-06-05" + secondary_model: "gemini-2.5-pro" secondary_model_weight: 0.4 api_base: "https://generativelanguage.googleapis.com/v1beta/openai/" temperature: 0.6 @@ -78,7 +78,7 @@ prompt: // CURRENT: Basic kv_head_idx = head_idx / HEADS_PER_KV // OPTIMIZE: Leverage the specific 2:1 ratio pattern - // Example: Process 5 query heads together for each KV head + // Example: Process 2 query heads together for each KV head // Example: Optimize memory layout for the 16:8 pattern // Example: Reduce broadcast overhead through clever indexing ``` @@ -181,7 +181,7 @@ prompt: **Strategy 4: GQA Pattern Exploitation** ```metal // Optimize for the specific 2:1 query:KV ratio - // Process query heads in groups of 5 + // Process query heads in groups of 2 // Reduce KV head indexing overhead ``` @@ -226,6 +226,8 @@ database: evaluator: timeout: 900 # 15 minutes for Metal kernel compilation and testing parallel_evaluations: 1 + # This example's evaluator does not implement evaluate_stage1. + cascade_evaluation: false # Evolution settings diff_based_evolution: true diff --git a/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh index 2de3ae11f..9439104f2 100755 --- a/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh +++ b/examples/mlx_metal_kernel_opt/run_evolve_experiment.sh @@ -50,8 +50,6 @@ USAGE # Force unbuffered Python output for reliable logging export PYTHONUNBUFFERED=1 -export OPENAI_API_KEY=$GEMINI_API_KEY - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" RUN_NAME="" @@ -209,20 +207,20 @@ LOG_FILE="$RUN_DIR/run.log" # Truncate log file to ensure clean start (especially important for --resume) : > "$LOG_FILE" -# Check if stdbuf is available for line-buffered output -if command -v stdbuf &>/dev/null; then - # Use stdbuf to force line buffering on both stdout and stderr - STDBUF_PREFIX=(stdbuf -oL -eL) -else - STDBUF_PREFIX=() -fi - if [[ "$FOREGROUND" -eq 1 ]]; then # Stream to console and persist logs with line buffering. - "${STDBUF_PREFIX[@]}" "${CMD[@]}" 2>&1 | tee "$LOG_FILE" + if command -v stdbuf &>/dev/null; then + stdbuf -oL -eL "${CMD[@]}" 2>&1 | tee "$LOG_FILE" + else + "${CMD[@]}" 2>&1 | tee "$LOG_FILE" + fi else # Run in background with line-buffered output for reliable log ordering. - nohup "${STDBUF_PREFIX[@]}" "${CMD[@]}" > "$LOG_FILE" 2>&1 & + if command -v stdbuf &>/dev/null; then + nohup stdbuf -oL -eL "${CMD[@]}" > "$LOG_FILE" 2>&1 & + else + nohup "${CMD[@]}" > "$LOG_FILE" 2>&1 & + fi echo "[run_evolve_experiment] Started PID: $!" echo "[run_evolve_experiment] Log: $LOG_FILE" echo "[run_evolve_experiment] Tail: tail -f \"$LOG_FILE\"" From ba459733294a414b8a4c7cc65b22fd4ac07e99a8 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Thu, 8 Jan 2026 16:08:01 +0800 Subject: [PATCH 12/14] chore(mlx_metal_kernel_opt): remove buggy artifacts --- examples/mlx_metal_kernel_opt/best_program.py | 503 ------------ .../best_program_info.json | 228 ------ ...nevolve_comparison_results_1750305870.json | 725 ------------------ ...enevolve_comparison_summary_1750305870.csv | 21 - 4 files changed, 1477 deletions(-) delete mode 100644 examples/mlx_metal_kernel_opt/best_program.py delete mode 100644 examples/mlx_metal_kernel_opt/best_program_info.json delete mode 100644 examples/mlx_metal_kernel_opt/openevolve_comparison_results_1750305870.json delete mode 100644 examples/mlx_metal_kernel_opt/openevolve_comparison_summary_1750305870.csv diff --git a/examples/mlx_metal_kernel_opt/best_program.py b/examples/mlx_metal_kernel_opt/best_program.py deleted file mode 100644 index a94d94c92..000000000 --- a/examples/mlx_metal_kernel_opt/best_program.py +++ /dev/null @@ -1,503 +0,0 @@ -""" -Qwen3 Custom Metal Kernel for Grouped Query Attention (GQA) Optimization - -This module implements a custom Metal kernel for Qwen3's 40:8 GQA pattern using -MLX's metal_kernel API. The kernel is designed to outperform mx.fast.scaled_dot_product_attention -by leveraging Apple Silicon specific optimizations and the 5:1 query-to-KV head ratio. - -Target: Qwen3-0.6B with 40 query heads : 8 KV heads -Hardware: Apple M-series GPUs with unified memory -Baseline: Standard MLX-LM using mx.fast.scaled_dot_product_attention -Goal: 5-15% performance improvement through custom Metal kernel optimization - -Evolution Target: The Metal kernel source code that computes GQA attention -""" - -import mlx.core as mx -import mlx.nn as nn -import numpy as np -import math -from typing import Optional, Tuple, Any -import time - - -def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): - """ - Custom Metal kernel implementation for Qwen3 GQA attention. - - Args: - queries: [B, num_heads=40, L, head_dim=128] - keys: [B, num_kv_heads=8, L, head_dim=128] - values: [B, num_kv_heads=8, L, head_dim=128] - scale: Attention scaling factor (1/sqrt(head_dim)) - mask: Attention mask (None, "causal", or boolean tensor) - - Returns: - Attention output [B, num_heads=40, L, head_dim=128] - """ - - B, num_heads, L, head_dim = queries.shape - _, num_kv_heads, _, _ = keys.shape - heads_per_kv = num_heads // num_kv_heads # Should be 5 for Qwen3 - - # Handle mask conversion - if mask == "causal" or mask is None: - # Create causal mask for autoregressive attention - causal_mask = mx.triu(mx.ones((L, L), dtype=mx.bool_), k=1) - mask_tensor = mx.logical_not(causal_mask) # True where attention is allowed - use_mask = True - elif isinstance(mask, (mx.array, type(None))): - if mask is None: - mask_tensor = mx.ones((L, L), dtype=mx.bool_) - use_mask = False - else: - mask_tensor = mask.astype(mx.bool_) - use_mask = True - else: - # Raise error for unsupported mask types - no fallback - raise ValueError( - f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask." - ) - - # Expand mask to match batch and head dimensions if needed - if mask_tensor.ndim == 2: - mask_tensor = mx.broadcast_to(mask_tensor[None, None, :, :], (B, num_heads, L, L)) - elif mask_tensor.ndim == 3: - mask_tensor = mx.broadcast_to(mask_tensor[:, None, :, :], (B, num_heads, L, L)) - - # EVOLVE-BLOCK-START - # Custom Metal kernel source for Qwen3 GQA optimization - # This kernel leverages the 40:8 head ratio and Apple Silicon architecture - kernel_source = """ - // Qwen3 GQA Metal Kernel - Optimized for 40:8 head pattern - // Thread mapping: each thread processes one query position - uint thread_id = thread_position_in_grid.x; - uint head_idx = thread_position_in_grid.y; - uint batch_idx = thread_position_in_grid.z; - uint query_pos = thread_id; - - // Bounds checking - if (batch_idx >= BATCH_SIZE || head_idx >= NUM_HEADS || query_pos >= SEQ_LEN) { - return; - } - - // Extract scalar values from input arrays - T scale_val = scale[0]; - bool use_mask_val = use_mask[0] > 0; - - // GQA mapping: determine which KV head corresponds to this query head - uint kv_head_idx = head_idx / HEADS_PER_KV; // 5 query heads per KV head - - // Pre-calculate base indices for memory access optimization - const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + - head_idx * (SEQ_LEN * HEAD_DIM) + - query_pos * HEAD_DIM; - - const uint k_base_start = batch_idx * (NUM_KV_HEADS * SEQ_LEN * HEAD_DIM) + - kv_head_idx * (SEQ_LEN * HEAD_DIM); - - const uint v_base_start = k_base_start; // Values have same layout as keys - - const uint mask_base = batch_idx * (NUM_HEADS * SEQ_LEN * SEQ_LEN) + - head_idx * (SEQ_LEN * SEQ_LEN) + - query_pos * SEQ_LEN; - - const uint out_base = q_base; - - // Use vector type for query_vec (e.g., float8 or half8 for better SIMD utilization) - // HEAD_DIM is 128, so 16 vec elements - vec query_vec_v[HEAD_DIM / 8]; - for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { - query_vec_v[d_vec] = ((device vec*) (queries + q_base))[d_vec]; - } - - // Pass 1: Compute max_score for numerical stability (online max) - T max_score = T(-INFINITY); - - for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { - bool is_valid = use_mask_val ? mask[mask_base + key_pos] : true; - - T score; - if (!is_valid) { - score = T(-INFINITY); // Masked scores are -infinity, consistent with Pass 2 - } else { - // Compute Q @ K^T for this key position using vectorized dot product - const uint k_base = k_base_start + key_pos * HEAD_DIM; - score = T(0.0); // Initialize score here - - for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { // Use vec - score += dot(query_vec_v[d_vec], ((device vec*) (keys + k_base))[d_vec]); - } - - // Apply attention scaling - score *= scale_val; - } - max_score = max(max_score, score); - } - - // Pass 2: Compute softmax denominator and weighted sum (online sum) - T sum_exp = T(0.0); - vec output_acc_v[HEAD_DIM / 8]; // Accumulator for output vector, use vec - - // Initialize output accumulator to zero - for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { - output_acc_v[d_vec] = T(0.0); - } - - for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { - bool is_valid = use_mask_val ? mask[mask_base + key_pos] : true; - - T current_score; - if (!is_valid) { - current_score = T(-INFINITY); // Masked scores are -infinity - } else { - // Recompute Q @ K^T for this key position - const uint k_base = k_base_start + key_pos * HEAD_DIM; - T score = T(0.0); - for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { // Use vec - score += dot(query_vec_v[d_vec], ((device vec*) (keys + k_base))[d_vec]); - } - current_score = score * scale_val; - } - - // Apply softmax (exp and sum) - T exp_score; - if (current_score == T(-INFINITY)) { - exp_score = T(0.0); // exp(-infinity) is 0 - } else { - exp_score = exp(current_score - max_score); - } - sum_exp += exp_score; - - // Compute weighted sum of values - if (exp_score > T(0.0)) { // Only add if exp_score is positive - const uint v_base = v_base_start + key_pos * HEAD_DIM; - for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { // Use vec - output_acc_v[d_vec] += exp_score * ((device vec*) (values + v_base))[d_vec]; - } - } - } - - // Final normalization and write result to global memory - if (sum_exp > T(0.0)) { - for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { // Use vec - output_acc_v[d_vec] /= sum_exp; - ((device vec*) (output + out_base))[d_vec] = output_acc_v[d_vec]; - } - } else { - // Handle case where sum_exp is zero (e.g., all scores were masked or extremely small) - // Set output to zero to avoid NaN/Inf results. - for (uint d_vec = 0; d_vec < HEAD_DIM / 8; d_vec++) { // Use vec - ((device vec*) (output + out_base))[d_vec] = T(0.0); - } - } - """ - # EVOLVE-BLOCK-END - - try: - # Prepare kernel inputs - scale_tensor = mx.array([scale], dtype=queries.dtype) - use_mask_tensor = mx.array([1 if use_mask else 0], dtype=mx.int32) - - # Create and execute custom Metal kernel - kernel = mx.fast.metal_kernel( - name="qwen3_gqa_attention_kernel", - input_names=["queries", "keys", "values", "mask", "scale", "use_mask"], - output_names=["output"], - source=kernel_source, - ) - - # Optimize thread group size for Apple Silicon - threadgroup_size = min(32, L) # Adapt to sequence length - - # Execute kernel - outputs = kernel( - inputs=[queries, keys, values, mask_tensor, scale_tensor, use_mask_tensor], - output_shapes=[(B, num_heads, L, head_dim)], - output_dtypes=[queries.dtype], - grid=(L, num_heads, B), # (SEQ_LEN, NUM_HEADS, BATCH_SIZE) - threadgroup=(threadgroup_size, 1, 1), - template=[ - ("T", queries.dtype), - ("BATCH_SIZE", B), - ("NUM_HEADS", num_heads), - ("NUM_KV_HEADS", num_kv_heads), - ("SEQ_LEN", L), - ("HEAD_DIM", head_dim), - ("HEADS_PER_KV", heads_per_kv), - ], - ) - - return outputs[0] - - except Exception as e: - # No fallback - let the custom kernel failure propagate for proper scoring - print(f"❌ Custom GQA kernel failed: {e}") - raise RuntimeError(f"Custom Metal kernel execution failed: {e}") from e - - -class CustomGQAAttention(nn.Module): - """ - Qwen3 attention module with custom Metal kernel optimization. - - This module integrates the custom Metal kernel while maintaining - compatibility with the standard MLX-LM interface. - """ - - def __init__(self, args): - super().__init__() - - # Standard Qwen3 parameters - dim = args.hidden_size # 5120 - self.n_heads = n_heads = args.num_attention_heads # 40 - assert args.num_key_value_heads is not None - self.n_kv_heads = n_kv_heads = args.num_key_value_heads # 8 - head_dim = args.head_dim # 128 - self.scale = head_dim**-0.5 - - # Standard MLX-LM projections - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - # Standard MLX-LM norms - self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) - self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) - - # Standard MLX-LM RoPE - try: - from mlx_lm.models.rope_utils import initialize_rope - - self.rope = initialize_rope( - head_dim, - base=args.rope_theta, - traditional=False, - scaling_config=args.rope_scaling, - max_position_embeddings=args.max_position_embeddings, - ) - except ImportError: - print("⚠️ Could not import mlx_lm rope_utils, using basic RoPE") - self.rope = None - - print(f"🔧 Initialized Custom Metal GQA Attention") - print(f" 📊 Architecture: {n_heads}:{n_kv_heads} heads ({n_heads//n_kv_heads}:1 ratio)") - print(f" 🎯 Head dimension: {head_dim}") - print(f" ⚡ Using custom Metal kernel for GQA optimization") - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - # Standard preprocessing (already optimized, don't evolve) - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose(0, 2, 1, 3) - keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - # Standard RoPE application (already optimized, don't evolve) - if cache is not None: - if self.rope is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - if self.rope is not None: - queries = self.rope(queries) - keys = self.rope(keys) - - # CORE INNOVATION: Custom Metal kernel for GQA attention - output = qwen3_custom_gqa_attention(queries, keys, values, scale=self.scale, mask=mask) - - # Standard postprocessing (already optimized, don't evolve) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -def create_metal_qwen3_optimization_hook(): - """ - Create hooks to replace Qwen3's attention with Metal kernel optimized version. - """ - - def apply_optimization_hook(): - """Apply the Metal kernel optimized attention""" - try: - import mlx_lm.models.qwen3 as qwen3_module - - # Store original attention class - original_attention = qwen3_module.Attention - - # Replace with Metal optimized implementation - qwen3_module.Attention = CustomGQAAttention - - print("✅ Applied Custom Metal GQA Attention hook") - return original_attention - - except ImportError: - print("❌ Could not import mlx_lm.models.qwen3") - return None - - def remove_optimization_hook(original_attention): - """Remove the optimization hook""" - try: - import mlx_lm.models.qwen3 as qwen3_module - - qwen3_module.Attention = original_attention - print("✅ Removed Custom Metal GQA Attention hook") - except ImportError: - pass - - return apply_optimization_hook, remove_optimization_hook - - -def benchmark_metal_gqa_optimization(): - """ - Benchmark Metal kernel optimized GQA attention against MLX baseline. - """ - - # Qwen3-0.6B configuration - class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 - num_key_value_heads = 8 - head_dim = 128 - rms_norm_eps = 1e-06 - rope_theta = 1000000 - rope_scaling = None - max_position_embeddings = 40960 - - args = MockArgs() - - # Test configurations for Metal kernel validation - test_configs = [ - ("short_sequence", 1, 128, 5120), - ("medium_sequence", 1, 512, 5120), - ("long_sequence", 1, 1024, 5120), - ("max_sequence", 1, 2048, 5120), - ] - - print("Benchmarking Custom Metal GQA Kernel vs MLX Baseline") - print("=" * 70) - - # Initialize Metal optimized attention - metal_attn = CustomGQAAttention(args) - - for config_name, batch_size, seq_len, hidden_size in test_configs: - print(f"\nTesting {config_name}: B={batch_size}, L={seq_len}") - - # Create test inputs - x = mx.random.normal((batch_size, seq_len, hidden_size)) - mask = "causal" - - # Warmup runs - for _ in range(3): - _ = metal_attn(x, mask=mask) - mx.eval(_) - - # Benchmark Metal optimized implementation - mx.synchronize() - start_time = time.perf_counter() - - for _ in range(10): - output = metal_attn(x, mask=mask) - mx.eval(output) - - mx.synchronize() - end_time = time.perf_counter() - - avg_time = (end_time - start_time) / 10 - tokens_per_sec = seq_len / avg_time - - print(f" Metal GQA: {avg_time*1000:.2f} ms, {tokens_per_sec:.1f} tokens/sec") - print(f" Memory: {mx.get_active_memory() / 1e9:.2f} GB") - - -def test_metal_gqa_correctness(): - """ - Test that Metal kernel implementation produces correct results. - """ - print("Testing Custom Metal GQA Correctness") - print("=" * 50) - - # Test configuration - B, L, D = 1, 64, 5120 - - class MockArgs: - hidden_size = 5120 - num_attention_heads = 40 - num_key_value_heads = 8 - head_dim = 128 - rms_norm_eps = 1e-06 - rope_theta = 1000000 - rope_scaling = None - max_position_embeddings = 40960 - - args = MockArgs() - - # Create test input - x = mx.random.normal((B, L, D)) - mask = "causal" - - # Test Metal optimized implementation - metal_attn = CustomGQAAttention(args) - output = metal_attn(x, mask=mask) - - print(f"✅ Metal GQA output shape: {output.shape}") - - # Check for valid output - has_nan = bool(mx.any(mx.isnan(output))) - has_inf = bool(mx.any(mx.isinf(output))) - - print(f"✅ Has NaN: {has_nan}, Has Inf: {has_inf}") - - # Check output statistics - output_mean = float(mx.mean(output)) - output_std = float(mx.std(output)) - - print(f"✅ Output statistics - Mean: {output_mean:.6f}, Std: {output_std:.6f}") - - # Test direct kernel function - print("\n=== Testing Direct Kernel Function ===") - B, H, L, D = 1, 40, 128, 128 - q = mx.random.normal((B, H, L, D)) - k = mx.random.normal((B, 8, L, D)) # 8 KV heads - v = mx.random.normal((B, 8, L, D)) - scale = 1.0 / math.sqrt(D) - - kernel_output = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal") - print(f"✅ Direct kernel output shape: {kernel_output.shape}") - - kernel_mean = float(mx.mean(kernel_output)) - kernel_std = float(mx.std(kernel_output)) - print(f"✅ Direct kernel stats - Mean: {kernel_mean:.6f}, Std: {kernel_std:.6f}") - - return True - - -if __name__ == "__main__": - print("Custom Metal Kernel Qwen3 GQA Optimization") - print("=" * 70) - - # Test correctness first - test_metal_gqa_correctness() - - print("\n") - - # Benchmark performance - benchmark_metal_gqa_optimization() - - print("\n" + "=" * 70) - print("Ready for Metal Kernel Evolution") - print("Evolution focus:") - print("1. 🔧 Metal kernel source code optimization") - print("2. 💾 Memory access pattern improvements for Apple Silicon") - print("3. 🎯 GQA-specific optimizations for 40:8 head ratio") - print("4. ⚡ Vectorization and SIMD optimization") - print("5. 🚀 Thread group and grid configuration tuning") - print("Target: 5-15% performance improvement through Metal kernel innovation") - print("=" * 70) diff --git a/examples/mlx_metal_kernel_opt/best_program_info.json b/examples/mlx_metal_kernel_opt/best_program_info.json deleted file mode 100644 index 59bd4f8a1..000000000 --- a/examples/mlx_metal_kernel_opt/best_program_info.json +++ /dev/null @@ -1,228 +0,0 @@ -{ - "id": "27d8cd88-e7b7-4191-8edf-4c60e9a778e1", - "generation": 2, - "iteration": 10, - "timestamp": 1750235175.896826, - "parent_id": "6c1c6009-4246-4e9b-9cec-4fd45bcbc10b", - "metrics": { - "success": true, - "final_score": 83.51156342903792, - "performance_metrics": { - "avg_decode_speed": 168.68739999999997, - "min_decode_speed": 144.906, - "max_decode_speed": 186.18, - "avg_prefill_speed": 2682.1746, - "avg_memory_gb": 1.6726000000000003, - "max_memory_gb": 2.709, - "num_successful_tests": 5, - "decode_speed_std": 13.33772465752686 - }, - "correctness_score": 1.0, - "benchmark_results": [ - { - "name": "short_context_quick", - "decode_tokens_per_sec": 186.18, - "prefill_tokens_per_sec": 455.084, - "peak_memory_gb": 1.243, - "generated_tokens": 50, - "total_time_sec": 2.4132528747431934 - }, - { - "name": "code_generation", - "decode_tokens_per_sec": 171.724, - "prefill_tokens_per_sec": 1939.369, - "peak_memory_gb": 1.309, - "generated_tokens": 300, - "total_time_sec": 3.8924263338558376 - }, - { - "name": "long_context_detailed", - "decode_tokens_per_sec": 169.006, - "prefill_tokens_per_sec": 4779.844, - "peak_memory_gb": 1.758, - "generated_tokens": 500, - "total_time_sec": 5.188338624779135 - }, - { - "name": "long_generation", - "decode_tokens_per_sec": 171.621, - "prefill_tokens_per_sec": 539.066, - "peak_memory_gb": 1.344, - "generated_tokens": 1000, - "total_time_sec": 8.105362374801189 - }, - { - "name": "maximum_context_stress_test", - "decode_tokens_per_sec": 144.906, - "prefill_tokens_per_sec": 5697.51, - "peak_memory_gb": 2.709, - "generated_tokens": 1642, - "total_time_sec": 13.786608333233744 - } - ], - "baseline_comparison": { - "avg_decode_improvement_pct": 21.823854476345975, - "avg_decode_improvement_absolute": 28.054599999999965, - "memory_change_gb": -0.0039999999999997815, - "target_achieved": true, - "num_benchmarks_improved": 4, - "total_benchmarks": 5, - "safety_score": 100.0 - }, - "individual_comparisons": [ - { - "benchmark_name": "short_context_quick", - "baseline": { - "name": "short_context_quick", - "decode_tokens_per_sec": 186.576, - "prefill_tokens_per_sec": 469.722, - "peak_memory_gb": 1.243, - "generated_tokens": 50, - "total_time_sec": 2.4104648330248892 - }, - "custom": { - "name": "short_context_quick", - "decode_tokens_per_sec": 186.18, - "prefill_tokens_per_sec": 455.084, - "peak_memory_gb": 1.243, - "generated_tokens": 50, - "total_time_sec": 2.4132528747431934 - }, - "improvements": { - "decode_speed_pct": -0.212245948031894, - "prefill_speed_pct": -3.1163113501177246, - "total_speed_pct": -0.11553044222939006, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -0.11553044222938498 - } - }, - { - "benchmark_name": "code_generation", - "baseline": { - "name": "code_generation", - "decode_tokens_per_sec": 134.074, - "prefill_tokens_per_sec": 1889.968, - "peak_memory_gb": 1.309, - "generated_tokens": 300, - "total_time_sec": 4.502297374885529 - }, - "custom": { - "name": "code_generation", - "decode_tokens_per_sec": 171.724, - "prefill_tokens_per_sec": 1939.369, - "peak_memory_gb": 1.309, - "generated_tokens": 300, - "total_time_sec": 3.8924263338558376 - }, - "improvements": { - "decode_speed_pct": 28.081507227352038, - "prefill_speed_pct": 2.613853779534883, - "total_speed_pct": 15.668146002536007, - "memory_reduction_pct": 0.0, - "time_reduction_pct": 15.668146002535993 - } - }, - { - "benchmark_name": "long_context_detailed", - "baseline": { - "name": "long_context_detailed", - "decode_tokens_per_sec": 123.595, - "prefill_tokens_per_sec": 4699.778, - "peak_memory_gb": 1.758, - "generated_tokens": 500, - "total_time_sec": 6.304242457728833 - }, - "custom": { - "name": "long_context_detailed", - "decode_tokens_per_sec": 169.006, - "prefill_tokens_per_sec": 4779.844, - "peak_memory_gb": 1.758, - "generated_tokens": 500, - "total_time_sec": 5.188338624779135 - }, - "improvements": { - "decode_speed_pct": 36.741777579999194, - "prefill_speed_pct": 1.7036123833934242, - "total_speed_pct": 21.507922162601755, - "memory_reduction_pct": 0.0, - "time_reduction_pct": 21.50792216260174 - } - }, - { - "benchmark_name": "long_generation", - "baseline": { - "name": "long_generation", - "decode_tokens_per_sec": 129.401, - "prefill_tokens_per_sec": 562.184, - "peak_memory_gb": 1.364, - "generated_tokens": 1000, - "total_time_sec": 9.933118666987866 - }, - "custom": { - "name": "long_generation", - "decode_tokens_per_sec": 171.621, - "prefill_tokens_per_sec": 539.066, - "peak_memory_gb": 1.344, - "generated_tokens": 1000, - "total_time_sec": 8.105362374801189 - }, - "improvements": { - "decode_speed_pct": 32.62725944930873, - "prefill_speed_pct": -4.112176796209059, - "total_speed_pct": 22.549963933370833, - "memory_reduction_pct": 1.4880952380952395, - "time_reduction_pct": 22.549963933370833 - } - }, - { - "benchmark_name": "maximum_context_stress_test", - "baseline": { - "name": "maximum_context_stress_test", - "decode_tokens_per_sec": 129.518, - "prefill_tokens_per_sec": 5305.524, - "peak_memory_gb": 2.709, - "generated_tokens": 1642, - "total_time_sec": 15.313574125058949 - }, - "custom": { - "name": "maximum_context_stress_test", - "decode_tokens_per_sec": 144.906, - "prefill_tokens_per_sec": 5697.51, - "peak_memory_gb": 2.709, - "generated_tokens": 1642, - "total_time_sec": 13.786608333233744 - }, - "improvements": { - "decode_speed_pct": 11.880974073101811, - "prefill_speed_pct": 7.388261743797594, - "total_speed_pct": 11.075717500034658, - "memory_reduction_pct": 0.0, - "time_reduction_pct": 11.07571750003465 - } - } - ], - "summary": "Bulletproof Custom GQA Implementation Results:\n\u2022 Decode Speed: 168.7 tokens/sec (baseline: 140.6)\n\u2022 Improvement: +21.8%\n\u2022 Memory Usage: 1.67 GB\n\u2022 Correctness: 100.0%\n\u2022 Safety Score: 100.0/100\n\u2022 Tests Passed: 5/5\n\u2022 Benchmarks Improved: 4/5\n\u2022 Metal Errors Handled: 0\n\ud83d\udee1\ufe0f PERFECT SAFETY: No Metal kernel errors\n\ud83c\udfaf EXCELLENT: 15%+ improvement achieved!", - "metal_safety_statistics": { - "metal_command_buffer_errors": 0, - "metal_memory_violations": 0, - "metal_compilation_errors": 0, - "gpu_resource_errors": 0, - "total_metal_errors": 0, - "successful_fallbacks": 0, - "retry_attempts_used": 0, - "safety_score": 100.0, - "error_breakdown": { - "command_buffer_pct": 0.0, - "memory_violation_pct": 0.0, - "compilation_error_pct": 0.0, - "resource_error_pct": 0.0 - } - }, - "safety_validation": { - "success": true, - "validated": true - } - }, - "language": "python", - "saved_at": 1750241608.788107 -} \ No newline at end of file diff --git a/examples/mlx_metal_kernel_opt/openevolve_comparison_results_1750305870.json b/examples/mlx_metal_kernel_opt/openevolve_comparison_results_1750305870.json deleted file mode 100644 index e9ad30af1..000000000 --- a/examples/mlx_metal_kernel_opt/openevolve_comparison_results_1750305870.json +++ /dev/null @@ -1,725 +0,0 @@ -{ - "model": "mlx-community/Qwen3-0.6B-bf16", - "timestamp": 1750305870, - "optimization_type": "chunked_gqa_processing", - "total_comparisons": 20, - "individual_comparisons": [ - { - "benchmark_name": "short_context_quick", - "standard": { - "name": "short_context_quick", - "prompt_tokens": 16, - "generated_tokens": 50, - "prefill_tokens_per_sec": 355.133, - "decode_tokens_per_sec": 186.437, - "total_tokens_per_sec": 19.89186747411851, - "peak_memory_gb": 1.243, - "total_time_sec": 2.513590042013675, - "prompt": "Brief answer: What is artificial intelligence?", - "generated_text": "\nOkay, the user is asking for a brief definition of artificial intelligence. Let me start by recalling the key points. AI is a branch of computer science that involves creating systems capable ..." - }, - "optimized": { - "name": "short_context_quick", - "prompt_tokens": 16, - "generated_tokens": 50, - "prefill_tokens_per_sec": 331.978, - "decode_tokens_per_sec": 183.74, - "total_tokens_per_sec": 19.301839590556543, - "peak_memory_gb": 1.243, - "total_time_sec": 2.5904266671277583, - "prompt": "Brief answer: What is artificial intelligence?", - "generated_text": "\nOkay, the user is asking for a brief definition of artificial intelligence. Let me start by recalling the key points. AI is a branch of computer science that involves creating systems capable ..." - }, - "improvements": { - "decode_speed_pct": -1.4466012647704065, - "prefill_speed_pct": -6.520092472397658, - "total_speed_pct": -2.9661764252635234, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -3.0568479278557414 - } - }, - { - "benchmark_name": "code_generation", - "standard": { - "name": "code_generation", - "prompt_tokens": 64, - "generated_tokens": 300, - "prefill_tokens_per_sec": 1286.789, - "decode_tokens_per_sec": 173.538, - "total_tokens_per_sec": 74.5889658469731, - "peak_memory_gb": 1.309, - "total_time_sec": 4.022042625118047, - "prompt": "Write a Python function to implement binary search:\n\ndef binary_search(arr, target):\n \"\"\"\n Implement binary search algorithm\n Args:\n arr: sorted array\n target: element to find\n ...", - "generated_text": "\nOkay, I need to write a Python function called binary_search that takes an array and a target. The function should return the index of the target or -1 if it's not found. Let me think about ho..." - }, - "optimized": { - "name": "code_generation", - "prompt_tokens": 64, - "generated_tokens": 300, - "prefill_tokens_per_sec": 1859.139, - "decode_tokens_per_sec": 144.969, - "total_tokens_per_sec": 69.72322167293892, - "peak_memory_gb": 1.309, - "total_time_sec": 4.302727166097611, - "prompt": "Write a Python function to implement binary search:\n\ndef binary_search(arr, target):\n \"\"\"\n Implement binary search algorithm\n Args:\n arr: sorted array\n target: element to find\n ...", - "generated_text": "\nOkay, I need to write a Python function called binary_search that takes an array and a target. The function should return the index of the target or -1 if it's not found. Let me think about ho..." - }, - "improvements": { - "decode_speed_pct": -16.462676762438207, - "prefill_speed_pct": 44.47893166634156, - "total_speed_pct": -6.523410156961754, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -6.978656546966011 - } - }, - { - "benchmark_name": "sustained_dialogue_generation", - "standard": { - "name": "sustained_dialogue_generation", - "prompt_tokens": 47, - "generated_tokens": 945, - "prefill_tokens_per_sec": 999.622, - "decode_tokens_per_sec": 108.362, - "total_tokens_per_sec": 84.07564971368124, - "peak_memory_gb": 1.341, - "total_time_sec": 11.239877458196133, - "prompt": "Create a detailed dialogue between an AI researcher and a software engineer discussing the future of artificial intelligence, covering topics like AGI, safety, ethics, and technological implications. ...", - "generated_text": "\nOkay, the user wants a detailed dialogue between an AI researcher and a software engineer discussing the future of AI, covering AGI, safety, ethics, and technological implications. It needs to..." - }, - "optimized": { - "name": "sustained_dialogue_generation", - "prompt_tokens": 47, - "generated_tokens": 945, - "prefill_tokens_per_sec": 1290.104, - "decode_tokens_per_sec": 158.907, - "total_tokens_per_sec": 114.54800525926025, - "peak_memory_gb": 1.334, - "total_time_sec": 8.249816291965544, - "prompt": "Create a detailed dialogue between an AI researcher and a software engineer discussing the future of artificial intelligence, covering topics like AGI, safety, ethics, and technological implications. ...", - "generated_text": "\nOkay, the user wants a detailed dialogue between an AI researcher and a software engineer discussing the future of AI, covering AGI, safety, ethics, and technological implications. It needs to..." - }, - "improvements": { - "decode_speed_pct": 46.64458020339235, - "prefill_speed_pct": 29.05918437169251, - "total_speed_pct": 36.24397271903613, - "memory_reduction_pct": 0.521998508575682, - "time_reduction_pct": 26.60225769677082 - } - }, - { - "benchmark_name": "technical_documentation", - "standard": { - "name": "technical_documentation", - "prompt_tokens": 84, - "generated_tokens": 1200, - "prefill_tokens_per_sec": 1616.155, - "decode_tokens_per_sec": 133.789, - "total_tokens_per_sec": 105.73404966830024, - "peak_memory_gb": 1.428, - "total_time_sec": 11.34922954114154, - "prompt": "Create comprehensive documentation for a REST API with the following endpoints:\n- GET /users - List all users\n- POST /users - Create new user \n- GET /users/{id} - Get specific user\n- PUT /users/{id} ...", - "generated_text": "\nOkay, I need to create comprehensive documentation for a REST API with the given endpoints. Let me start by breaking down each endpoint and thinking about what information should be included.\n..." - }, - "optimized": { - "name": "technical_documentation", - "prompt_tokens": 84, - "generated_tokens": 1200, - "prefill_tokens_per_sec": 1403.096, - "decode_tokens_per_sec": 145.408, - "total_tokens_per_sec": 114.65301020422453, - "peak_memory_gb": 1.403, - "total_time_sec": 10.46636279206723, - "prompt": "Create comprehensive documentation for a REST API with the following endpoints:\n- GET /users - List all users\n- POST /users - Create new user \n- GET /users/{id} - Get specific user\n- PUT /users/{id} ...", - "generated_text": "\nOkay, I need to create comprehensive documentation for a REST API with the given endpoints. Let me start by breaking down each endpoint and thinking about what information should be included.\n..." - }, - "improvements": { - "decode_speed_pct": 8.684570480383291, - "prefill_speed_pct": -13.183079593232083, - "total_speed_pct": 8.435277532548955, - "memory_reduction_pct": 1.7507002801120386, - "time_reduction_pct": 7.779089724759489 - } - }, - { - "benchmark_name": "progressive_context_building", - "standard": { - "name": "progressive_context_building", - "prompt_tokens": 348, - "generated_tokens": 600, - "prefill_tokens_per_sec": 3682.41, - "decode_tokens_per_sec": 90.467, - "total_tokens_per_sec": 66.01334784072361, - "peak_memory_gb": 1.733, - "total_time_sec": 9.089070917107165, - "prompt": "Chapter 1: The Beginning\n\nIn the early days of artificial intelligence, researchers dreamed of creating \nmachines that could think and reason like humans. The field began in the 1950s \nwith pioneers l...", - "generated_text": "\nOkay, the user wants me to continue the historical narrative from Chapter 5 into Chapter 6, focusing on the transformer era and large language models. Let me start by recalling the previous ch..." - }, - "optimized": { - "name": "progressive_context_building", - "prompt_tokens": 348, - "generated_tokens": 600, - "prefill_tokens_per_sec": 4294.586, - "decode_tokens_per_sec": 150.34, - "total_tokens_per_sec": 97.06952694112076, - "peak_memory_gb": 1.733, - "total_time_sec": 6.181136541068554, - "prompt": "Chapter 1: The Beginning\n\nIn the early days of artificial intelligence, researchers dreamed of creating \nmachines that could think and reason like humans. The field began in the 1950s \nwith pioneers l...", - "generated_text": "\nOkay, the user wants me to continue the historical narrative from Chapter 5 into Chapter 6, focusing on the transformer era and large language models. Let me start by recalling the previous ch..." - }, - "improvements": { - "decode_speed_pct": 66.18214376512984, - "prefill_speed_pct": 16.624330261975185, - "total_speed_pct": 47.04530237631517, - "memory_reduction_pct": 0.0, - "time_reduction_pct": 31.99374724390573 - } - }, - { - "benchmark_name": "maximum_context_stress_test", - "standard": { - "name": "maximum_context_stress_test", - "prompt_tokens": 1936, - "generated_tokens": 1642, - "prefill_tokens_per_sec": 5323.962, - "decode_tokens_per_sec": 90.432, - "total_tokens_per_sec": 78.57323431997136, - "peak_memory_gb": 2.709, - "total_time_sec": 20.897701541893184, - "prompt": "Research Paper Summary:\n\nTitle: \"Advances in Large Language Models: Architecture, Training, and Applications\"\n\nAbstract: This paper reviews recent developments in large language models (LLMs), \nfocusi...", - "generated_text": "\nOkay, let's tackle this query. The user wants a detailed analysis of how optimization strategies for Apple Silicon, specifically the M-series chips, apply to LLM inference. They mentioned cons..." - }, - "optimized": { - "name": "maximum_context_stress_test", - "prompt_tokens": 1936, - "generated_tokens": 1642, - "prefill_tokens_per_sec": 5307.325, - "decode_tokens_per_sec": 131.441, - "total_tokens_per_sec": 108.62816525269336, - "peak_memory_gb": 2.709, - "total_time_sec": 15.115785083733499, - "prompt": "Research Paper Summary:\n\nTitle: \"Advances in Large Language Models: Architecture, Training, and Applications\"\n\nAbstract: This paper reviews recent developments in large language models (LLMs), \nfocusi...", - "generated_text": "\nOkay, let's tackle this query. The user wants a detailed analysis of how optimization strategies for Apple Silicon, specifically the M-series chips, apply to LLM inference. They mentioned cons..." - }, - "improvements": { - "decode_speed_pct": 45.34788570417551, - "prefill_speed_pct": -0.3124928389797039, - "total_speed_pct": 38.2508511872252, - "memory_reduction_pct": 0.0, - "time_reduction_pct": 27.667714779870877 - } - }, - { - "benchmark_name": "very_long_generation", - "standard": { - "name": "very_long_generation", - "prompt_tokens": 18, - "generated_tokens": 1169, - "prefill_tokens_per_sec": 330.493, - "decode_tokens_per_sec": 167.434, - "total_tokens_per_sec": 125.5328968001133, - "peak_memory_gb": 1.383, - "total_time_sec": 9.312300040852278, - "prompt": "Write a comprehensive guide to machine learning for beginners:", - "generated_text": "\nOkay, the user wants a comprehensive guide to machine learning for beginners. Let me start by breaking down what they need. They probably want a solid foundation without getting too technical...." - }, - "optimized": { - "name": "very_long_generation", - "prompt_tokens": 18, - "generated_tokens": 1169, - "prefill_tokens_per_sec": 493.859, - "decode_tokens_per_sec": 131.146, - "total_tokens_per_sec": 104.55887658599336, - "peak_memory_gb": 1.373, - "total_time_sec": 11.180303750094026, - "prompt": "Write a comprehensive guide to machine learning for beginners:", - "generated_text": "\nOkay, the user wants a comprehensive guide to machine learning for beginners. Let me start by breaking down what they need. They probably want a solid foundation without getting too technical...." - }, - "improvements": { - "decode_speed_pct": -21.673017427762588, - "prefill_speed_pct": 49.431001564329655, - "total_speed_pct": -16.7079871083649, - "memory_reduction_pct": 0.7230657989877085, - "time_reduction_pct": -20.059530954189324 - } - }, - { - "benchmark_name": "extreme_long_generation", - "standard": { - "name": "extreme_long_generation", - "prompt_tokens": 35, - "generated_tokens": 1153, - "prefill_tokens_per_sec": 675.64, - "decode_tokens_per_sec": 90.801, - "total_tokens_per_sec": 76.0227511960408, - "peak_memory_gb": 1.395, - "total_time_sec": 15.166512417141348, - "prompt": "Write a complete tutorial on deep learning from basics to advanced topics, including mathematical foundations, architectures, training techniques, and real-world applications:", - "generated_text": "\nOkay, the user wants a complete tutorial on deep learning from basics to advanced topics. Let me start by breaking down the sections they mentioned: mathematical foundations, architectures, tr..." - }, - "optimized": { - "name": "extreme_long_generation", - "prompt_tokens": 35, - "generated_tokens": 1153, - "prefill_tokens_per_sec": 834.378, - "decode_tokens_per_sec": 157.88, - "total_tokens_per_sec": 117.97192751142086, - "peak_memory_gb": 1.367, - "total_time_sec": 9.77351158298552, - "prompt": "Write a complete tutorial on deep learning from basics to advanced topics, including mathematical foundations, architectures, training techniques, and real-world applications:", - "generated_text": "\nOkay, the user wants a complete tutorial on deep learning from basics to advanced topics. Let me start by breaking down the sections they mentioned: mathematical foundations, architectures, tr..." - }, - "improvements": { - "decode_speed_pct": 73.87473706236715, - "prefill_speed_pct": 23.49446450772602, - "total_speed_pct": 55.17976612975397, - "memory_reduction_pct": 2.0071684587813636, - "time_reduction_pct": 35.55860889983252 - } - }, - { - "benchmark_name": "repetitive_pattern_generation", - "standard": { - "name": "repetitive_pattern_generation", - "prompt_tokens": 27, - "generated_tokens": 2000, - "prefill_tokens_per_sec": 613.308, - "decode_tokens_per_sec": 71.494, - "total_tokens_per_sec": 65.91223332172675, - "peak_memory_gb": 1.549, - "total_time_sec": 30.343380874954164, - "prompt": "Generate a list of 100 creative product names for a tech startup, with explanations:", - "generated_text": "\nOkay, the user wants a list of 100 creative product names for a tech startup. Let me start by brainstorming some ideas. Tech startups often focus on innovative solutions, so I need to think ab..." - }, - "optimized": { - "name": "repetitive_pattern_generation", - "prompt_tokens": 27, - "generated_tokens": 2000, - "prefill_tokens_per_sec": 698.002, - "decode_tokens_per_sec": 147.488, - "total_tokens_per_sec": 127.07780282702558, - "peak_memory_gb": 1.465, - "total_time_sec": 15.738389832898974, - "prompt": "Generate a list of 100 creative product names for a tech startup, with explanations:", - "generated_text": "\nOkay, the user wants a list of 100 creative product names for a tech startup. Let me start by brainstorming some ideas. Tech startups often focus on innovative solutions, so I need to think ab..." - }, - "improvements": { - "decode_speed_pct": 106.2942344812152, - "prefill_speed_pct": 13.809374735043397, - "total_speed_pct": 92.79850859663821, - "memory_reduction_pct": 5.422853453841179, - "time_reduction_pct": 48.132378861283534 - } - }, - { - "benchmark_name": "long_context_detailed", - "standard": { - "name": "long_context_detailed", - "prompt_tokens": 391, - "generated_tokens": 500, - "prefill_tokens_per_sec": 4059.863, - "decode_tokens_per_sec": 170.307, - "total_tokens_per_sec": 94.50554749332285, - "peak_memory_gb": 1.758, - "total_time_sec": 5.290694708004594, - "prompt": "Research Paper Summary:\n\nTitle: \"Advances in Large Language Models: Architecture, Training, and Applications\"\n\nAbstract: This paper reviews recent developments in large language models (LLMs), \nfocusi...", - "generated_text": "\nOkay, the user wants a detailed analysis of how architectural and training advances impact inference efficiency on mobile and edge devices. Let me start by recalling the key points from the re..." - }, - "optimized": { - "name": "long_context_detailed", - "prompt_tokens": 391, - "generated_tokens": 500, - "prefill_tokens_per_sec": 3974.441, - "decode_tokens_per_sec": 120.803, - "total_tokens_per_sec": 75.56414253281604, - "peak_memory_gb": 1.758, - "total_time_sec": 6.616895040962845, - "prompt": "Research Paper Summary:\n\nTitle: \"Advances in Large Language Models: Architecture, Training, and Applications\"\n\nAbstract: This paper reviews recent developments in large language models (LLMs), \nfocusi...", - "generated_text": "\nOkay, the user wants a detailed analysis of how architectural and training advances impact inference efficiency on mobile and edge devices. Let me start by recalling the key points from the re..." - }, - "improvements": { - "decode_speed_pct": -29.067507501159668, - "prefill_speed_pct": -2.104061146890918, - "total_speed_pct": -20.042638197345074, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -25.066657710409316 - } - }, - { - "benchmark_name": "micro_generation", - "standard": { - "name": "micro_generation", - "prompt_tokens": 17, - "generated_tokens": 10, - "prefill_tokens_per_sec": 346.786, - "decode_tokens_per_sec": 203.067, - "total_tokens_per_sec": 4.517200654424452, - "peak_memory_gb": 1.249, - "total_time_sec": 2.213760416023433, - "prompt": "Complete this sentence: The future of AI is", - "generated_text": "\nOkay, the user wants me to complete" - }, - "optimized": { - "name": "micro_generation", - "prompt_tokens": 17, - "generated_tokens": 10, - "prefill_tokens_per_sec": 368.377, - "decode_tokens_per_sec": 203.11, - "total_tokens_per_sec": 4.236131800369787, - "peak_memory_gb": 1.249, - "total_time_sec": 2.360644208267331, - "prompt": "Complete this sentence: The future of AI is", - "generated_text": "\nOkay, the user wants me to complete" - }, - "improvements": { - "decode_speed_pct": 0.02117527712528691, - "prefill_speed_pct": 6.226029885866214, - "total_speed_pct": -6.22219103283286, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -6.63503562448481 - } - }, - { - "benchmark_name": "step_by_step_reasoning", - "standard": { - "name": "step_by_step_reasoning", - "prompt_tokens": 61, - "generated_tokens": 400, - "prefill_tokens_per_sec": 1279.141, - "decode_tokens_per_sec": 168.392, - "total_tokens_per_sec": 85.45661112975772, - "peak_memory_gb": 1.307, - "total_time_sec": 4.68073791731149, - "prompt": "Solve this step by step:\n\nA train travels from City A to City B at 80 mph. The distance is 240 miles. \nIf it leaves at 2:00 PM, what time will it arrive? Show your work.", - "generated_text": "\nOkay, let's see. I need to figure out what time the train will arrive at City B if it leaves at 2:00 PM and travels at 80 mph for 240 miles. Hmm, right. So, first, I remember that distance equ..." - }, - "optimized": { - "name": "step_by_step_reasoning", - "prompt_tokens": 61, - "generated_tokens": 400, - "prefill_tokens_per_sec": 1442.308, - "decode_tokens_per_sec": 142.962, - "total_tokens_per_sec": 78.87836216644345, - "peak_memory_gb": 1.307, - "total_time_sec": 5.071099209133536, - "prompt": "Solve this step by step:\n\nA train travels from City A to City B at 80 mph. The distance is 240 miles. \nIf it leaves at 2:00 PM, what time will it arrive? Show your work.", - "generated_text": "\nOkay, let's see. I need to figure out what time the train will arrive at City B if it leaves at 2:00 PM and travels at 80 mph for 240 miles. Hmm, right. So, first, I remember that distance equ..." - }, - "improvements": { - "decode_speed_pct": -15.101667537650249, - "prefill_speed_pct": 12.755982335020136, - "total_speed_pct": -7.69776483802502, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -8.339738278836615 - } - }, - { - "benchmark_name": "ultra_long_generation", - "standard": { - "name": "ultra_long_generation", - "prompt_tokens": 13, - "generated_tokens": 468, - "prefill_tokens_per_sec": 383.678, - "decode_tokens_per_sec": 171.811, - "total_tokens_per_sec": 92.45339073205282, - "peak_memory_gb": 1.523, - "total_time_sec": 5.062010125257075, - "prompt": "The future of AI is", - "generated_text": "\nOkay, the user is asking about the future of AI. Let me start by breaking down the key points they might be interested in. First, I should mention the current state of AI, like machine learnin..." - }, - "optimized": { - "name": "ultra_long_generation", - "prompt_tokens": 13, - "generated_tokens": 468, - "prefill_tokens_per_sec": 440.611, - "decode_tokens_per_sec": 139.934, - "total_tokens_per_sec": 83.87973277956566, - "peak_memory_gb": 1.503, - "total_time_sec": 5.579416916240007, - "prompt": "The future of AI is", - "generated_text": "\nOkay, the user is asking about the future of AI. Let me start by breaking down the key points they might be interested in. First, I should mention the current state of AI, like machine learnin..." - }, - "improvements": { - "decode_speed_pct": -18.5535268405399, - "prefill_speed_pct": 14.83874498928789, - "total_speed_pct": -9.273492172218138, - "memory_reduction_pct": 1.3131976362442561, - "time_reduction_pct": -10.221370131231321 - } - }, - { - "benchmark_name": "very_long_context_comprehensive", - "standard": { - "name": "very_long_context_comprehensive", - "prompt_tokens": 928, - "generated_tokens": 1000, - "prefill_tokens_per_sec": 5146.123, - "decode_tokens_per_sec": 161.682, - "total_tokens_per_sec": 117.59371221458863, - "peak_memory_gb": 2.158, - "total_time_sec": 8.503856041003019, - "prompt": "Research Paper Summary:\n\nTitle: \"Advances in Large Language Models: Architecture, Training, and Applications\"\n\nAbstract: This paper reviews recent developments in large language models (LLMs), \nfocusi...", - "generated_text": "\nOkay, so I need to analyze how the architectural and training advances in large language models impact inference efficiency on mobile and edge devices, especially considering Apple Silicon. Le..." - }, - "optimized": { - "name": "very_long_context_comprehensive", - "prompt_tokens": 928, - "generated_tokens": 1000, - "prefill_tokens_per_sec": 4958.784, - "decode_tokens_per_sec": 106.292, - "total_tokens_per_sec": 82.90796709835429, - "peak_memory_gb": 2.158, - "total_time_sec": 12.061567000113428, - "prompt": "Research Paper Summary:\n\nTitle: \"Advances in Large Language Models: Architecture, Training, and Applications\"\n\nAbstract: This paper reviews recent developments in large language models (LLMs), \nfocusi...", - "generated_text": "\nOkay, so I need to analyze how the architectural and training advances in large language models impact inference efficiency on mobile and edge devices, especially considering Apple Silicon. Le..." - }, - "improvements": { - "decode_speed_pct": -34.25860640021771, - "prefill_speed_pct": -3.6403910283528, - "total_speed_pct": -29.496258314338036, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -41.83644386683175 - } - }, - { - "benchmark_name": "short_generation", - "standard": { - "name": "short_generation", - "prompt_tokens": 19, - "generated_tokens": 100, - "prefill_tokens_per_sec": 388.449, - "decode_tokens_per_sec": 180.845, - "total_tokens_per_sec": 34.69684412864018, - "peak_memory_gb": 1.25, - "total_time_sec": 2.882106500212103, - "prompt": "Explain in one paragraph: What makes transformers effective?", - "generated_text": "\nOkay, the user wants me to explain why transformers are effective in one paragraph. Let me start by recalling what I know about transformers. They are used in power transmission, right? So, th..." - }, - "optimized": { - "name": "short_generation", - "prompt_tokens": 19, - "generated_tokens": 100, - "prefill_tokens_per_sec": 480.388, - "decode_tokens_per_sec": 166.885, - "total_tokens_per_sec": 33.4333817918928, - "peak_memory_gb": 1.25, - "total_time_sec": 2.991022584028542, - "prompt": "Explain in one paragraph: What makes transformers effective?", - "generated_text": "\nOkay, the user wants me to explain why transformers are effective in one paragraph. Let me start by recalling what I know about transformers. They are used in power transmission, right? So, th..." - }, - "improvements": { - "decode_speed_pct": -7.719317647709369, - "prefill_speed_pct": 23.668229291361275, - "total_speed_pct": -3.6414330135127986, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -3.77904438328089 - } - }, - { - "benchmark_name": "long_generation", - "standard": { - "name": "long_generation", - "prompt_tokens": 19, - "generated_tokens": 1000, - "prefill_tokens_per_sec": 383.041, - "decode_tokens_per_sec": 167.826, - "total_tokens_per_sec": 121.2860867452095, - "peak_memory_gb": 1.336, - "total_time_sec": 8.244968791026622, - "prompt": "Write a detailed technical explanation of how neural networks learn:", - "generated_text": "\nOkay, so I need to explain how neural networks learn. Let me start by recalling what I know. Neural networks are like big computers that can learn from data. They have layers of processing, ri..." - }, - "optimized": { - "name": "long_generation", - "prompt_tokens": 19, - "generated_tokens": 1000, - "prefill_tokens_per_sec": 515.049, - "decode_tokens_per_sec": 131.841, - "total_tokens_per_sec": 101.30268327558746, - "peak_memory_gb": 1.364, - "total_time_sec": 9.871406834106892, - "prompt": "Write a detailed technical explanation of how neural networks learn:", - "generated_text": "\nOkay, so I need to explain how neural networks learn. Let me start by recalling what I know. Neural networks are like big computers that can learn from data. They have layers of processing, ri..." - }, - "improvements": { - "decode_speed_pct": -21.441850488005425, - "prefill_speed_pct": 34.46315146420357, - "total_speed_pct": -16.47625379455269, - "memory_reduction_pct": -2.0958083832335346, - "time_reduction_pct": -19.726430557874245 - } - }, - { - "benchmark_name": "conversational_assistant", - "standard": { - "name": "conversational_assistant", - "prompt_tokens": 85, - "generated_tokens": 1060, - "prefill_tokens_per_sec": 1558.637, - "decode_tokens_per_sec": 110.265, - "total_tokens_per_sec": 88.00089711055672, - "peak_memory_gb": 1.404, - "total_time_sec": 12.045331750065088, - "prompt": "You are a helpful AI assistant. A user asks:\n\n\"I'm planning a trip to Japan for 2 weeks. I've never been there before. I like \nhistory, food, and nature. I have a moderate budget. Can you help me plan...", - "generated_text": "\nOkay, the user is planning a 2-week trip to Japan. They've never been before, so they need a detailed itinerary with recommendations for cities, activities, and travel tips. Let me start by br..." - }, - "optimized": { - "name": "conversational_assistant", - "prompt_tokens": 85, - "generated_tokens": 1060, - "prefill_tokens_per_sec": 1624.919, - "decode_tokens_per_sec": 147.833, - "total_tokens_per_sec": 110.80791478921105, - "peak_memory_gb": 1.367, - "total_time_sec": 9.566103667020798, - "prompt": "You are a helpful AI assistant. A user asks:\n\n\"I'm planning a trip to Japan for 2 weeks. I've never been there before. I like \nhistory, food, and nature. I have a moderate budget. Can you help me plan...", - "generated_text": "\nOkay, the user is planning a 2-week trip to Japan. They've never been before, so they need a detailed itinerary with recommendations for cities, activities, and travel tips. Let me start by br..." - }, - "improvements": { - "decode_speed_pct": 34.07064798440121, - "prefill_speed_pct": 4.252561693325653, - "total_speed_pct": 25.916801336697247, - "memory_reduction_pct": 2.63532763532763, - "time_reduction_pct": 20.582480702790885 - } - }, - { - "benchmark_name": "creative_writing", - "standard": { - "name": "creative_writing", - "prompt_tokens": 53, - "generated_tokens": 800, - "prefill_tokens_per_sec": 1112.589, - "decode_tokens_per_sec": 154.895, - "total_tokens_per_sec": 106.99556747700527, - "peak_memory_gb": 1.381, - "total_time_sec": 7.476945249829441, - "prompt": "Write a short story about a robot who discovers emotions for the first time. \nInclude dialogue and describe the robot's internal experience as it learns about feelings like \njoy, sadness, and wonder. ...", - "generated_text": "\nOkay, the user wants a short story about a robot discovering emotions for the first time. They specified including dialogue, internal experience, and making it engaging and thoughtful. Let me ..." - }, - "optimized": { - "name": "creative_writing", - "prompt_tokens": 53, - "generated_tokens": 800, - "prefill_tokens_per_sec": 1540.651, - "decode_tokens_per_sec": 141.137, - "total_tokens_per_sec": 100.8810695154153, - "peak_memory_gb": 1.335, - "total_time_sec": 7.930130041670054, - "prompt": "Write a short story about a robot who discovers emotions for the first time. \nInclude dialogue and describe the robot's internal experience as it learns about feelings like \njoy, sadness, and wonder. ...", - "generated_text": "\nOkay, the user wants a short story about a robot discovering emotions for the first time. They specified including dialogue, internal experience, and making it engaging and thoughtful. Let me ..." - }, - "improvements": { - "decode_speed_pct": -8.88214596985055, - "prefill_speed_pct": 38.47440519365194, - "total_speed_pct": -5.7147208111252334, - "memory_reduction_pct": 3.330919623461263, - "time_reduction_pct": -6.06109549686686 - } - }, - { - "benchmark_name": "medium_context_analysis", - "standard": { - "name": "medium_context_analysis", - "prompt_tokens": 127, - "generated_tokens": 200, - "prefill_tokens_per_sec": 2300.242, - "decode_tokens_per_sec": 169.049, - "total_tokens_per_sec": 59.57798010812093, - "peak_memory_gb": 1.396, - "total_time_sec": 3.3569449591450393, - "prompt": "Context: Machine learning has revolutionized many industries in recent years. \nFrom healthcare diagnosis to autonomous vehicles, AI systems are becoming increasingly \nsophisticated. However, challenge...", - "generated_text": "\nOkay, let's tackle this question. The user wants me to analyze the current state of AI development based on the given context and predict the most important research directions for the next fi..." - }, - "optimized": { - "name": "medium_context_analysis", - "prompt_tokens": 127, - "generated_tokens": 200, - "prefill_tokens_per_sec": 2099.829, - "decode_tokens_per_sec": 169.053, - "total_tokens_per_sec": 54.26174147081993, - "peak_memory_gb": 1.396, - "total_time_sec": 3.6858382089994848, - "prompt": "Context: Machine learning has revolutionized many industries in recent years. \nFrom healthcare diagnosis to autonomous vehicles, AI systems are becoming increasingly \nsophisticated. However, challenge...", - "generated_text": "\nOkay, let's tackle this question. The user wants me to analyze the current state of AI development based on the given context and predict the most important research directions for the next fi..." - }, - "improvements": { - "decode_speed_pct": 0.0023661778537528632, - "prefill_speed_pct": -8.712691968931964, - "total_speed_pct": -8.92316024754985, - "memory_reduction_pct": 0.0, - "time_reduction_pct": -9.7973977487617 - } - }, - { - "benchmark_name": "comprehensive_analysis_generation", - "standard": { - "name": "comprehensive_analysis_generation", - "prompt_tokens": 39, - "generated_tokens": 1232, - "prefill_tokens_per_sec": 899.455, - "decode_tokens_per_sec": 108.956, - "total_tokens_per_sec": 89.29787741356088, - "peak_memory_gb": 1.428, - "total_time_sec": 13.796520540956408, - "prompt": "Analyze the evolution of computer programming languages from assembly to modern high-level languages. Discuss paradigms, performance considerations, developer productivity, and future trends:", - "generated_text": "\nOkay, so I need to analyze the evolution of computer programming languages from assembly to modern high-level languages. Let me start by recalling what I know about this topic. \n\nFirst, assemb..." - }, - "optimized": { - "name": "comprehensive_analysis_generation", - "prompt_tokens": 39, - "generated_tokens": 1232, - "prefill_tokens_per_sec": 1003.789, - "decode_tokens_per_sec": 156.875, - "total_tokens_per_sec": 123.20302567134158, - "peak_memory_gb": 1.368, - "total_time_sec": 9.99975441582501, - "prompt": "Analyze the evolution of computer programming languages from assembly to modern high-level languages. Discuss paradigms, performance considerations, developer productivity, and future trends:", - "generated_text": "\nOkay, so I need to analyze the evolution of computer programming languages from assembly to modern high-level languages. Let me start by recalling what I know about this topic. \n\nFirst, assemb..." - }, - "improvements": { - "decode_speed_pct": 43.98013877161422, - "prefill_speed_pct": 11.599690923948385, - "total_speed_pct": 37.968593699889915, - "memory_reduction_pct": 4.201680672268896, - "time_reduction_pct": 27.519736689118844 - } - } - ], - "aggregate_improvements": { - "decode_speed_improvements_avg": 12.524778103377688, - "decode_speed_improvements_median": -0.7221175434583268, - "decode_speed_improvements_min": -34.25860640021771, - "decode_speed_improvements_max": 106.2942344812152, - "decode_speed_improvements_std": 38.29698329321707, - "prefill_speed_improvements_avg": 14.435163691749414, - "prefill_speed_improvements_median": 13.282678535031767, - "prefill_speed_improvements_min": -13.183079593232083, - "prefill_speed_improvements_max": 49.431001564329655, - "prefill_speed_improvements_std": 17.649765739092885, - "total_speed_improvements_avg": 10.407679373300745, - "total_speed_improvements_median": -4.678076912319016, - "total_speed_improvements_min": -29.496258314338036, - "total_speed_improvements_max": 92.79850859663821, - "total_speed_improvements_std": 30.698256840048263, - "memory_improvements_avg": 0.9905551842183241, - "memory_improvements_median": 0.0, - "memory_improvements_min": -2.0958083832335346, - "memory_improvements_max": 5.422853453841179, - "memory_improvements_std": 1.7245771941812529, - "time_improvements_avg": 3.213888268537205, - "time_improvements_median": -4.920069940073875, - "time_improvements_min": -41.83644386683175, - "time_improvements_max": 48.132378861283534, - "time_improvements_std": 23.136633995726953 - }, - "summary": { - "avg_decode_improvement_pct": 12.524778103377688, - "avg_total_improvement_pct": 10.407679373300745, - "avg_memory_reduction_pct": 0.9905551842183241, - "avg_time_reduction_pct": 3.213888268537205, - "avg_standard_decode_speed": 143.99245, - "avg_optimized_decode_speed": 148.9022, - "benchmarks_improved": 10, - "total_benchmarks": 20 - } -} \ No newline at end of file diff --git a/examples/mlx_metal_kernel_opt/openevolve_comparison_summary_1750305870.csv b/examples/mlx_metal_kernel_opt/openevolve_comparison_summary_1750305870.csv deleted file mode 100644 index 91fa0f5c7..000000000 --- a/examples/mlx_metal_kernel_opt/openevolve_comparison_summary_1750305870.csv +++ /dev/null @@ -1,21 +0,0 @@ -benchmark_name,category,standard_decode_speed,optimized_decode_speed,decode_improvement_pct,standard_prefill_speed,optimized_prefill_speed,prefill_improvement_pct,standard_total_speed,optimized_total_speed,total_improvement_pct,standard_memory_gb,optimized_memory_gb,memory_reduction_pct,standard_time_sec,optimized_time_sec,time_reduction_pct -short_context_quick,short_context,186.437,183.74,-1.4466012647704065,355.133,331.978,-6.520092472397658,19.89186747411851,19.301839590556543,-2.9661764252635234,1.243,1.243,0.0,2.513590042013675,2.5904266671277583,-3.0568479278557414 -code_generation,code_generation,173.538,144.969,-16.462676762438207,1286.789,1859.139,44.47893166634156,74.5889658469731,69.72322167293892,-6.523410156961754,1.309,1.309,0.0,4.022042625118047,4.302727166097611,-6.978656546966011 -sustained_dialogue_generation,general,108.362,158.907,46.64458020339235,999.622,1290.104,29.05918437169251,84.07564971368124,114.54800525926025,36.24397271903613,1.341,1.334,0.521998508575682,11.239877458196133,8.249816291965544,26.60225769677082 -technical_documentation,general,133.789,145.408,8.684570480383291,1616.155,1403.096,-13.183079593232083,105.73404966830024,114.65301020422453,8.435277532548955,1.428,1.403,1.7507002801120386,11.34922954114154,10.46636279206723,7.779089724759489 -progressive_context_building,general,90.467,150.34,66.18214376512984,3682.41,4294.586,16.624330261975185,66.01334784072361,97.06952694112076,47.04530237631517,1.733,1.733,0.0,9.089070917107165,6.181136541068554,31.99374724390573 -maximum_context_stress_test,stress_test,90.432,131.441,45.34788570417551,5323.962,5307.325,-0.3124928389797039,78.57323431997136,108.62816525269336,38.2508511872252,2.709,2.709,0.0,20.897701541893184,15.115785083733499,27.667714779870877 -very_long_generation,long_context,167.434,131.146,-21.673017427762588,330.493,493.859,49.431001564329655,125.5328968001133,104.55887658599336,-16.7079871083649,1.383,1.373,0.7230657989877085,9.312300040852278,11.180303750094026,-20.059530954189324 -extreme_long_generation,long_context,90.801,157.88,73.87473706236715,675.64,834.378,23.49446450772602,76.0227511960408,117.97192751142086,55.17976612975397,1.395,1.367,2.0071684587813636,15.166512417141348,9.77351158298552,35.55860889983252 -repetitive_pattern_generation,general,71.494,147.488,106.2942344812152,613.308,698.002,13.809374735043397,65.91223332172675,127.07780282702558,92.79850859663821,1.549,1.465,5.422853453841179,30.343380874954164,15.738389832898974,48.132378861283534 -long_context_detailed,long_context,170.307,120.803,-29.067507501159668,4059.863,3974.441,-2.104061146890918,94.50554749332285,75.56414253281604,-20.042638197345074,1.758,1.758,0.0,5.290694708004594,6.616895040962845,-25.066657710409316 -micro_generation,general,203.067,203.11,0.02117527712528691,346.786,368.377,6.226029885866214,4.517200654424452,4.236131800369787,-6.22219103283286,1.249,1.249,0.0,2.213760416023433,2.360644208267331,-6.63503562448481 -step_by_step_reasoning,general,168.392,142.962,-15.101667537650249,1279.141,1442.308,12.755982335020136,85.45661112975772,78.87836216644345,-7.69776483802502,1.307,1.307,0.0,4.68073791731149,5.071099209133536,-8.339738278836615 -ultra_long_generation,long_context,171.811,139.934,-18.5535268405399,383.678,440.611,14.83874498928789,92.45339073205282,83.87973277956566,-9.273492172218138,1.523,1.503,1.3131976362442561,5.062010125257075,5.579416916240007,-10.221370131231321 -very_long_context_comprehensive,long_context,161.682,106.292,-34.25860640021771,5146.123,4958.784,-3.6403910283528,117.59371221458863,82.90796709835429,-29.496258314338036,2.158,2.158,0.0,8.503856041003019,12.061567000113428,-41.83644386683175 -short_generation,short_context,180.845,166.885,-7.719317647709369,388.449,480.388,23.668229291361275,34.69684412864018,33.4333817918928,-3.6414330135127986,1.25,1.25,0.0,2.882106500212103,2.991022584028542,-3.77904438328089 -long_generation,long_context,167.826,131.841,-21.441850488005425,383.041,515.049,34.46315146420357,121.2860867452095,101.30268327558746,-16.47625379455269,1.336,1.364,-2.0958083832335346,8.244968791026622,9.871406834106892,-19.726430557874245 -conversational_assistant,general,110.265,147.833,34.07064798440121,1558.637,1624.919,4.252561693325653,88.00089711055672,110.80791478921105,25.916801336697247,1.404,1.367,2.63532763532763,12.045331750065088,9.566103667020798,20.582480702790885 -creative_writing,general,154.895,141.137,-8.88214596985055,1112.589,1540.651,38.47440519365194,106.99556747700527,100.8810695154153,-5.7147208111252334,1.381,1.335,3.330919623461263,7.476945249829441,7.930130041670054,-6.06109549686686 -medium_context_analysis,general,169.049,169.053,0.0023661778537528632,2300.242,2099.829,-8.712691968931964,59.57798010812093,54.26174147081993,-8.92316024754985,1.396,1.396,0.0,3.3569449591450393,3.6858382089994848,-9.7973977487617 -comprehensive_analysis_generation,general,108.956,156.875,43.98013877161422,899.455,1003.789,11.599690923948385,89.29787741356088,123.20302567134158,37.968593699889915,1.428,1.368,4.201680672268896,13.796520540956408,9.99975441582501,27.519736689118844 From 8ed7e9e1545bfe2b5c9c8aee3450a307a2b80e54 Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Thu, 8 Jan 2026 16:09:10 +0800 Subject: [PATCH 13/14] docs(mlx_metal_kernel_opt): add demo results output (20260105_180918) - Add curated demo output artifacts (best program + logs + config) - Document demo location in README --- examples/mlx_metal_kernel_opt/README.md | 14 +- .../best/best_program.py | 508 ++++++++++++++++++ .../best/best_program_info.json | 195 +++++++ .../demo_output_20260105_180918/config.yaml | 114 ++++ 4 files changed, 830 insertions(+), 1 deletion(-) create mode 100644 examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program.py create mode 100644 examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program_info.json create mode 100644 examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml diff --git a/examples/mlx_metal_kernel_opt/README.md b/examples/mlx_metal_kernel_opt/README.md index b79db2715..856a69686 100644 --- a/examples/mlx_metal_kernel_opt/README.md +++ b/examples/mlx_metal_kernel_opt/README.md @@ -27,7 +27,7 @@ export OPENAI_API_KEY="your-gemini-key" cd openevolve/examples/mlx_metal_kernel_opt # Using the experiment runner script -./run_evolve_experiment.sh --name test_run --iterations 25 +./run_evolve_experiment.sh --run-name test_run --iterations 25 # Or directly python -m openevolve.cli \ @@ -83,6 +83,18 @@ The evolution improved from an initial -11.5% regression to -3.2%, but never exc For detailed experiment results and analysis, see [EVOLUTION_ANALYSIS.md](./EVOLUTION_ANALYSIS.md). +### Demo Results (Committed) + +For review and reproducibility, this example repo includes one demo run output directory (curated subset of artifacts): + +- `demo_output_20260105_180918/` + +The key artifacts in that folder are: + +- `best/best_program.py`: best evolved program (iteration 23) +- `best/best_program_info.json`: metrics and baseline comparisons (includes the -3.2% result) +- `run.log` and `logs/openevolve_20260105_180918.log`: full run logs + ### Known Limitations 1. MAP-Elites selection uses abstract `combined_score` instead of direct speedup ratios diff --git a/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program.py b/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program.py new file mode 100644 index 000000000..3753fcc0d --- /dev/null +++ b/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program.py @@ -0,0 +1,508 @@ +""" +Qwen3 Custom Metal Kernel for Grouped Query Attention (GQA) Optimization + +This module implements a custom Metal kernel for Qwen3's 16:8 GQA pattern using +MLX's metal_kernel API. The kernel is designed to outperform mx.fast.scaled_dot_product_attention +by leveraging Apple Silicon specific optimizations and the 2:1 query-to-KV head ratio. + +Target: Qwen3-0.6B with 16 query heads : 8 KV heads +Hardware: Apple M-series GPUs with unified memory +Baseline: Standard MLX-LM using mx.fast.scaled_dot_product_attention +Goal: 5-15% performance improvement through custom Metal kernel optimization + +Evolution Target: The Metal kernel source code that computes GQA attention +""" + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +import math +from typing import Optional, Tuple, Any +import time + + +def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None): + """ + Custom Metal kernel implementation for Qwen3 GQA attention. + + Args: + queries: [B, num_heads=16, L, head_dim=128] + keys: [B, num_kv_heads=8, L, head_dim=128] + values: [B, num_kv_heads=8, L, head_dim=128] + scale: Attention scaling factor (1/sqrt(head_dim)) + mask: Attention mask (None, "causal", or boolean tensor) + + Returns: + Attention output [B, num_heads=16, L, head_dim=128] + """ + + B, num_heads, L, head_dim = queries.shape + _, num_kv_heads, _, _ = keys.shape + heads_per_kv = num_heads // num_kv_heads # 2 for Qwen3-0.6B (16:8) + + # Handle mask conversion + if mask == "causal" or mask is None: + # Create causal mask for autoregressive attention + causal_mask = mx.triu(mx.ones((L, L), dtype=mx.bool_), k=1) + mask_tensor = mx.logical_not(causal_mask) # True where attention is allowed + use_mask = True + elif isinstance(mask, (mx.array, type(None))): + if mask is None: + mask_tensor = mx.ones((L, L), dtype=mx.bool_) + use_mask = False + else: + mask_tensor = mask.astype(mx.bool_) + use_mask = True + else: + # Raise error for unsupported mask types - no fallback + raise ValueError( + f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask." + ) + + # Expand mask to match batch and head dimensions if needed + if mask_tensor.ndim == 2: + mask_tensor = mx.broadcast_to(mask_tensor[None, None, :, :], (B, num_heads, L, L)) + elif mask_tensor.ndim == 3: + mask_tensor = mx.broadcast_to(mask_tensor[:, None, :, :], (B, num_heads, L, L)) + + # EVOLVE-BLOCK-START + # Custom Metal kernel source for Qwen3 GQA optimization + # This kernel leverages the 16:8 head ratio and Apple Silicon architecture + kernel_source = """ + // Qwen3 GQA Metal Kernel - Optimized for 16:8 head pattern + // Thread mapping: each thread processes one query position + uint thread_id = thread_position_in_grid.x; + uint head_idx = thread_position_in_grid.y; + uint batch_idx = thread_position_in_grid.z; + uint query_pos = thread_id; + + // Bounds checking + if (batch_idx >= BATCH_SIZE || head_idx >= NUM_HEADS || query_pos >= SEQ_LEN) { + return; + } + + // Extract scalar values from input arrays + T scale_val = scale[0]; + bool use_mask_val = use_mask[0] > 0; + + // GQA mapping: determine which KV head corresponds to this query head + uint kv_head_idx = head_idx / HEADS_PER_KV; // 2 query heads per KV head + + // Pre-calculate base indices for memory access optimization + const uint q_base = batch_idx * (NUM_HEADS * SEQ_LEN * HEAD_DIM) + + head_idx * (SEQ_LEN * HEAD_DIM) + + query_pos * HEAD_DIM; + + const uint k_base_start = batch_idx * (NUM_KV_HEADS * SEQ_LEN * HEAD_DIM) + + kv_head_idx * (SEQ_LEN * HEAD_DIM); + + const uint v_base_start = k_base_start; // Values have same layout as keys + + const uint mask_base = batch_idx * (NUM_HEADS * SEQ_LEN * SEQ_LEN) + + head_idx * (SEQ_LEN * SEQ_LEN) + + query_pos * SEQ_LEN; + + const uint out_base = q_base; + + // Load query vector for this position (8-wide unrolled for better instruction-level parallelism) + T query_vec[HEAD_DIM]; + for (uint d = 0; d < HEAD_DIM; d += 8) { + query_vec[d] = queries[q_base + d]; + query_vec[d+1] = queries[q_base + d+1]; + query_vec[d+2] = queries[q_base + d+2]; + query_vec[d+3] = queries[q_base + d+3]; + query_vec[d+4] = queries[q_base + d+4]; + query_vec[d+5] = queries[q_base + d+5]; + query_vec[d+6] = queries[q_base + d+6]; + query_vec[d+7] = queries[q_base + d+7]; + } + + // First pass: compute attention scores and find maximum for numerical stability + T max_score = T(-INFINITY); + T scores[SEQ_LEN]; // Cache scores to avoid recomputation + + for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { + // Compute Q @ K^T for this key position + const uint k_base = k_base_start + key_pos * HEAD_DIM; + T score = T(0.0); + + // Vectorized dot product - process 8 elements at a time for wider SIMD utilization. + // HEAD_DIM=128 is a multiple of 8, so no remainder check is needed. + for (uint d = 0; d < HEAD_DIM; d += 8) { + score += query_vec[d] * keys[k_base + d] + + query_vec[d+1] * keys[k_base + d+1] + + query_vec[d+2] * keys[k_base + d+2] + + query_vec[d+3] * keys[k_base + d+3] + + query_vec[d+4] * keys[k_base + d+4] + + query_vec[d+5] * keys[k_base + d+5] + + query_vec[d+6] * keys[k_base + d+6] + + query_vec[d+7] * keys[k_base + d+7]; + } + + // Apply attention scaling + score *= scale_val; + + // Check attention mask and set score to -INFINITY if invalid. + // This makes the loop body uniform, avoiding conditional branching mid-loop. + bool is_valid = use_mask_val ? mask[mask_base + key_pos] : true; + scores[key_pos] = is_valid ? score : T(-INFINITY); + max_score = max(max_score, scores[key_pos]); + } + + // Second pass: compute softmax denominator + T sum_exp = T(0.0); + for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { + // Compute exp(score - max_score) unconditionally. + // If scores[key_pos] was -INFINITY (due to masking), exp(...) will correctly evaluate to 0. + T exp_score = exp(scores[key_pos] - max_score); + scores[key_pos] = exp_score; // Overwrite with exp(score - max) + sum_exp += exp_score; + } + + // Initialize output to zero (8-wide unrolled) + for (uint d = 0; d < HEAD_DIM; d += 8) { + output[out_base + d] = T(0.0); + output[out_base + d+1] = T(0.0); + output[out_base + d+2] = T(0.0); + output[out_base + d+3] = T(0.0); + output[out_base + d+4] = T(0.0); + output[out_base + d+5] = T(0.0); + output[out_base + d+6] = T(0.0); + output[out_base + d+7] = T(0.0); + } + + // Third pass: compute weighted sum of values + if (sum_exp > T(0.0)) { // This outer check is necessary to prevent division by zero + T inv_sum_exp = T(1.0) / sum_exp; // Pre-compute inverse for performance + for (uint key_pos = 0; key_pos < SEQ_LEN; key_pos++) { + T attention_weight = scores[key_pos] * inv_sum_exp; // Use pre-computed inverse + + // If scores[key_pos] was 0 (due to mask or exp(-inf)), attention_weight will be 0. + // Multiplying by 0 won't change the accumulator, so the branch is not strictly needed + // and removing it can improve SIMD utilization by making the loop uniform. + const uint v_base = v_base_start + key_pos * HEAD_DIM; + + // Vectorized accumulation - process 8 elements at a time. + // HEAD_DIM=128 is a multiple of 8, so no remainder check is needed. + for (uint d = 0; d < HEAD_DIM; d += 8) { + output[out_base + d] += attention_weight * values[v_base + d]; + output[out_base + d+1] += attention_weight * values[v_base + d+1]; + output[out_base + d+2] += attention_weight * values[v_base + d+2]; + output[out_base + d+3] += attention_weight * values[v_base + d+3]; + output[out_base + d+4] += attention_weight * values[v_base + d+4]; + output[out_base + d+5] += attention_weight * values[v_base + d+5]; + output[out_base + d+6] += attention_weight * values[v_base + d+6]; + output[out_base + d+7] += attention_weight * values[v_base + d+7]; + } + } + } + """ + # EVOLVE-BLOCK-END + + try: + # Prepare kernel inputs + scale_tensor = mx.array([scale], dtype=queries.dtype) + use_mask_tensor = mx.array([1 if use_mask else 0], dtype=mx.int32) + + # Create and execute custom Metal kernel + kernel = mx.fast.metal_kernel( + name="qwen3_gqa_attention_kernel", + input_names=["queries", "keys", "values", "mask", "scale", "use_mask"], + output_names=["output"], + source=kernel_source, + ) + + # Optimize thread group size for Apple Silicon + threadgroup_size = min(32, L) # Adapt to sequence length + + # Execute kernel + outputs = kernel( + inputs=[queries, keys, values, mask_tensor, scale_tensor, use_mask_tensor], + output_shapes=[(B, num_heads, L, head_dim)], + output_dtypes=[queries.dtype], + grid=(L, num_heads, B), # (SEQ_LEN, NUM_HEADS, BATCH_SIZE) + threadgroup=(threadgroup_size, 1, 1), + template=[ + ("T", queries.dtype), + ("BATCH_SIZE", B), + ("NUM_HEADS", num_heads), + ("NUM_KV_HEADS", num_kv_heads), + ("SEQ_LEN", L), + ("HEAD_DIM", head_dim), + ("HEADS_PER_KV", heads_per_kv), + ], + ) + + return outputs[0] + + except Exception as e: + # No fallback - let the custom kernel failure propagate for proper scoring + print(f"❌ Custom GQA kernel failed: {e}") + raise RuntimeError(f"Custom Metal kernel execution failed: {e}") from e + + +class CustomGQAAttention(nn.Module): + """ + Qwen3 attention module with custom Metal kernel optimization. + + This module integrates the custom Metal kernel while maintaining + compatibility with the standard MLX-LM interface. + """ + + def __init__(self, args): + super().__init__() + + # Standard Qwen3 parameters + dim = args.hidden_size # 2048 + self.n_heads = n_heads = args.num_attention_heads # 16 + assert args.num_key_value_heads is not None + self.n_kv_heads = n_kv_heads = args.num_key_value_heads # 8 + head_dim = args.head_dim # 128 + self.scale = head_dim**-0.5 + + # Standard MLX-LM projections + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + # Standard MLX-LM norms + self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) + + # Standard MLX-LM RoPE + try: + from mlx_lm.models.rope_utils import initialize_rope + + self.rope = initialize_rope( + head_dim, + base=args.rope_theta, + traditional=False, + scaling_config=args.rope_scaling, + max_position_embeddings=args.max_position_embeddings, + ) + except ImportError: + print("⚠️ Could not import mlx_lm rope_utils, using basic RoPE") + self.rope = None + + print(f"🔧 Initialized Custom Metal GQA Attention") + print(f" 📊 Architecture: {n_heads}:{n_kv_heads} heads ({n_heads//n_kv_heads}:1 ratio)") + print(f" 🎯 Head dimension: {head_dim}") + print(f" ⚡ Using custom Metal kernel for GQA optimization") + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + # Standard preprocessing (already optimized, don't evolve) + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose(0, 2, 1, 3) + keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + # Standard RoPE application (already optimized, don't evolve) + if cache is not None: + if self.rope is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + if self.rope is not None: + queries = self.rope(queries) + keys = self.rope(keys) + + # CORE INNOVATION: Custom Metal kernel for GQA attention + output = qwen3_custom_gqa_attention(queries, keys, values, scale=self.scale, mask=mask) + + # Standard postprocessing (already optimized, don't evolve) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +def create_metal_qwen3_optimization_hook(): + """ + Create hooks to replace Qwen3's attention with Metal kernel optimized version. + """ + + def apply_optimization_hook(): + """Apply the Metal kernel optimized attention""" + try: + import mlx_lm.models.qwen3 as qwen3_module + + # Store original attention class + original_attention = qwen3_module.Attention + + # Replace with Metal optimized implementation + qwen3_module.Attention = CustomGQAAttention + + print("✅ Applied Custom Metal GQA Attention hook") + return original_attention + + except ImportError: + print("❌ Could not import mlx_lm.models.qwen3") + return None + + def remove_optimization_hook(original_attention): + """Remove the optimization hook""" + try: + import mlx_lm.models.qwen3 as qwen3_module + + qwen3_module.Attention = original_attention + print("✅ Removed Custom Metal GQA Attention hook") + except ImportError: + pass + + return apply_optimization_hook, remove_optimization_hook + + +def benchmark_metal_gqa_optimization(): + """ + Benchmark Metal kernel optimized GQA attention against MLX baseline. + """ + + # Qwen3-0.6B configuration + class MockArgs: + hidden_size = 2048 + num_attention_heads = 16 + num_key_value_heads = 8 + head_dim = 128 + rms_norm_eps = 1e-06 + rope_theta = 1000000 + rope_scaling = None + max_position_embeddings = 40960 + + args = MockArgs() + + # Test configurations for Metal kernel validation + test_configs = [ + ("short_sequence", 1, 128, 2048), + ("medium_sequence", 1, 512, 2048), + ("long_sequence", 1, 1024, 2048), + ("max_sequence", 1, 2048, 2048), + ] + + print("Benchmarking Custom Metal GQA Kernel vs MLX Baseline") + print("=" * 70) + + # Initialize Metal optimized attention + metal_attn = CustomGQAAttention(args) + + for config_name, batch_size, seq_len, hidden_size in test_configs: + print(f"\nTesting {config_name}: B={batch_size}, L={seq_len}") + + # Create test inputs + x = mx.random.normal((batch_size, seq_len, hidden_size)) + mask = "causal" + + # Warmup runs + for _ in range(3): + _ = metal_attn(x, mask=mask) + mx.eval(_) + + # Benchmark Metal optimized implementation + mx.synchronize() + start_time = time.perf_counter() + + for _ in range(10): + output = metal_attn(x, mask=mask) + mx.eval(output) + + mx.synchronize() + end_time = time.perf_counter() + + avg_time = (end_time - start_time) / 10 + tokens_per_sec = seq_len / avg_time + + print(f" Metal GQA: {avg_time*1000:.2f} ms, {tokens_per_sec:.1f} tokens/sec") + print(f" Memory: {mx.get_active_memory() / 1e9:.2f} GB") + + +def test_metal_gqa_correctness(): + """ + Test that Metal kernel implementation produces correct results. + """ + print("Testing Custom Metal GQA Correctness") + print("=" * 50) + + # Test configuration + B, L, D = 1, 64, 2048 + + class MockArgs: + hidden_size = 2048 + num_attention_heads = 16 + num_key_value_heads = 8 + head_dim = 128 + rms_norm_eps = 1e-06 + rope_theta = 1000000 + rope_scaling = None + max_position_embeddings = 40960 + + args = MockArgs() + + # Create test input + x = mx.random.normal((B, L, D)) + mask = "causal" + + # Test Metal optimized implementation + metal_attn = CustomGQAAttention(args) + output = metal_attn(x, mask=mask) + + print(f"✅ Metal GQA output shape: {output.shape}") + + # Check for valid output + has_nan = bool(mx.any(mx.isnan(output))) + has_inf = bool(mx.any(mx.isinf(output))) + + print(f"✅ Has NaN: {has_nan}, Has Inf: {has_inf}") + + # Check output statistics + output_mean = float(mx.mean(output)) + output_std = float(mx.std(output)) + + print(f"✅ Output statistics - Mean: {output_mean:.6f}, Std: {output_std:.6f}") + + # Test direct kernel function + print("\n=== Testing Direct Kernel Function ===") + B, H, L, D = 1, 16, 128, 128 + q = mx.random.normal((B, H, L, D)) + k = mx.random.normal((B, 8, L, D)) # 8 KV heads + v = mx.random.normal((B, 8, L, D)) + scale = 1.0 / math.sqrt(D) + + kernel_output = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal") + print(f"✅ Direct kernel output shape: {kernel_output.shape}") + + kernel_mean = float(mx.mean(kernel_output)) + kernel_std = float(mx.std(kernel_output)) + print(f"✅ Direct kernel stats - Mean: {kernel_mean:.6f}, Std: {kernel_std:.6f}") + + return True + + +if __name__ == "__main__": + print("Custom Metal Kernel Qwen3 GQA Optimization") + print("=" * 70) + + # Test correctness first + test_metal_gqa_correctness() + + print("\n") + + # Benchmark performance + benchmark_metal_gqa_optimization() + + print("\n" + "=" * 70) + print("Ready for Metal Kernel Evolution") + print("Evolution focus:") + print("1. 🔧 Metal kernel source code optimization") + print("2. 💾 Memory access pattern improvements for Apple Silicon") + print("3. 🎯 GQA-specific optimizations for 16:8 head ratio") + print("4. ⚡ Vectorization and SIMD optimization") + print("5. 🚀 Thread group and grid configuration tuning") + print("Target: 5-15% performance improvement through Metal kernel innovation") + print("=" * 70) diff --git a/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program_info.json b/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program_info.json new file mode 100644 index 000000000..288aa113d --- /dev/null +++ b/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program_info.json @@ -0,0 +1,195 @@ +{ + "id": "2e9e8c13-2259-4908-92fb-a7b58d588abf", + "generation": 2, + "iteration": 23, + "timestamp": 1767618428.986909, + "parent_id": "2eada941-048f-4828-b81c-786434f63240", + "metrics": { + "success": true, + "final_score": 2.9589006960980697, + "combined_score": 2.9589006960980697, + "performance_metrics": { + "avg_decode_speed": 53.41525, + "min_decode_speed": 46.023, + "max_decode_speed": 63.108, + "avg_prefill_speed": 244.14375, + "avg_memory_gb": 1.39325, + "max_memory_gb": 1.554, + "num_successful_tests": 4, + "decode_speed_std": 7.40971879948895 + }, + "correctness_score": 1.0, + "benchmark_results": [ + { + "name": "short_context_quick", + "decode_tokens_per_sec": 63.108, + "prefill_tokens_per_sec": 135.605, + "peak_memory_gb": 1.35, + "generated_tokens": 50, + "total_time_sec": 4.579478166997433 + }, + { + "name": "code_generation", + "decode_tokens_per_sec": 58.112, + "prefill_tokens_per_sec": 355.6, + "peak_memory_gb": 1.291, + "generated_tokens": 300, + "total_time_sec": 9.802399624997634 + }, + { + "name": "long_context_detailed", + "decode_tokens_per_sec": 46.023, + "prefill_tokens_per_sec": 303.484, + "peak_memory_gb": 1.554, + "generated_tokens": 500, + "total_time_sec": 17.30625654100004 + }, + { + "name": "long_generation", + "decode_tokens_per_sec": 46.418, + "prefill_tokens_per_sec": 181.886, + "peak_memory_gb": 1.378, + "generated_tokens": 1000, + "total_time_sec": 27.308918209004332 + } + ], + "baseline_comparison": { + "avg_decode_improvement_pct": -3.199533101300643, + "avg_decode_improvement_absolute": -1.6237500000000011, + "memory_change_gb": -0.005749999999999922, + "target_achieved": false, + "num_benchmarks_improved": 1, + "total_benchmarks": 4, + "safety_score": 100.0 + }, + "individual_comparisons": [ + { + "benchmark_name": "short_context_quick", + "baseline": { + "name": "short_context_quick", + "decode_tokens_per_sec": 59.061, + "prefill_tokens_per_sec": 48.778, + "peak_memory_gb": 1.416, + "generated_tokens": 50, + "total_time_sec": 4.753752667005756 + }, + "custom": { + "name": "short_context_quick", + "decode_tokens_per_sec": 63.108, + "prefill_tokens_per_sec": 135.605, + "peak_memory_gb": 1.35, + "generated_tokens": 50, + "total_time_sec": 4.579478166997433 + }, + "improvements": { + "decode_speed_pct": 6.852237517143288, + "prefill_speed_pct": 178.00442822583952, + "total_speed_pct": 3.80555368216958, + "memory_reduction_pct": 4.888888888888877, + "time_reduction_pct": 3.80555368216959 + } + }, + { + "benchmark_name": "code_generation", + "baseline": { + "name": "code_generation", + "decode_tokens_per_sec": 58.335, + "prefill_tokens_per_sec": 214.398, + "peak_memory_gb": 1.277, + "generated_tokens": 300, + "total_time_sec": 9.370090207994508 + }, + "custom": { + "name": "code_generation", + "decode_tokens_per_sec": 58.112, + "prefill_tokens_per_sec": 355.6, + "peak_memory_gb": 1.291, + "generated_tokens": 300, + "total_time_sec": 9.802399624997634 + }, + "improvements": { + "decode_speed_pct": -0.382274792148794, + "prefill_speed_pct": 65.85975615444175, + "total_speed_pct": -4.41024069147997, + "memory_reduction_pct": -1.0844306738962055, + "time_reduction_pct": -4.4102406914799674 + } + }, + { + "benchmark_name": "long_context_detailed", + "baseline": { + "name": "long_context_detailed", + "decode_tokens_per_sec": 54.724, + "prefill_tokens_per_sec": 853.686, + "peak_memory_gb": 1.543, + "generated_tokens": 500, + "total_time_sec": 13.029723457999353 + }, + "custom": { + "name": "long_context_detailed", + "decode_tokens_per_sec": 46.023, + "prefill_tokens_per_sec": 303.484, + "peak_memory_gb": 1.554, + "generated_tokens": 500, + "total_time_sec": 17.30625654100004 + }, + "improvements": { + "decode_speed_pct": -15.899788027190983, + "prefill_speed_pct": -64.45016083196866, + "total_speed_pct": -24.710907716346416, + "memory_reduction_pct": -0.7078507078507156, + "time_reduction_pct": -24.710907716346423 + } + }, + { + "benchmark_name": "long_generation", + "baseline": { + "name": "long_generation", + "decode_tokens_per_sec": 48.036, + "prefill_tokens_per_sec": 164.979, + "peak_memory_gb": 1.36, + "generated_tokens": 1000, + "total_time_sec": 26.132775457997923 + }, + "custom": { + "name": "long_generation", + "decode_tokens_per_sec": 46.418, + "prefill_tokens_per_sec": 181.886, + "peak_memory_gb": 1.378, + "generated_tokens": 1000, + "total_time_sec": 27.308918209004332 + }, + "improvements": { + "decode_speed_pct": -3.368307103006083, + "prefill_speed_pct": 10.247970953878967, + "total_speed_pct": -4.306808281474169, + "memory_reduction_pct": -1.3062409288824235, + "time_reduction_pct": -4.306808281474182 + } + } + ], + "summary": "Bulletproof Custom GQA Implementation Results:\n\u2022 Decode Speed: 53.4 tokens/sec (baseline: 55.0)\n\u2022 Improvement: -3.2%\n\u2022 Memory Usage: 1.39 GB\n\u2022 Correctness: 100.0%\n\u2022 Safety Score: 100.0/100\n\u2022 Tests Passed: 4/4\n\u2022 Benchmarks Improved: 1/4\n\u2022 Metal Errors Handled: 0\n\ud83d\udee1\ufe0f PERFECT SAFETY: No Metal kernel errors\n\u26a0\ufe0f NO IMPROVEMENT: Performance regression", + "metal_safety_statistics": { + "metal_command_buffer_errors": 0, + "metal_memory_violations": 0, + "metal_compilation_errors": 0, + "gpu_resource_errors": 0, + "total_metal_errors": 0, + "successful_fallbacks": 0, + "retry_attempts_used": 0, + "safety_score": 100.0, + "error_breakdown": { + "command_buffer_pct": 0.0, + "memory_violation_pct": 0.0, + "compilation_error_pct": 0.0, + "resource_error_pct": 0.0 + } + }, + "safety_validation": { + "success": true, + "validated": true + } + }, + "language": "python", + "saved_at": 1767619206.585282 +} \ No newline at end of file diff --git a/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml b/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml new file mode 100644 index 000000000..1d159a6eb --- /dev/null +++ b/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml @@ -0,0 +1,114 @@ +max_iterations: 25 +checkpoint_interval: 5 +log_level: INFO +llm: + primary_model: gemini-2.5-flash + primary_model_weight: 0.6 + secondary_model: gemini-2.5-pro + secondary_model_weight: 0.4 + api_base: https://generativelanguage.googleapis.com/v1beta/openai/ + temperature: 0.6 + top_p: 0.95 + max_tokens: 32000 + timeout: 900 +prompt: + system_message: "You are an expert Metal GPU programmer specializing in custom attention\ + \ kernels for Apple Silicon.\n\n# TARGET: Optimize Metal Kernel for Qwen3 Grouped\ + \ Query Attention (GQA)\n# HARDWARE: Apple M-series GPUs with unified memory architecture\n\ + # BASELINE: Standard MLX scaled_dot_product_attention\n# ARCHITECTURE: 16 query\ + \ heads : 8 KV heads (2:1 ratio), 128 head dimension\n# GOAL: 5-15% performance\ + \ improvement through Metal kernel optimization\n\n# CURRENT METAL KERNEL STRUCTURE:\n\ + ```metal\nkernel void qwen3_gqa_attention_kernel() {\n // Thread mapping: each\ + \ thread handles one query position\n uint query_pos = thread_position_in_grid.x;\n\ + \ uint head_idx = thread_position_in_grid.y; \n uint batch_idx = thread_position_in_grid.z;\n\ + \ \n // GQA mapping: 2 query heads per KV head\n uint kv_head_idx = head_idx\ + \ / HEADS_PER_KV;\n \n // Current algorithm:\n // 1. Load query vector\n\ + \ // 2. First pass: compute scores and find max\n // 3. Second pass: compute\ + \ softmax denominator \n // 4. Third pass: compute weighted value sum\n}\n\ + ```\n\n# OPTIMIZATION OPPORTUNITIES IN THE EVOLVE-BLOCK:\n\n**1. Memory Access\ + \ Pattern Optimization:**\n```metal\n// CURRENT: Linear memory access\n// OPTIMIZE:\ + \ Coalesced access patterns for Apple Silicon\n\n// Example: Vectorized loading\n\ + for (uint d = 0; d < HEAD_DIM; d += 4) {\n // Load 4 elements at once using\ + \ SIMD\n query_vec[d] = queries[q_base + d];\n query_vec[d+1] = queries[q_base\ + \ + d+1];\n query_vec[d+2] = queries[q_base + d+2]; \n query_vec[d+3] =\ + \ queries[q_base + d+3];\n}\n\n// Example: Pre-compute and cache frequently used\ + \ indices\n```\n\n**2. Computation Algorithm Optimization:**\n```metal\n// CURRENT:\ + \ 3-pass attention (find max, softmax, weighted sum)\n// OPTIMIZE: Fused operations,\ + \ online algorithms\n\n// Example: Online softmax to reduce passes\n// Example:\ + \ Fused score computation and max finding\n// Example: Reduce redundant index\ + \ calculations\n```\n\n**3. GQA-Specific Optimizations:**\n```metal\n// CURRENT:\ + \ Basic kv_head_idx = head_idx / HEADS_PER_KV\n// OPTIMIZE: Leverage the specific\ + \ 2:1 ratio pattern\n\n// Example: Process 2 query heads together for each KV\ + \ head\n// Example: Optimize memory layout for the 16:8 pattern\n// Example: Reduce\ + \ broadcast overhead through clever indexing\n```\n\n**4. Apple Silicon Specific\ + \ Features:**\n```metal\n// OPTIMIZE: Use Apple GPU specific capabilities\n\n\ + // Example: Leverage unified memory bandwidth patterns\n// Example: Optimize for\ + \ Apple's SIMD group sizes (32 threads)\n// Example: Use native half-precision\ + \ operations efficiently\n// Example: Minimize memory allocation overhead\n```\n\ + \n**5. Vectorization and SIMD:**\n```metal\n// CURRENT: Scalar operations with\ + \ some vectorization\n// OPTIMIZE: Full SIMD utilization\n\n// Example: Process\ + \ multiple elements simultaneously\nfor (uint d = 0; d < HEAD_DIM; d += 8) {\n\ + \ // Process 8 elements at once\n // Use Metal's built-in vector operations\n\ + }\n\n// Example: Vectorized dot products and accumulation\n```\n\n**6. Thread\ + \ Group and Memory Hierarchy:**\n```metal\n// OPTIMIZE: Better utilize Apple GPU\ + \ memory hierarchy\n\n// Example: Use threadgroup memory for data sharing\nthreadgroup\ + \ T shared_data[SHARED_SIZE];\n\n// Example: Optimize thread cooperation patterns\n\ + // Example: Balance register usage vs memory bandwidth\n```\n\n**7. Numerical\ + \ Stability and Precision:**\n```metal\n// OPTIMIZE: Maintain accuracy while improving\ + \ performance\n\n// Example: More efficient max finding\n// Example: Optimized\ + \ exp() computation for softmax\n// Example: Better handling of edge cases\n```\n\ + \n# EVOLUTION CONSTRAINTS - CRITICAL SAFETY RULES:\n\n**MUST NOT CHANGE:**\n\u274C\ + \ Kernel function signature or input/output specifications\n\u274C Template parameter\ + \ names or types (T, BATCH_SIZE, NUM_HEADS, etc.)\n\u274C Overall algorithm correctness\ + \ (must compute same attention result)\n\u274C Thread grid mapping (thread_position_in_grid\ + \ usage)\n\u274C Bounds checking logic (batch_idx >= BATCH_SIZE checks)\n\u274C\ + \ Output tensor shapes or semantics\n\n**ALLOWED TO OPTIMIZE:**\n\u2705 Memory\ + \ access patterns and indexing within the kernel\n\u2705 Computation order and\ + \ algorithm efficiency\n\u2705 Vectorization and SIMD utilization\n\u2705 Loop\ + \ structures and data processing patterns\n\u2705 Variable declarations and data\ + \ types within kernel\n\u2705 Mathematical operations and optimizations\n\u2705\ + \ GQA-specific computation strategies\n\u2705 Apple Silicon specific optimizations\n\ + \n**METAL SYNTAX REQUIREMENTS:**\n- Use proper Metal C++ syntax\n- Maintain variable\ + \ type consistency (T for tensor element type)\n- Keep proper array indexing (no\ + \ out-of-bounds access)\n- Use valid Metal built-in functions and operations\n\ + - Ensure thread safety and proper synchronization\n\n# SPECIFIC OPTIMIZATION STRATEGIES\ + \ TO TRY:\n\n**Strategy 1: Enhanced Vectorization**\n```metal\n// Replace scalar\ + \ operations with SIMD vector operations\n// Process 4 or 8 elements simultaneously\n\ + // Use Metal's built-in vector math functions\n```\n\n**Strategy 2: Memory Access\ + \ Optimization**\n```metal\n// Reorganize memory access for better coalescing\n\ + // Pre-compute base indices once\n// Cache frequently accessed values in registers\n\ + // Minimize redundant address calculations\n```\n\n**Strategy 3: Algorithm Fusion**\n\ + ```metal\n// Combine max finding with score computation\n// Fuse exp() computation\ + \ with accumulation\n// Reduce the number of passes through data\n```\n\n**Strategy\ + \ 4: GQA Pattern Exploitation**\n```metal\n// Optimize for the specific 2:1 query:KV\ + \ ratio\n// Process query heads in groups of 2\n// Reduce KV head indexing overhead\n\ + ```\n\n**Strategy 5: Apple Silicon Specialization**\n```metal\n// Use optimal\ + \ thread group sizes for Apple GPUs\n// Leverage unified memory architecture\n\ + // Optimize for Apple's specific SIMD characteristics\n```\n\n# SUCCESS CRITERIA:\n\ + - **Compilation**: Metal kernel must compile without syntax errors\n- **Correctness**:\ + \ Output must match MLX baseline (within float precision)\n- **Performance**:\ + \ Target 5-15% improvement in attention computation time\n- **Memory**: Similar\ + \ or better memory usage compared to baseline\n- **Stability**: No crashes, undefined\ + \ behavior, or numerical instability\n\n# IMPORTANT NOTES:\n- Focus ONLY on optimizing\ + \ the Metal kernel source code in the EVOLVE-BLOCK\n- The kernel will be compiled\ + \ using mx.fast.metal_kernel() automatically\n- Maintain the exact same attention\ + \ computation semantics\n- Test with Qwen3-0.6B's specific 16:8 head configuration\n\ + - Leverage Apple Silicon's unified memory and SIMD capabilities\n\nYour goal is\ + \ to discover Metal kernel optimizations that outperform MLX's \nalready highly-optimized\ + \ scaled_dot_product_attention implementation.\n" + num_top_programs: 3 + num_diverse_programs: 2 +database: + db_path: /Users/mogu/Library/Mobile Documents/com~apple~CloudDocs/Personal/study/research/kernelbench-openevolve/openevolve/examples/mlx_metal_kernel_opt/openevolve_output_20260105_180918/qwen3_metal_kernel_evolution + population_size: 25 + archive_size: 12 + num_islands: 3 + elite_selection_ratio: 0.3 + exploitation_ratio: 0.65 + exploration_ratio: 0.35 +evaluator: + timeout: 900 + parallel_evaluations: 1 +diff_based_evolution: true +allow_full_rewrites: false +max_code_length: 60000 From 69d2dabaefc4656814d4b29eda1b836ccc2defde Mon Sep 17 00:00:00 2001 From: lanmogu98 <116992711+lanmogu98@users.noreply.github.com> Date: Fri, 9 Jan 2026 00:08:08 +0800 Subject: [PATCH 14/14] docs(mlx_metal_kernel_opt): keep demo best program snapshot - Commit best_program.py and best_program_info.json at example root - Git-ignore demo/output dirs; remove demo_output_20260105_180918 --- .gitignore | 2 + examples/mlx_metal_kernel_opt/README.md | 11 +- .../best => }/best_program.py | 0 .../best => }/best_program_info.json | 0 .../demo_output_20260105_180918/config.yaml | 114 ------------------ 5 files changed, 6 insertions(+), 121 deletions(-) rename examples/mlx_metal_kernel_opt/{demo_output_20260105_180918/best => }/best_program.py (100%) rename examples/mlx_metal_kernel_opt/{demo_output_20260105_180918/best => }/best_program_info.json (100%) delete mode 100644 examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml diff --git a/.gitignore b/.gitignore index 3c72fcea4..338d65a21 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,8 @@ ENV/ examples/*/output/ openevolve_output*/ *.log +demo_output*/ +pr_sanity*/ # Test cache .pytest_cache/ diff --git a/examples/mlx_metal_kernel_opt/README.md b/examples/mlx_metal_kernel_opt/README.md index 856a69686..e7fae54c6 100644 --- a/examples/mlx_metal_kernel_opt/README.md +++ b/examples/mlx_metal_kernel_opt/README.md @@ -85,15 +85,12 @@ For detailed experiment results and analysis, see [EVOLUTION_ANALYSIS.md](./EVOL ### Demo Results (Committed) -For review and reproducibility, this example repo includes one demo run output directory (curated subset of artifacts): +For review and reproducibility, this example repo includes a committed snapshot of one post-fix evolution run: -- `demo_output_20260105_180918/` +- `best_program.py`: best evolved program (iteration 23) +- `best_program_info.json`: metrics + baseline comparisons (includes the -3.2% result) -The key artifacts in that folder are: - -- `best/best_program.py`: best evolved program (iteration 23) -- `best/best_program_info.json`: metrics and baseline comparisons (includes the -3.2% result) -- `run.log` and `logs/openevolve_20260105_180918.log`: full run logs +The full run output directory is intentionally git-ignored (see `.gitignore`) to avoid committing large run artifacts. ### Known Limitations diff --git a/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program.py b/examples/mlx_metal_kernel_opt/best_program.py similarity index 100% rename from examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program.py rename to examples/mlx_metal_kernel_opt/best_program.py diff --git a/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program_info.json b/examples/mlx_metal_kernel_opt/best_program_info.json similarity index 100% rename from examples/mlx_metal_kernel_opt/demo_output_20260105_180918/best/best_program_info.json rename to examples/mlx_metal_kernel_opt/best_program_info.json diff --git a/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml b/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml deleted file mode 100644 index 1d159a6eb..000000000 --- a/examples/mlx_metal_kernel_opt/demo_output_20260105_180918/config.yaml +++ /dev/null @@ -1,114 +0,0 @@ -max_iterations: 25 -checkpoint_interval: 5 -log_level: INFO -llm: - primary_model: gemini-2.5-flash - primary_model_weight: 0.6 - secondary_model: gemini-2.5-pro - secondary_model_weight: 0.4 - api_base: https://generativelanguage.googleapis.com/v1beta/openai/ - temperature: 0.6 - top_p: 0.95 - max_tokens: 32000 - timeout: 900 -prompt: - system_message: "You are an expert Metal GPU programmer specializing in custom attention\ - \ kernels for Apple Silicon.\n\n# TARGET: Optimize Metal Kernel for Qwen3 Grouped\ - \ Query Attention (GQA)\n# HARDWARE: Apple M-series GPUs with unified memory architecture\n\ - # BASELINE: Standard MLX scaled_dot_product_attention\n# ARCHITECTURE: 16 query\ - \ heads : 8 KV heads (2:1 ratio), 128 head dimension\n# GOAL: 5-15% performance\ - \ improvement through Metal kernel optimization\n\n# CURRENT METAL KERNEL STRUCTURE:\n\ - ```metal\nkernel void qwen3_gqa_attention_kernel() {\n // Thread mapping: each\ - \ thread handles one query position\n uint query_pos = thread_position_in_grid.x;\n\ - \ uint head_idx = thread_position_in_grid.y; \n uint batch_idx = thread_position_in_grid.z;\n\ - \ \n // GQA mapping: 2 query heads per KV head\n uint kv_head_idx = head_idx\ - \ / HEADS_PER_KV;\n \n // Current algorithm:\n // 1. Load query vector\n\ - \ // 2. First pass: compute scores and find max\n // 3. Second pass: compute\ - \ softmax denominator \n // 4. Third pass: compute weighted value sum\n}\n\ - ```\n\n# OPTIMIZATION OPPORTUNITIES IN THE EVOLVE-BLOCK:\n\n**1. Memory Access\ - \ Pattern Optimization:**\n```metal\n// CURRENT: Linear memory access\n// OPTIMIZE:\ - \ Coalesced access patterns for Apple Silicon\n\n// Example: Vectorized loading\n\ - for (uint d = 0; d < HEAD_DIM; d += 4) {\n // Load 4 elements at once using\ - \ SIMD\n query_vec[d] = queries[q_base + d];\n query_vec[d+1] = queries[q_base\ - \ + d+1];\n query_vec[d+2] = queries[q_base + d+2]; \n query_vec[d+3] =\ - \ queries[q_base + d+3];\n}\n\n// Example: Pre-compute and cache frequently used\ - \ indices\n```\n\n**2. Computation Algorithm Optimization:**\n```metal\n// CURRENT:\ - \ 3-pass attention (find max, softmax, weighted sum)\n// OPTIMIZE: Fused operations,\ - \ online algorithms\n\n// Example: Online softmax to reduce passes\n// Example:\ - \ Fused score computation and max finding\n// Example: Reduce redundant index\ - \ calculations\n```\n\n**3. GQA-Specific Optimizations:**\n```metal\n// CURRENT:\ - \ Basic kv_head_idx = head_idx / HEADS_PER_KV\n// OPTIMIZE: Leverage the specific\ - \ 2:1 ratio pattern\n\n// Example: Process 2 query heads together for each KV\ - \ head\n// Example: Optimize memory layout for the 16:8 pattern\n// Example: Reduce\ - \ broadcast overhead through clever indexing\n```\n\n**4. Apple Silicon Specific\ - \ Features:**\n```metal\n// OPTIMIZE: Use Apple GPU specific capabilities\n\n\ - // Example: Leverage unified memory bandwidth patterns\n// Example: Optimize for\ - \ Apple's SIMD group sizes (32 threads)\n// Example: Use native half-precision\ - \ operations efficiently\n// Example: Minimize memory allocation overhead\n```\n\ - \n**5. Vectorization and SIMD:**\n```metal\n// CURRENT: Scalar operations with\ - \ some vectorization\n// OPTIMIZE: Full SIMD utilization\n\n// Example: Process\ - \ multiple elements simultaneously\nfor (uint d = 0; d < HEAD_DIM; d += 8) {\n\ - \ // Process 8 elements at once\n // Use Metal's built-in vector operations\n\ - }\n\n// Example: Vectorized dot products and accumulation\n```\n\n**6. Thread\ - \ Group and Memory Hierarchy:**\n```metal\n// OPTIMIZE: Better utilize Apple GPU\ - \ memory hierarchy\n\n// Example: Use threadgroup memory for data sharing\nthreadgroup\ - \ T shared_data[SHARED_SIZE];\n\n// Example: Optimize thread cooperation patterns\n\ - // Example: Balance register usage vs memory bandwidth\n```\n\n**7. Numerical\ - \ Stability and Precision:**\n```metal\n// OPTIMIZE: Maintain accuracy while improving\ - \ performance\n\n// Example: More efficient max finding\n// Example: Optimized\ - \ exp() computation for softmax\n// Example: Better handling of edge cases\n```\n\ - \n# EVOLUTION CONSTRAINTS - CRITICAL SAFETY RULES:\n\n**MUST NOT CHANGE:**\n\u274C\ - \ Kernel function signature or input/output specifications\n\u274C Template parameter\ - \ names or types (T, BATCH_SIZE, NUM_HEADS, etc.)\n\u274C Overall algorithm correctness\ - \ (must compute same attention result)\n\u274C Thread grid mapping (thread_position_in_grid\ - \ usage)\n\u274C Bounds checking logic (batch_idx >= BATCH_SIZE checks)\n\u274C\ - \ Output tensor shapes or semantics\n\n**ALLOWED TO OPTIMIZE:**\n\u2705 Memory\ - \ access patterns and indexing within the kernel\n\u2705 Computation order and\ - \ algorithm efficiency\n\u2705 Vectorization and SIMD utilization\n\u2705 Loop\ - \ structures and data processing patterns\n\u2705 Variable declarations and data\ - \ types within kernel\n\u2705 Mathematical operations and optimizations\n\u2705\ - \ GQA-specific computation strategies\n\u2705 Apple Silicon specific optimizations\n\ - \n**METAL SYNTAX REQUIREMENTS:**\n- Use proper Metal C++ syntax\n- Maintain variable\ - \ type consistency (T for tensor element type)\n- Keep proper array indexing (no\ - \ out-of-bounds access)\n- Use valid Metal built-in functions and operations\n\ - - Ensure thread safety and proper synchronization\n\n# SPECIFIC OPTIMIZATION STRATEGIES\ - \ TO TRY:\n\n**Strategy 1: Enhanced Vectorization**\n```metal\n// Replace scalar\ - \ operations with SIMD vector operations\n// Process 4 or 8 elements simultaneously\n\ - // Use Metal's built-in vector math functions\n```\n\n**Strategy 2: Memory Access\ - \ Optimization**\n```metal\n// Reorganize memory access for better coalescing\n\ - // Pre-compute base indices once\n// Cache frequently accessed values in registers\n\ - // Minimize redundant address calculations\n```\n\n**Strategy 3: Algorithm Fusion**\n\ - ```metal\n// Combine max finding with score computation\n// Fuse exp() computation\ - \ with accumulation\n// Reduce the number of passes through data\n```\n\n**Strategy\ - \ 4: GQA Pattern Exploitation**\n```metal\n// Optimize for the specific 2:1 query:KV\ - \ ratio\n// Process query heads in groups of 2\n// Reduce KV head indexing overhead\n\ - ```\n\n**Strategy 5: Apple Silicon Specialization**\n```metal\n// Use optimal\ - \ thread group sizes for Apple GPUs\n// Leverage unified memory architecture\n\ - // Optimize for Apple's specific SIMD characteristics\n```\n\n# SUCCESS CRITERIA:\n\ - - **Compilation**: Metal kernel must compile without syntax errors\n- **Correctness**:\ - \ Output must match MLX baseline (within float precision)\n- **Performance**:\ - \ Target 5-15% improvement in attention computation time\n- **Memory**: Similar\ - \ or better memory usage compared to baseline\n- **Stability**: No crashes, undefined\ - \ behavior, or numerical instability\n\n# IMPORTANT NOTES:\n- Focus ONLY on optimizing\ - \ the Metal kernel source code in the EVOLVE-BLOCK\n- The kernel will be compiled\ - \ using mx.fast.metal_kernel() automatically\n- Maintain the exact same attention\ - \ computation semantics\n- Test with Qwen3-0.6B's specific 16:8 head configuration\n\ - - Leverage Apple Silicon's unified memory and SIMD capabilities\n\nYour goal is\ - \ to discover Metal kernel optimizations that outperform MLX's \nalready highly-optimized\ - \ scaled_dot_product_attention implementation.\n" - num_top_programs: 3 - num_diverse_programs: 2 -database: - db_path: /Users/mogu/Library/Mobile Documents/com~apple~CloudDocs/Personal/study/research/kernelbench-openevolve/openevolve/examples/mlx_metal_kernel_opt/openevolve_output_20260105_180918/qwen3_metal_kernel_evolution - population_size: 25 - archive_size: 12 - num_islands: 3 - elite_selection_ratio: 0.3 - exploitation_ratio: 0.65 - exploration_ratio: 0.35 -evaluator: - timeout: 900 - parallel_evaluations: 1 -diff_based_evolution: true -allow_full_rewrites: false -max_code_length: 60000