diff --git a/helion/_testing.py b/helion/_testing.py
index 87f459558..fe6b1505a 100644
--- a/helion/_testing.py
+++ b/helion/_testing.py
@@ -17,7 +17,6 @@
 from typing import Generator
 import unittest
 
-import pytest
 import torch
 from torch.utils._pytree import tree_map
 import triton
@@ -267,6 +266,8 @@ def setUp(self) -> None:
         if not self._in_ref_eager_mode:
             return
 
+        import pytest
+
         # Reset assert_close counter for this test
         RefEagerTestBase._assert_close_count = 0
         # Reset assertRaises counter for this test
@@ -361,6 +362,8 @@ def tearDown(self) -> None:
             super().tearDown()  # type: ignore[misc]
             return
 
+        import pytest
+
         try:
             # Exit the run_ref tracker
             self._run_ref_tracker.__exit__(None, None, None)
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 01b96ae4a..bac7a740c 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -42,9 +42,11 @@
 from triton.testing import do_bench
 
 from .. import exc
+from .._testing import is_cuda
 from ..runtime.kernel import BoundKernel
 from ..runtime.precompile_shim import already_compiled
 from ..runtime.precompile_shim import make_precompiler
+from .benchmarking import do_bench_cudagraph_with_cache_clear
 from .benchmarking import interleaved_bench
 from .config_generation import ConfigGeneration
 from .config_generation import FlatConfig
@@ -338,12 +340,18 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
                 # Accuracy check failed; reject this config
                 return inf
             t1 = time.perf_counter()
-            res = do_bench(
-                functools.partial(fn, *self.args),
-                return_mode="median",
-                warmup=1,  # we are already warmed up above
-                rep=50,
-            )
+            kwargs = {
+                "fn": functools.partial(fn, *self.args),
+                "rep": 50,
+            }
+            if is_cuda():
+                res = do_bench_cudagraph_with_cache_clear(**kwargs)
+            else:
+                res = do_bench(
+                    **kwargs,
+                    return_mode="median",
+                    warmup=1,  # we are already warmed up above
+                )
             t2 = time.perf_counter()
             assert isinstance(res, float)
             self.log.debug(
diff --git a/helion/autotuner/benchmarking.py b/helion/autotuner/benchmarking.py
index b0f17a4f2..1d9bf93b7 100644
--- a/helion/autotuner/benchmarking.py
+++ b/helion/autotuner/benchmarking.py
@@ -4,12 +4,104 @@
 import math
 import statistics
 from typing import Callable
+from typing import Sequence
 
+import torch
+import triton
 from triton import runtime
 
 from .progress_bar import iter_with_progress
 
 
+def do_bench_cudagraph_with_cache_clear(
+    fn: Callable[[], object],
+    rep: int = 20,
+    grad_to_none: Sequence[torch.Tensor] | None = None,
+) -> float:
+    """
+    Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing.
+    Only supports calculating mean execution time.
+
+    Args:
+        fn: Function to benchmark
+        rep: Target total measurement time in milliseconds
+        grad_to_none: Tensors whose gradients should be cleared before each measurement
+
+    Returns:
+        Mean execution time in milliseconds
+    """
+    # Get a cache tensor and function to zero it for L2 cache clearing
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()  # type: ignore[attr-defined]
+    clear_cache_fn = cache.zero_
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # Warmup: clear cache and run function once to ensure it's compiled
+        clear_cache_fn()
+        fn()
+
+        # Reset gradients if needed
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
+        # Estimate execution time
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            clear_cache_fn()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        # Calculate number of repetitions needed to reach target measurement time (rep)
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
+        # Create a CUDA graph for measuring total time (cache clearing + kernel execution)
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                clear_cache_fn()
+                fn()
+        torch.cuda.synchronize()
+
+        # Create a separate CUDA graph for measuring cache clearing time only
+        cache_clear_graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(cache_clear_graph):
+            for _ in range(n_repeat):
+                clear_cache_fn()
+        torch.cuda.synchronize()
+
+        # Measure time for cache clearing only
+        cache_clear_start_event = torch.cuda.Event(enable_timing=True)
+        cache_clear_end_event = torch.cuda.Event(enable_timing=True)
+        cache_clear_start_event.record()
+        cache_clear_graph.replay()
+        cache_clear_end_event.record()
+        torch.cuda.synchronize()
+        cache_clear_time = (
+            cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat
+        )
+
+        # Measure total time (cache clearing + kernel execution)
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        g.replay()
+        end_event.record()
+        torch.cuda.synchronize()
+        total_time = start_event.elapsed_time(end_event) / n_repeat
+
+        # Subtract cache clearing overhead to get pure kernel execution time
+        return total_time - cache_clear_time
+
+
 def compute_repeat(
     fn: Callable[[], object],
     *,
diff --git a/test/test_debug_utils.py b/test/test_debug_utils.py
index 087d2ffd2..728f177db 100644
--- a/test/test_debug_utils.py
+++ b/test/test_debug_utils.py
@@ -13,6 +13,7 @@
 from helion._testing import DEVICE
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
+from helion._testing import is_cuda
 from helion._testing import skipIfCpu
 import helion.language as hl
 
@@ -142,20 +143,28 @@ def test_print_repro_on_autotune_error(self):
         torch.manual_seed(0)
         x = torch.randn([128], dtype=torch.float32, device=DEVICE)
 
-        # Mock do_bench to fail on the second config with PTXASError (warn level)
+        # Mock benchmark helper to fail on the second config with PTXASError (warn level)
         from torch._inductor.runtime.triton_compat import PTXASError
-        from triton.testing import do_bench as original_do_bench
+
+        from helion.autotuner import base_search
 
         call_count = [0]
 
+        bench_attr = (
+            "do_bench_cudagraph_with_cache_clear" if is_cuda() else "do_bench"
+        )
+
+        original_bench = getattr(base_search, bench_attr)
+        bench_target = f"helion.autotuner.base_search.{bench_attr}"
+
        def mock_do_bench(*args, **kwargs):
             call_count[0] += 1
             if call_count[0] == 2:  # Fail on second config
                 raise PTXASError("Mocked PTXAS error")
-            return original_do_bench(*args, **kwargs)
+            return original_bench(*args, **kwargs)
 
         with self.capture_output() as output_capture:
-            with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
+            with mock.patch(bench_target, mock_do_bench):
                 # Autotune will try both configs, second one will fail and print repro
                 kernel.autotune([x], force=False)