3 changes: 2 additions & 1 deletion benchmarks/__init__.py
@@ -20,7 +20,7 @@
from benchmarks.base import Benchmark, BenchmarkResult, BenchmarkSample
from benchmarks.metrics import Metrics
from benchmarks.results import ExperimentConfig, ExperimentRecord, ResultsStore
from benchmarks.runner import BenchmarkRunner, ProgressStats, compare_methods
from benchmarks.tasks import (
BrowseCompPlusBenchmark,
NIAHBenchmark,
@@ -35,6 +35,7 @@
"BenchmarkSample",
# Runner
"BenchmarkRunner",
"ProgressStats",
"compare_methods",
# Metrics
"Metrics",
10 changes: 10 additions & 0 deletions benchmarks/cli.py
@@ -316,6 +316,15 @@ def main():
default=1,
help="Parallel workers (1=sequential, >1=parallel threads, default: 1)",
)
run_parser.add_argument(
"--progress",
"-p",
type=str,
default="auto",
choices=["auto", "tqdm", "simple", "none"],
help="Progress display mode: auto (uses tqdm if available), tqdm (progress bar), "
"simple (periodic status), none (quiet). Default: auto",
)

# Benchmark-specific options for run
run_parser.add_argument("--context-length", type=int, default=100_000, help="NIAH context len")
@@ -390,6 +399,7 @@ def cmd_run(args: argparse.Namespace) -> int:
verbose=args.verbose,
log_dir=args.log_dir,
max_workers=args.max_workers,
progress=args.progress,
)

if args.benchmark == "all":
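
For example, assuming the CLI is invoked as a module (the entry point and the remaining arguments are not shown in this diff; only --progress/-p is new here):

python -m benchmarks.cli run <benchmark and model args> --progress simple

Passing -p none keeps the run quiet, while the default auto upgrades to a tqdm bar whenever tqdm is importable.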
277 changes: 234 additions & 43 deletions benchmarks/runner.py
@@ -9,6 +9,7 @@
- CodeAct (code-generation agents)

Supports parallel execution for faster evaluation.
Includes progress tracking with ETA via tqdm or custom callbacks.
"""

import time
@@ -19,6 +20,49 @@

from benchmarks.base import Benchmark, BenchmarkResult, BenchmarkSample, SampleResult

# Type alias for progress callback
ProgressCallback = Callable[[int, int, "SampleResult | None", "ProgressStats"], None]
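
For illustration, any callable with this shape can serve as a callback; the function below is a hypothetical example, not part of the PR:

def print_progress(completed: int, total: int, sample_result, stats) -> None:
    # sample_result is the SampleResult that just finished (None if unavailable);
    # stats is the running ProgressStats defined below.
    print(f"{completed}/{total} done, accuracy {stats.accuracy:.1%}")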


@dataclass
class ProgressStats:
"""Running statistics for progress reporting."""

completed: int = 0
total: int = 0
correct: int = 0
errors: int = 0
total_time_ms: float = 0.0

@property
def accuracy(self) -> float:
"""Current running accuracy."""
if self.completed == 0:
return 0.0
return self.correct / self.completed

@property
def error_rate(self) -> float:
"""Current error rate."""
if self.completed == 0:
return 0.0
return self.errors / self.completed

@property
def avg_time_ms(self) -> float:
"""Average time per sample in milliseconds."""
if self.completed == 0:
return 0.0
return self.total_time_ms / self.completed

@property
def eta_seconds(self) -> float:
"""Estimated time remaining in seconds."""
remaining = self.total - self.completed
if remaining <= 0 or self.avg_time_ms == 0:
return 0.0
return (remaining * self.avg_time_ms) / 1000.0
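
To make the ETA arithmetic concrete, a worked example using only the fields defined above: 20 of 100 samples finished in 40,000 ms total gives avg_time_ms = 2,000, so the 80 remaining samples imply an ETA of 160 seconds.

stats = ProgressStats(completed=20, total=100, correct=15, total_time_ms=40_000.0)
assert stats.avg_time_ms == 2000.0   # 40_000 ms / 20 samples
assert stats.eta_seconds == 160.0    # 80 remaining * 2_000 ms / 1_000
assert stats.accuracy == 0.75        # 15 correct / 20 completed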


@dataclass
class RunnerConfig:
@@ -31,6 +75,8 @@ class RunnerConfig:
verbose: bool = False
log_dir: str | None = None
max_workers: int = 1 # Number of parallel workers (1 = sequential)
progress: str = "auto" # Progress display: "auto", "tqdm", "simple", "none"
progress_callback: ProgressCallback | None = None # Custom progress callback
backend_kwargs: dict[str, Any] = field(default_factory=dict)
environment_kwargs: dict[str, Any] = field(default_factory=dict)

@@ -67,6 +113,8 @@ def __init__(
verbose: bool = False,
log_dir: str | None = None,
max_workers: int = 1,
progress: str = "auto",
progress_callback: ProgressCallback | None = None,
**kwargs,
):
"""Initialize runner with configuration.
@@ -79,6 +127,13 @@ def __init__(
verbose: Enable verbose output.
log_dir: Directory for logging trajectories.
max_workers: Number of parallel workers (1 = sequential, >1 = parallel).
progress: Progress display mode:
- "auto": Use tqdm if available, else simple
- "tqdm": Force tqdm progress bar
- "simple": Print periodic status updates
- "none": No progress output
progress_callback: Custom callback for progress updates.
Signature: (completed, total, sample_result, stats) -> None
**kwargs: Additional backend or environment kwargs.
"""
self.config = RunnerConfig(
@@ -89,9 +144,46 @@
verbose=verbose,
log_dir=log_dir,
max_workers=max_workers,
progress=progress,
progress_callback=progress_callback,
backend_kwargs={"model_name": model, **kwargs.get("backend_kwargs", {})},
environment_kwargs=kwargs.get("environment_kwargs", {}),
)
self._tqdm_available: bool | None = None

def _check_tqdm(self) -> bool:
"""Check if tqdm is available."""
if self._tqdm_available is None:
try:
import tqdm # noqa: F401

self._tqdm_available = True
except ImportError:
self._tqdm_available = False
return self._tqdm_available

def _get_progress_mode(self) -> str:
"""Determine effective progress mode."""
mode = self.config.progress
if mode == "auto":
return "tqdm" if self._check_tqdm() else "simple"
if mode == "tqdm" and not self._check_tqdm():
return "simple"
return mode

def _format_eta(self, seconds: float) -> str:
"""Format seconds as human-readable duration."""
if seconds <= 0:
return "--:--"
if seconds < 60:
return f"{seconds:.0f}s"
if seconds < 3600:
mins = int(seconds // 60)
secs = int(seconds % 60)
return f"{mins}:{secs:02d}"
hours = int(seconds // 3600)
mins = int((seconds % 3600) // 60)
return f"{hours}h{mins:02d}m"

def run(
self,
@@ -135,76 +227,175 @@ def run(

# Collect samples first (needed for parallel execution)
samples = list(benchmark.load_samples(num_samples=num_samples, seed=seed))
total = len(samples)

# Initialize progress tracking
stats = ProgressStats(total=total)
progress_mode = self._get_progress_mode()

if workers <= 1:
# Sequential execution with progress
result.sample_results = self._run_sequential(
samples, inference_fn, benchmark, stats, progress_mode
)
else:
# Parallel execution with progress
result.sample_results = self._run_parallel(
samples, inference_fn, benchmark, workers, stats, progress_mode
)

return result

def _update_progress(
self,
sample_result: SampleResult,
stats: ProgressStats,
progress_mode: str,
pbar: Any = None,
) -> None:
"""Update progress statistics and display."""
stats.completed += 1
if sample_result.is_correct:
stats.correct += 1
if sample_result.error:
stats.errors += 1
stats.total_time_ms += sample_result.execution_time_ms

# Call custom callback if provided
if self.config.progress_callback:
self.config.progress_callback(
stats.completed, stats.total, sample_result, stats
)

# Update display based on mode
if progress_mode == "tqdm" and pbar is not None:
pbar.set_postfix(
acc=f"{stats.accuracy:.1%}",
err=stats.errors,
eta=self._format_eta(stats.eta_seconds),
refresh=False,
)
pbar.update(1)
elif progress_mode == "simple":
# Print every 10% or every sample for small runs
interval = max(1, stats.total // 10)
if stats.completed % interval == 0 or stats.completed == stats.total:
print(
f" Progress: {stats.completed}/{stats.total} "
f"({stats.completed/stats.total:.0%}) | "
f"Acc: {stats.accuracy:.1%} | "
f"Errors: {stats.errors} | "
f"ETA: {self._format_eta(stats.eta_seconds)}"
)

# Verbose per-sample output
if self.config.verbose:
status = "✓" if sample_result.is_correct else "✗"
print(f" [{status}] Sample {sample_result.sample_id}: {sample_result.metrics}")

def _run_sequential(
self,
samples: list[BenchmarkSample],
inference_fn: Callable[[BenchmarkSample], tuple[str, dict[str, Any]]],
benchmark: Benchmark,
stats: ProgressStats,
progress_mode: str,
) -> list[SampleResult]:
"""Run samples sequentially with progress tracking."""
results: list[SampleResult] = []
pbar = None

try:
if progress_mode == "tqdm":
from tqdm import tqdm

pbar = tqdm(
total=stats.total,
desc=f"{benchmark.name}",
unit="sample",
ncols=100,
)

for sample in samples:
sample_result = self._run_sample(sample, inference_fn, benchmark)
results.append(sample_result)
self._update_progress(sample_result, stats, progress_mode, pbar)
finally:
if pbar is not None:
pbar.close()

return results

def _run_parallel(
self,
samples: list[BenchmarkSample],
inference_fn: Callable[[BenchmarkSample], tuple[str, dict[str, Any]]],
benchmark: Benchmark,
max_workers: int,
stats: ProgressStats,
progress_mode: str,
) -> list[SampleResult]:
"""Run samples in parallel using thread pool.
"""Run samples in parallel using thread pool with progress tracking.

Args:
samples: List of samples to process.
inference_fn: Inference function to apply.
benchmark: Benchmark for evaluation.
max_workers: Number of parallel threads.
stats: Progress statistics to update.
progress_mode: Progress display mode.

Returns:
List of SampleResult in original sample order.
"""
import threading

results: dict[str, SampleResult] = {}
lock = threading.Lock()
pbar = None

try:
if progress_mode == "tqdm":
from tqdm import tqdm

pbar = tqdm(
total=stats.total,
desc=f"{benchmark.name}",
unit="sample",
ncols=100,
)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
future_to_sample = {
executor.submit(self._run_sample, sample, inference_fn, benchmark): sample
for sample in samples
}

# Collect results as they complete
for future in as_completed(future_to_sample):
sample = future_to_sample[future]
try:
sample_result = future.result()
except Exception as e:
# Handle unexpected errors
sample_result = SampleResult(
sample_id=sample.id,
prediction="",
expected=sample.expected_answer,
is_correct=False,
metrics={m: 0.0 for m in benchmark.default_metrics()},
error=f"Parallel execution error: {e}",
)

with lock:
results[sample.id] = sample_result
self._update_progress(sample_result, stats, progress_mode, pbar)

finally:
if pbar is not None:
pbar.close()

# Return results in original sample order
return [results[sample.id] for sample in samples]
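
Putting the pieces together, usage would look roughly like the sketch below; the callback body is illustrative, and every constructor argument other than max_workers, progress, and progress_callback is elided because the diff truncates the rest of the BenchmarkRunner signature.

from benchmarks import BenchmarkRunner, ProgressStats

def on_progress(completed: int, total: int, sample_result, stats: ProgressStats) -> None:
    # Invoked after every finished sample, in addition to the chosen display mode.
    if sample_result is not None and sample_result.error:
        print(f"error on sample {sample_result.sample_id} ({stats.error_rate:.0%} so far)")

runner = BenchmarkRunner(
    # model/backend arguments from the unshown part of the signature omitted
    max_workers=4,              # worker threads share one lock-guarded stats object
    progress="simple",          # or "auto" / "tqdm" / "none"
    progress_callback=on_progress,
)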