diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp index 2b8eee71a9ec..fda222d34b24 100644 --- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp +++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp @@ -452,6 +452,29 @@ enumerateMatmulTileRiscv64(TypeRange elementTypes, DictionaryAttr config) { }; } } + // This adds support for our s8*sx`8->s32 kernel. + if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) && + out.isSignlessInteger(32)) { + // This logic follows the f32 case, as both use 32-bit accumulators. + // For your +zvl512b target, vlen = 512. + // N0 = (512 bits / 32 bits_per_element) * 4_LMUL = 64 elements. + + // N0 for LMUL=8 path (M0=16) + int N0_lmul8 = vlen / 4; + // N0 for LMUL=4 path (M0=8, 4, 2, 1) + int N0_lmul4 = vlen / 8; + + return { + // --- LMUL=8 Path --- + TileMxNxK{16, N0_lmul8, 1}, // Target tile for s8s8s32 (LMUL=8) + + // --- LMUL=4 Paths --- + TileMxNxK{8, N0_lmul4, 1}, // Truncation (LMUL=4) + TileMxNxK{4, N0_lmul4, 1}, // Truncation (LMUL=4) + TileMxNxK{2, N0_lmul4, 1}, // Truncation (LMUL=4) + TileMxNxK{1, N0_lmul4, 1}, // Truncation (vecmat) (LMUL=4) + }; + } // Fallback - no architecture-optimized tile size for this case. return {}; } diff --git a/experimental/iree_evo/__init__.py b/experimental/iree_evo/__init__.py new file mode 100644 index 000000000000..3050b67ea986 --- /dev/null +++ b/experimental/iree_evo/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""IREE Evolutionary Optimization Package. + +This package provides an autonomous optimization system for the IREE compiler +using Evolutionary Strategies via the `openevolve` library. + +The system optimizes compilation flags and Transform Dialect scripts, +specifically for Integer-Only Requantization (fusing float dequant/requant +operations into integer math). +""" + +from .knowledge_base import KnowledgeBase +from .slicer import MLIRSlicer +from .verification import LitGen +from .evaluator import IREEEvaluator, CompilationError, OpenEvolveCompatibleEvaluator +from .prompts import PLANNER_PROMPT, CODER_PROMPT + +__all__ = [ + "KnowledgeBase", + "MLIRSlicer", + "LitGen", + "IREEEvaluator", + "OpenEvolveCompatibleEvaluator", + "CompilationError", + "PLANNER_PROMPT", + "CODER_PROMPT", +] diff --git a/experimental/iree_evo/evaluator.py b/experimental/iree_evo/evaluator.py new file mode 100644 index 000000000000..bdb7105cde9d --- /dev/null +++ b/experimental/iree_evo/evaluator.py @@ -0,0 +1,513 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""IREE Evaluator for OpenEvolve integration. + +This module provides the evaluator class that interfaces with OpenEvolve +for evolutionary optimization of IREE compiler configurations. +""" + +import os +import re +import subprocess +import tempfile +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from .slicer import MLIRSlicer +from .verification import LitGen + + +class CompilationError(Exception): + """Exception raised when compilation fails. + + Attributes: + message: The error message. + debug_info: Optional debug information from the compiler. + stderr: The raw stderr output from the compiler. + """ + + def __init__( + self, + message: str, + debug_info: Optional[str] = None, + stderr: Optional[str] = None, + ): + super().__init__(message) + self.message = message + self.debug_info = debug_info + self.stderr = stderr + + def __str__(self): + parts = [self.message] + if self.debug_info: + parts.append(f"\nDebug Info:\n{self.debug_info}") + return "\n".join(parts) + + +class IREEEvaluator: + """Evaluator for IREE compiler optimization using OpenEvolve. + + This class implements the evaluation logic for evolutionary optimization + of IREE compiler configurations. It inherits from openevolve.Evaluator + when openevolve is available. + + Attributes: + baseline_mlir_path: Path to the baseline MLIR file. + target_backend: The target backend (e.g., 'llvm-cpu', 'cuda'). + work_dir: Working directory for intermediate files. + strategy: The optimization strategy name. + """ + + # Score constants for different outcomes + COMPILATION_FAILURE_SCORE = -100.0 + STRUCTURAL_FAILURE_SCORE = -50.0 + CORRECTNESS_FAILURE_SCORE = -10.0 + + def __init__( + self, + baseline_mlir_path: str, + target_backend: str = "llvm-cpu", + work_dir: Optional[str] = None, + strategy: str = "IntegerRequantization", + iree_compile_path: str = "iree-compile", + iree_run_module_path: str = "iree-run-module", + iree_benchmark_module_path: str = "iree-benchmark-module", + lit_executable: str = "llvm-lit", + baseline_inputs: Optional[List[str]] = None, + baseline_expected_outputs: Optional[np.ndarray] = None, + rtol: float = 1e-5, + atol: float = 1e-8, + ): + """Initializes the IREEEvaluator. + + Args: + baseline_mlir_path: Path to the baseline MLIR file. + target_backend: The target backend for compilation. + work_dir: Working directory for intermediate files. + strategy: The optimization strategy name. + iree_compile_path: Path to iree-compile executable. + iree_run_module_path: Path to iree-run-module executable. + iree_benchmark_module_path: Path to iree-benchmark-module executable. + lit_executable: Path to llvm-lit executable. + baseline_inputs: Optional list of input specifications for correctness. + baseline_expected_outputs: Optional expected outputs for correctness. + rtol: Relative tolerance for output comparison. + atol: Absolute tolerance for output comparison. + """ + if not os.path.exists(baseline_mlir_path): + raise FileNotFoundError( + f"Baseline MLIR file not found: {baseline_mlir_path}" + ) + + self.baseline_mlir_path = baseline_mlir_path + self.target_backend = target_backend + self.work_dir = work_dir or tempfile.mkdtemp(prefix="iree_evo_") + self.strategy = strategy + self.iree_compile_path = iree_compile_path + self.iree_run_module_path = iree_run_module_path + self.iree_benchmark_module_path = iree_benchmark_module_path + self.lit_executable = lit_executable + self.baseline_inputs = baseline_inputs or [] + self.baseline_expected_outputs = baseline_expected_outputs + self.rtol = rtol + self.atol = atol + + # Load baseline MLIR content + with open(baseline_mlir_path, "r") as f: + self.baseline_mlir_content = f.read() + + # Ensure work directory exists + os.makedirs(self.work_dir, exist_ok=True) + + def evaluate(self, individual: str) -> float: + """Evaluates an individual (compiler configuration). + + This method runs the full evaluation pipeline: + 1. Parse the individual string for flags/scripts + 2. Compile with the given configuration + 3. Verify structural correctness using LIT + 4. Verify numerical correctness + 5. Benchmark performance + + Args: + individual: A string containing compiler flags and/or transform script. + + Returns: + A fitness score (higher is better). Returns negative scores for failures: + -100: Compilation failure + -50: Structural verification failure + -10: Correctness failure + + Raises: + CompilationError: If compilation fails and error details are available. + """ + # Phase 1: Write individual to file + flags_path = os.path.join(self.work_dir, "flags.txt") + with open(flags_path, "w") as f: + f.write(individual) + + # Parse flags from individual + flags = self._parse_flags(individual) + + # Phase 2: Compile + try: + artifact_path = self._compile(flags) + except CompilationError as e: + # Re-raise with debug info for OpenEvolve feedback + raise + + # Phase 3: Structural Verification + if not self._verify_structure(): + return self.STRUCTURAL_FAILURE_SCORE + + # Phase 4: Correctness + if self.baseline_expected_outputs is not None: + if not self._verify_correctness(artifact_path): + return self.CORRECTNESS_FAILURE_SCORE + + # Phase 5: Benchmark + try: + mean_latency_ms = self._benchmark(artifact_path) + if mean_latency_ms <= 0: + return 0.0 + return 1000.0 / mean_latency_ms + except Exception: + # If benchmarking fails, return a neutral score + return 0.0 + + def _parse_flags(self, individual: str) -> List[str]: + """Parses compiler flags from the individual string. + + Args: + individual: The individual string containing flags. + + Returns: + A list of compiler flags. + """ + flags = [] + lines = individual.strip().split("\n") + for line in lines: + line = line.strip() + # Skip comments and empty lines + if not line or line.startswith("#") or line.startswith("//"): + continue + # Handle flags that may be on separate lines or space-separated + parts = line.split() + for part in parts: + if part.startswith("--") or part.startswith("-"): + flags.append(part) + return flags + + def _compile(self, flags: List[str]) -> str: + """Compiles the MLIR with given flags. + + Args: + flags: List of compiler flags. + + Returns: + Path to the compiled artifact. + + Raises: + CompilationError: If compilation fails. + """ + artifact_path = os.path.join(self.work_dir, "module.vmfb") + + # Build compile command + cmd = [ + self.iree_compile_path, + self.baseline_mlir_path, + f"--iree-hal-target-backends={self.target_backend}", + f"-o={artifact_path}", + ] + cmd.extend(flags) + + # First attempt: compile without debug output + result = subprocess.run( + cmd, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + # Compilation failed - get debug dump + debug_dump = self._get_debug_dump(flags) + error_summary = MLIRSlicer.parse_compilation_error(result.stderr) + + raise CompilationError( + message="Compilation failed", + debug_info=f"{error_summary}\n\n=== DEBUG DUMP ===\n{debug_dump}", + stderr=result.stderr, + ) + + return artifact_path + + def _get_debug_dump(self, flags: List[str]) -> str: + """Gets debug dump from a failing compilation. + + Args: + flags: The original compiler flags. + + Returns: + Debug dump output from the compiler. + """ + debug_artifact_path = os.path.join(self.work_dir, "debug_module.vmfb") + + cmd = [ + self.iree_compile_path, + self.baseline_mlir_path, + f"--iree-hal-target-backends={self.target_backend}", + f"-o={debug_artifact_path}", + "--mlir-print-ir-after-all", + "--mlir-print-op-on-diagnostic", + ] + cmd.extend(flags) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + ) + + # Return the combined output (debug info is in stderr), truncated if needed + return _truncate_output(result.stderr) + + def _verify_structure(self) -> bool: + """Verifies structural correctness using LIT tests. + + Returns: + True if structural verification passes, False otherwise. + """ + result = LitGen.create_and_run_test( + strategy=self.strategy, + original_mlir=self.baseline_mlir_content, + compile_flags=[f"--iree-hal-target-backends={self.target_backend}"], + work_dir=self.work_dir, + lit_executable=self.lit_executable, + ) + return result.get("passed", False) + + def _verify_correctness(self, artifact_path: str) -> bool: + """Verifies numerical correctness of the compiled module. + + Args: + artifact_path: Path to the compiled artifact. + + Returns: + True if outputs match expected values, False otherwise. + """ + if not os.path.exists(artifact_path): + return False + + if self.baseline_expected_outputs is None: + # No expected outputs provided, skip correctness check + return True + + # Build run command + cmd = [ + self.iree_run_module_path, + f"--module={artifact_path}", + ] + + # Add input specifications + for input_spec in self.baseline_inputs: + cmd.append(f"--input={input_spec}") + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode != 0: + return False + + # Parse output and compare + actual_outputs = self._parse_module_output(result.stdout) + if actual_outputs is None: + return False + + return np.allclose( + actual_outputs, + self.baseline_expected_outputs, + rtol=self.rtol, + atol=self.atol, + ) + + except subprocess.TimeoutExpired: + return False + except Exception: + return False + + def _parse_module_output(self, stdout: str) -> Optional[np.ndarray]: + """Parses module output from iree-run-module. + + Args: + stdout: The stdout from iree-run-module. + + Returns: + Parsed output as numpy array, or None if parsing fails. + """ + try: + # Pattern to match tensor output like "3x3xf32=[1 2 3][4 5 6][7 8 9]" + pattern = r"(\d+(?:x\d+)*x\w+)=\[([\d\s\.\-e\[\]]+)\]" + match = re.search(pattern, stdout) + if not match: + return None + + shape_str = match.group(1) + values_str = match.group(2) + + # Parse shape + parts = shape_str.rsplit("x", 1) + shape_parts = parts[0].split("x") + shape = tuple(int(s) for s in shape_parts) + + # Parse values (handle nested brackets) + values_str = values_str.replace("][", " ") + values_str = values_str.replace("[", " ").replace("]", " ") + values = [float(v) for v in values_str.split()] + + return np.array(values).reshape(shape) + + except Exception: + return None + + def _benchmark(self, artifact_path: str) -> float: + """Benchmarks the compiled module. + + Args: + artifact_path: Path to the compiled artifact. + + Returns: + Mean latency in milliseconds. + """ + if not os.path.exists(artifact_path): + return float("inf") + + cmd = [ + self.iree_benchmark_module_path, + f"--module={artifact_path}", + "--benchmark_repetitions=5", + ] + + # Add input specifications if available + for input_spec in self.baseline_inputs: + cmd.append(f"--input={input_spec}") + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=120, + ) + + if result.returncode != 0: + return float("inf") + + # Parse benchmark output for mean latency + return self._parse_benchmark_output(result.stdout) + + except subprocess.TimeoutExpired: + return float("inf") + except Exception: + return float("inf") + + def _parse_benchmark_output(self, stdout: str) -> float: + """Parses benchmark output for mean latency. + + Args: + stdout: The stdout from iree-benchmark-module. + + Returns: + Mean latency in milliseconds. + """ + try: + # Look for mean time in benchmark output + # Pattern matches lines like: "mean: 1.234 ms" + patterns = [ + r"mean[:\s]+(\d+\.?\d*)\s*ms", + r"BM_\w+\s+(\d+\.?\d*)\s*ms", + r"(\d+\.?\d*)\s*ms\s+mean", + ] + + for pattern in patterns: + match = re.search(pattern, stdout, re.IGNORECASE) + if match: + return float(match.group(1)) + + # Alternative: look for time values and compute mean + time_pattern = r"(\d+\.?\d*)\s*ms" + times = re.findall(time_pattern, stdout) + if times: + return sum(float(t) for t in times) / len(times) + + return float("inf") + + except Exception: + return float("inf") + + def get_baseline_summary(self) -> Dict[str, Any]: + """Returns a summary of the baseline MLIR for the LLM. + + Returns: + A dictionary containing the MLIR summary. + """ + return MLIRSlicer.extract_summary(self.baseline_mlir_content) + + def cleanup(self): + """Cleans up temporary files in the work directory.""" + import shutil + + if os.path.exists(self.work_dir) and self.work_dir.startswith( + tempfile.gettempdir() + ): + shutil.rmtree(self.work_dir, ignore_errors=True) + + +def _truncate_output(output: str, max_length: int = 10000) -> str: + """Truncates output string to a maximum length with ellipsis. + + Args: + output: The output string to truncate. + max_length: Maximum allowed length. + + Returns: + Truncated string if necessary, otherwise the original string. + """ + if len(output) <= max_length: + return output + half = max_length // 2 + truncated_chars = len(output) - max_length + return ( + output[:half] + + f"\n... [{truncated_chars} characters truncated] ...\n" + + output[-half:] + ) + + +# Try to inherit from openevolve.Evaluator if available +try: + from openevolve.evaluator import Evaluator as OpenEvolveEvaluator + + class IREEOpenEvolveEvaluator(IREEEvaluator, OpenEvolveEvaluator): + """IREE Evaluator with OpenEvolve integration. + + This class combines the IREEEvaluator functionality with the + OpenEvolve Evaluator base class for evolutionary optimization. + """ + + pass + + # Make it available as the preferred evaluator when openevolve is present + OpenEvolveCompatibleEvaluator = IREEOpenEvolveEvaluator + +except ImportError: + # OpenEvolve not installed, use standalone class + OpenEvolveCompatibleEvaluator = IREEEvaluator diff --git a/experimental/iree_evo/knowledge_base.py b/experimental/iree_evo/knowledge_base.py new file mode 100644 index 000000000000..87804a39335e --- /dev/null +++ b/experimental/iree_evo/knowledge_base.py @@ -0,0 +1,197 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Knowledge base for IREE compiler flags and MLIR Transform operations. + +This module provides information about valid IREE compilation flags +and MLIR Transform Dialect operations that can be used for optimization. +""" + +from typing import Dict, List + + +class KnowledgeBase: + """Knowledge base for IREE compiler flags and Transform Dialect operations. + + This class provides methods to retrieve valid compiler flags for different + backends and available Transform Dialect operations for code generation + optimization. + """ + + # Valid IREE flags for llvm-cpu backend relevant to quantization and tiling + _LLVM_CPU_FLAGS: List[str] = [ + # Quantization-related flags + "--iree-opt-data-tiling", + "--iree-opt-const-expr-hoisting", + "--iree-opt-const-eval", + "--iree-opt-numeric-precision-reduction", + # Tiling and vectorization flags + "--iree-llvmcpu-target-triple", + "--iree-llvmcpu-target-cpu", + "--iree-llvmcpu-target-cpu-features", + "--iree-llvmcpu-enable-ukernels", + # Code generation flags + "--iree-codegen-llvm-generic-ops-workgroup-size", + "--iree-llvmcpu-enable-pad-consumer-fusion", + # Transform dialect flags + "--iree-codegen-transform-dialect-library", + # Debug flags (useful for error analysis) + "--mlir-print-ir-after-all", + "--mlir-print-ir-before-all", + "--mlir-print-op-on-diagnostic", + # Global optimization flags + "--iree-global-opt-enable-fuse-horizontal-contractions", + "--iree-global-opt-enable-quantized-matmul-reassociation", + "--iree-opt-demote-f32-to-f16", + "--iree-opt-demote-f64-to-f32", + ] + + # Valid IREE flags for CUDA backend + _CUDA_FLAGS: List[str] = [ + # Target configuration + "--iree-hal-target-backends=cuda", + "--iree-hal-cuda-llvm-target-arch", + # Quantization flags + "--iree-opt-data-tiling", + "--iree-opt-const-expr-hoisting", + "--iree-opt-const-eval", + # Code generation flags + "--iree-codegen-llvmgpu-enable-transform-dialect-jit-default", + "--iree-codegen-transform-dialect-library", + # Debug flags + "--mlir-print-ir-after-all", + "--mlir-print-ir-before-all", + ] + + # Valid IREE flags for ROCM/AMD GPU backend + _ROCM_FLAGS: List[str] = [ + # Target configuration + "--iree-hal-target-backends=rocm", + "--iree-rocm-target-chip", + # Quantization flags (FP8 support) + "--iree-opt-data-tiling", + "--iree-opt-const-expr-hoisting", + # Code generation flags + "--iree-codegen-llvmgpu-enable-transform-dialect-jit-default", + # Debug flags + "--mlir-print-ir-after-all", + "--mlir-print-ir-before-all", + ] + + # Valid MLIR Transform Dialect operations + _TRANSFORM_OPS: List[str] = [ + # Structured transform ops for tiling + "transform.structured.tile_using_forall", + "transform.structured.tile_using_for", + "transform.structured.tile_reduction_using_for", + "transform.structured.tile_reduction_using_forall", + # Fusion operations + "transform.structured.fuse_into_containing_op", + "transform.structured.fuse", + # Vectorization operations + "transform.structured.vectorize", + "transform.structured.vectorize_children_and_apply_patterns", + # Pattern application + "transform.apply_patterns.canonicalization", + "transform.apply_patterns.linalg.tiling_canonicalization", + "transform.apply_patterns.iree.fold_fill_into_pad", + "transform.apply_patterns.scf.for_loop_canonicalization", + # CSE and cleanup + "transform.apply_cse", + "transform.structured.match", + # IREE-specific transforms + "transform.iree.populate_workgroup_count_region_using_num_threads_slice", + "transform.iree.bufferize", + "transform.iree.forall_to_workgroup", + # Loop transforms + "transform.loop.unroll", + "transform.loop.peel", + # Lowering transforms + "transform.iree.apply_lowering_strategy", + ] + + # Common flag values for different optimization strategies + _OPTIMIZATION_STRATEGIES: Dict[str, List[str]] = { + "IntegerRequantization": [ + "--iree-global-opt-enable-quantized-matmul-reassociation=true", + "--iree-opt-const-eval=true", + "--iree-opt-const-expr-hoisting=true", + ], + "Tiling": [ + "--iree-opt-data-tiling=true", + "--iree-llvmcpu-enable-pad-consumer-fusion=true", + ], + "Vectorization": [ + "--iree-opt-data-tiling=true", + ], + "QuantizationFusion": [ + "--iree-global-opt-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-quantized-matmul-reassociation=true", + ], + } + + @classmethod + def get_valid_flags(cls, backend: str) -> List[str]: + """Returns a list of valid IREE flags for a given backend. + + Args: + backend: The target backend (e.g., 'llvm-cpu', 'cuda', 'rocm'). + + Returns: + A list of valid compiler flags for the specified backend. + + Raises: + ValueError: If the backend is not recognized. + """ + backend_lower = backend.lower() + if backend_lower == "llvm-cpu": + return cls._LLVM_CPU_FLAGS.copy() + elif backend_lower == "cuda": + return cls._CUDA_FLAGS.copy() + elif backend_lower == "rocm": + return cls._ROCM_FLAGS.copy() + else: + raise ValueError( + f"Unknown backend: {backend}. " + f"Supported backends: llvm-cpu, cuda, rocm" + ) + + @classmethod + def get_transform_ops(cls) -> List[str]: + """Returns a list of valid MLIR Transform Dialect operations. + + Returns: + A list of available Transform Dialect operations. + """ + return cls._TRANSFORM_OPS.copy() + + @classmethod + def get_optimization_strategy_flags(cls, strategy: str) -> List[str]: + """Returns recommended flags for a specific optimization strategy. + + Args: + strategy: The optimization strategy name. + + Returns: + A list of recommended flags for the strategy. + + Raises: + ValueError: If the strategy is not recognized. + """ + if strategy not in cls._OPTIMIZATION_STRATEGIES: + raise ValueError( + f"Unknown strategy: {strategy}. " + f"Available strategies: {list(cls._OPTIMIZATION_STRATEGIES.keys())}" + ) + return cls._OPTIMIZATION_STRATEGIES[strategy].copy() + + @classmethod + def get_all_strategies(cls) -> List[str]: + """Returns all available optimization strategy names. + + Returns: + A list of available strategy names. + """ + return list(cls._OPTIMIZATION_STRATEGIES.keys()) diff --git a/experimental/iree_evo/main.py b/experimental/iree_evo/main.py new file mode 100644 index 000000000000..9f92037f4937 --- /dev/null +++ b/experimental/iree_evo/main.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Main entry point for IREE-EVO optimization system. + +This module initializes and runs the evolutionary optimization loop +for IREE compiler configurations using OpenEvolve. + +Usage: + python -m iree_evo.main --mlir-path --backend + +Example: + python -m iree_evo.main --mlir-path model.mlir --backend llvm-cpu +""" + +import argparse +import json +import logging +import os +import sys +from typing import Dict, Optional + +from .evaluator import IREEEvaluator, CompilationError +from .knowledge_base import KnowledgeBase +from .prompts import PLANNER_PROMPT, CODER_PROMPT, EVOLUTION_PROMPT +from .slicer import MLIRSlicer + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +def mock_planner_llm(mlir_summary: Dict, backend: str) -> Dict: + """Mock planner LLM for testing without API access. + + This function simulates the planner LLM's decision-making process. + In production, this would be replaced with actual API calls. + + Args: + mlir_summary: Summary of the MLIR content. + backend: Target backend. + + Returns: + A dictionary containing the optimization strategy and constraints. + """ + # Check for quantization-related operations + quant_ops = mlir_summary.get("quantization_ops", []) + compute_ops = mlir_summary.get("compute_ops", []) + + # Determine strategy based on IR content + if any("sitofp" in op or "uitofp" in op for op in quant_ops): + strategy = "IntegerRequantization" + rationale = ( + "Detected float dequantization operations (sitofp/uitofp). " + "These can be converted to pure integer arithmetic for better performance." + ) + elif any("matmul" in op.lower() for op in compute_ops): + strategy = "QuantizationFusion" + rationale = ( + "Detected matmul operations with potential quantization. " + "Fusing quantization with compute can reduce memory bandwidth." + ) + elif any("conv" in op.lower() for op in compute_ops): + strategy = "Tiling" + rationale = ( + "Detected convolution operations. " + "Tiling can improve cache utilization for large convolutions." + ) + else: + strategy = "Vectorization" + rationale = ( + "No specific optimization pattern detected. " + "Applying vectorization for general performance improvement." + ) + + return { + "analysis": { + "compute_bottleneck": "quantization overhead" if quant_ops else "compute", + "memory_pattern": "sequential access", + "quantization_status": "float dequant present" if quant_ops else "none", + }, + "strategy": strategy, + "rationale": rationale, + "constraints": [ + "Preserve numerical precision", + "Ensure output matches baseline within tolerance", + ], + "priority_ops": compute_ops[:5] if compute_ops else ["all"], + } + + +def mock_coder_llm(strategy: str, backend: str, constraints: list) -> str: + """Mock coder LLM for testing without API access. + + This function generates compiler flags based on the strategy. + In production, this would be replaced with actual API calls. + + Args: + strategy: The optimization strategy. + backend: Target backend. + constraints: List of constraints from the planner. + + Returns: + A string containing compiler flags. + """ + base_flags = [ + f"# Optimization: {strategy}", + f"# Target: {backend}", + "", + ] + + # Get strategy-specific flags + try: + strategy_flags = KnowledgeBase.get_optimization_strategy_flags(strategy) + except ValueError: + strategy_flags = [] + + # Add backend-specific flags + if backend == "llvm-cpu": + base_flags.extend([ + "--iree-opt-const-eval=true", + "--iree-opt-const-expr-hoisting=true", + ]) + elif backend == "cuda": + base_flags.extend([ + "--iree-codegen-llvmgpu-enable-transform-dialect-jit-default=true", + ]) + + # Add strategy flags + base_flags.extend(strategy_flags) + + return "\n".join(base_flags) + + +def run_optimization( + mlir_path: str, + backend: str, + work_dir: Optional[str] = None, + max_iterations: int = 10, + use_openevolve: bool = True, +) -> Dict: + """Runs the IREE-EVO optimization loop. + + Args: + mlir_path: Path to the input MLIR file. + backend: Target backend. + work_dir: Working directory for intermediate files. + max_iterations: Maximum number of evolution iterations. + use_openevolve: Whether to use OpenEvolve (if available). + + Returns: + A dictionary containing the optimization results. + """ + logger.info(f"Starting IREE-EVO optimization for {mlir_path}") + logger.info(f"Target backend: {backend}") + + # Initialize evaluator + evaluator = IREEEvaluator( + baseline_mlir_path=mlir_path, + target_backend=backend, + work_dir=work_dir, + strategy="IntegerRequantization", # Will be updated by planner + ) + + # Step 1: Analyze MLIR and run planner + logger.info("Step 1: Analyzing MLIR and running planner...") + mlir_summary = evaluator.get_baseline_summary() + logger.info(f"MLIR Summary: {json.dumps(mlir_summary, indent=2)}") + + planner_result = mock_planner_llm(mlir_summary, backend) + logger.info(f"Planner decision: {planner_result['strategy']}") + logger.info(f"Rationale: {planner_result['rationale']}") + + # Update evaluator strategy + evaluator.strategy = planner_result["strategy"] + + # Step 2: Generate initial configuration + logger.info("Step 2: Generating initial configuration...") + initial_config = mock_coder_llm( + planner_result["strategy"], + backend, + planner_result["constraints"], + ) + logger.info(f"Initial configuration:\n{initial_config}") + + # Step 3: Run OpenEvolve loop (or mock loop) + best_config = initial_config + best_score = float("-inf") + + if use_openevolve: + try: + # Try to use OpenEvolve if available + from openevolve import Controller + + logger.info("Step 3: Starting OpenEvolve optimization loop...") + + controller = Controller( + evaluator=evaluator, + initial_population=[initial_config], + system_prompt=EVOLUTION_PROMPT, + max_iterations=max_iterations, + ) + + best_config, best_score = controller.run() + logger.info(f"OpenEvolve completed. Best score: {best_score}") + + except ImportError: + logger.warning( + "OpenEvolve not installed. Running mock optimization loop." + ) + use_openevolve = False + + if not use_openevolve: + # Mock optimization loop + logger.info("Step 3: Running mock optimization loop...") + + for iteration in range(max_iterations): + logger.info(f"Iteration {iteration + 1}/{max_iterations}") + + try: + score = evaluator.evaluate(best_config) + logger.info(f"Score: {score}") + + if score > best_score: + best_score = score + logger.info(f"New best score: {best_score}") + + # Simple mutation: toggle a flag + mutated_config = mutate_config(best_config, iteration) + try: + mutated_score = evaluator.evaluate(mutated_config) + if mutated_score > best_score: + best_config = mutated_config + best_score = mutated_score + logger.info(f"Mutation improved score to: {best_score}") + except CompilationError as e: + logger.warning(f"Mutation failed to compile: {e.message}") + + except CompilationError as e: + logger.error(f"Compilation error: {e.message}") + if e.debug_info: + logger.debug(f"Debug info: {e.debug_info[:500]}...") + + # Cleanup + evaluator.cleanup() + + return { + "strategy": planner_result["strategy"], + "best_config": best_config, + "best_score": best_score, + "iterations": max_iterations, + } + + +def mutate_config(config: str, iteration: int) -> str: + """Simple mutation function for testing. + + Args: + config: Current configuration. + iteration: Current iteration number. + + Returns: + Mutated configuration. + """ + lines = config.split("\n") + + # Add a comment about the mutation + lines.insert(0, f"# Mutation iteration {iteration + 1}") + + # Simple mutation: add or remove a flag based on iteration + mutations = [ + "--iree-opt-data-tiling=true", + "--iree-llvmcpu-enable-pad-consumer-fusion=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", + ] + + mutation_flag = mutations[iteration % len(mutations)] + + if mutation_flag in config: + # Remove the flag + lines = [l for l in lines if mutation_flag not in l] + else: + # Add the flag + lines.append(mutation_flag) + + return "\n".join(lines) + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="IREE-EVO: Evolutionary Optimization for IREE Compiler", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Optimize a model for CPU + python -m iree_evo.main --mlir-path model.mlir --backend llvm-cpu + + # Run with more iterations + python -m iree_evo.main --mlir-path model.mlir --backend cuda --max-iterations 50 + + # Specify working directory + python -m iree_evo.main --mlir-path model.mlir --backend llvm-cpu --work-dir /tmp/iree_evo + """, + ) + + parser.add_argument( + "--mlir-path", + type=str, + required=True, + help="Path to the input MLIR file", + ) + parser.add_argument( + "--backend", + type=str, + default="llvm-cpu", + choices=["llvm-cpu", "cuda", "rocm"], + help="Target backend (default: llvm-cpu)", + ) + parser.add_argument( + "--work-dir", + type=str, + default=None, + help="Working directory for intermediate files", + ) + parser.add_argument( + "--max-iterations", + type=int, + default=10, + help="Maximum number of evolution iterations (default: 10)", + ) + parser.add_argument( + "--no-openevolve", + action="store_true", + help="Disable OpenEvolve and use mock optimization loop", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default=None, + help="Output file for results (JSON format)", + ) + + args = parser.parse_args() + + # Set logging level + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Validate input file + if not os.path.exists(args.mlir_path): + logger.error(f"MLIR file not found: {args.mlir_path}") + sys.exit(1) + + # Run optimization + try: + results = run_optimization( + mlir_path=args.mlir_path, + backend=args.backend, + work_dir=args.work_dir, + max_iterations=args.max_iterations, + use_openevolve=not args.no_openevolve, + ) + + # Print results + print("\n" + "=" * 60) + print("IREE-EVO Optimization Results") + print("=" * 60) + print(f"Strategy: {results['strategy']}") + print(f"Best Score: {results['best_score']}") + print(f"Iterations: {results['iterations']}") + print("\nBest Configuration:") + print("-" * 40) + print(results["best_config"]) + print("-" * 40) + + # Save results if output file specified + if args.output: + with open(args.output, "w") as f: + json.dump(results, f, indent=2) + logger.info(f"Results saved to {args.output}") + + except Exception as e: + logger.error(f"Optimization failed: {e}") + if args.verbose: + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/experimental/iree_evo/prompts.py b/experimental/iree_evo/prompts.py new file mode 100644 index 000000000000..5e5dbfe26719 --- /dev/null +++ b/experimental/iree_evo/prompts.py @@ -0,0 +1,294 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""System prompts for LLM agents in the IREE-EVO system. + +This module defines the system prompts for the Planner and Coder agents +used in the evolutionary optimization system. +""" + +# ============================================================================= +# PLANNER PROMPT +# ============================================================================= +PLANNER_PROMPT = """You are an IREE Compiler Architect specializing in ML compiler optimization. + +## Your Role +You analyze MLIR intermediate representations and devise optimization strategies +for the IREE compiler. Your goal is to identify performance bottlenecks and +recommend the best optimization approach. + +## Input Format +You will receive: +1. An MLIR module or function to optimize +2. A summary of compute operations and tensor shapes +3. The target backend (llvm-cpu, cuda, rocm) +4. Performance requirements or constraints + +## Optimization Menu +Choose from the following optimization strategies: + +### 1. IntegerRequantization +**When to use:** Quantized models with float dequant/requant operations +**Goal:** Fuse float operations into pure integer arithmetic +**Math:** For a quantized value Q with scale S and zero-point Z: + - Float: x = S * (Q - Z) + - Integer: x ≈ (Q - Z) * M >> n, where M ≈ S * 2^n + +### 2. QuantizationFusion +**When to use:** Models with separate dequantize-compute-quantize sequences +**Goal:** Fuse quantization/dequantization with compute operations +**Benefit:** Reduces memory bandwidth and intermediate precision conversions + +### 3. Tiling +**When to use:** Large tensor operations that exceed cache size +**Goal:** Partition computations into smaller tiles that fit in cache +**Parameters:** Tile sizes, loop ordering, thread mapping + +### 4. Vectorization +**When to use:** Element-wise operations and small reductions +**Goal:** Utilize SIMD instructions for parallel element processing +**Consideration:** Vector width depends on target (AVX-512, NEON, etc.) + +### 5. LoopFusion +**When to use:** Multiple operations with compatible iteration spaces +**Goal:** Combine loops to improve data locality +**Benefit:** Reduces memory traffic between operations + +## Output Format +Provide your analysis in the following JSON format: +```json +{ + "analysis": { + "compute_bottleneck": "description of the main bottleneck", + "memory_pattern": "description of memory access patterns", + "quantization_status": "current quantization state if applicable" + }, + "strategy": "strategy name from the menu above", + "rationale": "detailed explanation of why this strategy was chosen", + "constraints": ["list of constraints the coder should follow"], + "priority_ops": ["list of ops to focus on"] +} +``` + +## Guidelines +- Prioritize numerical correctness over performance +- Consider target hardware capabilities +- For quantized models, prefer integer-only arithmetic when possible +- Account for memory bandwidth limitations on target devices +""" + +# ============================================================================= +# CODER PROMPT +# ============================================================================= +CODER_PROMPT = """You are an IREE Compiler Engineer specializing in code generation. + +## Your Role +You implement optimization strategies by generating: +1. IREE compiler flags +2. MLIR Transform Dialect scripts +3. Custom pass configurations + +## Input Format +You will receive: +1. The optimization strategy to implement +2. Constraints from the Planner +3. Target MLIR operations to optimize +4. Error feedback from previous attempts (if any) + +## Available Tools + +### IREE Compiler Flags +Common flags for optimization: +``` +--iree-global-opt-enable-quantized-matmul-reassociation=true +--iree-opt-const-eval=true +--iree-opt-const-expr-hoisting=true +--iree-opt-data-tiling=true +--iree-llvmcpu-enable-ukernels=all +--iree-codegen-transform-dialect-library= +``` + +### Transform Dialect Operations +Key operations for code generation: +```mlir +transform.structured.tile_using_forall %op num_threads [N, M] +transform.structured.tile_using_for %op [tile_size_0, tile_size_1] +transform.structured.vectorize %op vector_sizes [V0, V1] +transform.structured.fuse_into_containing_op %producer into %consumer +transform.apply_patterns.canonicalization +transform.iree.bufferize %op +``` + +## Integer Requantization Math + +For implementing integer-only requantization, use the following math: + +### Scale Representation +A floating-point scale S can be approximated as: + S ≈ M * 2^(-n) +where M is an integer multiplier and n is the shift amount. + +### Integer Multiply-Shift +Replace: `arith.mulf %a, %scale : f32` +With: `%mul = arith.muli %a_int, %M : i32` + `%result = arith.shrsi %mul, %n : i32` + +### Zero-Point Handling +Replace: `%sub = arith.subf %dequant, %zp : f32` +With: `%sub = arith.subi %quant, %zp_int : i32` + +### Example Pattern +Input (Float): +```mlir +%ext = arith.extui %q : i4 to i32 +%fp = arith.uitofp %ext : i32 to f32 +%sub = arith.subf %fp, %zp : f32 +%mul = arith.mulf %sub, %scale : f32 +``` + +Output (Integer): +```mlir +%ext = arith.extui %q : i4 to i32 +%sub = arith.subi %ext, %zp_int : i32 +%mul = arith.muli %sub, %M : i32 +%result = arith.shrsi %mul, %n : i32 +``` + +## Output Format +Generate flags/scripts in the following format: + +### For Compiler Flags Only: +``` +# Optimization: [strategy name] +# Target: [backend] + +--flag-name=value +--another-flag=value +``` + +### For Transform Dialect Scripts: +```mlir +// Optimization: [strategy name] +// Target: [backend] + +module attributes { transform.with_named_sequence } { + transform.named_sequence @optimization(%variant_op: !transform.any_op {transform.consumed}) { + // Your transform operations here + transform.yield + } +} +``` + +## Error Handling +If you receive error feedback: +1. Analyze the error message carefully +2. Identify the failing operation or pass +3. Adjust your approach: + - Try smaller tile sizes if tiling fails + - Use different vector widths if vectorization fails + - Add canonicalization passes if patterns don't match +4. Explain your fix in comments + +## Guidelines +- Start with conservative configurations +- Validate flag combinations are compatible +- Include canonicalization after major transforms +- Test on small inputs before scaling up +- Preserve numerical precision requirements +""" + +# ============================================================================= +# EVOLUTION PROMPT (for OpenEvolve mutation guidance) +# ============================================================================= +EVOLUTION_PROMPT = """You are an optimization evolution specialist. + +## Your Role +Mutate compiler configurations to explore the optimization space efficiently. + +## Mutation Strategies + +### 1. Parameter Tuning +- Adjust tile sizes (powers of 2: 1, 2, 4, 8, 16, 32, 64, 128, 256) +- Modify vector widths (4, 8, 16, 32) +- Change loop unroll factors (1, 2, 4, 8) + +### 2. Flag Toggling +- Enable/disable optimization flags +- Combine compatible optimizations +- Remove conflicting flags + +### 3. Transform Sequence Modification +- Reorder transform operations +- Add/remove canonicalization passes +- Adjust operation matching patterns + +## Input Format +You will receive: +1. The current best configuration +2. Its fitness score +3. Recent mutation history +4. Compilation errors (if any) + +## Output Format +Generate a mutated configuration following the same format as the input. +Include a comment explaining the mutation: +``` +# Mutation: [description of change] +# Rationale: [why this might improve fitness] +``` + +## Guidelines +- Make incremental changes (one major change per mutation) +- Learn from compilation errors to avoid invalid configurations +- Balance exploration (new regions) vs exploitation (refining good solutions) +- Track which mutations improved fitness +""" + +# ============================================================================= +# ERROR RECOVERY PROMPT +# ============================================================================= +ERROR_RECOVERY_PROMPT = """You are a compiler debugging specialist. + +## Your Role +Analyze compilation errors and suggest fixes for IREE/MLIR issues. + +## Input Format +You will receive: +1. The failing configuration (flags/scripts) +2. Error message and context +3. Debug dump (IR after failed pass) + +## Common Error Patterns + +### 1. Type Mismatch +Error: "operand type mismatch" +Fix: Ensure all operands have compatible types, add explicit casts + +### 2. Shape Mismatch +Error: "incompatible shapes" +Fix: Check tensor dimensions, adjust tile sizes to divide evenly + +### 3. Legalization Failure +Error: "failed to legalize operation" +Fix: Add required lowering passes, check operation is supported on target + +### 4. Memory Allocation +Error: "allocation failed" +Fix: Reduce tile sizes, enable memory optimization flags + +### 5. Transform Failure +Error: "failed to match" +Fix: Adjust matching patterns, ensure ops exist before transform + +## Output Format +```json +{ + "error_type": "category of the error", + "root_cause": "detailed explanation of the cause", + "fix": "corrected configuration or flags", + "explanation": "why this fix should work" +} +``` +""" diff --git a/experimental/iree_evo/slicer.py b/experimental/iree_evo/slicer.py new file mode 100644 index 000000000000..e95c1d7e873e --- /dev/null +++ b/experimental/iree_evo/slicer.py @@ -0,0 +1,273 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""MLIR IR analysis and error parsing utilities. + +This module provides functionality to extract summaries from MLIR content +and parse compilation errors for LLM-based debugging. +""" + +import re +from typing import Any, Dict, List, Optional + + +class MLIRSlicer: + """Utilities for slicing and analyzing MLIR IR content. + + This class provides methods to extract structural summaries from MLIR IR + and parse compilation errors for debugging purposes. + """ + + # Patterns for extracting compute operations + _COMPUTE_OP_PATTERNS = [ + # Linalg named operations + r"linalg\.(matmul|conv_2d|conv_2d_nchw_fchw|conv_2d_nhwc_hwcf|" + r"batch_matmul|dot|pooling_nhwc_sum|pooling_nchw_sum|" + r"generic|fill|copy|transpose|reduce|broadcast)", + # Tensor operations + r"tensor\.(empty|extract|extract_slice|insert|insert_slice|" + r"pad|collapse_shape|expand_shape|concat|generate)", + # Arithmetic operations for quantization + r"arith\.(sitofp|fptosi|extui|extsi|trunci|muli|mulf|addi|addf|" + r"subi|subf|shrsi|shrui|shli|divsi|divui|andi|ori|xori)", + # Math operations + r"math\.(exp|log|sqrt|rsqrt|tanh|absf|ceil|floor|round)", + # Vector operations + r"vector\.(transfer_read|transfer_write|broadcast|contract|" + r"reduction|fma|outerproduct)", + # SCF operations + r"scf\.(for|forall|if|while|yield)", + ] + + # Pattern for extracting tensor shapes + _TENSOR_SHAPE_PATTERN = r"tensor<([^>]+)>" + + # Pattern for extracting memref shapes + _MEMREF_SHAPE_PATTERN = r"memref<([^>]+)>" + + # Error patterns for parsing compilation output + _ERROR_PATTERNS = [ + r"error:\s*(.+)", + r"note:\s*(.+)", + r"warning:\s*(.+)", + r"failed to legalize operation\s*'([^']+)'", + r"'([^']+)' op\s+(.+)", + ] + + # Pattern for extracting line numbers from error messages + _LINE_NUMBER_PATTERN = r":(\d+):(\d+):" + + @classmethod + def extract_summary(cls, mlir_content: str) -> Dict[str, List[str]]: + """Extracts a structural summary from MLIR content. + + Args: + mlir_content: The MLIR IR content as a string. + + Returns: + A dictionary containing: + - 'compute_ops': List of compute operations found in the IR. + - 'tensor_shapes': List of tensor shapes found in the IR. + - 'memref_shapes': List of memref shapes found in the IR. + - 'quantization_ops': List of quantization-related operations. + """ + compute_ops: List[str] = [] + tensor_shapes: List[str] = [] + memref_shapes: List[str] = [] + quantization_ops: List[str] = [] + + # Extract compute operations + for pattern in cls._COMPUTE_OP_PATTERNS: + matches = re.findall(pattern, mlir_content) + for match in matches: + if isinstance(match, tuple): + # Extract the first non-empty element from the match tuple + op_name = "" + for element in match: + if element: + op_name = element + break + else: + op_name = match + if op_name: + full_op = re.search(rf"\b\w+\.{re.escape(op_name)}\b", mlir_content) + if full_op: + compute_ops.append(full_op.group()) + + # Extract tensor shapes + tensor_matches = re.findall(cls._TENSOR_SHAPE_PATTERN, mlir_content) + tensor_shapes = list(set(tensor_matches)) + + # Extract memref shapes + memref_matches = re.findall(cls._MEMREF_SHAPE_PATTERN, mlir_content) + memref_shapes = list(set(memref_matches)) + + # Identify quantization-related operations + quant_patterns = [ + r"arith\.sitofp", + r"arith\.fptosi", + r"arith\.extui", + r"arith\.extsi", + r"arith\.trunci", + r"arith\.muli", + r"arith\.shrsi", + r"arith\.shrui", + r"arith\.uitofp", + r"arith\.subf", + r"arith\.mulf", + ] + for pattern in quant_patterns: + if re.search(pattern, mlir_content): + quantization_ops.append(pattern.replace(r"\.", ".")) + + # Remove duplicates while preserving order + compute_ops = list(dict.fromkeys(compute_ops)) + + return { + "compute_ops": compute_ops, + "tensor_shapes": tensor_shapes, + "memref_shapes": memref_shapes, + "quantization_ops": quantization_ops, + } + + @classmethod + def parse_compilation_error( + cls, stderr: str, context_lines: int = 5 + ) -> str: + """Parses compilation error output for LLM debugging. + + Args: + stderr: The standard error output from a failed iree-compile. + context_lines: Number of lines of context to include around errors. + + Returns: + A clean string summarizing the error for the LLM. + """ + if not stderr: + return "No error output provided." + + lines = stderr.split("\n") + error_summary_parts: List[str] = [] + error_locations: List[Dict[str, Any]] = [] + + # Find all error lines and their locations + for i, line in enumerate(lines): + for pattern in cls._ERROR_PATTERNS: + match = re.search(pattern, line) + if match: + # Try to extract line number + loc_match = re.search(cls._LINE_NUMBER_PATTERN, line) + error_locations.append({ + "index": i, + "line": line, + "match": match.group(), + "line_number": int(loc_match.group(1)) if loc_match else None, + "column": int(loc_match.group(2)) if loc_match else None, + }) + break + + if not error_locations: + # No specific errors found, return the full stderr truncated + truncated = "\n".join(lines[:50]) + if len(lines) > 50: + truncated += f"\n... ({len(lines) - 50} more lines)" + return f"Compilation failed with output:\n{truncated}" + + # Build error summary with context + error_summary_parts.append("=== COMPILATION ERROR SUMMARY ===\n") + + for error in error_locations[:5]: # Limit to first 5 errors + idx = error["index"] + start = max(0, idx - context_lines) + end = min(len(lines), idx + context_lines + 1) + + error_summary_parts.append(f"--- Error at line {idx + 1} ---") + if error["line_number"]: + error_summary_parts.append( + f"Source location: line {error['line_number']}, " + f"column {error['column']}" + ) + + # Add context + context = "\n".join(lines[start:end]) + error_summary_parts.append(f"Context:\n{context}\n") + + # Add primary error message + primary_errors = [e["match"] for e in error_locations if "error:" in e["line"]] + if primary_errors: + error_summary_parts.append("=== PRIMARY ERRORS ===") + for err in primary_errors[:3]: + error_summary_parts.append(f" - {err}") + + return "\n".join(error_summary_parts) + + @classmethod + def extract_op_at_location( + cls, mlir_content: str, line_number: int, context_lines: int = 5 + ) -> Optional[str]: + """Extracts the operation and context at a specific line number. + + Args: + mlir_content: The MLIR IR content. + line_number: The line number to extract (1-indexed). + context_lines: Number of lines of context to include. + + Returns: + A string with the operation and surrounding context, or None if + the line number is out of range. + """ + lines = mlir_content.split("\n") + if line_number < 1 or line_number > len(lines): + return None + + idx = line_number - 1 # Convert to 0-indexed + start = max(0, idx - context_lines) + end = min(len(lines), idx + context_lines + 1) + + result_lines = [] + for i in range(start, end): + prefix = ">>> " if i == idx else " " + result_lines.append(f"{i + 1:4d} {prefix}{lines[i]}") + + return "\n".join(result_lines) + + @classmethod + def identify_quantization_pattern(cls, mlir_content: str) -> Optional[str]: + """Identifies the quantization pattern present in the MLIR. + + Args: + mlir_content: The MLIR IR content. + + Returns: + A string identifying the quantization pattern, or None if no + recognized pattern is found. + """ + # Check for dequantization pattern (int -> float -> multiply) + dequant_pattern = ( + r"arith\.extui.*i\d+.*i\d+" + r".*arith\.uitofp" + r".*arith\.subf" + r".*arith\.mulf" + ) + if re.search(dequant_pattern, mlir_content, re.DOTALL): + return "GroupedDequantization" + + # Check for requantization pattern (float -> int) + requant_pattern = r"arith\.fptosi.*arith\.trunci" + if re.search(requant_pattern, mlir_content, re.DOTALL): + return "Requantization" + + # Check for integer-only operations + int_only_pattern = r"arith\.muli.*arith\.shrsi" + if re.search(int_only_pattern, mlir_content, re.DOTALL): + return "IntegerOnly" + + # Check for float quantization conversion + if re.search(r"arith\.sitofp", mlir_content) and re.search( + r"arith\.mulf", mlir_content + ): + return "FloatDequantization" + + return None diff --git a/experimental/iree_evo/verification.py b/experimental/iree_evo/verification.py new file mode 100644 index 000000000000..a96b4878b4c5 --- /dev/null +++ b/experimental/iree_evo/verification.py @@ -0,0 +1,302 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Structural verification using LIT tests. + +This module provides functionality to generate and run LLVM LIT tests +for verifying the structural properties of compiled MLIR. +""" + +import os +import subprocess +import tempfile +from typing import Any, Dict, List, Optional + + +class LitGen: + """Generator for LLVM LIT structural tests. + + This class creates and runs LIT tests to verify that compiler + optimizations produce the expected structural changes in the IR. + """ + + # Verification patterns for different optimization strategies + _VERIFICATION_PATTERNS: Dict[str, Dict[str, List[str]]] = { + "IntegerRequantization": { + # Operations that should NOT appear after optimization + "check_not": [ + "arith.sitofp", + "arith.mulf", + "arith.uitofp", + "arith.subf", + ], + # Operations that SHOULD appear after optimization + "check": [ + "arith.muli", + "arith.shrsi", + ], + }, + "QuantizationFusion": { + "check_not": [ + "linalg.generic", # Should be fused + ], + "check": [ + "linalg.matmul", # Or fused matmul variant + ], + }, + "Tiling": { + "check_not": [], + "check": [ + "scf.forall", # Tiled loops should appear + ], + }, + "Vectorization": { + "check_not": [], + "check": [ + "vector.transfer_read", + "vector.transfer_write", + ], + }, + } + + @classmethod + def generate_structural_test( + cls, + strategy: str, + original_mlir: str, + compile_flags: Optional[List[str]] = None, + ) -> str: + """Creates a .mlir file content ready for llvm-lit. + + Args: + strategy: The optimization strategy name (e.g., 'IntegerRequantization'). + original_mlir: The original MLIR content to test. + compile_flags: Optional list of compiler flags to include in RUN line. + + Returns: + A string containing the complete LIT test file content. + """ + # Get verification patterns for the strategy + if strategy in cls._VERIFICATION_PATTERNS: + patterns = cls._VERIFICATION_PATTERNS[strategy] + else: + # Default to empty patterns if strategy not recognized + patterns = {"check_not": [], "check": []} + + # Build the RUN line + flags_str = " ".join(compile_flags) if compile_flags else "" + run_line = f"// RUN: iree-opt {flags_str} %s | FileCheck %s" + + # Build CHECK lines + check_lines: List[str] = [] + + # Add CHECK-NOT lines for operations that should be eliminated + for op in patterns.get("check_not", []): + check_lines.append(f"// CHECK-NOT: {op}") + + # Add CHECK lines for operations that should be present + for op in patterns.get("check", []): + check_lines.append(f"// CHECK: {op}") + + # Combine all parts + test_content_parts = [ + run_line, + "", + ] + test_content_parts.extend(check_lines) + test_content_parts.extend(["", original_mlir]) + + return "\n".join(test_content_parts) + + @classmethod + def generate_custom_test( + cls, + original_mlir: str, + check_patterns: List[str], + check_not_patterns: List[str], + compile_flags: Optional[List[str]] = None, + ) -> str: + """Creates a custom LIT test with specified patterns. + + Args: + original_mlir: The original MLIR content to test. + check_patterns: List of patterns that should be present. + check_not_patterns: List of patterns that should not be present. + compile_flags: Optional list of compiler flags. + + Returns: + A string containing the complete LIT test file content. + """ + # Build the RUN line + flags_str = " ".join(compile_flags) if compile_flags else "" + run_line = f"// RUN: iree-opt {flags_str} %s | FileCheck %s" + + # Build CHECK lines + check_lines: List[str] = [] + + for pattern in check_not_patterns: + check_lines.append(f"// CHECK-NOT: {pattern}") + + for pattern in check_patterns: + check_lines.append(f"// CHECK: {pattern}") + + # Combine all parts + test_content_parts = [ + run_line, + "", + ] + test_content_parts.extend(check_lines) + test_content_parts.extend(["", original_mlir]) + + return "\n".join(test_content_parts) + + @classmethod + def run_lit( + cls, + test_filepath: str, + lit_executable: str = "llvm-lit", + timeout: int = 60, + ) -> bool: + """Runs llvm-lit on a test file and returns the result. + + Args: + test_filepath: Path to the LIT test file. + lit_executable: Path to the llvm-lit executable. + timeout: Timeout in seconds for the test. + + Returns: + True if the test passed, False otherwise. + """ + if not os.path.exists(test_filepath): + raise FileNotFoundError(f"Test file not found: {test_filepath}") + + try: + result = subprocess.run( + [lit_executable, "-v", test_filepath], + capture_output=True, + text=True, + timeout=timeout, + ) + return result.returncode == 0 + except subprocess.TimeoutExpired: + return False + except FileNotFoundError: + raise FileNotFoundError( + f"llvm-lit executable not found: {lit_executable}. " + "Please ensure LLVM/LIT is installed and in PATH." + ) + + @classmethod + def run_lit_with_output( + cls, + test_filepath: str, + lit_executable: str = "llvm-lit", + timeout: int = 60, + ) -> Dict[str, Any]: + """Runs llvm-lit and returns detailed output. + + Args: + test_filepath: Path to the LIT test file. + lit_executable: Path to the llvm-lit executable. + timeout: Timeout in seconds for the test. + + Returns: + A dictionary containing: + - 'passed': bool indicating if the test passed + - 'stdout': stdout from llvm-lit + - 'stderr': stderr from llvm-lit + - 'returncode': the return code + """ + if not os.path.exists(test_filepath): + raise FileNotFoundError(f"Test file not found: {test_filepath}") + + try: + result = subprocess.run( + [lit_executable, "-v", test_filepath], + capture_output=True, + text=True, + timeout=timeout, + ) + return { + "passed": result.returncode == 0, + "stdout": result.stdout, + "stderr": result.stderr, + "returncode": result.returncode, + } + except subprocess.TimeoutExpired: + return { + "passed": False, + "stdout": "", + "stderr": f"Test timed out after {timeout} seconds", + "returncode": -1, + } + except FileNotFoundError: + return { + "passed": False, + "stdout": "", + "stderr": f"llvm-lit executable not found: {lit_executable}", + "returncode": -1, + } + + @classmethod + def create_and_run_test( + cls, + strategy: str, + original_mlir: str, + compile_flags: Optional[List[str]] = None, + work_dir: Optional[str] = None, + lit_executable: str = "llvm-lit", + ) -> Dict[str, Any]: + """Creates a test file, runs it, and returns results. + + This is a convenience method that combines test generation and execution. + + Args: + strategy: The optimization strategy name. + original_mlir: The original MLIR content to test. + compile_flags: Optional list of compiler flags. + work_dir: Optional working directory for the test file. + lit_executable: Path to the llvm-lit executable. + + Returns: + A dictionary containing the test results. + """ + # Generate test content + test_content = cls.generate_structural_test( + strategy, original_mlir, compile_flags + ) + + # Create temporary file for the test + if work_dir: + os.makedirs(work_dir, exist_ok=True) + test_file = os.path.join(work_dir, "test.mlir") + with open(test_file, "w") as f: + f.write(test_content) + else: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".mlir", delete=False + ) as f: + f.write(test_content) + test_file = f.name + + try: + # Run the test + result = cls.run_lit_with_output(test_file, lit_executable) + result["test_file"] = test_file + result["test_content"] = test_content + return result + finally: + # Clean up if using temp file + if not work_dir and os.path.exists(test_file): + os.remove(test_file) + + @classmethod + def get_available_strategies(cls) -> List[str]: + """Returns the list of available verification strategies. + + Returns: + A list of strategy names that have defined verification patterns. + """ + return list(cls._VERIFICATION_PATTERNS.keys()) diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt index 9469b86b277e..ab6c80b44c22 100644 --- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt +++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt @@ -42,6 +42,7 @@ iree_bitcode_library( "${PROJECT_BINARY_DIR}/runtime/src/iree/schemas/cpu_data_headers_filegroup.stamp" "common_riscv_64.h" "mmt4d_riscv_64_internal.h" + "bme.h" "mmt4d_riscv_64_tiles.inl" "pack_riscv_64_internal.h" "unpack_riscv_64_internal.h" @@ -177,6 +178,8 @@ iree_cc_library( riscv_64_v SRCS "mmt4d_riscv_64_v.c" + HDRS + "bme.h" COPTS "${IREE_UK_COPTS_RISCV_64_V}" DEPS diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/bme.h b/runtime/src/iree/builtins/ukernel/arch/riscv_64/bme.h new file mode 100644 index 000000000000..6ee763e309f1 --- /dev/null +++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/bme.h @@ -0,0 +1,64 @@ +// HACK reuse the scalar registers to avoid assembler hacking for now +#define m0 "x0" +#define m1 "x1" +#define m2 "x2" +#define m3 "x3" +#define m4 "x4" +#define m5 "x5" +#define m6 "x6" +#define m7 "x7" + +#define v0 "x0" +#define v1 "x1" +#define v2 "x2" +#define v3 "x3" +#define v4 "x4" +#define v5 "x5" +#define v6 "x6" +#define v7 "x7" +#define v8 "x8" +#define v9 "x9" +#define v10 "x10" +#define v11 "x11" +#define v12 "x12" +#define v13 "x13" +#define v14 "x14" +#define v15 "x15" +#define v16 "x16" +#define v17 "x17" +#define v18 "x18" +#define v19 "x19" +#define v20 "x20" +#define v21 "x21" +#define v22 "x22" +#define v23 "x23" +#define v24 "x24" +#define v25 "x25" +#define v26 "x26" +#define v27 "x27" +#define v28 "x28" +#define v29 "x29" +#define v30 "x30" +#define v31 "x31" + +// opmvx. f6=b101010, f7=b1010101 +#define VMV_RV(md, rs1, vs2) \ + asm volatile(".insn r 0x57, 0x6, 0x55, " md ", %0, " vs2 : : "r"(rs1)); + +// opmvx. f6=b101110, f7=b1011101 +#define VMV_VR(vd, rs1, ms2) \ + asm volatile(".insn r 0x57, 0x6, 0x5d, " vd ", %0, " ms2 : : "r"(rs1)); + +// opmvx. f6=b101100, f7=b1011001 +#define OPMVINBCAST(md, vs2) \ + asm volatile(".insn r 0x57, 0x6, 0x59, " md ", x0, " vs2); + +// opmvv. f6=b101000, f7=b1010001 +#define VOPACC(md, vs2, vs1) \ + asm volatile(".insn r 0x57, 0x2, 0x51, " md ", " vs1 ", " vs2); + +#include // For size_t +#include // For int8_t and int32_t + +//void i8_mm_bme_1x2(int32_t* c_bias, int32_t* c_out, int8_t* at, int8_t* b, +// size_t M, size_t N, size_t K); \ No newline at end of file diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl index 4f4d0413a907..48247dc98dc9 100644 --- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl +++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl @@ -32,3 +32,10 @@ IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 1, 1, _zvfh) IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 2, 1, _zvfh) IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 4, 1, _zvfh) IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 7, 1, _zvfh) + +// s8s8s32 tiles using the 'v' extension +IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 1, 1, _v) +IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 2, 1, _v) +IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 4, 1, _v) +IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 8, 1, _v) +IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 16, 1, _v) \ No newline at end of file diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c index 2517080da24a..7280e3aef0af 100644 --- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c +++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c @@ -8,6 +8,7 @@ #include "iree/builtins/ukernel/arch/riscv_64/common_riscv_64.h" #include "iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_internal.h" +#include "bme.h" IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v( @@ -121,6 +122,87 @@ iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v( } } +IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void +iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v( + void* IREE_UK_RESTRICT out_tile, + const void* IREE_UK_RESTRICT lhs_panel, + const void* IREE_UK_RESTRICT rhs_panel, + const iree_uk_mmt4d_params_t* params, int M0) { + IREE_UK_ASSERT(M0 >= 1 && M0 <= 16); + iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile; + const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel; + const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel; + + const int N0 = params->N0; + const int K = params->K; + size_t ml = M0; + size_t vl = N0; + + // Performance case for M0=16 (LMUL=8) + if (M0 == 16) { + // init m0 to zero (LMUL=8) + asm volatile("vsetvli zero, %0, e32, m8, ta, ma" : : "r"(vl)); + asm volatile("vmv.v.i v0, 0"); + OPMVINBCAST(m0, v0); + + // K-loop unrolled by 2 + size_t k = 0; + while (k + 2 <= K) { + asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml)); // ml=16 + asm volatile("vle8.v v16, (%0)" : : "r"(&lhs_ptr[k * M0])); + asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl)); // vl=N0 + asm volatile("vle8.v v18, (%0)" : : "r"(&rhs_ptr[k * N0])); + VOPACC(m0, v18, v16); + k++; + asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml)); // ml=16 + asm volatile("vle8.v v20, (%0)" : : "r"(&lhs_ptr[k * M0])); + asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl)); // vl=N0 + asm volatile("vle8.v v22, (%0)" : : "r"(&rhs_ptr[k * N0])); + VOPACC(m0, v22, v20); + k++; + } + if (k < K) { + asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml)); // ml=16 + asm volatile("vle8.v v16, (%0)" : : "r"(&lhs_ptr[k * M0])); + asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl)); // vl=N0 + asm volatile("vle8.v v18, (%0)" : : "r"(&rhs_ptr[k * N0])); + VOPACC(m0, v18, v16); + } + + // store results + asm volatile("vsetvli zero, %0, e32, m8, ta, ma" : : "r"(vl)); + for (size_t r = 0; r < ml; r++) { // ml=16 + VMV_VR(v0, r, m0); + asm volatile("vse32.v v0, (%0)" : : "r"(&out_ptr[r * N0])); + } + } + // Tail case for M0 < 16 (using LMUL=4) + else { + // 1. Initialize accumulators to ZERO (LMUL=4) + asm volatile("vsetvli zero, %0, e32, m4, ta, ma" : : "r"(vl)); + asm volatile("vmv.v.i v0, 0"); + OPMVINBCAST(m3, v0); // Initialize m3 to zero + + // 2. Main K-loop + for (int k = 0; k < K; ++k) { + asm volatile("vsetvli zero, %0, e8, m1, ta, ma" : : "r"(ml)); + asm volatile("vle8.v v5, (%0)" : : "r"(&lhs_ptr[k * M0])); + + asm volatile("vsetvli zero, %0, e8, m1, ta, ma" : : "r"(vl)); + asm volatile("vle8.v v4, (%0)" : : "r"(&rhs_ptr[k * N0])); + + VOPACC(m3, v4, v5); + } + + // 3. Store results + asm volatile("vsetvli zero, %0, e32, m4, ta, ma" : : "r"(vl)); + for (size_t r = 0; r < ml; r++) { + VMV_VR(v0, r, m3); + asm volatile("vse32.v v0, (%0)" : : "r"(&out_ptr[r * N0])); + } + } +} + IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v, iree_uk_mmt4d_tile_f32f32f32_1xXXx1_riscv_64_v, 1) @@ -133,3 +215,22 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v, iree_uk_mmt4d_tile_f32f32f32_7xXXx1_riscv_64_v, 7) + +// *** UPDATED SECTION *** +// Point all s8s8s32 tiles to the new generic function +IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( + iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v, + iree_uk_mmt4d_tile_s8s8s32_1xXXx1_riscv_64_v, 1) +IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( + iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v, + iree_uk_mmt4d_tile_s8s8s32_2xXXx1_riscv_64_v, 2) +IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( + iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v, + iree_uk_mmt4d_tile_s8s8s32_4xXXx1_riscv_64_v, 4) +IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( + iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v, + iree_uk_mmt4d_tile_s8s8s32_8xXXx1_riscv_64_v, 8) +// Add the new M0=16 tile +IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0( + iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v, + iree_uk_mmt4d_tile_s8s8s32_16xXXx1_riscv_64_v, 16) \ No newline at end of file diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c index 352a0725a7ad..0f9a6d0b926e 100644 --- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c +++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c @@ -19,6 +19,21 @@ iree_uk_query_matmul_tile_sizes_riscv_64_f32f32f32( return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 1, .N = 8}; } +static iree_uk_matmul_tile_sizes_t +iree_uk_query_matmul_tile_sizes_riscv_64_s8s8s32( + const iree_uk_query_tile_sizes_2d_params_t* params) { +#if defined(IREE_UK_BUILD_RISCV_64_V) + if (iree_uk_cpu_riscv_64_v(params->cpu_data)) { + // Corresponds to the new target M0=16. + // N=32 is based on a minimum-VLEN (128-bit) and the new LMUL=8. + // N0 = (128 bits / 32 bits_per_element) * 8_LMUL = 32. + return (iree_uk_matmul_tile_sizes_t){.M = 16, .K = 1, .N = 32}; + } +#endif + // generic fallback + return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 1, .N = 8}; +} + bool iree_uk_query_matmul_tile_sizes_arch( const iree_uk_query_tile_sizes_2d_params_t* params, iree_uk_matmul_tile_sizes_t* out_matmul_tile_sizes) { @@ -27,8 +42,12 @@ bool iree_uk_query_matmul_tile_sizes_arch( *out_matmul_tile_sizes = iree_uk_query_matmul_tile_sizes_riscv_64_f32f32f32(params); return true; + } else if (op == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32) { + *out_matmul_tile_sizes = + iree_uk_query_matmul_tile_sizes_riscv_64_s8s8s32(params); + return true; } else { // Shouldn't happen, validated earlier. return false; } -} +} \ No newline at end of file