Merged
7 changes: 7 additions & 0 deletions benchmarks/cli.py
@@ -365,6 +365,12 @@ def main():
         help="Progress display mode: auto (uses tqdm if available), tqdm (progress bar), "
         "simple (periodic status), none (quiet). Default: auto",
     )
+    run_parser.add_argument(
+        "--sample-timeout",
+        type=int,
+        default=300,
+        help="Timeout per sample in seconds (default: 300 = 5 minutes)",
+    )
 
     # Benchmark-specific options for run
     run_parser.add_argument("--context-length", type=int, default=100_000, help="NIAH context len")
@@ -470,6 +476,7 @@ def cmd_run(args: argparse.Namespace) -> int:
         log_dir=args.log_dir,
         max_workers=args.max_workers,
         progress=args.progress,
+        sample_timeout=args.sample_timeout,
     )
 
     for benchmark in benchmarks:
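For reference, a hedged invocation sketch for the new flag. The run subcommand name is implied by run_parser/cmd_run above; the assumptions that main() parses sys.argv and returns an int exit code are not confirmed by this diff.

# Hedged sketch: exercises the new --sample-timeout flag through benchmarks.cli.main().
# Assumes main() parses sys.argv and returns an int exit code.
import sys

from benchmarks.cli import main

sys.argv = [
    "benchmarks-cli",  # placeholder program name
    "run",
    "--sample-timeout", "600",  # raise the per-sample limit from 300 s to 10 minutes
    "--max-workers", "4",
    "--progress", "simple",
]
raise SystemExit(main())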
20 changes: 18 additions & 2 deletions benchmarks/runner.py
@@ -38,6 +38,8 @@
     ("no such model", "Model not found. Check available models for your account."),
     # Rate limiting (without retry-after)
     ("rate_limit", "Rate limited. Consider reducing --max-workers or adding delays."),
+    # Timeout
+    ("timed out", "Sample timed out. Try --sample-timeout to increase, or reduce --context-length."),
 ]
 
 
@@ -119,6 +121,7 @@ class RunnerConfig:
     max_workers: int = 1  # Number of parallel workers (1 = sequential)
     progress: str = "auto"  # Progress display: "auto", "tqdm", "simple", "none"
     progress_callback: ProgressCallback | None = None  # Custom progress callback
+    sample_timeout: int = 300  # Timeout per sample in seconds (default: 5 minutes)
     backend_kwargs: dict[str, Any] = field(default_factory=dict)
     environment_kwargs: dict[str, Any] = field(default_factory=dict)
 
@@ -157,6 +160,7 @@ def __init__(
         max_workers: int = 1,
         progress: str = "auto",
         progress_callback: ProgressCallback | None = None,
+        sample_timeout: int = 300,
         **kwargs,
     ):
         """Initialize runner with configuration.
@@ -176,6 +180,7 @@ def __init__(
                 - "none": No progress output
             progress_callback: Custom callback for progress updates.
                 Signature: (completed, total, sample_result, stats) -> None
+            sample_timeout: Timeout per sample in seconds (default: 300 = 5 minutes).
             **kwargs: Additional backend or environment kwargs.
         """
         self.config = RunnerConfig(
@@ -188,6 +193,7 @@ def __init__(
             max_workers=max_workers,
             progress=progress,
             progress_callback=progress_callback,
+            sample_timeout=sample_timeout,
             backend_kwargs={"model_name": model, **kwargs.get("backend_kwargs", {})},
             environment_kwargs=kwargs.get("environment_kwargs", {}),
         )
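Programmatically, the new keyword flows from the runner constructor into RunnerConfig. A hedged usage sketch follows; the class name BenchmarkRunner and the exact required arguments are assumptions, only the keyword names come from the diff above.

# Hedged sketch: the runner class name is hypothetical; keyword names mirror RunnerConfig.
from benchmarks.runner import BenchmarkRunner  # hypothetical class name

runner = BenchmarkRunner(
    model="example-model",  # forwarded into backend_kwargs as model_name
    max_workers=4,
    progress="simple",
    sample_timeout=600,  # allow 10 minutes per sample instead of the 300 s default
)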
@@ -615,14 +621,24 @@ def _run_sample(
         inference_fn: Callable[[BenchmarkSample], tuple[str, dict[str, Any]]],
         benchmark: Benchmark,
     ) -> SampleResult:
-        """Run a single sample and evaluate."""
+        """Run a single sample and evaluate with timeout."""
+        from concurrent.futures import TimeoutError as FuturesTimeoutError
+
         start_time = time.time()
         error = None
         prediction = ""
         metadata: dict[str, Any] = {}
+        timeout = self.config.sample_timeout
 
         try:
-            prediction, metadata = inference_fn(sample)
+            # Use a thread pool to enforce timeout
+            with ThreadPoolExecutor(max_workers=1) as executor:
+                future = executor.submit(inference_fn, sample)
+                try:
+                    prediction, metadata = future.result(timeout=timeout)
+                except FuturesTimeoutError:
+                    error = f"Sample timed out after {timeout}s. Try --sample-timeout to increase."
+                    prediction = ""
         except Exception as e:
             error = str(e)
             prediction = ""
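The enforcement mechanism above is the standard concurrent.futures pattern: submit the inference call to a single-worker pool and bound the wait with future.result(timeout=...). A standalone sketch of that pattern, with illustrative names, including one behavioural detail worth keeping in mind:

# Minimal sketch of the timeout pattern used in _run_sample; names are illustrative.
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError


def run_with_timeout(fn, arg, timeout: int):
    """Return (result, error); error is set when fn does not finish within timeout seconds."""
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(fn, arg)
        try:
            return future.result(timeout=timeout), None
        except FuturesTimeoutError:
            # Python threads cannot be forcibly cancelled: leaving the `with` block
            # calls executor.shutdown(wait=True), so the in-flight call still runs to
            # completion before control returns. The timeout bounds how long we wait
            # for a usable result and lets the sample be recorded as timed out; it
            # does not reclaim the worker early.
            return None, f"timed out after {timeout}s"

That caveat is also why the suggestion string added to the pattern table points at reducing --context-length: shrinking the per-sample work is what actually frees a stuck worker sooner, whereas raising --sample-timeout only gives the call more time before it is recorded as timed out.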
9 changes: 9 additions & 0 deletions tests/benchmarks/test_benchmarks.py
@@ -611,6 +611,15 @@ def test_classify_generic_error(self):
         is_fatal, suggestion = classify_error(error)
         assert not is_fatal
 
+    def test_classify_timeout_error(self):
+        """Test that timeout errors are classified as fatal with helpful suggestion."""
+        from benchmarks.runner import classify_error
+
+        error = "Sample timed out after 300s"
+        is_fatal, suggestion = classify_error(error)
+        assert is_fatal
+        assert "timeout" in suggestion.lower() or "context" in suggestion.lower()
+
 
 class TestBenchmarkIntegration:
     """Integration tests for benchmark framework."""
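classify_error itself is not touched by this diff; the (pattern, suggestion) table at the top of runner.py and the assertions above imply a case-insensitive substring match that returns (is_fatal, suggestion). A hedged sketch of that shape, with the table name and the rule that every listed pattern is fatal treated as assumptions:

# Hedged sketch of what the tests imply; the real classify_error in benchmarks/runner.py
# may be structured differently. Table name and fatality rule are assumptions.
KNOWN_ERROR_PATTERNS = [
    ("no such model", "Model not found. Check available models for your account."),
    ("rate_limit", "Rate limited. Consider reducing --max-workers or adding delays."),
    ("timed out", "Sample timed out. Try --sample-timeout to increase, or reduce --context-length."),
]


def classify_error_sketch(error: str) -> tuple[bool, str]:
    lowered = error.lower()
    for pattern, suggestion in KNOWN_ERROR_PATTERNS:
        if pattern in lowered:
            return True, suggestion  # recognised pattern: treat as fatal and surface the hint
    return False, ""  # anything else is non-fatal, matching test_classify_generic_error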