5 changes: 4 additions & 1 deletion helion/_testing.py
@@ -17,7 +17,6 @@
from typing import Generator
import unittest

import pytest
import torch
from torch.utils._pytree import tree_map
import triton
@@ -267,6 +266,8 @@ def setUp(self) -> None:
if not self._in_ref_eager_mode:
return

import pytest
Contributor: Why?

# Reset assert_close counter for this test
RefEagerTestBase._assert_close_count = 0
# Reset assertRaises counter for this test
@@ -361,6 +362,8 @@ def tearDown(self) -> None:
super().tearDown() # type: ignore[misc]
return

import pytest

try:
# Exit the run_ref tracker
self._run_ref_tracker.__exit__(None, None, None)
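For context, a minimal standalone sketch of the deferred-import pattern these hunks switch to; the helper name below is hypothetical, and the assumption that the goal is to avoid a hard module-level pytest dependency is mine, not stated in the diff.

def _skip_with_pytest(reason: str) -> None:
    # pytest is imported only when the helper actually runs, so merely
    # importing the enclosing module works in environments without pytest.
    import pytest

    pytest.skip(reason)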
20 changes: 14 additions & 6 deletions helion/autotuner/base_search.py
@@ -42,9 +42,11 @@
from triton.testing import do_bench

from .. import exc
from .._testing import is_cuda
from ..runtime.kernel import BoundKernel
from ..runtime.precompile_shim import already_compiled
from ..runtime.precompile_shim import make_precompiler
from .benchmarking import do_bench_cudagraph_with_cache_clear
from .benchmarking import interleaved_bench
from .config_generation import ConfigGeneration
from .config_generation import FlatConfig
@@ -338,12 +340,18 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
# Accuracy check failed; reject this config
return inf
t1 = time.perf_counter()
res = do_bench(
functools.partial(fn, *self.args),
return_mode="median",
warmup=1, # we are already warmed up above
rep=50,
)
kwargs = {
"fn": functools.partial(fn, *self.args),
"rep": 50,
}
if is_cuda():
res = do_bench_cudagraph_with_cache_clear(**kwargs)
else:
res = do_bench(
**kwargs,
return_mode="median",
warmup=1, # we are already warmed up above
)

Contributor: Does the above also use median? Median tends to remove outliers.

t2 = time.perf_counter()
assert isinstance(res, float)
self.log.debug(
92 changes: 92 additions & 0 deletions helion/autotuner/benchmarking.py
@@ -4,12 +4,104 @@
import math
import statistics
from typing import Callable
from typing import Sequence

import torch
import triton
from triton import runtime

from .progress_bar import iter_with_progress


def do_bench_cudagraph_with_cache_clear(
fn: Callable[[], object],
rep: int = 20,
grad_to_none: Sequence[torch.Tensor] | None = None,
) -> float:
"""
Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing.
Only supports calculating mean execution time.

Args:
fn: Function to benchmark
rep: Target total measurement time in milliseconds
grad_to_none: Tensors whose gradients should be cleared before each measurement

Returns:
Mean execution time in milliseconds
"""
# Get a cache tensor and function to zero it for L2 cache clearing
cache = triton.runtime.driver.active.get_empty_cache_for_benchmark() # type: ignore[attr-defined]
clear_cache_fn = cache.zero_

with torch.cuda.stream(torch.cuda.Stream()):
# Warmup: clear cache and run function once to ensure it's compiled
clear_cache_fn()
fn()

# Reset gradients if needed
if grad_to_none is not None:
for x in grad_to_none:
x.detach_()
x.requires_grad_(True)
x.grad = None

# Estimate execution time
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
clear_cache_fn()
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# Calculate number of repetitions needed to reach target measurement time (rep)
n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))

# Create a CUDA graph for measuring total time (cache clearing + kernel execution)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for _ in range(n_repeat):
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
clear_cache_fn()
fn()
torch.cuda.synchronize()

# Create a separate CUDA graph for measuring cache clearing time only
cache_clear_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cache_clear_graph):
for _ in range(n_repeat):
clear_cache_fn()
torch.cuda.synchronize()

# Measure time for cache clearing only
cache_clear_start_event = torch.cuda.Event(enable_timing=True)
cache_clear_end_event = torch.cuda.Event(enable_timing=True)
cache_clear_start_event.record()
cache_clear_graph.replay()
cache_clear_end_event.record()
torch.cuda.synchronize()
cache_clear_time = (
cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat
)

# Measure total time (cache clearing + kernel execution)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
g.replay()
end_event.record()
torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event) / n_repeat

# Subtract cache clearing overhead to get pure kernel execution time
return total_time - cache_clear_time


def compute_repeat(
fn: Callable[[], object],
*,
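For reference, a minimal sketch of how the new helper might be called outside the autotuner; the workload below (a plain matmul on square tensors) is hypothetical and stands in for a compiled config.

import torch

from helion.autotuner.benchmarking import do_bench_cudagraph_with_cache_clear

# Hypothetical CUDA workload; any zero-argument callable that launches GPU work fits.
a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")

def matmul_fn() -> torch.Tensor:
    return a @ b

# Mean execution time in milliseconds, with the L2-cache-clear overhead
# measured in a separate CUDA graph and subtracted out.
mean_ms = do_bench_cudagraph_with_cache_clear(matmul_fn, rep=50)
print(f"mean kernel time: {mean_ms:.3f} ms")

Note that, per its docstring, this helper reports only a mean, whereas the do_bench fallback in base_search.py requests the median.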
17 changes: 13 additions & 4 deletions test/test_debug_utils.py
@@ -13,6 +13,7 @@
from helion._testing import DEVICE
from helion._testing import RefEagerTestDisabled
from helion._testing import TestCase
from helion._testing import is_cuda
from helion._testing import skipIfCpu
import helion.language as hl

@@ -142,20 +143,28 @@ def test_print_repro_on_autotune_error(self):
torch.manual_seed(0)
x = torch.randn([128], dtype=torch.float32, device=DEVICE)

# Mock do_bench to fail on the second config with PTXASError (warn level)
# Mock benchmark helper to fail on the second config with PTXASError (warn level)
from torch._inductor.runtime.triton_compat import PTXASError
from triton.testing import do_bench as original_do_bench

from helion.autotuner import base_search

call_count = [0]

bench_attr = (
"do_bench_cudagraph_with_cache_clear" if is_cuda() else "do_bench"
)

original_bench = getattr(base_search, bench_attr)
bench_target = f"helion.autotuner.base_search.{bench_attr}"

def mock_do_bench(*args, **kwargs):
call_count[0] += 1
if call_count[0] == 2: # Fail on second config
raise PTXASError("Mocked PTXAS error")
return original_do_bench(*args, **kwargs)
return original_bench(*args, **kwargs)

with self.capture_output() as output_capture:
with mock.patch("helion.autotuner.base_search.do_bench", mock_do_bench):
with mock.patch(bench_target, mock_do_bench):
# Autotune will try both configs, second one will fail and print repro
kernel.autotune([x], force=False)
