diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 00000000..5e12c53d --- /dev/null +++ b/benchmarks/__init__.py @@ -0,0 +1,50 @@ +""" +RLM Benchmarks Framework + +A minimal, extensible framework for evaluating Recursive Language Models +against the benchmarks from the original RLM paper (arXiv:2512.24601). + +Benchmark tasks: +- S-NIAH: Single-Needle-in-a-Haystack (context-insensitive retrieval) +- BrowseComp-Plus: Compositional multi-hop QA over document corpora +- OOLONG: Semantic aggregation over long contexts +- OOLONG-Pairs: Pairwise combinatorial aggregation + +Usage: + from benchmarks import BenchmarkRunner, OolongBenchmark + + runner = BenchmarkRunner(backend="openai", model="gpt-5-mini") + results = runner.run(OolongBenchmark(subset="toy_dnd", num_samples=10)) +""" + +from benchmarks.base import Benchmark, BenchmarkResult, BenchmarkSample +from benchmarks.metrics import Metrics +from benchmarks.results import ExperimentConfig, ExperimentRecord, ResultsStore +from benchmarks.runner import BenchmarkRunner, compare_methods +from benchmarks.tasks import ( + BrowseCompPlusBenchmark, + NIAHBenchmark, + OolongBenchmark, + OolongPairsBenchmark, +) + +__all__ = [ + # Base classes + "Benchmark", + "BenchmarkResult", + "BenchmarkSample", + # Runner + "BenchmarkRunner", + "compare_methods", + # Metrics + "Metrics", + # Results storage + "ResultsStore", + "ExperimentConfig", + "ExperimentRecord", + # Benchmark tasks + "NIAHBenchmark", + "OolongBenchmark", + "OolongPairsBenchmark", + "BrowseCompPlusBenchmark", +] diff --git a/benchmarks/base.py b/benchmarks/base.py new file mode 100644 index 00000000..fbd5e942 --- /dev/null +++ b/benchmarks/base.py @@ -0,0 +1,174 @@ +""" +Base classes for RLM benchmarks. + +Provides abstract interfaces that all benchmark implementations must follow, +enabling consistent evaluation and extensibility. +""" + +from abc import ABC, abstractmethod +from collections.abc import Iterator +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class BenchmarkSample: + """A single benchmark sample with context, question, and expected answer.""" + + id: str + context: str + question: str + expected_answer: str | list[str] + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def context_tokens_approx(self) -> int: + """Approximate token count (rough estimate: ~4 chars per token).""" + return len(self.context) // 4 + + +@dataclass +class SampleResult: + """Result for a single benchmark sample.""" + + sample_id: str + prediction: str + expected: str | list[str] + is_correct: bool + metrics: dict[str, float] = field(default_factory=dict) + iterations: int = 0 + total_tokens: int = 0 + execution_time_ms: float = 0.0 + error: str | None = None + + +@dataclass +class BenchmarkResult: + """Aggregated results for a complete benchmark run.""" + + benchmark_name: str + method: str # "rlm", "direct", "summarize", "rag", etc. + model: str + sample_results: list[SampleResult] = field(default_factory=list) + config: dict[str, Any] = field(default_factory=dict) + + @property + def accuracy(self) -> float: + """Fraction of correct predictions.""" + if not self.sample_results: + return 0.0 + correct = sum(1 for r in self.sample_results if r.is_correct) + return correct / len(self.sample_results) + + @property + def mean_f1(self) -> float: + """Mean F1 score across samples.""" + f1_scores = [r.metrics.get("f1", 0.0) for r in self.sample_results] + return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0 + + @property + def total_tokens(self) -> int: + """Total tokens used across all samples.""" + return sum(r.total_tokens for r in self.sample_results) + + @property + def mean_iterations(self) -> float: + """Mean number of RLM iterations.""" + iters = [r.iterations for r in self.sample_results] + return sum(iters) / len(iters) if iters else 0.0 + + @property + def mean_execution_time_ms(self) -> float: + """Mean execution time per sample in milliseconds.""" + times = [r.execution_time_ms for r in self.sample_results] + return sum(times) / len(times) if times else 0.0 + + @property + def error_rate(self) -> float: + """Fraction of samples that resulted in errors.""" + errors = sum(1 for r in self.sample_results if r.error is not None) + return errors / len(self.sample_results) if self.sample_results else 0.0 + + def summary(self) -> dict[str, Any]: + """Return summary statistics.""" + return { + "benchmark": self.benchmark_name, + "method": self.method, + "model": self.model, + "num_samples": len(self.sample_results), + "accuracy": self.accuracy, + "mean_f1": self.mean_f1, + "total_tokens": self.total_tokens, + "mean_iterations": self.mean_iterations, + "mean_execution_time_ms": self.mean_execution_time_ms, + "error_rate": self.error_rate, + } + + +class Benchmark(ABC): + """Abstract base class for all benchmarks. + + To create a new benchmark: + 1. Subclass Benchmark + 2. Implement name, description, load_samples(), and evaluate() + 3. Optionally override default_metrics() for custom evaluation + + Example: + class MyBenchmark(Benchmark): + @property + def name(self) -> str: + return "my-benchmark" + + def load_samples(self, num_samples: int | None = None) -> Iterator[BenchmarkSample]: + # Load from dataset, file, or generate samples + yield BenchmarkSample(...) + + def evaluate(self, prediction: str, expected: str | list[str]) -> dict[str, float]: + # Return metrics dict with at least "correct" and "f1" + return {"correct": 1.0 if prediction == expected else 0.0, "f1": ...} + """ + + @property + @abstractmethod + def name(self) -> str: + """Unique identifier for this benchmark.""" + pass + + @property + def description(self) -> str: + """Human-readable description of the benchmark.""" + return "" + + @abstractmethod + def load_samples( + self, num_samples: int | None = None, seed: int | None = None + ) -> Iterator[BenchmarkSample]: + """Load benchmark samples. + + Args: + num_samples: Maximum number of samples to load. None for all. + seed: Random seed for reproducible sampling. + + Yields: + BenchmarkSample instances. + """ + pass + + @abstractmethod + def evaluate(self, prediction: str, expected: str | list[str]) -> dict[str, float]: + """Evaluate a prediction against expected answer(s). + + Args: + prediction: Model's prediction string. + expected: Expected answer(s). Can be a single string or list of valid answers. + + Returns: + Dictionary with at least: + - "correct": 1.0 if correct, 0.0 otherwise + - "f1": F1 score between prediction and expected + """ + pass + + def default_metrics(self) -> list[str]: + """List of metric names this benchmark produces.""" + return ["correct", "f1"] diff --git a/benchmarks/cli.py b/benchmarks/cli.py new file mode 100644 index 00000000..a646d75e --- /dev/null +++ b/benchmarks/cli.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python3 +""" +CLI for running RLM benchmarks. + +Usage: + # Run benchmarks + python -m benchmarks.cli run --benchmark oolong --num-samples 10 + python -m benchmarks.cli run --benchmark niah --methods rlm direct + + # Query stored results + python -m benchmarks.cli history --benchmark oolong-toy_dnd --limit 10 + python -m benchmarks.cli compare --benchmark niah-100k --group-by method + + # Legacy (no subcommand defaults to 'run') + python -m benchmarks.cli --benchmark oolong --num-samples 10 +""" + +import argparse +import json +import sys +from datetime import datetime + +from benchmarks.base import BenchmarkResult +from benchmarks.results import ExperimentConfig, ResultsStore +from benchmarks.runner import BenchmarkRunner, compare_methods +from benchmarks.tasks.browsecomp import BrowseCompPlusBenchmark +from benchmarks.tasks.niah import NIAHBenchmark +from benchmarks.tasks.oolong import OolongBenchmark, OolongPairsBenchmark + + +def get_benchmark(args: argparse.Namespace): + """Instantiate benchmark from CLI arguments.""" + name = args.benchmark.lower() + + if name == "niah": + return NIAHBenchmark( + context_length=args.context_length, + needle_depth=args.needle_depth, + ) + elif name == "oolong": + return OolongBenchmark(subset=args.subset) + elif name == "oolong-pairs": + return OolongPairsBenchmark( + num_items=args.num_items, + num_pairs=args.num_pairs, + ) + elif name == "browsecomp": + return BrowseCompPlusBenchmark( + num_documents=args.num_documents, + num_hops=args.num_hops, + ) + else: + raise ValueError(f"Unknown benchmark: {name}") + + +def get_all_benchmarks(args: argparse.Namespace): + """Get all benchmarks for comprehensive evaluation.""" + return [ + NIAHBenchmark(context_length=args.context_length), + OolongBenchmark(subset="toy_dnd"), + OolongPairsBenchmark(num_items=30, num_pairs=15), + BrowseCompPlusBenchmark(num_documents=50, num_hops=2), + ] + + +def save_results(results: dict[str, BenchmarkResult], output_path: str): + """Save results to JSON file.""" + output = { + "timestamp": datetime.now().isoformat(), + "results": { + name: { + "summary": result.summary(), + "sample_results": [ + { + "id": sr.sample_id, + "prediction": sr.prediction, + "expected": sr.expected, + "is_correct": sr.is_correct, + "metrics": sr.metrics, + "iterations": sr.iterations, + "total_tokens": sr.total_tokens, + "execution_time_ms": sr.execution_time_ms, + "error": sr.error, + } + for sr in result.sample_results + ], + } + for name, result in results.items() + }, + } + + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + + print(f"\nResults saved to: {output_path}") + + +def print_summary(results: dict[str, BenchmarkResult]): + """Print a summary table of results.""" + print("\n" + "=" * 70) + print("BENCHMARK RESULTS SUMMARY") + print("=" * 70) + + headers = ["Method", "Accuracy", "Mean F1", "Tokens", "Iterations", "Time (ms)"] + col_widths = [15, 10, 10, 12, 12, 12] + + # Print header + header_line = " | ".join(h.ljust(w) for h, w in zip(headers, col_widths, strict=True)) + print(header_line) + print("-" * len(header_line)) + + # Print results + for name, result in results.items(): + row = [ + name[:15], + f"{result.accuracy:.1%}", + f"{result.mean_f1:.3f}", + str(result.total_tokens), + f"{result.mean_iterations:.1f}", + f"{result.mean_execution_time_ms:.0f}", + ] + print(" | ".join(v.ljust(w) for v, w in zip(row, col_widths, strict=True))) + + print("=" * 70) + + +def cmd_history(args: argparse.Namespace) -> int: + """Show historical results for a benchmark.""" + store = ResultsStore(args.results_dir) + records = store.get_history( + benchmark=args.benchmark, + model=args.model, + method=args.method, + limit=args.limit, + ) + + if not records: + print(f"No results found for benchmark: {args.benchmark}") + return 1 + + print(f"\nHistory for {args.benchmark} ({len(records)} results)") + print("=" * 80) + + headers = ["Timestamp", "Model", "Method", "Accuracy", "F1", "Tokens"] + col_widths = [20, 20, 10, 10, 8, 10] + header_line = " | ".join(h.ljust(w) for h, w in zip(headers, col_widths, strict=True)) + print(header_line) + print("-" * len(header_line)) + + for r in records: + row = [ + r.timestamp[:19], + r.config.model[:20], + r.config.method[:10], + f"{r.accuracy:.1%}", + f"{r.mean_f1:.3f}", + str(r.total_tokens), + ] + print(" | ".join(v.ljust(w) for v, w in zip(row, col_widths, strict=True))) + + return 0 + + +def cmd_compare(args: argparse.Namespace) -> int: + """Compare results grouped by a dimension.""" + store = ResultsStore(args.results_dir) + comparison = store.compare( + benchmark=args.benchmark, + group_by=args.group_by, + filter_model=args.model, + ) + + if not comparison: + print(f"No results found for benchmark: {args.benchmark}") + return 1 + + print(f"\nComparison for {args.benchmark} (grouped by {args.group_by})") + print("=" * 70) + + headers = ["Group", "Count", "Accuracy", "Max Acc", "F1", "Tokens"] + col_widths = [15, 8, 10, 10, 8, 12] + header_line = " | ".join(h.ljust(w) for h, w in zip(headers, col_widths, strict=True)) + print(header_line) + print("-" * len(header_line)) + + for group, metrics in sorted(comparison.items()): + row = [ + group[:15], + str(metrics["count"]), + f"{metrics['mean_accuracy']:.1%}", + f"{metrics['max_accuracy']:.1%}", + f"{metrics['mean_f1']:.3f}", + f"{metrics['mean_tokens']:.0f}", + ] + print(" | ".join(v.ljust(w) for v, w in zip(row, col_widths, strict=True))) + + return 0 + + +def cmd_list(args: argparse.Namespace) -> int: + """List available benchmarks and summary.""" + store = ResultsStore(args.results_dir) + summary = store.summary() + + print("\nBenchmark Results Summary") + print("=" * 50) + print(f"Results directory: {summary['results_dir']}") + print(f"Total experiments: {summary['total_experiments']}") + print("\nBenchmarks:") + + for name, count in summary["benchmarks"].items(): + print(f" - {name}: {count} runs") + + return 0 + + +def cmd_export(args: argparse.Namespace) -> int: + """Export results to CSV.""" + store = ResultsStore(args.results_dir) + store.export_csv(args.benchmark, args.output) + print(f"Exported {args.benchmark} results to {args.output}") + return 0 + + +def main(): + parser = argparse.ArgumentParser( + description="RLM Benchmarks CLI", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Global options + parser.add_argument( + "--results-dir", + type=str, + default="./benchmark_results", + help="Directory for storing results (default: ./benchmark_results)", + ) + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # ========== RUN subcommand ========== + run_parser = subparsers.add_parser("run", help="Run benchmarks") + + run_parser.add_argument( + "--benchmark", + "-b", + type=str, + required=True, + choices=["niah", "oolong", "oolong-pairs", "browsecomp", "all"], + help="Benchmark to run", + ) + run_parser.add_argument( + "--num-samples", + "-n", + type=int, + default=10, + help="Number of samples to evaluate (default: 10)", + ) + run_parser.add_argument( + "--seed", + "-s", + type=int, + default=None, + help="Random seed for reproducibility", + ) + run_parser.add_argument( + "--methods", + "-m", + nargs="+", + default=["rlm"], + choices=["rlm", "direct", "summarize"], + help="Inference methods to compare (default: rlm)", + ) + run_parser.add_argument( + "--output", + "-o", + type=str, + default=None, + help="Output file path for results (JSON)", + ) + run_parser.add_argument( + "--backend", + type=str, + default="openai", + help="LLM backend (default: openai)", + ) + run_parser.add_argument( + "--model", + type=str, + default="gpt-5-mini", + help="Model name (default: gpt-5-mini)", + ) + run_parser.add_argument( + "--environment", + type=str, + default="subprocess", + help="REPL environment (default: subprocess)", + ) + run_parser.add_argument( + "--max-iterations", + type=int, + default=30, + help="Max RLM iterations (default: 30)", + ) + run_parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose output", + ) + run_parser.add_argument("--log-dir", type=str, default=None, help="Trajectory logs directory") + run_parser.add_argument( + "--max-workers", + "-w", + type=int, + default=1, + help="Parallel workers (1=sequential, >1=parallel threads, default: 1)", + ) + + # Benchmark-specific options for run + run_parser.add_argument("--context-length", type=int, default=100_000, help="NIAH context len") + run_parser.add_argument("--needle-depth", type=float, default=None, help="NIAH needle position") + run_parser.add_argument("--subset", type=str, default="toy_dnd", help="OOLONG subset") + run_parser.add_argument("--num-items", type=int, default=50, help="OOLONG-Pairs items") + run_parser.add_argument("--num-pairs", type=int, default=25, help="OOLONG-Pairs pairs") + run_parser.add_argument("--num-documents", type=int, default=100, help="BrowseComp docs") + run_parser.add_argument("--num-hops", type=int, default=2, help="BrowseComp hops") + + # ========== HISTORY subcommand ========== + history_parser = subparsers.add_parser("history", help="Show historical results") + history_parser.add_argument("--benchmark", "-b", type=str, required=True, help="Benchmark name") + history_parser.add_argument("--model", type=str, default=None, help="Filter by model") + history_parser.add_argument("--method", type=str, default=None, help="Filter by method") + history_parser.add_argument("--limit", type=int, default=20, help="Max results to show") + + # ========== COMPARE subcommand ========== + compare_parser = subparsers.add_parser("compare", help="Compare results") + compare_parser.add_argument("--benchmark", "-b", type=str, required=True, help="Benchmark name") + compare_parser.add_argument( + "--group-by", + type=str, + default="method", + choices=["method", "model", "environment"], + help="Group by dimension", + ) + compare_parser.add_argument("--model", type=str, default=None, help="Filter by model") + + # ========== LIST subcommand ========== + subparsers.add_parser("list", help="List benchmarks and summary") + + # ========== EXPORT subcommand ========== + export_parser = subparsers.add_parser("export", help="Export results to CSV") + export_parser.add_argument("--benchmark", "-b", type=str, required=True, help="Benchmark name") + export_parser.add_argument("--output", "-o", type=str, required=True, help="Output CSV path") + + args = parser.parse_args() + + # Handle legacy usage (no subcommand) + if args.command is None: + # Check if --benchmark was passed (legacy mode) + if hasattr(args, "benchmark") and args.benchmark: + args.command = "run" + else: + parser.print_help() + return 1 + + # Dispatch to appropriate command + if args.command == "run": + return cmd_run(args) + elif args.command == "history": + return cmd_history(args) + elif args.command == "compare": + return cmd_compare(args) + elif args.command == "list": + return cmd_list(args) + elif args.command == "export": + return cmd_export(args) + else: + parser.print_help() + return 1 + + +def cmd_run(args: argparse.Namespace) -> int: + """Run benchmarks (main command).""" + runner = BenchmarkRunner( + backend=args.backend, + model=args.model, + environment=args.environment, + max_iterations=args.max_iterations, + verbose=args.verbose, + log_dir=args.log_dir, + max_workers=args.max_workers, + ) + + if args.benchmark == "all": + benchmarks = get_all_benchmarks(args) + else: + benchmarks = [get_benchmark(args)] + + all_results = {} + + for benchmark in benchmarks: + print(f"\n{'=' * 60}") + print(f"Running: {benchmark.name}") + print(f"Description: {benchmark.description}") + print(f"{'=' * 60}") + + results = compare_methods( + benchmark=benchmark, + runner=runner, + methods=args.methods, + num_samples=args.num_samples, + seed=args.seed, + ) + + for method, result in results.items(): + key = f"{benchmark.name}/{method}" + all_results[key] = result + + # Print summary + print_summary(all_results) + + # Save to results store + store = ResultsStore(args.results_dir) + config = ExperimentConfig( + backend=args.backend, + model=args.model, + environment=args.environment, + max_iterations=args.max_iterations, + num_samples=args.num_samples, + seed=args.seed, + ) + + for key, result in all_results.items(): + exp_id = store.save(result, config) + print(f"Saved {key} as experiment {exp_id}") + + if args.output: + save_results(all_results, args.output) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/benchmarks/metrics.py b/benchmarks/metrics.py new file mode 100644 index 00000000..635ad88b --- /dev/null +++ b/benchmarks/metrics.py @@ -0,0 +1,188 @@ +""" +Evaluation metrics for RLM benchmarks. + +Provides common metrics used across benchmark evaluations: +- Exact match +- Containment (answer in prediction or vice versa) +- Token-level F1 +- Normalized string comparison +""" + +import re +import string +from collections import Counter + + +class Metrics: + """Collection of evaluation metrics for benchmark scoring.""" + + @staticmethod + def normalize(text: str) -> str: + """Normalize text for comparison. + + - Lowercase + - Remove punctuation + - Collapse whitespace + - Strip leading/trailing whitespace + """ + text = text.lower() + text = text.translate(str.maketrans("", "", string.punctuation)) + text = re.sub(r"\s+", " ", text) + return text.strip() + + @staticmethod + def exact_match(prediction: str, expected: str | list[str]) -> bool: + """Check if prediction exactly matches expected (after normalization). + + Args: + prediction: Model prediction. + expected: Single expected answer or list of valid answers. + + Returns: + True if normalized prediction matches any expected answer. + """ + pred_norm = Metrics.normalize(prediction) + + if isinstance(expected, str): + expected = [expected] + + return any(pred_norm == Metrics.normalize(exp) for exp in expected) + + @staticmethod + def containment(prediction: str, expected: str | list[str]) -> bool: + """Check if expected is contained in prediction or vice versa. + + More lenient than exact_match - useful when answers may be + embedded in longer responses. + + Args: + prediction: Model prediction. + expected: Single expected answer or list of valid answers. + + Returns: + True if any containment relationship exists. + """ + pred_norm = Metrics.normalize(prediction) + + if isinstance(expected, str): + expected = [expected] + + for exp in expected: + exp_norm = Metrics.normalize(exp) + if exp_norm in pred_norm or pred_norm in exp_norm: + return True + + return False + + @staticmethod + def token_f1(prediction: str, expected: str | list[str]) -> float: + """Compute token-level F1 score. + + Treats prediction and expected as bags of tokens and computes + precision, recall, and F1. + + Args: + prediction: Model prediction. + expected: Single expected answer or list of valid answers. + If list, returns max F1 across all expected answers. + + Returns: + F1 score between 0.0 and 1.0. + """ + if isinstance(expected, str): + expected = [expected] + + pred_tokens = Metrics.normalize(prediction).split() + + if not pred_tokens: + return 0.0 + + max_f1 = 0.0 + for exp in expected: + exp_tokens = Metrics.normalize(exp).split() + + if not exp_tokens: + continue + + pred_counter = Counter(pred_tokens) + exp_counter = Counter(exp_tokens) + + common = sum((pred_counter & exp_counter).values()) + + if common == 0: + continue + + precision = common / len(pred_tokens) + recall = common / len(exp_tokens) + f1 = 2 * precision * recall / (precision + recall) + + max_f1 = max(max_f1, f1) + + return max_f1 + + @staticmethod + def evaluate_standard(prediction: str, expected: str | list[str]) -> dict[str, float]: + """Standard evaluation combining multiple metrics. + + Returns: + Dictionary with: + - "correct": 1.0 if exact_match or containment, else 0.0 + - "exact_match": 1.0 if exact match, else 0.0 + - "containment": 1.0 if containment, else 0.0 + - "f1": Token-level F1 score + """ + exact = Metrics.exact_match(prediction, expected) + contained = Metrics.containment(prediction, expected) + f1 = Metrics.token_f1(prediction, expected) + + return { + "correct": 1.0 if (exact or contained) else 0.0, + "exact_match": 1.0 if exact else 0.0, + "containment": 1.0 if contained else 0.0, + "f1": f1, + } + + @staticmethod + def pairwise_f1( + predicted_pairs: set[tuple[str, str]], + expected_pairs: set[tuple[str, str]], + ) -> dict[str, float]: + """Compute F1 for pairwise predictions (e.g., OOLONG-Pairs). + + Both predicted and expected should be sets of (item1, item2) tuples. + Pairs are treated as unordered (a, b) == (b, a). + + Args: + predicted_pairs: Set of predicted pairs. + expected_pairs: Set of expected pairs. + + Returns: + Dictionary with precision, recall, and f1. + """ + + # Normalize pairs to be order-independent + def normalize_pair(p: tuple[str, str]) -> tuple[str, str]: + return tuple(sorted([Metrics.normalize(p[0]), Metrics.normalize(p[1])])) + + pred_normalized = {normalize_pair(p) for p in predicted_pairs} + exp_normalized = {normalize_pair(p) for p in expected_pairs} + + if not pred_normalized and not exp_normalized: + return {"precision": 1.0, "recall": 1.0, "f1": 1.0} + + if not pred_normalized: + return {"precision": 0.0, "recall": 0.0, "f1": 0.0} + + if not exp_normalized: + return {"precision": 0.0, "recall": 0.0, "f1": 0.0} + + common = len(pred_normalized & exp_normalized) + precision = common / len(pred_normalized) + recall = common / len(exp_normalized) + + if precision + recall == 0: + f1 = 0.0 + else: + f1 = 2 * precision * recall / (precision + recall) + + return {"precision": precision, "recall": recall, "f1": f1} diff --git a/benchmarks/results.py b/benchmarks/results.py new file mode 100644 index 00000000..932dc9b8 --- /dev/null +++ b/benchmarks/results.py @@ -0,0 +1,438 @@ +""" +Results storage and comparison for RLM benchmarks. + +Provides persistent storage of benchmark results with: +- Structured experiment metadata (model, environment, config) +- Historical comparison across runs +- Statistical analysis utilities +""" + +import hashlib +import json +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any + +from benchmarks.base import BenchmarkResult + + +@dataclass +class ExperimentConfig: + """Complete configuration for an experiment run.""" + + # Model configuration + backend: str + model: str + + # Environment configuration + environment: str + max_iterations: int = 30 + + # Benchmark configuration + benchmark_name: str = "" + method: str = "rlm" + num_samples: int | None = None + seed: int | None = None + + # Additional kwargs + backend_kwargs: dict[str, Any] = field(default_factory=dict) + environment_kwargs: dict[str, Any] = field(default_factory=dict) + + def config_hash(self) -> str: + """Generate a hash of the configuration for deduplication.""" + config_str = json.dumps(asdict(self), sort_keys=True, default=str) + return hashlib.md5(config_str.encode()).hexdigest()[:12] + + +@dataclass +class ExperimentRecord: + """A single experiment run with full metadata.""" + + # Identifiers + experiment_id: str + timestamp: str + + # Configuration + config: ExperimentConfig + + # Results summary + accuracy: float + mean_f1: float + total_tokens: int + mean_iterations: float + mean_execution_time_ms: float + error_rate: float + num_samples: int + + # Optional: full sample results + sample_results: list[dict[str, Any]] | None = None + + # Git/version info for reproducibility + git_commit: str | None = None + rlm_version: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "experiment_id": self.experiment_id, + "timestamp": self.timestamp, + "config": asdict(self.config), + "results": { + "accuracy": self.accuracy, + "mean_f1": self.mean_f1, + "total_tokens": self.total_tokens, + "mean_iterations": self.mean_iterations, + "mean_execution_time_ms": self.mean_execution_time_ms, + "error_rate": self.error_rate, + "num_samples": self.num_samples, + }, + "sample_results": self.sample_results, + "metadata": { + "git_commit": self.git_commit, + "rlm_version": self.rlm_version, + }, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ExperimentRecord": + """Create from dictionary.""" + config = ExperimentConfig(**data["config"]) + results = data["results"] + return cls( + experiment_id=data["experiment_id"], + timestamp=data["timestamp"], + config=config, + accuracy=results["accuracy"], + mean_f1=results["mean_f1"], + total_tokens=results["total_tokens"], + mean_iterations=results["mean_iterations"], + mean_execution_time_ms=results["mean_execution_time_ms"], + error_rate=results["error_rate"], + num_samples=results["num_samples"], + sample_results=data.get("sample_results"), + git_commit=data.get("metadata", {}).get("git_commit"), + rlm_version=data.get("metadata", {}).get("rlm_version"), + ) + + +class ResultsStore: + """Persistent storage for benchmark results. + + Stores results in JSON-lines format for efficient append and query. + Each benchmark gets its own file for easy management. + + Directory structure: + results_dir/ + niah-100k.jsonl + oolong-toy_dnd.jsonl + oolong-pairs-50x25.jsonl + index.json # Quick lookup index + + Usage: + store = ResultsStore("./benchmark_results") + + # Save a result + store.save(benchmark_result, config) + + # Query historical results + history = store.get_history("oolong-toy_dnd", model="gpt-5") + + # Compare methods + comparison = store.compare( + benchmark="niah-100k", + group_by="method", + filter_model="gpt-5" + ) + """ + + def __init__(self, results_dir: str = "./benchmark_results"): + self.results_dir = Path(results_dir) + self.results_dir.mkdir(parents=True, exist_ok=True) + self.index_path = self.results_dir / "index.json" + self._load_index() + + def _load_index(self): + """Load or create the index file.""" + if self.index_path.exists(): + with open(self.index_path) as f: + self._index = json.load(f) + else: + self._index = {"benchmarks": {}, "experiments": []} + + def _save_index(self): + """Save the index file.""" + with open(self.index_path, "w") as f: + json.dump(self._index, f, indent=2) + + def _get_git_commit(self) -> str | None: + """Get current git commit hash.""" + try: + import subprocess + + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + cwd=self.results_dir.parent, + ) + if result.returncode == 0: + return result.stdout.strip()[:12] + except Exception: + pass + return None + + def _get_rlm_version(self) -> str | None: + """Get RLM package version.""" + try: + import importlib.metadata + + return importlib.metadata.version("rlm") + except Exception: + return None + + def save( + self, + result: BenchmarkResult, + config: ExperimentConfig | None = None, + include_samples: bool = True, + ) -> str: + """Save a benchmark result. + + Args: + result: BenchmarkResult to save. + config: ExperimentConfig with full configuration. + If None, creates from result.config. + include_samples: Whether to store individual sample results. + + Returns: + Experiment ID for reference. + """ + # Create config if not provided + if config is None: + config = ExperimentConfig( + backend=result.config.get("backend", "unknown"), + model=result.model, + environment=result.config.get("environment", "unknown"), + max_iterations=result.config.get("max_iterations", 30), + benchmark_name=result.benchmark_name, + method=result.method, + num_samples=result.config.get("num_samples"), + seed=result.config.get("seed"), + ) + else: + config.benchmark_name = result.benchmark_name + config.method = result.method + + # Generate experiment ID + timestamp = datetime.now().isoformat() + exp_id = f"{config.config_hash()}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + # Create record + sample_results = None + if include_samples: + sample_results = [ + { + "sample_id": sr.sample_id, + "prediction": sr.prediction, + "expected": sr.expected, + "is_correct": sr.is_correct, + "metrics": sr.metrics, + "iterations": sr.iterations, + "total_tokens": sr.total_tokens, + "execution_time_ms": sr.execution_time_ms, + "error": sr.error, + } + for sr in result.sample_results + ] + + record = ExperimentRecord( + experiment_id=exp_id, + timestamp=timestamp, + config=config, + accuracy=result.accuracy, + mean_f1=result.mean_f1, + total_tokens=result.total_tokens, + mean_iterations=result.mean_iterations, + mean_execution_time_ms=result.mean_execution_time_ms, + error_rate=result.error_rate, + num_samples=len(result.sample_results), + sample_results=sample_results, + git_commit=self._get_git_commit(), + rlm_version=self._get_rlm_version(), + ) + + # Write to benchmark-specific file + benchmark_file = self.results_dir / f"{result.benchmark_name}.jsonl" + with open(benchmark_file, "a") as f: + json.dump(record.to_dict(), f) + f.write("\n") + + # Update index + if result.benchmark_name not in self._index["benchmarks"]: + self._index["benchmarks"][result.benchmark_name] = { + "file": f"{result.benchmark_name}.jsonl", + "count": 0, + } + self._index["benchmarks"][result.benchmark_name]["count"] += 1 + self._index["experiments"].append( + { + "id": exp_id, + "benchmark": result.benchmark_name, + "model": config.model, + "method": config.method, + "accuracy": result.accuracy, + "timestamp": timestamp, + } + ) + self._save_index() + + return exp_id + + def get_history( + self, + benchmark: str, + model: str | None = None, + method: str | None = None, + limit: int | None = None, + ) -> list[ExperimentRecord]: + """Get historical results for a benchmark. + + Args: + benchmark: Benchmark name to query. + model: Filter by model name (substring match). + method: Filter by method (rlm, direct, etc.). + limit: Maximum number of results to return. + + Returns: + List of ExperimentRecord objects, newest first. + """ + benchmark_file = self.results_dir / f"{benchmark}.jsonl" + if not benchmark_file.exists(): + return [] + + records = [] + with open(benchmark_file) as f: + for line in f: + if line.strip(): + data = json.loads(line) + record = ExperimentRecord.from_dict(data) + + # Apply filters + if model and model.lower() not in record.config.model.lower(): + continue + if method and record.config.method != method: + continue + + records.append(record) + + # Sort by timestamp descending + records.sort(key=lambda r: r.timestamp, reverse=True) + + if limit: + records = records[:limit] + + return records + + def compare( + self, + benchmark: str, + group_by: str = "method", + filter_model: str | None = None, + ) -> dict[str, dict[str, float]]: + """Compare results grouped by a dimension. + + Args: + benchmark: Benchmark to compare. + group_by: Dimension to group by ("method", "model", "environment"). + filter_model: Optional model filter. + + Returns: + Dictionary mapping group key to aggregated metrics. + """ + records = self.get_history(benchmark, model=filter_model) + + groups: dict[str, list[ExperimentRecord]] = {} + for record in records: + if group_by == "method": + key = record.config.method + elif group_by == "model": + key = record.config.model + elif group_by == "environment": + key = record.config.environment + else: + key = getattr(record.config, group_by, "unknown") + + if key not in groups: + groups[key] = [] + groups[key].append(record) + + # Aggregate metrics for each group + comparison = {} + for key, group_records in groups.items(): + comparison[key] = { + "count": len(group_records), + "mean_accuracy": sum(r.accuracy for r in group_records) / len(group_records), + "max_accuracy": max(r.accuracy for r in group_records), + "mean_f1": sum(r.mean_f1 for r in group_records) / len(group_records), + "mean_tokens": sum(r.total_tokens for r in group_records) / len(group_records), + "mean_iterations": sum(r.mean_iterations for r in group_records) + / len(group_records), + } + + return comparison + + def list_benchmarks(self) -> list[str]: + """List all benchmarks with stored results.""" + return list(self._index["benchmarks"].keys()) + + def summary(self) -> dict[str, Any]: + """Get summary of all stored results.""" + return { + "total_experiments": len(self._index["experiments"]), + "benchmarks": {name: info["count"] for name, info in self._index["benchmarks"].items()}, + "results_dir": str(self.results_dir), + } + + def export_csv(self, benchmark: str, output_path: str): + """Export benchmark results to CSV for external analysis.""" + import csv + + records = self.get_history(benchmark) + + if not records: + return + + fieldnames = [ + "experiment_id", + "timestamp", + "model", + "method", + "environment", + "accuracy", + "mean_f1", + "total_tokens", + "mean_iterations", + "num_samples", + "git_commit", + ] + + with open(output_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for r in records: + writer.writerow( + { + "experiment_id": r.experiment_id, + "timestamp": r.timestamp, + "model": r.config.model, + "method": r.config.method, + "environment": r.config.environment, + "accuracy": r.accuracy, + "mean_f1": r.mean_f1, + "total_tokens": r.total_tokens, + "mean_iterations": r.mean_iterations, + "num_samples": r.num_samples, + "git_commit": r.git_commit, + } + ) diff --git a/benchmarks/runner.py b/benchmarks/runner.py new file mode 100644 index 00000000..d6681c16 --- /dev/null +++ b/benchmarks/runner.py @@ -0,0 +1,417 @@ +""" +Benchmark runner for evaluating RLM and baseline methods. + +Orchestrates running benchmarks with different inference methods: +- RLM (recursive language model) +- Direct LLM call +- Summarization-based +- RAG (retrieval-augmented) +- CodeAct (code-generation agents) + +Supports parallel execution for faster evaluation. +""" + +import time +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from typing import Any + +from benchmarks.base import Benchmark, BenchmarkResult, BenchmarkSample, SampleResult + + +@dataclass +class RunnerConfig: + """Configuration for benchmark runner.""" + + backend: str = "openai" + model: str = "gpt-5-mini" + environment: str = "subprocess" + max_iterations: int = 30 + verbose: bool = False + log_dir: str | None = None + max_workers: int = 1 # Number of parallel workers (1 = sequential) + backend_kwargs: dict[str, Any] = field(default_factory=dict) + environment_kwargs: dict[str, Any] = field(default_factory=dict) + + +class BenchmarkRunner: + """Runs benchmarks with configurable inference methods. + + Supports multiple methods for comparison: + - "rlm": Full RLM with REPL environment + - "direct": Direct LLM call (context + question) + - "summarize": Iterative summarization baseline + - "custom": User-provided inference function + + Example: + runner = BenchmarkRunner(backend="openai", model="gpt-5-mini") + + # Run with RLM + rlm_results = runner.run(benchmark, method="rlm", num_samples=100) + + # Run with direct LLM for comparison + direct_results = runner.run(benchmark, method="direct", num_samples=100) + + # Compare results + print(f"RLM accuracy: {rlm_results.accuracy:.2%}") + print(f"Direct accuracy: {direct_results.accuracy:.2%}") + """ + + def __init__( + self, + backend: str = "openai", + model: str = "gpt-5-mini", + environment: str = "subprocess", + max_iterations: int = 30, + verbose: bool = False, + log_dir: str | None = None, + max_workers: int = 1, + **kwargs, + ): + """Initialize runner with configuration. + + Args: + backend: LLM backend (openai, anthropic, etc.) + model: Model name to use. + environment: REPL environment for RLM (local, subprocess, docker, modal). + max_iterations: Max iterations for RLM. + verbose: Enable verbose output. + log_dir: Directory for logging trajectories. + max_workers: Number of parallel workers (1 = sequential, >1 = parallel). + **kwargs: Additional backend or environment kwargs. + """ + self.config = RunnerConfig( + backend=backend, + model=model, + environment=environment, + max_iterations=max_iterations, + verbose=verbose, + log_dir=log_dir, + max_workers=max_workers, + backend_kwargs={"model_name": model, **kwargs.get("backend_kwargs", {})}, + environment_kwargs=kwargs.get("environment_kwargs", {}), + ) + + def run( + self, + benchmark: Benchmark, + method: str = "rlm", + num_samples: int | None = None, + seed: int | None = None, + custom_fn: Callable[[BenchmarkSample], str] | None = None, + max_workers: int | None = None, + ) -> BenchmarkResult: + """Run a benchmark with the specified method. + + Args: + benchmark: Benchmark instance to run. + method: Inference method ("rlm", "direct", "summarize", "custom"). + num_samples: Number of samples to evaluate. None for all. + seed: Random seed for reproducible sampling. + custom_fn: Custom inference function for method="custom". + Takes BenchmarkSample, returns prediction string. + max_workers: Override default max_workers for this run. + 1 = sequential, >1 = parallel threads. + + Returns: + BenchmarkResult with all sample results and aggregate metrics. + """ + result = BenchmarkResult( + benchmark_name=benchmark.name, + method=method, + model=self.config.model, + config={ + "backend": self.config.backend, + "environment": self.config.environment, + "max_iterations": self.config.max_iterations, + "num_samples": num_samples, + "seed": seed, + }, + ) + + inference_fn = self._get_inference_fn(method, custom_fn) + workers = max_workers if max_workers is not None else self.config.max_workers + + # Collect samples first (needed for parallel execution) + samples = list(benchmark.load_samples(num_samples=num_samples, seed=seed)) + + if workers <= 1: + # Sequential execution + for sample in samples: + sample_result = self._run_sample(sample, inference_fn, benchmark) + result.sample_results.append(sample_result) + + if self.config.verbose: + status = "✓" if sample_result.is_correct else "✗" + print(f" [{status}] Sample {sample.id}: {sample_result.metrics}") + else: + # Parallel execution + result.sample_results = self._run_parallel(samples, inference_fn, benchmark, workers) + + return result + + def _run_parallel( + self, + samples: list[BenchmarkSample], + inference_fn: Callable[[BenchmarkSample], tuple[str, dict[str, Any]]], + benchmark: Benchmark, + max_workers: int, + ) -> list[SampleResult]: + """Run samples in parallel using thread pool. + + Args: + samples: List of samples to process. + inference_fn: Inference function to apply. + benchmark: Benchmark for evaluation. + max_workers: Number of parallel threads. + + Returns: + List of SampleResult in original sample order. + """ + results: dict[str, SampleResult] = {} + completed = 0 + total = len(samples) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_sample = { + executor.submit(self._run_sample, sample, inference_fn, benchmark): sample + for sample in samples + } + + # Collect results as they complete + for future in as_completed(future_to_sample): + sample = future_to_sample[future] + try: + sample_result = future.result() + results[sample.id] = sample_result + completed += 1 + + if self.config.verbose: + status = "✓" if sample_result.is_correct else "✗" + print( + f" [{status}] ({completed}/{total}) " + f"Sample {sample.id}: {sample_result.metrics}" + ) + except Exception as e: + # Handle unexpected errors + results[sample.id] = SampleResult( + sample_id=sample.id, + prediction="", + expected=sample.expected_answer, + is_correct=False, + metrics={m: 0.0 for m in benchmark.default_metrics()}, + error=f"Parallel execution error: {e}", + ) + completed += 1 + + # Return results in original sample order + return [results[sample.id] for sample in samples] + + def _get_inference_fn( + self, + method: str, + custom_fn: Callable[[BenchmarkSample], str] | None = None, + ) -> Callable[[BenchmarkSample], tuple[str, dict[str, Any]]]: + """Get inference function for the specified method. + + Returns: + Function that takes BenchmarkSample and returns (prediction, metadata). + """ + if method == "rlm": + return self._inference_rlm + elif method == "direct": + return self._inference_direct + elif method == "summarize": + return self._inference_summarize + elif method == "custom": + if custom_fn is None: + raise ValueError("custom_fn required for method='custom'") + return lambda s: (custom_fn(s), {}) + else: + raise ValueError(f"Unknown method: {method}") + + def _inference_rlm(self, sample: BenchmarkSample) -> tuple[str, dict[str, Any]]: + """Run inference using RLM.""" + from rlm import RLM + from rlm.logger import RLMLogger + + logger = None + if self.config.log_dir: + logger = RLMLogger(log_dir=self.config.log_dir) + + rlm = RLM( + backend=self.config.backend, + backend_kwargs=self.config.backend_kwargs, + environment=self.config.environment, + environment_kwargs=self.config.environment_kwargs, + max_iterations=self.config.max_iterations, + logger=logger, + verbose=self.config.verbose, + ) + + result = rlm.completion(prompt=sample.context, root_prompt=sample.question) + + metadata = { + "iterations": result.iterations if hasattr(result, "iterations") else 0, + "total_tokens": ( + result.usage.total_tokens if hasattr(result, "usage") and result.usage else 0 + ), + } + + return result.response, metadata + + def _inference_direct(self, sample: BenchmarkSample) -> tuple[str, dict[str, Any]]: + """Run inference using direct LLM call.""" + from rlm.clients import get_client + + client = get_client(self.config.backend, **self.config.backend_kwargs) + + prompt = f"""Context: +{sample.context} + +Question: {sample.question} + +Answer the question based on the context above. Provide only the answer, nothing else.""" + + response = client.completion(prompt) + + usage = client.get_last_usage() + total_tokens = 0 + if usage and usage.model_usage_summaries: + for model_usage in usage.model_usage_summaries.values(): + total_tokens += model_usage.total_input_tokens + model_usage.total_output_tokens + + return response, {"iterations": 1, "total_tokens": total_tokens} + + def _inference_summarize(self, sample: BenchmarkSample) -> tuple[str, dict[str, Any]]: + """Run inference using iterative summarization. + + Splits context into chunks, summarizes each, then answers from summaries. + """ + from rlm.clients import get_client + + client = get_client(self.config.backend, **self.config.backend_kwargs) + + # Chunk the context (simple split by paragraphs, ~4k chars each) + chunk_size = 4000 + context = sample.context + chunks = [] + + while context: + if len(context) <= chunk_size: + chunks.append(context) + break + # Find a good break point + break_point = context.rfind("\n\n", 0, chunk_size) + if break_point == -1: + break_point = context.rfind(". ", 0, chunk_size) + if break_point == -1: + break_point = chunk_size + chunks.append(context[:break_point]) + context = context[break_point:].lstrip() + + # Summarize each chunk + summaries = [] + iterations = 0 + for chunk in chunks: + iterations += 1 + summary_prompt = f"""Summarize the following text, keeping all important facts and details relevant to answering questions: + +{chunk} + +Summary:""" + summary = client.completion(summary_prompt) + summaries.append(summary) + + # Combine summaries and answer + iterations += 1 + combined = "\n\n".join(summaries) + answer_prompt = f"""Based on these summaries: + +{combined} + +Question: {sample.question} + +Answer:""" + response = client.completion(answer_prompt) + + usage = client.get_usage_summary() + total_tokens = 0 + if usage and usage.model_usage_summaries: + for model_usage in usage.model_usage_summaries.values(): + total_tokens += model_usage.total_input_tokens + model_usage.total_output_tokens + + return response, {"iterations": iterations, "total_tokens": total_tokens} + + def _run_sample( + self, + sample: BenchmarkSample, + inference_fn: Callable[[BenchmarkSample], tuple[str, dict[str, Any]]], + benchmark: Benchmark, + ) -> SampleResult: + """Run a single sample and evaluate.""" + start_time = time.time() + error = None + prediction = "" + metadata: dict[str, Any] = {} + + try: + prediction, metadata = inference_fn(sample) + except Exception as e: + error = str(e) + prediction = "" + + execution_time_ms = (time.time() - start_time) * 1000 + + if error: + metrics = {m: 0.0 for m in benchmark.default_metrics()} + is_correct = False + else: + metrics = benchmark.evaluate(prediction, sample.expected_answer) + is_correct = metrics.get("correct", 0.0) > 0.5 + + return SampleResult( + sample_id=sample.id, + prediction=prediction, + expected=sample.expected_answer, + is_correct=is_correct, + metrics=metrics, + iterations=metadata.get("iterations", 0), + total_tokens=metadata.get("total_tokens", 0), + execution_time_ms=execution_time_ms, + error=error, + ) + + +def compare_methods( + benchmark: Benchmark, + runner: BenchmarkRunner, + methods: list[str] | None = None, + num_samples: int | None = None, + seed: int | None = None, +) -> dict[str, BenchmarkResult]: + """Run benchmark with multiple methods for comparison. + + Args: + benchmark: Benchmark to run. + runner: Configured BenchmarkRunner. + methods: List of methods to compare. Default: ["rlm", "direct"]. + num_samples: Number of samples per method. + seed: Random seed for reproducibility. + + Returns: + Dictionary mapping method name to BenchmarkResult. + """ + if methods is None: + methods = ["rlm", "direct"] + + results = {} + for method in methods: + print(f"\nRunning {benchmark.name} with method={method}...") + results[method] = runner.run(benchmark, method=method, num_samples=num_samples, seed=seed) + print(f" Accuracy: {results[method].accuracy:.2%}") + print(f" Mean F1: {results[method].mean_f1:.3f}") + + return results diff --git a/benchmarks/tasks/__init__.py b/benchmarks/tasks/__init__.py new file mode 100644 index 00000000..222177cc --- /dev/null +++ b/benchmarks/tasks/__init__.py @@ -0,0 +1,20 @@ +""" +Benchmark task implementations. + +Contains implementations for the four benchmark tasks from the RLM paper: +- S-NIAH: Single-Needle-in-a-Haystack +- BrowseComp-Plus: Multi-hop QA over document corpora +- OOLONG: Semantic aggregation +- OOLONG-Pairs: Pairwise combinatorial aggregation +""" + +from benchmarks.tasks.browsecomp import BrowseCompPlusBenchmark +from benchmarks.tasks.niah import NIAHBenchmark +from benchmarks.tasks.oolong import OolongBenchmark, OolongPairsBenchmark + +__all__ = [ + "NIAHBenchmark", + "OolongBenchmark", + "OolongPairsBenchmark", + "BrowseCompPlusBenchmark", +] diff --git a/benchmarks/tasks/browsecomp.py b/benchmarks/tasks/browsecomp.py new file mode 100644 index 00000000..e2f7efcd --- /dev/null +++ b/benchmarks/tasks/browsecomp.py @@ -0,0 +1,258 @@ +""" +BrowseComp-Plus Benchmark - Multi-hop QA over Large Document Corpora. + +Tests compositional question answering that requires reasoning across +multiple documents. This is the most realistic benchmark setting, +with document corpora ranging from 6M to 11M tokens. + +Note: This benchmark requires access to a document corpus. Users can +either provide their own corpus or use a synthetic generator. +""" + +import random +from collections.abc import Iterator +from typing import Any + +from benchmarks.base import Benchmark, BenchmarkSample +from benchmarks.metrics import Metrics + + +class BrowseCompPlusBenchmark(Benchmark): + """BrowseComp-Plus benchmark for multi-hop QA. + + Tests ability to answer questions that require: + 1. Finding relevant documents in a large corpus + 2. Reasoning across multiple documents + 3. Synthesizing a final answer + + This implementation provides: + - A synthetic corpus generator for testing + - Interface for loading custom document corpora + - Multi-hop question templates + + Args: + num_documents: Number of documents in the corpus. + doc_length: Average length of each document in characters. + num_hops: Number of reasoning hops required (1-3). + corpus_path: Path to custom corpus (JSON/JSONL file). + """ + + def __init__( + self, + num_documents: int = 100, + doc_length: int = 5000, + num_hops: int = 2, + corpus_path: str | None = None, + ): + self.num_documents = num_documents + self.doc_length = doc_length + self.num_hops = num_hops + self.corpus_path = corpus_path + + @property + def name(self) -> str: + return f"browsecomp-{self.num_documents}docs-{self.num_hops}hop" + + @property + def description(self) -> str: + return f"BrowseComp-Plus: {self.num_hops}-hop QA over {self.num_documents} documents" + + # Templates for generating synthetic documents and questions + ENTITIES = { + "companies": [ + "Acme Corp", + "TechNova", + "GlobalSync", + "DataFlow", + "CloudPeak", + "InnovateTech", + "FutureWorks", + "QuantumLeap", + "NexGen", + "PrimeAI", + ], + "people": [ + "Dr. Sarah Chen", + "James Rodriguez", + "Maria Santos", + "David Kim", + "Emily Watson", + "Michael Brown", + "Lisa Park", + "Robert Taylor", + "Jennifer Lee", + "William Davis", + ], + "products": [ + "SynthOS", + "NeuralNet Pro", + "DataVault", + "CloudBridge", + "AI Assistant", + "SmartAnalytics", + "SecureFlow", + "AutoML Platform", + "EdgeCompute", + "RoboSuite", + ], + "locations": [ + "San Francisco", + "New York", + "London", + "Tokyo", + "Singapore", + "Berlin", + "Sydney", + "Toronto", + "Mumbai", + "São Paulo", + ], + } + + FACT_TEMPLATES = [ + "{company} was founded by {person} in {location}.", + "{person} serves as the CEO of {company}.", + "{company} developed {product} in collaboration with {company2}.", + "{product} was first released in {location}.", + "{person} led the research team that created {product}.", + "{company} acquired {company2} to expand into {location}.", + "{person} previously worked at {company} before joining {company2}.", + ] + + MULTIHOP_TEMPLATES = [ + { + "hops": 2, + "question": "Who founded the company that created {product}?", + "chain": ["product->company", "company->founder"], + }, + { + "hops": 2, + "question": "Where is the headquarters of the company where {person} works?", + "chain": ["person->company", "company->location"], + }, + { + "hops": 3, + "question": "Who is the CEO of the company that acquired the creator of {product}?", + "chain": ["product->company", "company->acquirer", "acquirer->ceo"], + }, + ] + + def _generate_corpus(self, rng: random.Random) -> tuple[list[dict], dict[str, Any]]: + """Generate a synthetic document corpus with linked facts.""" + documents = [] + fact_graph = {} + + # Generate relationships + companies = self.ENTITIES["companies"].copy() + people = self.ENTITIES["people"].copy() + products = self.ENTITIES["products"].copy() + locations = self.ENTITIES["locations"].copy() + + rng.shuffle(companies) + rng.shuffle(people) + rng.shuffle(products) + rng.shuffle(locations) + + # Assign founders to companies + for i, company in enumerate(companies[: len(people)]): + founder = people[i % len(people)] + location = locations[i % len(locations)] + fact_graph[company] = { + "founder": founder, + "location": location, + "ceo": people[(i + 1) % len(people)], + } + + # Assign products to companies + for i, product in enumerate(products): + company = companies[i % len(companies)] + fact_graph[product] = {"company": company} + + # Generate documents with facts + for i in range(self.num_documents): + doc_content = [] + + # Add some facts from the graph + for _ in range(rng.randint(3, 8)): + template = rng.choice(self.FACT_TEMPLATES) + try: + fact = template.format( + company=rng.choice(companies), + company2=rng.choice(companies), + person=rng.choice(people), + product=rng.choice(products), + location=rng.choice(locations), + ) + doc_content.append(fact) + except (KeyError, IndexError): + pass + + # Add filler content + while len(" ".join(doc_content)) < self.doc_length: + doc_content.append( + f"Additional information about business operations and market trends " + f"in the {rng.choice(locations)} region continues to evolve." + ) + + documents.append( + { + "id": f"doc-{i:04d}", + "content": " ".join(doc_content), + } + ) + + return documents, fact_graph + + def _generate_question( + self, fact_graph: dict[str, Any], rng: random.Random + ) -> tuple[str, str, list[str]]: + """Generate a multi-hop question with answer and reasoning chain.""" + # Simple 2-hop question for now + products = [k for k in fact_graph.keys() if "company" in fact_graph.get(k, {})] + if not products: + return "What is 2+2?", "4", ["direct"] + + product = rng.choice(products) + company = fact_graph[product]["company"] + + if company in fact_graph and "founder" in fact_graph[company]: + founder = fact_graph[company]["founder"] + question = f"Who founded the company that created {product}?" + answer = founder + chain = [f"{product} -> {company}", f"{company} -> {founder}"] + return question, answer, chain + + return "What is 2+2?", "4", ["direct"] + + def load_samples( + self, num_samples: int | None = None, seed: int | None = None + ) -> Iterator[BenchmarkSample]: + """Generate BrowseComp-Plus samples.""" + rng = random.Random(seed) + num_samples = num_samples or 50 + + for i in range(num_samples): + # Generate fresh corpus for each sample (or reuse with different questions) + documents, fact_graph = self._generate_corpus(rng) + question, answer, chain = self._generate_question(fact_graph, rng) + + # Combine all documents into context + context = "\n\n---\n\n".join( + f"Document {doc['id']}:\n{doc['content']}" for doc in documents + ) + + yield BenchmarkSample( + id=f"browsecomp-{i:04d}", + context=context, + question=question, + expected_answer=answer, + metadata={ + "num_documents": len(documents), + "num_hops": len(chain), + "reasoning_chain": chain, + }, + ) + + def evaluate(self, prediction: str, expected: str | list[str]) -> dict[str, float]: + """Evaluate using standard metrics.""" + return Metrics.evaluate_standard(prediction, expected) diff --git a/benchmarks/tasks/niah.py b/benchmarks/tasks/niah.py new file mode 100644 index 00000000..c0444008 --- /dev/null +++ b/benchmarks/tasks/niah.py @@ -0,0 +1,203 @@ +""" +Single-Needle-in-a-Haystack (S-NIAH) Benchmark. + +Tests the ability to retrieve a single piece of information from a large context. +This is a context-insensitive retrieval task where the answer location doesn't +depend on semantic understanding of the surrounding text. + +The benchmark generates synthetic haystacks with a hidden "needle" fact. +""" + +import random +import string +from collections.abc import Iterator + +from benchmarks.base import Benchmark, BenchmarkSample +from benchmarks.metrics import Metrics + + +class NIAHBenchmark(Benchmark): + """Single-Needle-in-a-Haystack benchmark. + + Generates synthetic documents with a single retrievable fact (the "needle") + hidden at various positions within distractor text (the "haystack"). + + Args: + context_length: Target length of the haystack in characters. + needle_depth: Position of needle (0.0=start, 0.5=middle, 1.0=end). + If None, randomized per sample. + distractor_type: Type of haystack content ("lorem", "essays", "random"). + """ + + LOREM_IPSUM = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.""" + + NEEDLE_TEMPLATES = [ + ("The secret code is {value}.", "What is the secret code?"), + ("The password for the vault is {value}.", "What is the password for the vault?"), + ("The answer to the riddle is {value}.", "What is the answer to the riddle?"), + ("The hidden number is {value}.", "What is the hidden number?"), + ("The magic word is {value}.", "What is the magic word?"), + ] + + def __init__( + self, + context_length: int = 100_000, + needle_depth: float | None = None, + distractor_type: str = "lorem", + ): + self.context_length = context_length + self.needle_depth = needle_depth + self.distractor_type = distractor_type + + @property + def name(self) -> str: + return f"niah-{self.context_length // 1000}k" + + @property + def description(self) -> str: + return f"Single-Needle-in-a-Haystack with {self.context_length:,} char context" + + def _generate_haystack(self, length: int, rng: random.Random) -> str: + """Generate distractor text of approximately the target length.""" + if self.distractor_type == "lorem": + # Repeat lorem ipsum with slight variations + paragraphs = [] + while len("\n\n".join(paragraphs)) < length: + # Shuffle words slightly for variation + words = self.LOREM_IPSUM.split() + rng.shuffle(words) + paragraphs.append(" ".join(words)) + return "\n\n".join(paragraphs)[:length] + + elif self.distractor_type == "random": + # Random sentences + words = [ + "the", + "a", + "is", + "are", + "was", + "were", + "has", + "have", + "had", + "do", + "does", + "did", + "will", + "would", + "could", + "should", + "may", + "might", + "must", + "can", + "be", + "been", + "being", + "have", + "has", + ] + nouns = [ + "cat", + "dog", + "house", + "tree", + "car", + "book", + "table", + "chair", + "computer", + "phone", + "person", + "city", + "country", + "world", + ] + adjectives = [ + "big", + "small", + "red", + "blue", + "green", + "old", + "new", + "fast", + "slow", + "hot", + "cold", + "bright", + "dark", + ] + + sentences = [] + current_length = 0 + while current_length < length: + sentence = f"The {rng.choice(adjectives)} {rng.choice(nouns)} {rng.choice(words)} {rng.choice(adjectives)}." + sentences.append(sentence) + current_length += len(sentence) + 1 + + return " ".join(sentences)[:length] + + else: + raise ValueError(f"Unknown distractor_type: {self.distractor_type}") + + def _generate_needle(self, rng: random.Random) -> tuple[str, str, str]: + """Generate a needle (fact, question, answer).""" + template, question = rng.choice(self.NEEDLE_TEMPLATES) + + # Generate a random value + value_type = rng.choice(["word", "number", "code"]) + if value_type == "word": + value = "".join(rng.choices(string.ascii_lowercase, k=rng.randint(6, 10))) + elif value_type == "number": + value = str(rng.randint(1000, 9999)) + else: + value = "".join(rng.choices(string.ascii_uppercase + string.digits, k=8)) + + needle = template.format(value=value) + return needle, question, value + + def load_samples( + self, num_samples: int | None = None, seed: int | None = None + ) -> Iterator[BenchmarkSample]: + """Generate synthetic NIAH samples.""" + rng = random.Random(seed) + num_samples = num_samples or 100 + + for i in range(num_samples): + needle, question, answer = self._generate_needle(rng) + + # Determine needle position + depth = self.needle_depth if self.needle_depth is not None else rng.random() + + # Generate haystack + haystack_length = self.context_length - len(needle) - 10 + haystack = self._generate_haystack(haystack_length, rng) + + # Insert needle at depth position + insert_pos = int(len(haystack) * depth) + # Find a good break point (paragraph or sentence) + break_pos = haystack.rfind("\n\n", max(0, insert_pos - 100), insert_pos + 100) + if break_pos == -1: + break_pos = haystack.rfind(". ", max(0, insert_pos - 50), insert_pos + 50) + if break_pos == -1: + break_pos = insert_pos + + context = haystack[:break_pos] + "\n\n" + needle + "\n\n" + haystack[break_pos:] + + yield BenchmarkSample( + id=f"niah-{i:04d}", + context=context, + question=question, + expected_answer=answer, + metadata={ + "needle_depth": depth, + "context_length": len(context), + "needle": needle, + }, + ) + + def evaluate(self, prediction: str, expected: str | list[str]) -> dict[str, float]: + """Evaluate prediction - exact match or containment for NIAH.""" + return Metrics.evaluate_standard(prediction, expected) diff --git a/benchmarks/tasks/oolong.py b/benchmarks/tasks/oolong.py new file mode 100644 index 00000000..67bf0304 --- /dev/null +++ b/benchmarks/tasks/oolong.py @@ -0,0 +1,336 @@ +""" +OOLONG Benchmarks - Semantic Aggregation over Long Contexts. + +OOLONG tests the ability to aggregate information across a long context. +Two variants: +- OolongBenchmark: Standard QA requiring semantic understanding +- OolongPairsBenchmark: Pairwise combinatorial aggregation (hardest setting) + +Uses the oolongbench/oolong-real dataset from HuggingFace. +""" + +import random +import re +from collections.abc import Iterator +from typing import Any + +from benchmarks.base import Benchmark, BenchmarkSample +from benchmarks.metrics import Metrics + + +class OolongBenchmark(Benchmark): + """OOLONG benchmark for semantic aggregation. + + Loads samples from the oolongbench/oolong-real HuggingFace dataset. + Tests ability to answer questions requiring understanding and aggregation + of information spread across a long context. + + Args: + subset: Dataset subset to use (e.g., "toy_dnd", "counting", etc.) + """ + + AVAILABLE_SUBSETS = [ + "toy_dnd", + "counting", + "retrieval", + "reasoning", + "aggregation", + ] + + def __init__(self, subset: str = "toy_dnd"): + self.subset = subset + self._validate_subset() + + def _validate_subset(self): + """Check that subset is available.""" + if self.subset not in self.AVAILABLE_SUBSETS: + raise ValueError(f"Unknown subset: {self.subset}. Available: {self.AVAILABLE_SUBSETS}") + + @property + def name(self) -> str: + return f"oolong-{self.subset}" + + @property + def description(self) -> str: + return f"OOLONG semantic aggregation benchmark ({self.subset})" + + def _load_dataset(self, seed: int | None = None): + """Load the oolong dataset with streaming.""" + try: + from datasets import load_dataset + except ImportError as e: + raise ImportError("Please install datasets: uv pip install datasets") from e + + ds = load_dataset( + "oolongbench/oolong-real", + self.subset, + split="test", + streaming=True, + ) + + if seed is not None: + ds = ds.shuffle(seed=seed, buffer_size=1000) + + return ds + + def load_samples( + self, num_samples: int | None = None, seed: int | None = None + ) -> Iterator[BenchmarkSample]: + """Load samples from the OOLONG dataset.""" + ds = self._load_dataset(seed=seed) + + count = 0 + for row in ds: + if num_samples is not None and count >= num_samples: + break + + yield BenchmarkSample( + id=f"oolong-{self.subset}-{count:04d}", + context=row["context_window_text"], + question=row["question"], + expected_answer=row["answer"], + metadata={ + "subset": self.subset, + "original_id": row.get("id", count), + }, + ) + count += 1 + + def evaluate(self, prediction: str, expected: str | list[str]) -> dict[str, float]: + """Evaluate using standard metrics (containment + F1).""" + return Metrics.evaluate_standard(prediction, expected) + + +class OolongPairsBenchmark(Benchmark): + """OOLONG-Pairs benchmark for pairwise combinatorial aggregation. + + The hardest setting from the RLM paper. Requires identifying all pairs + of items that satisfy a given relationship from a long context. + + This is a synthetic benchmark that generates pairwise relationships. + """ + + def __init__( + self, + num_items: int = 50, + num_pairs: int = 25, + context_length: int = 50_000, + ): + """Initialize OOLONG-Pairs benchmark. + + Args: + num_items: Number of unique items in the context. + num_pairs: Number of pairs that satisfy the relationship. + context_length: Approximate target context length. + """ + self.num_items = num_items + self.num_pairs = num_pairs + self.context_length = context_length + + @property + def name(self) -> str: + return f"oolong-pairs-{self.num_items}x{self.num_pairs}" + + @property + def description(self) -> str: + return f"OOLONG-Pairs: Find all {self.num_pairs} pairs among {self.num_items} items" + + # Item categories and relationships for generating diverse content + DOMAINS = [ + { + "items": [ + "Alice", + "Bob", + "Charlie", + "Diana", + "Eve", + "Frank", + "Grace", + "Henry", + "Ivy", + "Jack", + "Kate", + "Leo", + "Mia", + "Noah", + "Olivia", + "Peter", + "Quinn", + "Rose", + "Sam", + "Tina", + "Uma", + "Victor", + "Wendy", + "Xavier", + "Yara", + "Zack", + ], + "relationship": "collaborated with", + "question": "List all pairs of people who collaborated together.", + "context_template": "{a} and {b} worked together on a project last year.", + "distractor_templates": [ + "{a} attended the conference in {city}.", + "{a} published a paper on {topic}.", + "{a} received an award for {achievement}.", + ], + }, + { + "items": [ + "Apple", + "Banana", + "Cherry", + "Date", + "Elderberry", + "Fig", + "Grape", + "Honeydew", + "Jackfruit", + "Kiwi", + "Lemon", + "Mango", + "Nectarine", + "Orange", + "Papaya", + "Quince", + "Raspberry", + "Strawberry", + "Tangerine", + "Ugli", + "Vanilla", + "Watermelon", + ], + "relationship": "pairs well with", + "question": "List all pairs of fruits that pair well together in recipes.", + "context_template": "{a} and {b} create an excellent flavor combination.", + "distractor_templates": [ + "{a} is commonly grown in {region}.", + "{a} contains high levels of {nutrient}.", + "{a} is harvested during {season}.", + ], + }, + ] + + def _generate_sample(self, sample_id: int, rng: random.Random) -> BenchmarkSample: + """Generate a single OOLONG-Pairs sample.""" + domain = rng.choice(self.DOMAINS) + items = domain["items"].copy() + + # Ensure we have enough items + while len(items) < self.num_items: + items.append(f"Item{len(items)}") + + rng.shuffle(items) + items = items[: self.num_items] + + # Generate pairs + all_possible = [] + for i, a in enumerate(items): + for b in items[i + 1 :]: + all_possible.append((a, b)) + + rng.shuffle(all_possible) + true_pairs = all_possible[: self.num_pairs] + + # Generate context with pair statements and distractors + statements = [] + + # Add true pair statements + for a, b in true_pairs: + stmt = domain["context_template"].format(a=a, b=b) + statements.append(stmt) + + # Add distractor statements + cities = ["London", "Paris", "Tokyo", "New York", "Sydney", "Berlin"] + topics = ["machine learning", "sustainability", "economics", "design"] + achievements = ["innovation", "leadership", "research", "creativity"] + regions = ["California", "Mediterranean", "South America", "Southeast Asia"] + nutrients = ["vitamin C", "antioxidants", "fiber", "potassium"] + seasons = ["summer", "fall", "spring", "winter"] + + while len("\n".join(statements)) < self.context_length: + item = rng.choice(items) + template = rng.choice(domain["distractor_templates"]) + + stmt = template.format( + a=item, + city=rng.choice(cities), + topic=rng.choice(topics), + achievement=rng.choice(achievements), + region=rng.choice(regions), + nutrient=rng.choice(nutrients), + season=rng.choice(seasons), + ) + statements.append(stmt) + + # Shuffle all statements + rng.shuffle(statements) + context = "\n".join(statements) + + return BenchmarkSample( + id=f"oolong-pairs-{sample_id:04d}", + context=context, + question=domain["question"], + expected_answer=true_pairs, # List of tuples + metadata={ + "num_items": len(items), + "num_pairs": len(true_pairs), + "domain": domain["relationship"], + }, + ) + + def load_samples( + self, num_samples: int | None = None, seed: int | None = None + ) -> Iterator[BenchmarkSample]: + """Generate synthetic OOLONG-Pairs samples.""" + rng = random.Random(seed) + num_samples = num_samples or 50 + + for i in range(num_samples): + yield self._generate_sample(i, rng) + + def _parse_pairs(self, prediction: str) -> set[tuple[str, str]]: + """Parse pairs from prediction string. + + Handles various formats: + - (A, B) + - A and B + - A - B + - A with B + """ + pairs = set() + + # Try to find pairs in various formats + patterns = [ + r"\(([^,]+),\s*([^)]+)\)", # (A, B) + r"([A-Z][a-z]+)\s+and\s+([A-Z][a-z]+)", # A and B + r"([A-Z][a-z]+)\s*[-–]\s*([A-Z][a-z]+)", # A - B + r"([A-Z][a-z]+)\s+with\s+([A-Z][a-z]+)", # A with B + ] + + for pattern in patterns: + matches = re.findall(pattern, prediction) + for match in matches: + pairs.add((match[0].strip(), match[1].strip())) + + return pairs + + def evaluate(self, prediction: str, expected: str | list[Any]) -> dict[str, float]: + """Evaluate pairwise F1 score.""" + predicted_pairs = self._parse_pairs(prediction) + + # Convert expected to set of tuples + if isinstance(expected, str): + expected_pairs = self._parse_pairs(expected) + else: + expected_pairs = {(a, b) for a, b in expected} + + metrics = Metrics.pairwise_f1(predicted_pairs, expected_pairs) + + # Add "correct" for compatibility (1.0 if F1 > 0.5) + metrics["correct"] = 1.0 if metrics["f1"] > 0.5 else 0.0 + + return metrics + + def default_metrics(self) -> list[str]: + return ["correct", "precision", "recall", "f1"] diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 00000000..de9636f2 --- /dev/null +++ b/tests/benchmarks/__init__.py @@ -0,0 +1 @@ +"""Tests for the RLM benchmarks framework.""" diff --git a/tests/benchmarks/test_benchmarks.py b/tests/benchmarks/test_benchmarks.py new file mode 100644 index 00000000..a84d9d48 --- /dev/null +++ b/tests/benchmarks/test_benchmarks.py @@ -0,0 +1,269 @@ +"""Tests for benchmark framework components.""" + +from benchmarks.base import BenchmarkResult, BenchmarkSample, SampleResult +from benchmarks.metrics import Metrics +from benchmarks.tasks.niah import NIAHBenchmark +from benchmarks.tasks.oolong import OolongPairsBenchmark + + +class TestMetrics: + """Tests for evaluation metrics.""" + + def test_normalize(self): + """Test text normalization.""" + assert Metrics.normalize("Hello, World!") == "hello world" + assert Metrics.normalize(" Multiple Spaces ") == "multiple spaces" + assert Metrics.normalize("UPPERCASE") == "uppercase" + + def test_exact_match_single(self): + """Test exact match with single expected.""" + assert Metrics.exact_match("hello", "hello") + assert Metrics.exact_match("Hello!", "hello") + assert not Metrics.exact_match("hello world", "hello") + + def test_exact_match_multiple(self): + """Test exact match with multiple expected answers.""" + assert Metrics.exact_match("hello", ["hello", "hi", "hey"]) + assert Metrics.exact_match("hi", ["hello", "hi", "hey"]) + assert not Metrics.exact_match("goodbye", ["hello", "hi", "hey"]) + + def test_containment(self): + """Test containment matching.""" + assert Metrics.containment("The answer is 42", "42") + assert Metrics.containment("42", "The answer is 42") + assert Metrics.containment("hello world", ["hello", "goodbye"]) + assert not Metrics.containment("hello", "world") + + def test_token_f1(self): + """Test token-level F1 score.""" + # Perfect match + assert Metrics.token_f1("hello world", "hello world") == 1.0 + + # Partial overlap + f1 = Metrics.token_f1("hello world today", "hello world") + assert 0.5 < f1 < 1.0 + + # No overlap + assert Metrics.token_f1("hello", "goodbye") == 0.0 + + # Empty prediction + assert Metrics.token_f1("", "hello") == 0.0 + + def test_pairwise_f1(self): + """Test pairwise F1 for OOLONG-Pairs.""" + pred = {("A", "B"), ("C", "D")} + exp = {("A", "B"), ("C", "D")} + result = Metrics.pairwise_f1(pred, exp) + assert result["f1"] == 1.0 + + # Partial overlap + pred = {("A", "B"), ("E", "F")} + exp = {("A", "B"), ("C", "D")} + result = Metrics.pairwise_f1(pred, exp) + assert result["f1"] == 0.5 + + # Order independence + pred = {("B", "A")} + exp = {("A", "B")} + result = Metrics.pairwise_f1(pred, exp) + assert result["f1"] == 1.0 + + def test_evaluate_standard(self): + """Test standard evaluation combining metrics.""" + result = Metrics.evaluate_standard("The answer is 42", "42") + assert result["correct"] == 1.0 + assert result["containment"] == 1.0 + assert result["f1"] > 0 + + +class TestBenchmarkSample: + """Tests for BenchmarkSample.""" + + def test_create_sample(self): + """Test creating a benchmark sample.""" + sample = BenchmarkSample( + id="test-001", + context="This is the context.", + question="What is this?", + expected_answer="context", + ) + assert sample.id == "test-001" + assert sample.context_tokens_approx > 0 + + def test_sample_with_multiple_answers(self): + """Test sample with multiple valid answers.""" + sample = BenchmarkSample( + id="test-002", + context="Context", + question="Question?", + expected_answer=["answer1", "answer2"], + ) + assert len(sample.expected_answer) == 2 + + +class TestBenchmarkResult: + """Tests for BenchmarkResult aggregation.""" + + def test_empty_result(self): + """Test empty result defaults.""" + result = BenchmarkResult( + benchmark_name="test", + method="rlm", + model="test-model", + ) + assert result.accuracy == 0.0 + assert result.mean_f1 == 0.0 + assert result.error_rate == 0.0 + + def test_result_aggregation(self): + """Test metric aggregation.""" + result = BenchmarkResult( + benchmark_name="test", + method="rlm", + model="test-model", + sample_results=[ + SampleResult( + sample_id="1", + prediction="correct", + expected="correct", + is_correct=True, + metrics={"f1": 1.0}, + iterations=5, + total_tokens=100, + ), + SampleResult( + sample_id="2", + prediction="wrong", + expected="right", + is_correct=False, + metrics={"f1": 0.0}, + iterations=3, + total_tokens=50, + ), + ], + ) + assert result.accuracy == 0.5 + assert result.mean_f1 == 0.5 + assert result.total_tokens == 150 + assert result.mean_iterations == 4.0 + + def test_result_summary(self): + """Test summary dict generation.""" + result = BenchmarkResult( + benchmark_name="test", + method="rlm", + model="test-model", + ) + summary = result.summary() + assert "benchmark" in summary + assert "accuracy" in summary + assert "mean_f1" in summary + + +class TestNIAHBenchmark: + """Tests for NIAH benchmark.""" + + def test_load_samples(self): + """Test loading NIAH samples.""" + benchmark = NIAHBenchmark(context_length=10_000) + samples = list(benchmark.load_samples(num_samples=5, seed=42)) + + assert len(samples) == 5 + for sample in samples: + assert len(sample.context) > 0 + assert len(sample.question) > 0 + assert len(sample.expected_answer) > 0 + + def test_needle_in_context(self): + """Test that needle is present in context.""" + benchmark = NIAHBenchmark(context_length=5_000) + sample = next(benchmark.load_samples(num_samples=1, seed=42)) + + # The expected answer should be findable in context + needle = sample.metadata["needle"] + assert needle in sample.context + + def test_evaluate(self): + """Test NIAH evaluation.""" + benchmark = NIAHBenchmark() + + # Exact match + result = benchmark.evaluate("abc123", "abc123") + assert result["correct"] == 1.0 + + # Containment + result = benchmark.evaluate("The code is abc123.", "abc123") + assert result["correct"] == 1.0 + + # Miss + result = benchmark.evaluate("xyz789", "abc123") + assert result["correct"] == 0.0 + + +class TestOolongPairsBenchmark: + """Tests for OOLONG-Pairs benchmark.""" + + def test_load_samples(self): + """Test loading OOLONG-Pairs samples.""" + benchmark = OolongPairsBenchmark(num_items=20, num_pairs=10) + samples = list(benchmark.load_samples(num_samples=3, seed=42)) + + assert len(samples) == 3 + for sample in samples: + assert len(sample.expected_answer) == 10 + assert all(isinstance(p, tuple) for p in sample.expected_answer) + + def test_parse_pairs(self): + """Test parsing pairs from prediction.""" + benchmark = OolongPairsBenchmark() + + # Test various formats + pred = "(Alice, Bob) and (Charlie, Diana)" + pairs = benchmark._parse_pairs(pred) + assert len(pairs) >= 2 + + pred = "Alice and Bob collaborated. Charlie with Diana worked." + pairs = benchmark._parse_pairs(pred) + assert len(pairs) >= 2 + + def test_evaluate(self): + """Test OOLONG-Pairs evaluation.""" + benchmark = OolongPairsBenchmark() + + # Perfect match + expected = [("Alice", "Bob"), ("Charlie", "Diana")] + prediction = "(Alice, Bob), (Charlie, Diana)" + result = benchmark.evaluate(prediction, expected) + assert result["f1"] == 1.0 + + # Partial match + prediction = "(Alice, Bob), (Eve, Frank)" + result = benchmark.evaluate(prediction, expected) + assert 0 < result["f1"] < 1.0 + + +class TestBenchmarkIntegration: + """Integration tests for benchmark framework.""" + + def test_niah_reproducibility(self): + """Test that same seed produces same samples.""" + benchmark = NIAHBenchmark(context_length=5_000) + + samples1 = list(benchmark.load_samples(num_samples=3, seed=123)) + samples2 = list(benchmark.load_samples(num_samples=3, seed=123)) + + for s1, s2 in zip(samples1, samples2, strict=True): + assert s1.expected_answer == s2.expected_answer + assert s1.question == s2.question + + def test_different_seeds_different_samples(self): + """Test that different seeds produce different samples.""" + benchmark = NIAHBenchmark(context_length=5_000) + + samples1 = list(benchmark.load_samples(num_samples=3, seed=1)) + samples2 = list(benchmark.load_samples(num_samples=3, seed=2)) + + # At least one should be different + answers1 = [s.expected_answer for s in samples1] + answers2 = [s.expected_answer for s in samples2] + assert answers1 != answers2