diff --git a/retrieval-bench/.gitignore b/retrieval-bench/.gitignore new file mode 100644 index 000000000..14c0c7ca8 --- /dev/null +++ b/retrieval-bench/.gitignore @@ -0,0 +1,15 @@ +# Transient directories generated by retrieval_bench +cache/ +results/ +traces/ +nemo_agentic_logs/ +shards/ + +# Environment +.env +.venv/ + +# Python +__pycache__/ +*.pyc +*.egg-info/ diff --git a/retrieval-bench/README.md b/retrieval-bench/README.md new file mode 100644 index 000000000..dc1ca6814 --- /dev/null +++ b/retrieval-bench/README.md @@ -0,0 +1,115 @@ + + +# retrieval-bench + +Pipeline evaluation framework for document retrieval. + +This package depends on `vidore-benchmark` for shared framework modules that are +not vendored in this slim repository snapshot. + +## Security note + +All model backends load HuggingFace models with `trust_remote_code=True`, which +executes Python code shipped with the model repository. Only use model IDs from +sources you trust. + +## Installation + +```bash +pip install -e . +``` + +### Model dependencies + +Most models work with `transformers==4.51.0`: + +```bash +pip install transformers==4.51.0 +pip install flash-attn==2.6.3 --no-build-isolation +``` + +The `nemotron-colembed-vl-8b-v2` model requires a newer transformers release: + +```bash +pip install transformers==5.0.0rc0 +pip install flash-attn==2.6.3 --no-build-isolation +``` + +## Dense retrieval + +Retrieval-only evaluation with a pluggable backend. The `--backend` flag +selects which retriever to use. + +```bash +retrieval-bench evaluate dense-retrieval \ + --dataset-name bright/biology \ + --backend llama-nv-embed-reasoning-3b + +retrieval-bench evaluate dense-retrieval \ + --dataset-name vidore/vidore_v3_hr \ + --backend llama-nemotron-embed-vl-1b-v2 \ + --language english + +retrieval-bench evaluate dense-retrieval \ + --dataset-name vidore/vidore_v3_hr \ + --backend nemotron-colembed-vl-8b-v2 \ + --language english +``` + +Backend-specific overrides can be passed via `--pipeline-args` JSON: + +```bash +retrieval-bench evaluate dense-retrieval \ + --dataset-name bright/biology \ + --backend llama-nv-embed-reasoning-3b \ + --pipeline-args '{"model_id":"~/checkpoints/my_model","scoring_batch_size":2048}' +``` + +## Agentic retrieval + +Dense retrieval augmented with an LLM agent that iteratively refines results. + +```bash +retrieval-bench evaluate agentic-retrieval \ + --dataset-name bright/biology \ + --backend llama-nv-embed-reasoning-3b \ + --llm-model your-llm-model \ + --num-concurrent 10 + +retrieval-bench evaluate agentic-retrieval \ + --dataset-name vidore/vidore_v3_hr \ + --backend llama-nemotron-embed-vl-1b-v2 \ + --llm-model your-llm-model +``` + +By default the pipeline reads `OPENAI_API_KEY` and `OPENAI_BASE_URL` from +environment variables. Override them via `--pipeline-args`: + +```bash +retrieval-bench evaluate agentic-retrieval \ + --dataset-name bright/biology \ + --backend llama-nv-embed-reasoning-3b \ + --llm-model your-llm-model \ + --pipeline-args '{"api_key":"os.environ/MY_KEY","base_url":"os.environ/MY_URL"}' +``` + +### Available backends + +| Backend | Retriever | +|---------|-----------| +| `llama-nv-embed-reasoning-3b` | NeMo Reasoning dense retriever | +| `llama-nemoretriever-colembed-3b-v1` | ColEmbed late-interaction retriever | +| `llama-nemotron-embed-vl-1b-v2` | Nemotron Embed VL 1B v2 (multimodal dense) | +| `nemotron-colembed-vl-8b-v2` | Nemotron ColEmbed VL 8B v2 (multimodal late-interaction) | + +## Utilities + +```bash +retrieval-bench evaluate utils list-datasets +retrieval-bench evaluate utils list-backends +retrieval-bench evaluate utils report-results --results-dir results/my_run +retrieval-bench evaluate utils compare-results --results-dirs results/run_a --results-dirs results/run_b +``` diff --git a/retrieval-bench/pyproject.toml b/retrieval-bench/pyproject.toml new file mode 100644 index 000000000..b8244b708 --- /dev/null +++ b/retrieval-bench/pyproject.toml @@ -0,0 +1,64 @@ +[project] +name = "retrieval-bench" +version = "0.1.0" +description = "Retrieval pipeline benchmarking toolkit for ViDoRe V3 and BRIGHT leaderboards." +authors = [] +readme = "README.md" +requires-python = ">=3.9,<3.14" +license = { text = "Apache-2.0" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "vidore-benchmark @ git+https://github.com/illuin-tech/vidore-benchmark.git@a93c1d916e305a1116ea557e1170de7adeb531e4", + "aiofiles", + "bm25s", + "datasets>=4.5.0", + "jinja2", + "litellm", + "Pillow", + "pydantic>=2.0.0,<3.0.0", + "pytrec_eval", + "PyStemmer", + "python-dotenv>=1.0.1,<2.0.0", + "rich", + "torch", + "transformers>=4.56.0", + "typer>=0.12.3,<1.0.0", + "torchvision>=0.23.0", +] + +[project.optional-dependencies] +dev = ["pytest>=8.2.1", "ruff>=0.4.5"] + +[project.scripts] +retrieval-bench = "retrieval_bench.cli.main:app" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/retrieval_bench"] + +[tool.pytest.ini_options] +filterwarnings = ["ignore::Warning"] +markers = ["slow: marks test as slow"] +testpaths = ["tests"] + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "W", "I", "N"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/retrieval-bench/src/retrieval_bench/__init__.py b/retrieval-bench/src/retrieval_bench/__init__.py new file mode 100644 index 000000000..702be4814 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Top-level package exports. + +Note: This package previously re-exported many optional retriever implementations. +Those imports can be heavyweight and may pull optional dependencies at import time. +For pipeline evaluation workflows, we keep this module lightweight and only export +the pipeline-evaluation API surface. +""" + +from .pipeline_evaluation import ( # noqa: F401 + BasePipeline, + aggregate_results, + evaluate_retrieval, + get_available_datasets, + load_vidore_dataset, + print_dataset_info, +) + +__all__ = [ + "BasePipeline", + "evaluate_retrieval", + "aggregate_results", + "load_vidore_dataset", + "get_available_datasets", + "print_dataset_info", +] diff --git a/retrieval-bench/src/retrieval_bench/cli/evaluate.py b/retrieval-bench/src/retrieval_bench/cli/evaluate.py new file mode 100644 index 000000000..d3cab7d76 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/cli/evaluate.py @@ -0,0 +1,507 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +CLI commands for ``retrieval-bench evaluate dense-retrieval`` and +``retrieval-bench evaluate agentic-retrieval``. +""" + +import json +import logging +import os +from pathlib import Path +from typing import Annotated, Any, Dict, Optional, Set + +import typer + +from retrieval_bench.pipeline_evaluation import ( + BasePipeline, + aggregate_results, + evaluate_retrieval, + load_vidore_dataset, + print_dataset_info, +) +from retrieval_bench.pipeline_evaluation.tracing import dataset_trace_dir, default_trace_run_name +from retrieval_bench.pipelines.backends import VALID_BACKENDS +from vidore_benchmark.utils.logging_utils import setup_logging + +# Import pipeline utility commands (report-results, compare-results, etc.) +from retrieval_bench.cli.pipeline_evaluation import app as _pipeline_utils_app + +logger = logging.getLogger(__name__) + +app = typer.Typer( + help="Evaluate retrieval pipelines on ViDoRe v3 / BRIGHT datasets.", + no_args_is_help=True, +) + +app.add_typer( + _pipeline_utils_app, + name="utils", + help="Reporting & trace utilities (list-datasets, report-results, compare-results, ...)", +) + + +@app.callback() +def main(log_level: Annotated[str, typer.Option("--log", help="Logging level")] = "warning"): + setup_logging(log_level) + logger.info("Logging level set to `%s`", log_level) + + +# --------------------------------------------------------------------------- +# Query-id selector helpers (duplicated minimally from pipeline_evaluation) +# --------------------------------------------------------------------------- + + +def _parse_query_ids_selector(selector: str) -> Set[str]: + selector = (selector or "").strip() + if not selector: + raise ValueError("Empty --query-ids selector.") + ids: Set[str] = set() + for tok in [t.strip() for t in selector.split(",") if t.strip()]: + if "-" in tok: + parts = [p.strip() for p in tok.split("-", 1)] + start, end = int(parts[0]), int(parts[1]) + for i in range(start, end + 1): + ids.add(str(i)) + else: + ids.add(str(int(tok))) + return ids + + +def _parse_index_selector(selector: str) -> Set[int]: + selector = (selector or "").strip() + if not selector: + raise ValueError("Empty --query-ids selector.") + ids: Set[int] = set() + for tok in [t.strip() for t in selector.split(",") if t.strip()]: + if "-" in tok: + parts = [p.strip() for p in tok.split("-", 1)] + start, end = int(parts[0]), int(parts[1]) + for i in range(start, end + 1): + ids.add(i) + else: + ids.add(int(tok)) + return ids + + +def _query_ids_are_numeric(query_ids: list[str]) -> bool: + return all(str(qid).isdigit() for qid in query_ids) + + +def _filter_queries_by_ids(query_ids, queries, qrels, query_languages, requested_ids): + fq_ids, fq = [], [] + for qid, q in zip(query_ids, queries): + if qid in requested_ids: + fq_ids.append(qid) + fq.append(q) + if not fq_ids: + raise ValueError("After applying --query-ids, zero queries remain.") + fqrels = {qid: qrels[qid] for qid in fq_ids if qid in qrels} + fql = {qid: query_languages.get(qid, "unknown") for qid in fq_ids} + return fq_ids, fq, fqrels, fql + + +def _filter_queries_by_positions(query_ids, queries, qrels, query_languages, requested_positions): + fq_ids, fq = [], [] + for idx, (qid, q) in enumerate(zip(query_ids, queries)): + if idx in requested_positions: + fq_ids.append(qid) + fq.append(q) + if not fq_ids: + raise ValueError("After applying --query-ids, zero queries remain.") + fqrels = {qid: qrels[qid] for qid in fq_ids if qid in qrels} + fql = {qid: query_languages.get(qid, "unknown") for qid in fq_ids} + return fq_ids, fq, fqrels, fql + + +# --------------------------------------------------------------------------- +# Shared evaluation runner +# --------------------------------------------------------------------------- + +_METRICS = [ + "ndcg_cut_1", + "ndcg_cut_5", + "ndcg_cut_10", + "ndcg_cut_20", + "ndcg_cut_100", + "recall_1", + "recall_5", + "recall_10", + "recall_20", + "recall_50", + "recall_100", + "P_1", + "P_5", + "P_10", + "P_20", + "map", + "map_cut_1", + "map_cut_10", + "map_cut_100", + "recip_rank", +] + + +def _run_evaluation( + *, + pipeline: BasePipeline, + dataset_name: str, + split: str, + language: Optional[str], + query_ids_selector: Optional[str], + trace_run_name: Optional[str], + traces_dir: str, + output_file: Optional[str], + show_dataset_info: bool, + pipeline_label: str, + pipeline_args_for_output: Dict[str, Any], + cache_only: bool = False, +) -> None: + """Shared dataset-loading / evaluation / result-display logic.""" + + # Load dataset + try: + query_ids, queries, corpus_ids, corpus_images, corpus_texts, qrels, query_languages, excluded_ids_by_query = ( + load_vidore_dataset(dataset_name=dataset_name, split=split, language=language) + ) + except Exception as e: + print(f"\nError loading dataset: {e}\n") + raise typer.Exit(code=1) + + # Apply query-id filter + if query_ids_selector: + try: + if _query_ids_are_numeric(query_ids): + requested = _parse_query_ids_selector(query_ids_selector) + query_ids, queries, qrels, query_languages = _filter_queries_by_ids( + query_ids, + queries, + qrels, + query_languages, + requested, + ) + else: + requested_pos = _parse_index_selector(query_ids_selector) + query_ids, queries, qrels, query_languages = _filter_queries_by_positions( + query_ids, + queries, + qrels, + query_languages, + requested_pos, + ) + except ValueError as e: + print(f"\nError parsing/applying --query-ids: {e}\n") + raise typer.Exit(code=1) + + if show_dataset_info: + print_dataset_info(dataset_name, query_ids, queries, corpus_ids, corpus_images, corpus_texts, qrels) + + # Cache-only mode: build corpus embeddings and exit without running queries. + if cache_only: + pipeline.dataset_name = dataset_name + pipeline.index(corpus_ids=corpus_ids, corpus_images=corpus_images, corpus_texts=corpus_texts) + print("Corpus embeddings cached. Exiting (--cache-only).") + return + + # Evaluate + print("\nRunning evaluation...") + try: + pipeline.dataset_name = dataset_name + trace_run_name_eff = trace_run_name or default_trace_run_name(pipeline) + results = evaluate_retrieval( + pipeline=pipeline, + query_ids=query_ids, + queries=queries, + corpus_ids=corpus_ids, + corpus_images=corpus_images, + corpus_texts=corpus_texts, + qrels=qrels, + traces_dir=traces_dir, + trace_run_name=trace_run_name_eff, + dataset_name=dataset_name, + split=split, + language=language, + query_ids_selector=query_ids_selector, + excluded_ids_by_query=excluded_ids_by_query, + metrics=_METRICS, + ) + except Exception as e: + print(f"\nError during evaluation: {e}\n") + raise typer.Exit(code=1) + + timing_info = results.get("_timing", {}) if isinstance(results, dict) else {} + aggregated = aggregate_results(results, query_languages) + + # Display results + _display_results(aggregated, timing_info) + + # Save results + if output_file is None: + pipeline_name = trace_run_name_eff + dataset_short = dataset_trace_dir(dataset_name) + os.makedirs(f"results/{pipeline_name}", exist_ok=True) + output_file = f"results/{pipeline_name}/{dataset_short}.json" + + output_path = Path(output_file) + wall_time_per_query_ms = None + if ( + isinstance(timing_info, dict) + and timing_info.get("total_wall_time_milliseconds") is not None + and len(query_ids) > 0 + ): + wall_time_per_query_ms = timing_info["total_wall_time_milliseconds"] / len(query_ids) + + output_data = { + "dataset": dataset_name, + "split": split, + "language": language, + "query_ids_selector": query_ids_selector, + "traces_dir": traces_dir, + "trace_run_name": trace_run_name_eff, + "pipeline_label": pipeline_label, + "pipeline_args": pipeline_args_for_output, + "aggregated_metrics": aggregated, + "wall_time_per_query_milliseconds": wall_time_per_query_ms, + } + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"Results saved to: {output_path}\n") + + if isinstance(timing_info, dict): + missing_num = timing_info.get("missing_num_queries", 0) + expected_num = timing_info.get("expected_num_queries", len(query_ids)) + if isinstance(missing_num, int) and missing_num > 0: + typer.secho( + f"WARNING: {missing_num}/{expected_num} expected queries were not evaluated.", + fg=typer.colors.RED, + bold=True, + ) + + +def _display_results(aggregated: Dict[str, Any], timing_info: Any) -> None: + print("\n" + "=" * 70) + print("Evaluation Results") + print("=" * 70) + + key_metrics = ["ndcg_cut_10", "ndcg_cut_5", "recall_10", "recall_5", "map", "recip_rank"] + + if "overall" in aggregated and "by_language" in aggregated: + overall_metrics = aggregated["overall"] + timing_info = aggregated.get("timing", {}) + + print("\n--- Overall Results ---") + for metric in key_metrics: + if metric in overall_metrics: + print(f" {metric:25s}: {overall_metrics[metric]:.4f}") + + print("\n--- Results by Language ---") + for lang, lang_metrics in aggregated["by_language"].items(): + num_queries = lang_metrics.get("num_queries", 0) + print(f"\n{lang.capitalize()} ({num_queries} queries):") + for metric in key_metrics: + if metric in lang_metrics: + print(f" {metric:25s}: {lang_metrics[metric]:.4f}") + + other_count = len(overall_metrics) - len([m for m in key_metrics if m in overall_metrics]) + if other_count > 0: + print(f"\n({other_count} additional metrics saved to file)") + + if timing_info: + print("\n--- Timing Metrics ---") + for metric, value in timing_info.items(): + if "milliseconds" in metric: + print(f" {metric:40s}: {value:.2f}ms") + else: + print(f" {metric:40s}: {value:.2f}") + else: + timing_metrics = { + k: v for k, v in aggregated.items() if k.startswith(("total_", "average_", "queries_", "num_")) + } + retrieval_metrics = {k: v for k, v in aggregated.items() if k not in timing_metrics} + + if retrieval_metrics: + print("\nKey Retrieval Metrics:") + for metric in key_metrics: + if metric in retrieval_metrics: + print(f" {metric:25s}: {retrieval_metrics[metric]:.4f}") + other_count = len(retrieval_metrics) - len([m for m in key_metrics if m in retrieval_metrics]) + if other_count > 0: + print(f"\n ({other_count} additional metrics saved to file)") + + if timing_metrics: + print("\nTiming Metrics:") + for metric, value in timing_metrics.items(): + if "milliseconds" in metric: + print(f" {metric:40s}: {value:.2f}ms") + else: + print(f" {metric:40s}: {value:.2f}") + + print("=" * 70 + "\n") + + +# --------------------------------------------------------------------------- +# Commands +# --------------------------------------------------------------------------- + +_backend_help = "Dense retriever backend. One of: " + ", ".join(sorted(VALID_BACKENDS)) + + +@app.command("dense-retrieval") +def dense_retrieval( + dataset_name: Annotated[str, typer.Option(help="Dataset name (e.g. 'bright/biology', 'vidore/vidore_v3_hr')")], + backend: Annotated[str, typer.Option(help=_backend_help)], + top_k: Annotated[int, typer.Option(help="Number of results per query")] = 100, + language: Annotated[Optional[str], typer.Option(help="Language filter (e.g. 'english')")] = None, + split: Annotated[str, typer.Option(help="Dataset split")] = "test", + query_ids_selector: Annotated[ + Optional[str], + typer.Option("--query-ids", help="Query-id selector, e.g. '0-99,120'"), + ] = None, + output_file: Annotated[Optional[str], typer.Option(help="Path to save results JSON")] = None, + trace_run_name: Annotated[Optional[str], typer.Option(help="Trace subdirectory name")] = None, + traces_dir: Annotated[str, typer.Option(help="Traces root directory")] = "traces", + show_dataset_info: Annotated[bool, typer.Option(help="Show dataset info")] = True, + pipeline_args: Annotated[ + Optional[str], + typer.Option(help="JSON string of additional backend overrides"), + ] = None, + cache_only: Annotated[ + bool, typer.Option("--cache-only", help="Only build corpus embedding cache, skip evaluation") + ] = False, +): + """ + Evaluate a dense retrieval backend on a dataset. + + Examples: + + retrieval-bench evaluate dense-retrieval \\ + --dataset-name bright/biology \\ + --backend llama-nv-embed-reasoning-3b + + retrieval-bench evaluate dense-retrieval \\ + --dataset-name vidore/vidore_v3_hr \\ + --backend llama-nemotron-embed-vl-1b-v2 \\ + --language english + """ + if backend not in VALID_BACKENDS: + print(f"\nUnknown backend: {backend!r}") + print(f"Must be one of: {', '.join(sorted(VALID_BACKENDS))}") + raise typer.Exit(code=1) + + overrides: Dict[str, Any] = {} + if pipeline_args: + try: + overrides = json.loads(pipeline_args) + except json.JSONDecodeError as e: + print(f"\nError parsing --pipeline-args: {e}\n") + raise typer.Exit(code=1) + + from retrieval_bench.pipelines.dense import DenseRetrievalPipeline + + pipeline = DenseRetrievalPipeline(backend=backend, top_k=top_k, **overrides) + + _run_evaluation( + pipeline=pipeline, + dataset_name=dataset_name, + split=split, + language=language, + query_ids_selector=query_ids_selector, + trace_run_name=trace_run_name, + traces_dir=traces_dir, + output_file=output_file, + show_dataset_info=show_dataset_info, + pipeline_label=f"dense-retrieval/{backend}", + pipeline_args_for_output={"backend": backend, "top_k": top_k, **overrides}, + cache_only=cache_only, + ) + + +@app.command("agentic-retrieval") +def agentic_retrieval( + dataset_name: Annotated[str, typer.Option(help="Dataset name (e.g. 'bright/biology', 'vidore/vidore_v3_hr')")], + backend: Annotated[str, typer.Option(help=_backend_help)], + llm_model: Annotated[str, typer.Option(help="LLM model identifier (e.g. 'gpt-4o', 'openai/my-model')")], + num_concurrent: Annotated[int, typer.Option(help="Number of concurrent agent queries")] = 1, + reasoning_effort: Annotated[str, typer.Option(help="Reasoning effort level")] = "high", + target_top_k: Annotated[int, typer.Option(help="Target number of final results per query")] = 10, + retriever_top_k: Annotated[int, typer.Option(help="Retriever top-k (overrides default 500)")] = 500, + language: Annotated[Optional[str], typer.Option(help="Language filter (e.g. 'english')")] = None, + split: Annotated[str, typer.Option(help="Dataset split")] = "test", + query_ids_selector: Annotated[ + Optional[str], + typer.Option("--query-ids", help="Query-id selector, e.g. '0-99,120'"), + ] = None, + output_file: Annotated[Optional[str], typer.Option(help="Path to save results JSON")] = None, + trace_run_name: Annotated[Optional[str], typer.Option(help="Trace subdirectory name")] = None, + traces_dir: Annotated[str, typer.Option(help="Traces root directory")] = "traces", + show_dataset_info: Annotated[bool, typer.Option(help="Show dataset info")] = True, + pipeline_args: Annotated[ + Optional[str], + typer.Option(help="JSON string of additional overrides (backend and agent)"), + ] = None, +): + """ + Evaluate an agentic retrieval pipeline (dense retrieval + LLM agent). + + Examples: + + retrieval-bench evaluate agentic-retrieval \\ + --dataset-name bright/biology \\ + --backend llama-nv-embed-reasoning-3b \\ + --llm-model your-llm-model \\ + --num-concurrent 10 + + retrieval-bench evaluate agentic-retrieval \\ + --dataset-name vidore/vidore_v3_hr \\ + --backend llama-nemotron-embed-vl-1b-v2 \\ + --llm-model your-llm-model + """ + if backend not in VALID_BACKENDS: + print(f"\nUnknown backend: {backend!r}") + print(f"Must be one of: {', '.join(sorted(VALID_BACKENDS))}") + raise typer.Exit(code=1) + + overrides: Dict[str, Any] = {} + if pipeline_args: + try: + overrides = json.loads(pipeline_args) + except json.JSONDecodeError as e: + print(f"\nError parsing --pipeline-args: {e}\n") + raise typer.Exit(code=1) + + from retrieval_bench.pipelines.agentic import AgenticRetrievalPipeline + + pipeline = AgenticRetrievalPipeline( + backend=backend, + retriever_top_k=retriever_top_k, + num_concurrent=num_concurrent, + llm_model=llm_model, + reasoning_effort=reasoning_effort, + target_top_k=target_top_k, + **overrides, + ) + + _run_evaluation( + pipeline=pipeline, + dataset_name=dataset_name, + split=split, + language=language, + query_ids_selector=query_ids_selector, + trace_run_name=trace_run_name, + traces_dir=traces_dir, + output_file=output_file, + show_dataset_info=show_dataset_info, + pipeline_label=f"agentic-retrieval/{backend}", + pipeline_args_for_output={ + "backend": backend, + "llm_model": llm_model, + "num_concurrent": num_concurrent, + "reasoning_effort": reasoning_effort, + "target_top_k": target_top_k, + "retriever_top_k": retriever_top_k, + **overrides, + }, + ) diff --git a/retrieval-bench/src/retrieval_bench/cli/main.py b/retrieval-bench/src/retrieval_bench/cli/main.py new file mode 100644 index 000000000..ef4113d3b --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/cli/main.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Top-level CLI entrypoint. + +This package intentionally exposes only the ``evaluate`` command group. +""" + +import logging +from typing import Annotated + +import typer + +from retrieval_bench.cli.evaluate import app as evaluate_app +from vidore_benchmark.utils.logging_utils import setup_logging + +logger = logging.getLogger(__name__) + +app = typer.Typer( + help="CLI for retrieval pipeline benchmarking.", + no_args_is_help=True, +) + +app.add_typer( + evaluate_app, + name="evaluate", + help="Evaluate retrieval pipelines on ViDoRe v3 / BRIGHT datasets", +) + + +@app.callback() +def main(log_level: Annotated[str, typer.Option("--log", help="Logging level")] = "warning"): + setup_logging(log_level) + logger.info("Logging level set to `%s`", log_level) + + +if __name__ == "__main__": + app() diff --git a/retrieval-bench/src/retrieval_bench/cli/pipeline_evaluation.py b/retrieval-bench/src/retrieval_bench/cli/pipeline_evaluation.py new file mode 100644 index 000000000..6a1b7bdf6 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/cli/pipeline_evaluation.py @@ -0,0 +1,1196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Pipeline utility CLI commands (report-results, compare-results, etc.). + +The evaluation commands (dense-retrieval, agentic-retrieval) live in +``retrieval_bench.cli.evaluate``. This module retains the utility / +reporting commands that operate on saved results and traces. +""" + +import importlib.util +import json +import logging +import sys +from pathlib import Path +from typing import Annotated, Any, Optional, Set, Tuple + +import typer + +from retrieval_bench.pipeline_evaluation import ( + BasePipeline, + get_available_datasets, +) +from retrieval_bench.pipeline_evaluation.tracing import dataset_trace_dir +from vidore_benchmark.utils.logging_utils import setup_logging + +logger = logging.getLogger(__name__) + +app = typer.Typer( + help=""" + CLI for evaluating abstract pipelines on ViDoRe v3 datasets. + + Evaluate custom retrieval pipelines that inherit from BasePipeline. + Supports built-in pipelines (random, file-based) and custom Python implementations. + """, + no_args_is_help=True, +) + + +@app.callback() +def main(log_level: Annotated[str, typer.Option("--log", help="Logging level")] = "warning"): + """Initialize logging configuration.""" + setup_logging(log_level) + logger.info("Logging level set to `%s`", log_level) + + +def _parse_query_ids_selector(selector: str) -> Set[str]: + """ + Parse a comma-separated list of integer IDs and inclusive ranges. + + Examples: + - "0-99,120,200-210" -> {"0","1",...,"99","120","200",...,"210"} + - "5" -> {"5"} + """ + selector = (selector or "").strip() + if not selector: + raise ValueError("Empty --query-ids selector.") + + ids: Set[str] = set() + tokens = [t.strip() for t in selector.split(",") if t.strip()] + if not tokens: + raise ValueError("Empty --query-ids selector.") + + for tok in tokens: + if "-" in tok: + parts = [p.strip() for p in tok.split("-", 1)] + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError(f"Invalid range token '{tok}'. Expected format like '0-10'.") + try: + start = int(parts[0]) + end = int(parts[1]) + except ValueError: + raise ValueError(f"Invalid range token '{tok}'. Range bounds must be integers.") + if start < 0 or end < 0: + raise ValueError(f"Invalid range token '{tok}'. Query IDs must be >= 0.") + if start > end: + raise ValueError(f"Invalid range token '{tok}'. Range start must be <= end.") + for i in range(start, end + 1): + ids.add(str(i)) + else: + try: + val = int(tok) + except ValueError: + raise ValueError(f"Invalid id token '{tok}'. Expected an integer or a range like '0-10'.") + if val < 0: + raise ValueError(f"Invalid id token '{tok}'. Query IDs must be >= 0.") + ids.add(str(val)) + + return ids + + +def _parse_index_selector(selector: str) -> Set[int]: + """ + Parse a comma-separated list of integer indices and inclusive ranges. + + Examples: + - "0-99,120,200-210" -> {0,1,...,99,120,200,...,210} + - "5" -> {5} + """ + selector = (selector or "").strip() + if not selector: + raise ValueError("Empty --query-ids selector.") + + ids: Set[int] = set() + tokens = [t.strip() for t in selector.split(",") if t.strip()] + if not tokens: + raise ValueError("Empty --query-ids selector.") + + for tok in tokens: + if "-" in tok: + parts = [p.strip() for p in tok.split("-", 1)] + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError(f"Invalid range token '{tok}'. Expected format like '0-10'.") + try: + start = int(parts[0]) + end = int(parts[1]) + except ValueError: + raise ValueError(f"Invalid range token '{tok}'. Range bounds must be integers.") + if start < 0 or end < 0: + raise ValueError(f"Invalid range token '{tok}'. Indices must be >= 0.") + if start > end: + raise ValueError(f"Invalid range token '{tok}'. Range start must be <= end.") + for i in range(start, end + 1): + ids.add(int(i)) + else: + try: + val = int(tok) + except ValueError: + raise ValueError(f"Invalid id token '{tok}'. Expected an integer or a range like '0-10'.") + if val < 0: + raise ValueError(f"Invalid id token '{tok}'. Indices must be >= 0.") + ids.add(int(val)) + + return ids + + +def _query_ids_are_numeric(query_ids: list[str]) -> bool: + return all(str(qid).isdigit() for qid in query_ids) + + +def _filter_queries_by_ids( + query_ids: list[str], + queries: list[str], + qrels: dict, + query_languages: dict, + requested_ids: Set[str], +): + """Filter query_ids/queries/qrels/query_languages in sync, preserving original order.""" + available = set(query_ids) + missing = sorted(requested_ids - available, key=lambda x: int(x) if x.isdigit() else x) + if missing: + raise ValueError( + "Some requested query IDs are not available after filtering (e.g. language).\n" + f"Missing IDs: {', '.join(missing[:50])}" + (" ..." if len(missing) > 50 else "") + ) + + filtered_query_ids: list[str] = [] + filtered_queries: list[str] = [] + for qid, q in zip(query_ids, queries): + if qid in requested_ids: + filtered_query_ids.append(qid) + filtered_queries.append(q) + + if not filtered_query_ids: + raise ValueError("After applying --query-ids, zero queries remain to evaluate.") + + filtered_qrels = {qid: qrels[qid] for qid in filtered_query_ids if qid in qrels} + filtered_query_languages = {qid: query_languages.get(qid, "unknown") for qid in filtered_query_ids} + + return filtered_query_ids, filtered_queries, filtered_qrels, filtered_query_languages + + +def _filter_queries_by_positions( + query_ids: list[str], + queries: list[str], + qrels: dict, + query_languages: dict, + requested_positions: Set[int], +): + """ + Filter query_ids/queries/qrels/query_languages in sync by *position* (0-based index). + + This is used for datasets whose query IDs are non-numeric (e.g. BRIGHT aops), + so we can still use compact numeric selectors like '0-99,120,200-210'. + """ + if not requested_positions: + raise ValueError("Empty --query-ids selector.") + + n = len(query_ids) + bad = sorted([i for i in requested_positions if i < 0 or i >= n]) + if bad: + raise ValueError( + "Some requested query indices are out of range.\n" + f"Valid range: 0..{max(0, n-1)}\n" + f"Bad indices (first 50): {bad[:50]}" + (" ..." if len(bad) > 50 else "") + ) + + filtered_query_ids: list[str] = [] + filtered_queries: list[str] = [] + for idx, (qid, q) in enumerate(zip(query_ids, queries)): + if idx in requested_positions: + filtered_query_ids.append(qid) + filtered_queries.append(q) + + if not filtered_query_ids: + raise ValueError("After applying --query-ids, zero queries remain to evaluate.") + + filtered_qrels = {qid: qrels[qid] for qid in filtered_query_ids if qid in qrels} + filtered_query_languages = {qid: query_languages.get(qid, "unknown") for qid in filtered_query_ids} + return filtered_query_ids, filtered_queries, filtered_qrels, filtered_query_languages + + +def _extract_aggregated_metric(obj: Any, metric: str) -> Optional[float]: + """ + Extract a metric from a pipeline results JSON object. + + Supports both formats: + - language-split: aggregated_metrics.overall. + - flat: aggregated_metrics. + """ + if not isinstance(obj, dict): + return None + aggregated = obj.get("aggregated_metrics", None) + if not isinstance(aggregated, dict): + return None + + if "overall" in aggregated and isinstance(aggregated.get("overall"), dict): + val = aggregated["overall"].get(metric, None) + else: + val = aggregated.get(metric, None) + + if isinstance(val, (int, float)): + return float(val) + return None + + +def _extract_pipeline_infos(obj: Any) -> Optional[dict]: + """ + Extract the pipeline_infos object from a results JSON payload (if present). + + Expected location: + aggregated_metrics._infos.pipeline_infos + """ + if not isinstance(obj, dict): + return None + aggregated = obj.get("aggregated_metrics", None) + if not isinstance(aggregated, dict): + return None + infos = aggregated.get("infos", None) or aggregated.get("_infos", None) + if not isinstance(infos, dict): + return None + pipeline_infos = infos.get("pipeline_infos", None) + if not isinstance(pipeline_infos, dict): + return None + return pipeline_infos + + +def _try_extract_llm_summary_from_results(obj: Any) -> Optional[Tuple[int, int, int, int, Optional[float]]]: + """ + Return (llm_error_count, prompt_tokens, completion_tokens, total_tokens, avg_trajectory_steps) + if present and valid. Otherwise return None. + """ + pi = _extract_pipeline_infos(obj) + if not isinstance(pi, dict): + return None + + err_ids = pi.get("llm_error_query_ids", None) + pt = pi.get("llm_total_prompt_tokens", None) + ct = pi.get("llm_total_completion_tokens", None) + tt = pi.get("llm_total_tokens", None) + + if not isinstance(err_ids, list) or not all(isinstance(x, str) for x in err_ids): + return None + if not isinstance(pt, int) or not isinstance(ct, int) or not isinstance(tt, int): + return None + + avg_traj = pi.get("avg_trajectory_steps", None) + avg_traj_f = float(avg_traj) if isinstance(avg_traj, (int, float)) else None + + return (len(err_ids), int(pt), int(ct), int(tt), avg_traj_f) + + +def _extract_wall_time_and_nq(obj: Any) -> Tuple[Optional[float], Optional[int]]: + """ + Return (total_wall_time_milliseconds, num_queries) from a results JSON object. + + Handles both flat (BRIGHT) and language-split (vidore v3) aggregated_metrics layouts. + Returns (None, None) when the fields are unavailable. + """ + if not isinstance(obj, dict): + return None, None + aggregated = obj.get("aggregated_metrics", None) + if not isinstance(aggregated, dict): + return None, None + + if "overall" in aggregated and isinstance(aggregated.get("overall"), dict): + timing = aggregated.get("timing", {}) + else: + timing = aggregated + + if not isinstance(timing, dict): + return None, None + + wall_ms = timing.get("total_wall_time_milliseconds", None) + nq = timing.get("num_queries", None) + if not isinstance(wall_ms, (int, float)) or not isinstance(nq, (int, float)): + return None, None + return float(wall_ms), int(nq) + + +def _compute_llm_summary_from_traces( + traces_dir: str, trace_run_name: str, dataset_short: str +) -> Optional[Tuple[int, int, int, int, Optional[float], int]]: + """ + Return (llm_error_count, prompt_tokens, completion_tokens, total_tokens, avg_trajectory_steps, num_traces) + by scanning per-query trace files. + Returns None if the trace directory does not exist. + """ + trace_root = Path(traces_dir) / trace_run_name / dataset_short + if not trace_root.exists() or not trace_root.is_dir(): + return None + + llm_error_count = 0 + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + trajectory_steps_vals: list[int] = [] + num_traces = 0 + + for p in sorted(trace_root.glob("*.json")): + try: + with open(p, "r") as f: + obj = json.load(f) + except Exception: + continue + if not isinstance(obj, dict): + continue + num_traces += 1 + + pipeline_trace = obj.get("pipeline_trace", None) + if not isinstance(pipeline_trace, dict): + continue + + if isinstance(pipeline_trace.get("llm_error", None), str): + llm_error_count += 1 + + usage = pipeline_trace.get("llm_usage", None) + if isinstance(usage, dict): + pt = usage.get("prompt_tokens", None) + ct = usage.get("completion_tokens", None) + tt = usage.get("total_tokens", None) + if isinstance(pt, int): + prompt_tokens += pt + if isinstance(ct, int): + completion_tokens += ct + if isinstance(tt, int): + total_tokens += tt + + ts = pipeline_trace.get("trajectory_steps", None) + if isinstance(ts, int): + trajectory_steps_vals.append(ts) + + avg_traj = (sum(trajectory_steps_vals) / len(trajectory_steps_vals)) if trajectory_steps_vals else None + return (llm_error_count, prompt_tokens, completion_tokens, total_tokens, avg_traj, num_traces) + + +def _compute_llm_errors_from_traces( + traces_dir: str, trace_run_name: str, dataset_short: str, *, max_error_len: int = 500 +) -> Optional[list[tuple[str, str]]]: + """ + Return a list of (query_id, llm_error) by scanning per-query trace files. + Returns None if the trace directory does not exist. + """ + trace_root = Path(traces_dir) / trace_run_name / dataset_short + if not trace_root.exists() or not trace_root.is_dir(): + return None + + out: list[tuple[str, str]] = [] + for p in sorted(trace_root.glob("*.json")): + try: + with open(p, "r") as f: + obj = json.load(f) + except Exception: + continue + if not isinstance(obj, dict): + continue + + pipeline_trace = obj.get("pipeline_trace", None) + if not isinstance(pipeline_trace, dict): + continue + + err = pipeline_trace.get("llm_error", None) + if not isinstance(err, str): + continue + + qid = obj.get("query_id", None) + qid_s = str(qid) if qid is not None else p.stem + if max_error_len <= 0: + err_s = err + else: + err_s = ( + err + if len(err) <= max_error_len + else (err[:max_error_len] + f"... [truncated, original_len={len(err)}]") + ) + out.append((qid_s, err_s)) + + # Sort numerically when possible (query ids are usually numeric strings). + out.sort(key=lambda x: int(x[0]) if x[0].isdigit() else x[0]) + return out + + +def _load_pipeline_from_module(module_path: str, class_name: str, **kwargs) -> BasePipeline: + """ + Dynamically load a pipeline class from a Python file. + + Args: + module_path: Path to the Python file containing the pipeline class + class_name: Name of the pipeline class to instantiate + **kwargs: Arguments to pass to the pipeline constructor + + Returns: + Instantiated pipeline object + """ + module_path = Path(module_path).resolve() + + if not module_path.exists(): + raise FileNotFoundError(f"Module file not found: {module_path}") + + # Load the module + spec = importlib.util.spec_from_file_location("custom_pipeline", module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load module from {module_path}") + + module = importlib.util.module_from_spec(spec) + sys.modules["custom_pipeline"] = module + spec.loader.exec_module(module) + + # Get the class + if not hasattr(module, class_name): + raise AttributeError(f"Class '{class_name}' not found in {module_path}") + + pipeline_class = getattr(module, class_name) + + # Verify it's a BasePipeline subclass + if not issubclass(pipeline_class, BasePipeline): + raise TypeError(f"Class '{class_name}' must inherit from BasePipeline") + + # Instantiate the pipeline + return pipeline_class(**kwargs) + + +@app.command() +def list_datasets(): + """ + List all available ViDoRe v3 datasets. + + Example: + retrieval-bench pipeline list-datasets + """ + datasets = get_available_datasets() + + print("\n" + "=" * 70) + print("Available ViDoRe v3 Datasets") + print("=" * 70) + for i, dataset_name in enumerate(datasets, 1): + print(f"{i:2d}. {dataset_name}") + print("=" * 70) + print(f"\nTotal: {len(datasets)} datasets\n") + + +@app.command() +def list_backends(): + """ + List available dense retrieval backends. + + Example: + retrieval-bench evaluate utils list-backends + """ + from retrieval_bench.pipelines.backends import VALID_BACKENDS as _backends + + print("\n" + "=" * 70) + print("Available Dense Retrieval Backends") + print("=" * 70) + for i, b in enumerate(sorted(_backends), 1): + print(f" {i}. {b}") + print("=" * 70) + print("\nUsage:") + print(" retrieval-bench evaluate dense-retrieval --backend --dataset-name ") + print(" retrieval-bench evaluate agentic-retrieval --backend --dataset-name \n") + + +@app.command() +def report_results( + results_dir: Annotated[ + str, + typer.Option( + "--results-dir", + help="Directory containing per-dataset result JSON files (e.g. results/ColEmbedPipeline).", + ), + ], +): + """ + Report NDCG@10 and query-coverage per dataset for a results directory. + + This command expects the JSON structure produced by `retrieval-bench pipeline evaluate` + and `retrieval-bench pipeline evaluate-all`, i.e. each file contains: + - aggregated_metrics (including ndcg_cut_10) + - timing info (including expected_num_queries / num_queries / missing_num_queries) when available + + We intentionally do not require per-id lists (evaluated/missing ids); coverage is computed + from counts stored in the result file. + """ + root = Path(results_dir) + if not root.exists() or not root.is_dir(): + raise typer.BadParameter(f"--results-dir must be an existing directory: {root}") + + datasets = get_available_datasets() + + print("\n" + "=" * 70) + print(f"Results report: {root}") + print("=" * 70) + print(f"{'dataset':28s} {'ndcg@10':>8s} {'coverage':>12s} {'status':>8s}") + print("-" * 70) + + for dataset_name in datasets: + ds_short = dataset_trace_dir(dataset_name) + path = root / f"{ds_short}.json" + if not path.exists(): + print(f"{ds_short:28s} {'-':>8s} {'-':>12s} {'MISSING':>8s}") + continue + + try: + with open(path, "r") as f: + obj = json.load(f) + except Exception: + print(f"{ds_short:28s} {'-':>8s} {'-':>12s} {'ERROR':>8s}") + continue + + aggregated: Any = obj.get("aggregated_metrics", {}) if isinstance(obj, dict) else {} + if not isinstance(aggregated, dict): + aggregated = {} + + # Extract NDCG@10 (handle both nested language-split format and flat format). + ndcg = None + timing: Any = {} + if "overall" in aggregated and isinstance(aggregated.get("overall"), dict): + ndcg = aggregated["overall"].get("ndcg_cut_10", None) + timing = aggregated.get("timing", {}) if isinstance(aggregated.get("timing", {}), dict) else {} + else: + ndcg = aggregated.get("ndcg_cut_10", None) + # In flat mode, timing fields are merged into aggregated_metrics. + timing = aggregated + + ndcg_str = f"{float(ndcg):.4f}" if isinstance(ndcg, (int, float)) else "-" + + expected = timing.get("expected_num_queries", None) if isinstance(timing, dict) else None + num = timing.get("num_queries", None) if isinstance(timing, dict) else None + missing = timing.get("missing_num_queries", None) if isinstance(timing, dict) else None + + coverage = "-" + status = "UNKNOWN" + if isinstance(expected, (int, float)) and isinstance(num, (int, float)): + expected_i = int(expected) + num_i = int(num) + coverage = f"{num_i}/{expected_i}" + status = "FULL" if expected_i == num_i else "PARTIAL" + elif isinstance(expected, (int, float)) and isinstance(missing, (int, float)): + expected_i = int(expected) + missing_i = int(missing) + num_i = max(0, expected_i - missing_i) + coverage = f"{num_i}/{expected_i}" + status = "FULL" if missing_i == 0 else "PARTIAL" + + print(f"{ds_short:28s} {ndcg_str:>8s} {coverage:>12s} {status:>8s}") + + print("-" * 70 + "\n") + + +@app.command() +def compare_results( + results_dirs: Annotated[ + list[str], + typer.Option( + "--results-dirs", + help="One or more pipeline results directories (e.g. results/ColEmbedPipeline__...).", + ), + ], + all_datasets: Annotated[ + bool, + typer.Option( + "--all-datasets", + "--all", + help="Report across all available ViDoRe v3 datasets (overrides --datasets).", + ), + ] = False, + datasets: Annotated[ + Optional[str], + typer.Option( + "--datasets", + help="Optional comma-separated dataset filter (accepts either full names like 'vidore/vidore_v3_hr' " + "or shorts like 'vidore_v3_hr'). Defaults to all datasets.", + ), + ] = None, + metric: Annotated[ + str, + typer.Option("--metric", help="Aggregated metric key to report (default: ndcg_cut_10)."), + ] = "ndcg_cut_10", + csv: Annotated[ + bool, + typer.Option( + "--csv", + help="Print only a comma-separated header + one row per dataset (easy to paste into Google Docs/Sheets).", + ), + ] = False, +): + """ + Compare aggregated metrics (default: NDCG@10) across multiple pipeline results directories. + """ + if not results_dirs: + raise typer.BadParameter("--results-dirs must include at least one directory.") + + # Canonical dataset list. + all_ds = get_available_datasets() + short_by_full = {ds: dataset_trace_dir(ds) for ds in all_ds} + full_by_short = {v: k for k, v in short_by_full.items()} + + # Optional dataset filter. + selected = list(all_ds) + if (not all_datasets) and datasets: + toks = [t.strip() for t in datasets.split(",") if t.strip()] + if not toks: + raise typer.BadParameter("--datasets was provided but empty.") + selected = [] + for tok in toks: + # Allow specifying either the full dataset name (e.g. 'bright/biology') + # or the dataset file key used on disk (e.g. 'bright__biology'). + if tok in short_by_full: + selected.append(tok) + elif tok in full_by_short: + selected.append(full_by_short[tok]) + else: + raise typer.BadParameter( + f"Unknown dataset '{tok}'. Expected one of: {', '.join(short_by_full.values())}" + ) + + # Pipeline labels for columns. + dirs: list[Path] = [Path(d) for d in results_dirs] + labels: list[str] = [p.name or str(p) for p in dirs] + + # Build matrix: dataset_short -> label -> (status, value) + rows: list[tuple[str, dict[str, str]]] = [] + coverage_by_label: dict[str, int] = {lab: 0 for lab in labels} + + for ds in selected: + ds_short = short_by_full[ds] + per_label: dict[str, str] = {} + for lab, root in zip(labels, dirs): + path = root / f"{ds_short}.json" + if not path.exists(): + per_label[lab] = "MISSING" + continue + try: + with open(path, "r") as f: + obj = json.load(f) + except Exception: + per_label[lab] = "ERROR" + continue + + val = _extract_aggregated_metric(obj, metric) + if val is None: + per_label[lab] = "-" + else: + per_label[lab] = f"{val:.4f}" + coverage_by_label[lab] += 1 + + rows.append((ds_short, per_label)) + + def _csv_escape(s: str) -> str: + s = "" if s is None else str(s) + if any(ch in s for ch in [",", '"', "\n", "\r"]): + s = s.replace('"', '""') + return f'"{s}"' + return s + + if csv: + header_cells = ["dataset"] + labels + print(",".join(_csv_escape(x) for x in header_cells)) + for ds_short, per_label in rows: + row_cells = [ds_short] + [per_label.get(lab, "") for lab in labels] + print(",".join(_csv_escape(x) for x in row_cells)) + return + + # Pretty print aligned table. + dataset_col_w = max(len("dataset"), max((len(ds_short) for ds_short, _ in rows), default=0)) + col_widths: dict[str, int] = {} + for lab in labels: + max_cell = max((len(r[lab]) for _, r in rows), default=0) + col_widths[lab] = max(len(lab), max_cell, 8) + + print("\n" + "=" * 70) + print(f"Compare results ({metric})") + print("=" * 70) + header = ["dataset".ljust(dataset_col_w)] + header += [lab.rjust(col_widths[lab]) for lab in labels] + print(" ".join(header)) + print("-" * (dataset_col_w + sum(col_widths[lab] for lab in labels) + 2 * len(labels))) + + for ds_short, per_label in rows: + line = [ds_short.ljust(dataset_col_w)] + line += [per_label[lab].rjust(col_widths[lab]) for lab in labels] + print(" ".join(line)) + + # Optional coverage summary. + print("-" * (dataset_col_w + sum(col_widths[lab] for lab in labels) + 2 * len(labels))) + cov_line = ["coverage".ljust(dataset_col_w)] + for lab in labels: + cov_line.append(f"{coverage_by_label[lab]}/{len(rows)}".rjust(col_widths[lab])) + print(" ".join(cov_line)) + print("=" * 70 + "\n") + + +@app.command() +def report_llm_usage( + results_dir: Annotated[ + str, + typer.Option( + "--results-dir", + help="Pipeline results directory (e.g. results/AgenticPipelineV1__...).", + ), + ], + all_datasets: Annotated[ + bool, + typer.Option( + "--all-datasets", + "--all", + help="Report across all available ViDoRe v3 datasets (overrides --datasets).", + ), + ] = False, + datasets: Annotated[ + Optional[str], + typer.Option( + "--datasets", + help="Optional comma-separated dataset filter (accepts either full names like 'vidore/vidore_v3_hr' " + "or shorts like 'vidore_v3_hr'). Defaults to all datasets.", + ), + ] = None, + traces_dir: Annotated[ + str, + typer.Option( + "--traces-dir", + help="Default traces root directory to use when results JSON does not specify traces_dir.", + ), + ] = "traces", + list_errors: Annotated[ + bool, + typer.Option( + "--list-errors", + help="After the summary table, list all encountered LLM errors per dataset (query id + error string).", + ), + ] = False, + max_error_len: Annotated[ + int, + typer.Option( + "--max-error-len", + help=( + "Maximum number of characters to print per error string" + " when using --list-errors. Use 0 for no truncation." + ), + ), + ] = 0, +): + """ + Report LLM error counts and token usage per dataset for a pipeline results directory. + + Uses results JSON summary fields when present; otherwise falls back to scanning per-query traces. + """ + results_root = Path(results_dir) + if not results_root.exists() or not results_root.is_dir(): + raise typer.BadParameter(f"--results-dir must be an existing directory: {results_root}") + + # Canonical dataset list. + ds_full = get_available_datasets() + short_by_full = {ds: dataset_trace_dir(ds) for ds in ds_full} + full_by_short = {v: k for k, v in short_by_full.items()} + + selected = list(ds_full) + if not all_datasets and datasets: + toks = [t.strip() for t in datasets.split(",") if t.strip()] + if not toks: + raise typer.BadParameter("--datasets was provided but empty.") + selected = [] + for tok in toks: + if tok in short_by_full: + selected.append(tok) + elif tok in full_by_short: + selected.append(full_by_short[tok]) + else: + raise typer.BadParameter( + f"Unknown dataset '{tok}'. Expected one of: {', '.join(short_by_full.values())}" + ) + + # Table rows: (dataset_short, nq, err_count, pt, ct, tt, wall_ms_q, avg_traj, source) + rows: list[tuple[str, str, str, str, str, str, str, str, str]] = [] + tot_err = 0 + tot_pt = 0 + tot_ct = 0 + tot_tt = 0 + tot_wall_ms = 0.0 + tot_nq = 0 + tot_traj_steps_sum = 0.0 + tot_traj_steps_count = 0 + + inferred_trace_run_name = results_root.name + trace_ctx_by_dataset: dict[str, tuple[str, str]] = {} + results_obj_by_dataset: dict[str, Any] = {} + + for ds in selected: + ds_short = short_by_full[ds] + result_path = results_root / f"{ds_short}.json" + + # Prefer results JSON summary fields. + if result_path.exists(): + try: + with open(result_path, "r") as f: + obj = json.load(f) + except Exception: + obj = None + results_obj_by_dataset[ds_short] = obj + + # Determine effective trace location (used for fallback and optional error listing). + trace_run_name_eff = inferred_trace_run_name + traces_dir_eff = traces_dir + if isinstance(obj, dict): + trn = obj.get("trace_run_name", None) + tdir = obj.get("traces_dir", None) + if isinstance(trn, str) and trn: + trace_run_name_eff = trn + if isinstance(tdir, str) and tdir: + traces_dir_eff = tdir + trace_ctx_by_dataset[ds_short] = (traces_dir_eff, trace_run_name_eff) + + summary = _try_extract_llm_summary_from_results(obj) + if summary is not None: + err_c, pt, ct, tt, avg_traj = summary + wall_ms, nq = _extract_wall_time_and_nq(obj) + nq_str = str(nq) if nq is not None else "-" + wms_str = f"{wall_ms / nq:.1f}" if wall_ms is not None and nq else "-" + if wall_ms is not None and nq: + tot_wall_ms += wall_ms + if nq is not None: + tot_nq += nq + traj_str = f"{avg_traj:.1f}" if avg_traj is not None else "-" + if avg_traj is not None and nq: + tot_traj_steps_sum += avg_traj * nq + tot_traj_steps_count += nq + rows.append((ds_short, nq_str, str(err_c), str(pt), str(ct), str(tt), wms_str, traj_str, "RESULTS")) + tot_err += err_c + tot_pt += pt + tot_ct += ct + tot_tt += tt + continue + + # Fallback to traces using trace_run_name/traces_dir from the results file if available. + traces_dir_eff, trace_run_name_eff = trace_ctx_by_dataset[ds_short] + trace_summary = _compute_llm_summary_from_traces(traces_dir_eff, trace_run_name_eff, ds_short) + if trace_summary is not None: + err_c, pt, ct, tt, avg_traj, num_traces = trace_summary + wall_ms, nq = _extract_wall_time_and_nq(obj) + nq_eff = nq if nq is not None else num_traces + nq_str = str(nq_eff) + wms_str = f"{wall_ms / nq_eff:.1f}" if wall_ms is not None and nq_eff else "-" + if wall_ms is not None and nq_eff: + tot_wall_ms += wall_ms + tot_nq += nq_eff + traj_str = f"{avg_traj:.1f}" if avg_traj is not None else "-" + if avg_traj is not None and nq_eff: + tot_traj_steps_sum += avg_traj * nq_eff + tot_traj_steps_count += nq_eff + rows.append((ds_short, nq_str, str(err_c), str(pt), str(ct), str(tt), wms_str, traj_str, "TRACES")) + tot_err += err_c + tot_pt += pt + tot_ct += ct + tot_tt += tt + else: + rows.append((ds_short, "-", "-", "-", "-", "-", "-", "-", "MISSING")) + continue + + # Results JSON missing: try traces with inferred run name. + trace_ctx_by_dataset[ds_short] = (traces_dir, inferred_trace_run_name) + trace_summary = _compute_llm_summary_from_traces(traces_dir, inferred_trace_run_name, ds_short) + if trace_summary is not None: + err_c, pt, ct, tt, avg_traj, num_traces = trace_summary + traj_str = f"{avg_traj:.1f}" if avg_traj is not None else "-" + rows.append((ds_short, str(num_traces), str(err_c), str(pt), str(ct), str(tt), "-", traj_str, "TRACES")) + tot_nq += num_traces + tot_err += err_c + tot_pt += pt + tot_ct += ct + tot_tt += tt + else: + rows.append((ds_short, "-", "-", "-", "-", "-", "-", "-", "MISSING")) + + # Pretty print. + headers = ( + "dataset", + "num_queries", + "llm_errors", + "prompt_tokens", + "completion_tokens", + "total_tokens", + "wall_ms/q", + "avg_traj_steps", + "source", + ) + col_w = [max(len(headers[i]), max((len(r[i]) for r in rows), default=0)) for i in range(len(headers))] + table_width = sum(col_w) + 2 * (len(col_w) - 1) + title = f"LLM usage report: {results_root}" + rule_width = max(70, len(title), table_width) + + print("\n" + "=" * rule_width) + print(title) + print("=" * rule_width) + print( + " ".join( + [ + headers[0].ljust(col_w[0]), + headers[1].rjust(col_w[1]), + headers[2].rjust(col_w[2]), + headers[3].rjust(col_w[3]), + headers[4].rjust(col_w[4]), + headers[5].rjust(col_w[5]), + headers[6].rjust(col_w[6]), + headers[7].rjust(col_w[7]), + headers[8].ljust(col_w[8]), + ] + ) + ) + print("-" * rule_width) + + for ds_short, nq_s, err_s, pt_s, ct_s, tt_s, wms_s, traj_s, src in rows: + print( + " ".join( + [ + ds_short.ljust(col_w[0]), + nq_s.rjust(col_w[1]), + err_s.rjust(col_w[2]), + pt_s.rjust(col_w[3]), + ct_s.rjust(col_w[4]), + tt_s.rjust(col_w[5]), + wms_s.rjust(col_w[6]), + traj_s.rjust(col_w[7]), + src.ljust(col_w[8]), + ] + ) + ) + + avg_wall_str = f"{tot_wall_ms / tot_nq:.1f}" if tot_nq > 0 else "-" + avg_traj_str = f"{tot_traj_steps_sum / tot_traj_steps_count:.1f}" if tot_traj_steps_count > 0 else "-" + print("-" * rule_width) + print( + " ".join( + [ + "TOTAL".ljust(col_w[0]), + str(tot_nq).rjust(col_w[1]), + str(tot_err).rjust(col_w[2]), + str(tot_pt).rjust(col_w[3]), + str(tot_ct).rjust(col_w[4]), + str(tot_tt).rjust(col_w[5]), + avg_wall_str.rjust(col_w[6]), + avg_traj_str.rjust(col_w[7]), + "".ljust(col_w[8]), + ] + ) + ) + print("=" * rule_width + "\n") + + if list_errors: + print("=" * rule_width) + print("LLM errors (query id -> error)") + print("=" * rule_width) + any_errs = False + for ds in selected: + ds_short = short_by_full[ds] + traces_dir_eff, trace_run_name_eff = trace_ctx_by_dataset.get( + ds_short, (traces_dir, inferred_trace_run_name) + ) + errs_from_traces = _compute_llm_errors_from_traces( + traces_dir_eff, trace_run_name_eff, ds_short, max_error_len=max_error_len + ) + err_map = {qid: err for qid, err in (errs_from_traces or [])} + + # Also include any error ids recorded in the results JSON summary (even if the trace was deleted later). + err_ids_from_results: list[str] = [] + pi = _extract_pipeline_infos(results_obj_by_dataset.get(ds_short, None)) + if isinstance(pi, dict): + ids = pi.get("llm_error_query_ids", None) + if isinstance(ids, list) and all(isinstance(x, str) for x in ids): + err_ids_from_results = ids + + all_ids = set(err_map.keys()) | set(err_ids_from_results) + if not all_ids: + continue + + def _qid_key(x: str): + return int(x) if x.isdigit() else x + + any_errs = True + sorted_ids = sorted(all_ids, key=_qid_key) + print(f"{ds_short}: {len(sorted_ids)}") + for qid in sorted_ids: + err = err_map.get(qid, None) + if err is None: + err = "(trace missing or llm_error not present in trace file)" + print(f" {qid}: {err}") + if not any_errs: + print("(no llm_error entries found in traces)") + print("=" * rule_width + "\n") + + +@app.command() +def purge_llm_error_traces( + results_dir: Annotated[ + str, + typer.Option( + "--results-dir", + help="Pipeline results directory (e.g. results/AgenticPipelineV1__...). Used to locate trace_run_name.", + ), + ], + all_datasets: Annotated[ + bool, + typer.Option( + "--all-datasets", + "--all", + help="Purge across all available ViDoRe v3 datasets (overrides --datasets).", + ), + ] = False, + datasets: Annotated[ + Optional[str], + typer.Option( + "--datasets", + help="Optional comma-separated dataset filter (accepts either full names like 'vidore/vidore_v3_hr' " + "or shorts like 'vidore_v3_hr'). Defaults to all datasets.", + ), + ] = None, + traces_dir: Annotated[ + str, + typer.Option( + "--traces-dir", + help="Default traces root directory to use when results JSON does not specify traces_dir.", + ), + ] = "traces", + dry_run: Annotated[ + bool, + typer.Option( + "--dry-run/--no-dry-run", + help="If enabled (default), print what would be deleted without deleting anything.", + ), + ] = True, + yes: Annotated[ + bool, + typer.Option( + "--yes", + help="Required to actually delete files when using --no-dry-run.", + ), + ] = False, + print_query_ids: Annotated[ + bool, + typer.Option( + "--print-query-ids", + help="Print the query ids that would be purged per dataset (useful for logging).", + ), + ] = False, +): + """ + Delete per-query trace files where `pipeline_trace.llm_error` is present. + + This is a practical way to retrigger only queries that hit LLM failures, since evaluation + uses per-query trace caching: missing traces are rerun, existing valid traces are skipped. + """ + results_root = Path(results_dir) + if not results_root.exists() or not results_root.is_dir(): + raise typer.BadParameter(f"--results-dir must be an existing directory: {results_root}") + + if (not dry_run) and (not yes): + raise typer.BadParameter("Refusing to delete traces without --yes. Re-run with --no-dry-run --yes.") + + # Canonical dataset list. + ds_full = get_available_datasets() + short_by_full = {ds: dataset_trace_dir(ds) for ds in ds_full} + full_by_short = {v: k for k, v in short_by_full.items()} + + selected = list(ds_full) + if not all_datasets and datasets: + toks = [t.strip() for t in datasets.split(",") if t.strip()] + if not toks: + raise typer.BadParameter("--datasets was provided but empty.") + selected = [] + for tok in toks: + if tok in short_by_full: + selected.append(tok) + elif tok in full_by_short: + selected.append(full_by_short[tok]) + else: + raise typer.BadParameter( + f"Unknown dataset '{tok}'. Expected one of: {', '.join(short_by_full.values())}" + ) + + inferred_trace_run_name = results_root.name + total_marked = 0 + total_deleted = 0 + + print("\n" + "=" * 70) + mode = "DRY RUN (no deletions)" if dry_run else "DELETE" + print(f"Purge LLM error traces: {results_root} [{mode}]") + print("=" * 70) + + for ds in selected: + ds_short = short_by_full[ds] + result_path = results_root / f"{ds_short}.json" + + trace_run_name_eff = inferred_trace_run_name + traces_dir_eff = traces_dir + + if result_path.exists(): + try: + with open(result_path, "r") as f: + obj = json.load(f) + except Exception: + obj = None + if isinstance(obj, dict): + trn = obj.get("trace_run_name", None) + tdir = obj.get("traces_dir", None) + if isinstance(trn, str) and trn: + trace_run_name_eff = trn + if isinstance(tdir, str) and tdir: + traces_dir_eff = tdir + + trace_root = Path(traces_dir_eff) / trace_run_name_eff / ds_short + if not trace_root.exists() or not trace_root.is_dir(): + print(f"{ds_short}: no trace dir at {trace_root}") + continue + + marked_paths: list[Path] = [] + marked_qids: list[str] = [] + + for p in sorted(trace_root.glob("*.json")): + try: + with open(p, "r") as f: + obj = json.load(f) + except Exception: + continue + if not isinstance(obj, dict): + continue + pipeline_trace = obj.get("pipeline_trace", None) + if not isinstance(pipeline_trace, dict): + continue + if not isinstance(pipeline_trace.get("llm_error", None), str): + continue + + marked_paths.append(p) + qid = obj.get("query_id", None) + marked_qids.append(str(qid) if qid is not None else p.stem) + + if not marked_paths: + print(f"{ds_short}: 0 traces with llm_error") + continue + + total_marked += len(marked_paths) + print(f"{ds_short}: {len(marked_paths)} traces with llm_error ({trace_root})") + if print_query_ids: + print(f" query_ids: {', '.join(marked_qids)}") + + if not dry_run: + deleted = 0 + for p in marked_paths: + try: + p.unlink() + deleted += 1 + except Exception as e: + print(f" WARNING: failed to delete {p}: {type(e).__name__}: {e}") + total_deleted += deleted + print(f" deleted: {deleted}/{len(marked_paths)}") + + print("-" * 70) + if dry_run: + print(f"TOTAL marked for deletion: {total_marked}") + print("Re-run with: --no-dry-run --yes to delete.") + else: + print(f"TOTAL deleted: {total_deleted}/{total_marked}") + print("=" * 70 + "\n") + + +if __name__ == "__main__": + app() diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/__init__.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/__init__.py new file mode 100644 index 000000000..c8612ed8a --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Internalized NeMo agentic retrieval components. + +This package contains the minimal subset of code originally sourced from an +external repository (formerly imported via sys.path injection). It is maintained +in-tree so this repo can run without depending on that external checkout. +""" + +from .agent import Agent # noqa: F401 +from .configs import AgentConfig, LLMConfig # noqa: F401 +from .llm_handler import LLM # noqa: F401 +from . import tool_helpers # noqa: F401 diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/agent.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/agent.py new file mode 100644 index 000000000..e98691c70 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/agent.py @@ -0,0 +1,459 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Implements an LLM agent with ability to call functions.""" + +import asyncio +import copy +import json +import os +import warnings +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set +from uuid import uuid4 + +from jinja2 import Environment, FileSystemLoader, Template + +from . import image_explainer, llm_handler, query_rewriter, tool_helpers, utils +from .configs import AgentConfig +from .logging_utils import get_logger_with_config +from .selection_agent.selection_agent import SelectionAgent + +logger, _ = get_logger_with_config() + + +def is_context_exceeds_error(msg: dict) -> bool: + """Return True if the given message is an agent error because the context window is exceeded.""" + return msg["role"] == "agent_error" and "context" in msg["content"].lower() and "window" in msg["content"].lower() + + +class Agent: + def __init__( + self, + config: AgentConfig, + llm: llm_handler.LLM, + session_id: Optional[str] = None, + tool_wrapper: Optional[Callable] = None, + doc_transform_fn: Optional[Callable] = None, + tool_map: Optional[Dict[str, tool_helpers.BaseTool]] = None, + ) -> None: + """Implement an LLM agent loop with access to tools.""" + + self.message_history = list() + self.config = config + self.doc_transform_fn = doc_transform_fn + self.tool_wrapper = tool_wrapper + self.session_id = session_id + self.llm = llm + self.tool_map = tool_map + + self._curr_time = datetime.now() + self._uuid = uuid4().hex + + if self.config.user_msg_type not in ["simple", "with_results"]: + raise ValueError(f"user_msg_type must be 'simple' or 'with_results', got {self.config.user_msg_type!r}") + + # Load system prompt template + if Path(config.system_prompt).exists(): + prompt_path = Path(config.system_prompt) + system_prompt_template = Template(prompt_path.read_text().strip()) + else: + env = Environment( + loader=FileSystemLoader(Path(__file__).parent.joinpath("prompts").absolute().resolve().as_posix()) + ) + system_prompt_template = env.get_template(config.system_prompt) + + system_prompt = system_prompt_template.render( + with_init_docs=config.user_msg_type == "with_results", + enforce_top_k=config.enforce_top_k, + top_k=config.target_top_k, + extended_relevance=config.extended_relevance, + ) + + self._system_msg = {"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]} + self.message_history = [self._system_msg] + self.current_user_msg = None + + self.auto_user_msg = ( + "Please continue on whatever approach you think is suitable.\n" + "If you think you have solved the task, please finish the interaction.\n" + "IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN RESPONSE.\n" + ) + + self.selection_agent = SelectionAgent( + llm=self.llm, + prompt_name=self.config.selection_prompt, + max_steps=self.config.selection_max_steps, + extended_relevance=self.config.extended_relevance, + ) + + self.steps = 0 + self.extra_data: Dict[str, Any] = {} + self.retrieved_docs: Set[str] = set() + self.retrieval_log: list[dict] = [] + self.exclude_docs: Set[str] = set() + self.api_response_extras: list = [] + + def reset(self) -> None: + self.message_history = [self._system_msg] + self.current_user_msg = None + self.steps = 0 + self._curr_time = datetime.now() + self._uuid = uuid4().hex + self.extra_data = {} + self.retrieved_docs = set() + self.retrieval_log = [] + self.exclude_docs = set() + self.api_response_extras = [] + + def get_llm_raw_io_subdir(self): + raw_log_subdir = self._curr_time.strftime("%d-%m-%y_%H-%M-%S-%f") + "_" + self._uuid[:8] + if self.session_id is not None: + raw_log_subdir = self.session_id + "_" + raw_log_subdir + return "main_agent/" + raw_log_subdir + + async def step(self) -> Optional[dict]: + """Take one step and process LLM outputs.""" + log_exp_name = self._curr_time.strftime("%d-%m_%H-%M") + "_" + self._uuid[:4] + if self.session_id is not None: + log_exp_name = self.session_id + "_" + log_exp_name + trace_prefix = os.environ.get("LOG_TRACE_PREFIX") + if trace_prefix is not None: + log_exp_name = str(trace_prefix) + "_" + log_exp_name + + logging_kwargs = {"step": self.steps, "subdir": self.get_llm_raw_io_subdir(), "log_exp_name": log_exp_name} + + params = { + "messages": copy.deepcopy(self.message_history), + "tools": [t.spec for t in self.tool_map.values()], + "logging_kwargs": logging_kwargs, + } + response = await self.llm.acompletion(**params, return_metadata=True) + if llm_handler.is_error(response): + self.message_history.append(utils.AgentErrorMessage(content=response).model_dump()) + return None + + metadata_kv = response["metadata_kv"] + io_log_data = response["io_log_kwargs"] + if "api_response_extras" in response: + self.api_response_extras.append(response["api_response_extras"]) + response = response["response"] + self.steps += 1 + + metadata_kv = {"S": str(self.steps), **metadata_kv} + log_str = " ".join([f"{k}: {v}" for k, v in metadata_kv.items()]) + logger.info(f"[{self.session_id}] " + log_str) + + if len(response.choices) != 1: + raise RuntimeError(f"Expected exactly 1 choice from LLM, got {len(response.choices)}") + if response.choices[0].finish_reason == "tool_calls": + conv_msg = {"content": [], "role": "assistant", "tool_calls": []} + for call_info in response.choices[0].message.tool_calls: + conv_msg["tool_calls"].append(call_info.model_dump()) + if response.choices[0].message.content is not None: + if not isinstance(response.choices[0].message.content, str): + raise TypeError("Expected string content from LLM, got non-text content") + conv_msg["content"].append({"type": "text", "text": response.choices[0].message.content}) + self.message_history.append(conv_msg) + elif response.choices[0].finish_reason == "stop": + self.message_history.append( + {"role": "assistant", "content": [{"type": "text", "text": response.choices[0].message.content}]} + ) + else: + self.message_history.append( + utils.AgentErrorMessage( + content=f"LLM failed with finish_reason '{response.choices[0].finish_reason}'" + ).model_dump() + ) + return io_log_data + + async def call_one_tool( + self, + fn_name: str, + fn_kwargs: Dict, + store_state: bool = True, + include_in_retrieval_logs: bool = True, + query_type: Optional[str] = "agent", + ): + """Call one tool and return the results.""" + old_sub_query = None + + if ( + self.config.use_query_rewriting + and fn_name == "retrieve" + and isinstance(self.tool_map[fn_name], tool_helpers.RetrieveToolBase) + ): + old_sub_query = fn_kwargs["query"] + new_sub_query = await query_rewriter.rewrite_query( + llm=self.llm, + main_query=self.current_user_msg, + sub_query=old_sub_query, + ) + fn_kwargs["query"] = new_sub_query + if store_state: + if "query_rewriting" not in self.extra_data: + self.extra_data["query_rewriting"] = {} + self.extra_data["query_rewriting"][old_sub_query] = new_sub_query + + async def tool_caller(**kw): + new_kwargs = {**fn_kwargs, **kw} + if self.tool_wrapper is not None: + return await self.tool_wrapper(tool=self.tool_map[fn_name], tool_kwargs=new_kwargs) + return await self.tool_map[fn_name].acall(**new_kwargs) + + if fn_name == "retrieve" and isinstance(self.tool_map[fn_name], tool_helpers.RetrieveToolBase): + fn_res = await tool_helpers.retrieve_with_guarantees( + tool_caller=tool_caller, + top_k=fn_kwargs.get("top_k", getattr(self.tool_map[fn_name], "_default_top_k", 5)), + seen_docids=(self.retrieved_docs if self.config.ensure_new_docs else set()), + exclude_docids=self.exclude_docs, + ) + + if self.doc_transform_fn is not None: + if isinstance(fn_res, str): + if "Error" not in fn_res: + raise RuntimeError(f"Unexpected non-list, non-error retrieve result: {fn_res!r}") + else: + fn_res = await asyncio.gather(*(self.doc_transform_fn(d) for d in fn_res)) + if isinstance(fn_res, list): + for i in range(len(fn_res)): + if fn_res[i]["id"] in self.retrieved_docs: + fn_res[i].pop("image", None) + fn_res[i].pop("text", None) + note = ( + "This document is retrieved before. See previous retrieval results for the content of this " + f"document (id: {fn_res[i]['id']})." + ) + fn_res[i]["note"] = note + else: + fn_res = await tool_caller() + + if fn_name == "retrieve" and isinstance(self.tool_map[fn_name], tool_helpers.RetrieveToolBase): + if self.config.use_image_explainer and not isinstance(fn_res, str): + image_descs = await asyncio.gather( + *( + image_explainer.explain_image( + llm=self.llm, + main_query=self.current_user_msg, + sub_query=fn_kwargs["query"], + image_b64=d.get("image", None), + prompt_type=self.config.image_explainer_prompt, + ) + for d in fn_res + ) + ) + for i in range(len(fn_res)): + if image_descs[i] is not None: + fn_res[i]["text"] = image_descs[i] + fn_res[i].pop("image", None) + + tmp_output_list: list = [] + if isinstance(fn_res, list): + for item in fn_res: + self.retrieved_docs.add(item["id"]) + tmp_output_list.append(item) + elif isinstance(fn_res, str) and not fn_res.startswith("Error"): + try: + res_obj = json.loads(fn_res) + for item in res_obj: + self.retrieved_docs.add(item["id"]) + tmp_output_list.append(item) + except Exception: + pass + if include_in_retrieval_logs: + self.retrieval_log.append( + { + "input": fn_kwargs, + "query_before_rewriting": old_sub_query, + "query_type": query_type, + "output": tmp_output_list, + } + ) + content = tool_helpers.retrieve_output_to_msg_content(output=fn_res) + else: + if not isinstance(fn_res, str): + fn_res = json.dumps(fn_res) + content = [{"type": "text", "text": fn_res}] + return content + + async def process_tool_calls(self) -> List[Dict[str, Any]]: + if len(self.message_history[-1].get("tool_calls", [])) < 1: + raise RuntimeError("There are no tool calls to process") + + tool_messages = [] + for call_info in self.message_history[-1]["tool_calls"]: + fn_name = call_info["function"]["name"] + content = None + try: + fn_kwargs = json.loads(call_info["function"]["arguments"]) + except Exception: + content = "Error parsing tool arguments. Tool arguments not correctly formatted." + if content is None: + content = await self.call_one_tool(fn_name=fn_name, fn_kwargs=fn_kwargs, store_state=True) + tool_messages.append({"content": content, "role": "tool", "tool_call_id": call_info["id"], "name": fn_name}) + return tool_messages + + def is_last_msg_error(self) -> bool: + """Check if the last message in the history is an error.""" + if self.message_history[-1]["role"] == "agent_error": + err_msg = self.message_history[-1]["content"] + if self.config.only_warn_on_error: + warnings.warn(err_msg) + return True + raise RuntimeError(err_msg) + return False + + async def run_for_input( + self, + query: str, + task_instruction: Optional[str] = None, + task_info: Optional[Any] = None, + exclude_docids: Optional[Set] = None, + ) -> Dict[str, Any]: + """Run the agent for a given user message.""" + self.reset() + if self.tool_map is None: + raise RuntimeError("Agent requires tool_map to be provided by the caller.") + + await self.llm.log_extra_data_log_dir(subdir=self.get_llm_raw_io_subdir(), info=task_info) + self.current_user_msg = query + + self.exclude_docs = set() if exclude_docids is None else set(exclude_docids) + + if task_instruction is None: + task_instruction = "" + task_instruction = task_instruction.strip() + if task_instruction != "" and not task_instruction.lower().startswith("instruct"): + task_instruction = f"Instruct: {task_instruction}" + if task_instruction != "": + task_instruction = task_instruction.strip() + "\n" + task_inst_query = f"{task_instruction}Query:\n{query}" + + if self.config.user_msg_type == "simple": + self.message_history.append({"role": "user", "content": [{"type": "text", "text": task_inst_query}]}) + elif self.config.user_msg_type == "with_results": + res = await self.call_one_tool( + fn_name="retrieve", + fn_kwargs={"query": query}, + store_state=False, + query_type="main", + ) + user_msg = { + "role": "user", + "content": [{"type": "text", "text": task_inst_query}, {"type": "text", "text": "Retrieved Documents:"}] + + res, + } + self.message_history.append(user_msg) + else: + raise ValueError(f"`{self.config.user_msg_type}` is not a valid user_msg_type.") + + io_log_data = None + while True: + if self.config.max_steps is not None and self.steps >= self.config.max_steps: + self.message_history.append( + utils.AgentErrorMessage(content="Agent reached maximum allowed iterations").model_dump() + ) + + if self.is_last_msg_error(): + break + + new_io_log_data = await self.step() + if new_io_log_data is not None: + io_log_data = new_io_log_data + + if not self.is_last_msg_error(): + _tc = self.message_history[-1].get("tool_calls", []) + if _tc is None or len(_tc) == 0: + self.message_history.append( + {"role": "user", "content": [{"type": "text", "text": self.auto_user_msg}]} + ) + else: + tool_calls = [tc["function"]["name"] for tc in self.message_history[-1]["tool_calls"]] + tool_messages = await self.process_tool_calls() + self.message_history.extend(tool_messages) + ended_successfully = False + if self.config.end_tool in tool_calls: + end_tool = self.tool_map[self.config.end_tool] + _correct_val = end_tool.correct_call_return_value # type: ignore[attr-defined] + for tm in tool_messages: + if tm["name"] == self.config.end_tool and tm["content"][0]["text"] == _correct_val: + ended_successfully = True + break + if ended_successfully: + break + else: + break + + await self.llm.log_extra_data_log_dir( + subdir=self.get_llm_raw_io_subdir(), + info=self.api_response_extras, + filename="api_response_extras.json", + ) + + if not self.llm.config.instant_log and io_log_data is not None: + await llm_handler.awrite_json(**io_log_data["input_json"]) + await llm_handler.awrite_json(**io_log_data["output_json"]) + + return await self.conclude_task(query=query, task_info=task_info) + + async def conclude_task(self, query: str, task_info: Optional[Any] = None) -> Dict[str, Any]: + """Calculate final top_k results if needed and return artifacts.""" + output_artifacts: Dict[str, Any] = { + "agent_trajectories": self.message_history, + "retrieval_log": self.retrieval_log, + } + if len(self.extra_data): + output_artifacts["agent_extra_data"] = self.extra_data + + if self.config.main_agent_only: + return output_artifacts + + selection_input = [] + rrf_input = [] + + seen_docids = set() + for ret_data in self.retrieval_log: + rrf_input.append(ret_data["output"]) + for doc in ret_data["output"]: + if doc["id"] not in seen_docids: + selection_input.append(doc) + seen_docids.add(doc["id"]) + + rrf_res: Optional[Dict[str, float]] = None + if self.config.calculate_rrf or self.is_last_msg_error(): + rrf_res = utils.rrf_from_subquery_results(retrieval_results=rrf_input) + output_artifacts["rrf_scores"] = rrf_res + + selection_topk_list = self.config.selection_topk_list + if len(selection_topk_list) == 0 and self.is_last_msg_error(): + selection_topk_list = [5, 10] + + if len(selection_topk_list) != 0: + for _ in range(2): + selection_output = await asyncio.gather( + *( + self.selection_agent.select_topk( + query=query, + documents=selection_input, + top_k=k, + session_id=self.session_id, + task_info=task_info, + ) + for k in selection_topk_list + ) + ) + if not any([a is None and is_context_exceeds_error(h[-1]) for h, a in selection_output]): + break + drop_items = len(selection_input) // 4 + if rrf_res is None: + rrf_res = utils.rrf_from_subquery_results(retrieval_results=rrf_input) + output_artifacts["rrf_scores"] = rrf_res + least_rrf_scores = dict(sorted(rrf_res.items(), key=lambda x: x[1])[:drop_items]) + least_rrf_scores_set = set(list(least_rrf_scores.keys())) + selection_input = [i for i in selection_input if i["id"] not in least_rrf_scores_set] + for topk, (msg_hist, final_ans) in zip(selection_topk_list, selection_output): + output_artifacts[f"top{topk}_agent_trajectories"] = msg_hist + if final_ans is not None: + output_artifacts[f"top{topk}_selection_result"] = final_ans + return output_artifacts diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/configs.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/configs.py new file mode 100644 index 000000000..d7604a46c --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/configs.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +from pydantic import BaseModel, ConfigDict + + +class AgentConfig(BaseModel): + """Configuration for the agent.""" + + model_config = ConfigDict(extra="forbid") + system_prompt: str = "02_v1.j2" + enforce_top_k: bool = True + target_top_k: Optional[int] = 10 + extended_relevance: bool = True + max_steps: Optional[int] = None + disable_think: bool = False + only_warn_on_error: bool = True + end_tool: str = "final_results" + use_image_explainer: bool = False + use_query_rewriting: bool = False + image_explainer_prompt: str = "simple" + ensure_new_docs: bool = True + user_msg_type: str = "with_results" + selection_topk_list: List[int] = [5, 10] + calculate_rrf: bool = True + selection_prompt: str = "01_v0.j2" + selection_max_steps: int = 10 + main_agent_only: bool = False + + +class LLMConfig(BaseModel): + """Configuration for the LLM (LiteLLM wrapper).""" + + model_config = ConfigDict(extra="forbid") + model: str + api_key: Optional[str] = None + tool_choice: str = "auto" + reasoning_effort: Optional[str] = None + base_url: Optional[str] = None + api_version: Optional[str] = None + num_retries: Optional[int] = 4 + max_completion_tokens: Optional[int] = None + raw_log_pardir: Optional[str] = None + instant_log: bool = False + strict_error_handling: bool = False + drop_params: bool = False + allowed_openai_params: Optional[List[str]] = None diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/image_explainer.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/image_explainer.py new file mode 100644 index 000000000..e97b06880 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/image_explainer.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +from . import llm_handler + +_SYSTEM_PROMPT_WITH_QUERY = """You are the image understanding component of an agentic retrieval system. + + +You are part of a larger agentic retrieval system, which works like following. The retrieval agent is given a user query and a search tool. The goal of the retrieval agent is to find all the related documents and images related to the given query. It primarily does this by making multiple search calls (parallel or sequential) with different related queries to explore the corpus. These search calls might find the related documents or they might provide new information that help the agent generate better sub-queries in its next round of search calls. Unfortunately the main search agent is text only and cannot process and understand images. Your job is to help with this. + + + +After each search call, you are given the main (i.e., user's) query, the sub-query used for the most recent search call, and one of the retrieved images for this sub-query. +Your task is to understand the main query, the sub-query, and the image. +Then, generate a text for the text-only search agent that provides all the helpful information about the image as if the search agent itself is seeing this image. +Your text should describe the image in the context of the main and sub-query. + +Notes and best practices: +- start your answer by saying what the image is (e.g., "this is an image of a slide that shows ...", "this is an image of a text book page that ...", "this is image of a cat that is standing on ...") then continue with the rest of your output. +- include information from image that is relevant to the main query or the sub-query. +- include any information that could help the search agent future searches (i.e., the information that helps it to generate better sub-queries in the next search call). +- also if all or part of the information requested in the query or sub-query is not present in the image, explicitely mention that that specific piece of information is not mentioned. +- do not provide any judgment on whether the image is useful/related or not. Your task is to just help the agent see what is in the image (avoid saying things like x is or is not useful for this query, etc.). +- if the image is text heavy, you should provide most of the information to the agent even if it is remotely important. + +**VERY IMPORTANT**: Be precise and faithful to the image. Do **NOT** include information that is not present in the image. + + + +You should generate only the target text without any additional announcements, prefixes, or suffixes. All your output for the image is directly given to the search agent. +""" + + +_SYSTEM_PROMPT_SIMPLE = """Your task is to convert an image to text. + +I want to input the screenshot from some documents to a Large Language Model (LLM). But, my LLM can only process text as input. Your task is to take an image as the input and generate a text equivalent of the image that I can pass to my LLM. +For images that only contain text data, this is identical to OCR (optical character recognition). +But, for documents that involve figures, charts, tables, etc., you should describe all the details in these documents such that reading the text provides similar information to seeing the original image. + +You should generate only the target text without any additional comments, explanations, etc. All your output is used as the text description for the image.""" + + +async def explain_image( + llm: llm_handler.LLM, + main_query: str, + sub_query: str, + image_b64: str, + prompt_type: str = "simple", +) -> Optional[str]: + if image_b64 is None: + return None + if prompt_type == "with_query": + main_query = main_query.replace("Query:", "").replace("query:", "") + sub_query = sub_query.replace("Query:", "").replace("query:", "") + txt_msg = f"Main Query: {main_query}\n\nCurrent Sub-query: {sub_query}" + msg_list = [ + {"role": "system", "content": _SYSTEM_PROMPT_WITH_QUERY}, + { + "role": "user", + "content": [ + {"type": "text", "text": txt_msg}, + {"type": "image_url", "image_url": {"url": image_b64}}, + ], + }, + ] + elif prompt_type == "simple": + msg_list = [ + {"role": "system", "content": _SYSTEM_PROMPT_SIMPLE}, + { + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": image_b64}}], + }, + ] + else: + raise ValueError(f"Unknown prompt_type: {prompt_type}") + + response = await llm.acompletion( + messages=msg_list, + return_metadata=True, + ) + response = response["response"] + return response.choices[0].message.content diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/llm_handler.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/llm_handler.py new file mode 100644 index 000000000..faab62c10 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/llm_handler.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Implements a class and related utilities to make LLM API calls.""" + +import asyncio +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 + +import aiofiles +import litellm +from litellm.exceptions import BadRequestError, ContentPolicyViolationError, ContextWindowExceededError +from litellm.types.utils import ModelResponse + +from .configs import LLMConfig +from .logging_utils import get_logger_with_config + +logger, _ = get_logger_with_config() + +LLM_ERROR_PREFIX = "LLMError:" + +# Reduce noisy provider/help banners on handled API errors. +if hasattr(litellm, "suppress_debug_info"): + litellm.suppress_debug_info = True + + +def is_error(response: Any) -> bool: + """Return True if the response from the LLM handler is an error string.""" + return isinstance(response, str) and response.startswith(LLM_ERROR_PREFIX) + + +def normalize_messages_for_api(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Normalize message content from list-of-content-blocks to plain strings. + + Some OpenAI-compatible endpoints only accept string content for certain roles. + This converts text-only `content` lists (e.g., [{"type":"text","text":"..."}]) + into a plain string. Messages with non-text blocks (e.g., image_url) are left as-is. + """ + + normalized: List[Dict[str, Any]] = [] + for msg in messages: + msg = dict(msg) + content = msg.get("content") + if isinstance(content, list): + text_parts: List[str] = [] + all_text = True + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + text_parts.append(str(item.get("text", ""))) + else: + all_text = False + break + if all_text: + if len(text_parts) == 0: + msg["content"] = None + elif len(text_parts) == 1: + msg["content"] = text_parts[0] + else: + msg["content"] = "\n".join(text_parts) + normalized.append(msg) + return normalized + + +def write_json(obj: Any, log_dir: Union[str, Path], filename: Union[str, Path]): + """Write an object to a json file.""" + if log_dir is None: + return + path = Path(log_dir, filename) + path.parent.mkdir(exist_ok=True, parents=True) + with open(path, "w") as f: + json.dump(obj, f, indent=2) + + +async def awrite_json(obj: Any, log_dir: Union[str, Path], filename: Union[str, Path]): + """Write an object to a json file async.""" + if log_dir is None: + return + path = Path(log_dir, filename) + path.parent.mkdir(exist_ok=True, parents=True) + async with aiofiles.open(path.as_posix(), "w") as f: + await f.write(json.dumps(obj, indent=2)) + + +async def make_acall_with_ratelimit_pause(*args: Any, **kwargs: Any): + """Make an LLM call and pause on ratelimits (fixed 60s, max 3 attempts).""" + i = 0 + while True: + try: + return await litellm.acompletion(*args, **kwargs) + except litellm.RateLimitError: + if i > 2: + raise + logger.info(f"Rate Limit. Sleep for a min (i={i}).") + await asyncio.sleep(60) + i += 1 + + +class LLM: + def __init__(self, llm_config: LLMConfig) -> None: + """LLM client to make calls to chat completion endpoint via LiteLLM.""" + self.config = llm_config + + self.completion_kwargs = dict( + model=self.config.model, + tool_choice=self.config.tool_choice, + base_url=self.config.base_url, + api_version=self.config.api_version, + num_retries=self.config.num_retries, + max_completion_tokens=self.config.max_completion_tokens, + ) + if self.config.drop_params: + self.completion_kwargs["drop_params"] = self.config.drop_params + if self.config.allowed_openai_params is not None and len(self.config.allowed_openai_params) != 0: + self.completion_kwargs["allowed_openai_params"] = self.config.allowed_openai_params + if self.config.reasoning_effort is not None: + self.completion_kwargs["reasoning_effort"] = self.config.reasoning_effort + if self.config.api_key is not None: + self.completion_kwargs["api_key"] = self.config.api_key + + self._api_key = None + if self.config.api_key is not None and self.config.api_key.strip().startswith("os.environ/"): + self._api_key = os.environ[self.config.api_key.strip().removeprefix("os.environ/")] + + async def log_extra_data_log_dir( + self, + subdir: Optional[str] = None, + info: Optional[Any] = None, + filename: str = "extra_info.json", + ) -> None: + if info is None or self.config.raw_log_pardir is None or subdir is None: + return + if not filename.endswith(".json"): + raise ValueError(f"filename must end with '.json', got {filename!r}") + json_log_dir = Path(self.config.raw_log_pardir, subdir) + await awrite_json(obj=info, log_dir=json_log_dir, filename=filename) + + def pre_completion(self, messages: list[dict], tools: Optional[list[dict]] = None, **kwargs: Any) -> tuple: + """Prepare request kwargs and logging context before calling LiteLLM.""" + logging_kwargs = kwargs.pop("logging_kwargs", None) + if logging_kwargs is None or "step" not in logging_kwargs: + curr_step = uuid4().hex + else: + curr_step = logging_kwargs["step"] + + json_log_dir = None + if self.config.raw_log_pardir is not None and logging_kwargs is not None and "subdir" in logging_kwargs: + json_log_dir = Path(self.config.raw_log_pardir, logging_kwargs["subdir"]) + + # Merge in steps so per-call kwargs can intentionally override defaults + # without raising `TypeError` on duplicate keys. + request_kwargs = {"messages": messages} + request_kwargs.update(self.completion_kwargs) + request_kwargs.update(kwargs) + if tools is not None: + request_kwargs["tools"] = tools + else: + request_kwargs.pop("tool_choice", None) + + io_log_kwargs = { + "input_json": dict(obj=request_kwargs, log_dir=json_log_dir, filename=f"{curr_step}_prompt.json") + } + logging_ctx = { + "logging_kwargs": logging_kwargs, + "json_log_dir": json_log_dir, + "curr_step": curr_step, + } + return request_kwargs, logging_ctx, io_log_kwargs + + def post_completion(self, request_kwargs: dict, response: ModelResponse, logging_ctx: dict): + """Finish logging after the completion method is finished.""" + metadata_kv = {} + additional_headers = getattr(response, "_hidden_params", {}).get("additional_headers", {}) or {} + remaining_tpm = additional_headers.get("llm_provider-x-ratelimit-remaining-tokens", "-1") + remaining_rq = additional_headers.get("llm_provider-x-ratelimit-remaining-requests", "-1") + metadata_kv["TPM"] = f"{int(remaining_tpm):,}" + metadata_kv["RQ"] = f"{int(remaining_rq):,}" + + json_log_dir = logging_ctx["json_log_dir"] + curr_step = logging_ctx["curr_step"] + + io_log_kwargs = {} + io_log_kwargs["output_json"] = dict( + obj=response.model_dump(), log_dir=json_log_dir, filename=f"{curr_step}_response.json" + ) + return response, metadata_kv, io_log_kwargs + + async def acompletion(self, messages: list[dict], tools: Optional[list[dict]] = None, **kwargs: Any): + """Call the chat completion endpoint and return results.""" + io_log_kwargs: dict[str, Any] = {} + return_metadata = kwargs.pop("return_metadata", False) + + request_kwargs, logging_ctx, pre_io = self.pre_completion(messages=messages, tools=tools, **kwargs) + io_log_kwargs.update(pre_io) + if self.config.instant_log: + await awrite_json(**io_log_kwargs["input_json"]) + + try: + if self._api_key is not None: + request_kwargs["api_key"] = self._api_key + response = await make_acall_with_ratelimit_pause(**request_kwargs) + except (ContentPolicyViolationError, ContextWindowExceededError, BadRequestError) as e: + if isinstance(e, BadRequestError): + if not ( + "ContentPolicyViolationError".lower() in str(e).lower() + or "ContextWindowExceededError".lower() in str(e).lower() + or ("context" in str(e).lower() and "window" in str(e).lower()) + ): + raise + if self.config.strict_error_handling: + raise + print(LLM_ERROR_PREFIX + " " + str(e)) + return LLM_ERROR_PREFIX + " " + str(e) + + _, metadata_kv, post_io = self.post_completion( + request_kwargs=request_kwargs, response=response, logging_ctx=logging_ctx + ) + io_log_kwargs.update(post_io) + + if self.config.instant_log: + await awrite_json(**io_log_kwargs["output_json"]) + + if return_metadata: + output = dict(response=response, metadata_kv=metadata_kv, io_log_kwargs=io_log_kwargs) + try: + r = response.model_dump() + r.pop("choices", None) + output["api_response_extras"] = r + except Exception: + pass + return output + return response diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/logging_utils.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/logging_utils.py new file mode 100644 index 000000000..c498e6f5b --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/logging_utils.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from BatsResearch/trove (Apache-2.0): +# https://github.com/BatsResearch/trove/blob/main/src/trove/logging_utils.py +"""Simple utilities to create python loggers.""" + +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple, Union + +from rich.logging import RichHandler + + +@dataclass +class LoggingModuleConfig: + LOGGING_DISABLE: bool = False + LOG_LEVEL: str = "INFO" + LOGGING_DISABLE_RICH: bool = False + + +config = LoggingModuleConfig( + LOG_LEVEL=os.environ.get("LOG_LEVEL", "INFO"), + LOGGING_DISABLE=os.environ.get("LOGGING_DISABLE", "false").lower() == "true", + LOGGING_DISABLE_RICH=os.environ.get("LOGGING_DISABLE_RICH", "false").lower() == "true", +) + + +# LOGGING_FORMAT_STR_RICH = "%(funcName)s() %(message)s" +LOGGING_FORMAT_STR_RICH = "%(message)s" +LOGGING_FORMAT_STR = "%(asctime)s %(module)s:%(lineno)s, [%(funcName)s] (%(levelname)s)- %(message)s" +LOGGING_TIME_FORMAT_STR = "%H:%M:%S" + +STR_LOG_LEVEL_TO_INT = { + "CRITICAL": logging.CRITICAL, + "DEBUG": logging.DEBUG, + "ERROR": logging.ERROR, + "FATAL": logging.FATAL, + "INFO": logging.INFO, + "NOTSET": logging.NOTSET, + "WARN": logging.WARN, + "WARNING": logging.WARNING, +} +INT_LOG_LEVEL_TO_STR = {v: k for k, v in STR_LOG_LEVEL_TO_INT.items()} + + +def ensure_str_log_level(level: Union[str, int]) -> str: + """Ensures log level is valid and converts it to ``str`` levels recognized by python logging.""" + if isinstance(level, str): + if level.upper() not in STR_LOG_LEVEL_TO_INT: + raise ValueError(f"Log level can only be one of '{list(STR_LOG_LEVEL_TO_INT.keys())}'. Got '{level}'") + return level.upper() + if isinstance(level, int): + if level not in STR_LOG_LEVEL_TO_INT.values(): + raise ValueError(f"Log level can only be one of '{list(INT_LOG_LEVEL_TO_STR.keys())}'. Got '{level}'") + return INT_LOG_LEVEL_TO_STR[level] + raise TypeError(f"Log level can only be 'int' or 'str'. Got '{type(level)}'") + + +class LoggerConfig: + def __init__(self, level: Union[int, str]) -> None: + self.level = ensure_str_log_level(level) + + def is_debug(self) -> bool: + return STR_LOG_LEVEL_TO_INT[self.level] == logging.DEBUG + + def is_info(self) -> bool: + return STR_LOG_LEVEL_TO_INT[self.level] == logging.INFO + + def is_warning(self) -> bool: + return STR_LOG_LEVEL_TO_INT[self.level] == logging.WARNING + + +class LoggerWrapper: + def __init__(self, logger: logging.Logger) -> None: + self._logger = logger + + def info(self, *args, **kwargs): + if not config.LOGGING_DISABLE: + return self._logger.info(*args, **kwargs) + + def debug(self, *args, **kwargs): + if not config.LOGGING_DISABLE: + return self._logger.debug(*args, **kwargs) + + def warning(self, *args, **kwargs): + if not config.LOGGING_DISABLE: + return self._logger.warning(*args, **kwargs) + + def error(self, *args, **kwargs): + if not config.LOGGING_DISABLE: + return self._logger.error(*args, **kwargs) + + def log(self, *args, **kwargs): + if not config.LOGGING_DISABLE: + return self._logger.log(*args, **kwargs) + + def critical(self, *args, **kwargs): + if not config.LOGGING_DISABLE: + return self._logger.critical(*args, **kwargs) + + +def rpath(path: Union[Path, str, os.PathLike]) -> str: + """Make sure path starts with ``/`` or ``./`` so rich will highlight it in logs.""" + path = Path(path).as_posix() + if not (path.startswith("/") or path.startswith(".")): + path = "./" + path + return path + + +def get_logger_with_config( + name: Optional[str] = "ART", + log_level: Optional[str] = None, + rank: Optional[int] = None, + force: bool = False, +) -> Tuple[Union[logging.Logger, LoggerWrapper], LoggerConfig]: + """Creates and returns a logger instance and a config object.""" + logger = logging.getLogger(name) + log_level = ensure_str_log_level(config.LOG_LEVEL if log_level is None else log_level) + if force and len(logger.handlers): + for h in logger.handlers[:]: + logger.removeHandler(h) + if not len(logger.handlers): + logger.setLevel(getattr(logging, log_level)) + if config.LOGGING_DISABLE_RICH: + handler = logging.StreamHandler() + handler.setLevel(getattr(logging, log_level)) + format_str = LOGGING_FORMAT_STR + if rank is not None: + format_str = format_str.replace("%(asctime)s", f"%(asctime)s ") + if name: + format_str = f"[{name}] " + format_str + handler.setFormatter(logging.Formatter(format_str, LOGGING_TIME_FORMAT_STR)) + else: + handler = RichHandler( + level=getattr(logging, log_level), + omit_repeated_times=False, + locals_max_length=None, + locals_max_string=None, + ) + handler.setLevel(getattr(logging, log_level)) + format_str = LOGGING_FORMAT_STR_RICH + if rank is not None: + format_str = f" " + LOGGING_FORMAT_STR_RICH + if name: + format_str = f"[{name}] " + format_str + handler.setFormatter(logging.Formatter(format_str, LOGGING_TIME_FORMAT_STR)) + logger.addHandler(handler) + logging_config = LoggerConfig(level=log_level) + return logger, logging_config diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/00_default.j2 b/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/00_default.j2 new file mode 100644 index 000000000..90ce03438 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/00_default.j2 @@ -0,0 +1 @@ +You are a helpful assistant. Use the "retrieve" tool to find all the documents related to the given query. diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/01_v0.j2 b/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/01_v0.j2 new file mode 100644 index 000000000..f054bb163 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/01_v0.j2 @@ -0,0 +1,19 @@ +You are a retrieval agent, which uses retrieval to help users find the documents they need. + + +Your goal is to help users find documents they related to their search queries. You should be thorough and find all documents relevant to the user's query. If the user's query is a question, you should not answer the question yourself. Instead, you should find the related documents for the given query. + + + +* You are given a retrieval tool, powered by a dense embedding model, that takes a text query and returns the most similar documents. +* You can call the search tool multiple times. +* Search for related documents to the user's query from different angles. +* If needed, revise your search queries based on the documents you find in previous steps. +* Once you find the relevant documents, report the ID of the relevant documents by calling the "final_results" tool, which also ends the interaction. + + + +* You should be thorough and find all related documents. +* While you can use the retrieval tools as many times as you want, it is an expensive tool. So, try to be efficient and find the documents in as few searches as possible. +* The goal is to increase the **Recall** of your search attempt. So, if multiple documents are relevant to the given query, you should find and report all of them even if only a subset of them is enough for answering the query. + diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/02_v1.j2 b/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/02_v1.j2 new file mode 100644 index 000000000..4da14085d --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/prompts/02_v1.j2 @@ -0,0 +1,47 @@ +You are a retrieval agent that finds all documents related to a given query. + + +You are given a search query and a list of documents retrieved for that query. Your task is to write new queries and use the given search tool to find *ALL* the related and somewhat related documents to the given query (i.e., maximize recall). +If the user's query is a question, you should not answer the question yourself. Instead, you should find the related documents for the given query. + + +{% if extended_relevance %} + +- You should be careful, in the context of this task, what it means to be a "query", "document", and "relevant" can sometimes be very complex and might not follow the traditional definition of these terms in standard information retrieval. +- In standard retrieval, a query is usually a user question (like a web search query), the document is some sort of content that provides information (e.g., a web page), and these two are considered relevant if the document provides information that answers the user's query. +- However, in our setting, this could be different. Here are some examples: + * the query is a programming problem and documents are programming language syntax references. A document is relevant if it contains the reference for the programming syntax used for solving the problem. + * both query and documents are descriptions programming problems and a query and document are relevant if the same approach is used to solve them. + * the query is a math problem and documents are theorems. Relevant documents (theorems) are the ones that are useful for solving the math problem. + * the query and document are both math problems. A query and a document are relevant if the same theorem is used for solving them. + * the query is a task description (e.g., for an API programmer) and documents are descriptions of available APIs. Relevant documents (e.g., APIs) are the ones needed for completing the task. +- This is not an exhaustive list. These are just some examples to show you the complexity of queries, documents, and the concept of relevance in this task. +- Note that even here, the relevant documents are still the ones that are useful for a user who is searching for the given query. But the relation is more nuanced. +- You should analyze the query and some of the available documents. And then reason about what could be a meaningful definition of relevance in this case, and what the user could be looking for. +- Moreover, sometimes, the query could be even a prompt that is given to a Large Language Model (LLM) and the user wants to find the useful documents for the LLM that help answering/solving this prompt. + + +{% endif %} + +- You are given a retrieval tool, powered by a dense embedding model, that takes a text query and returns the most similar documents. +{%- if extended_relevance %} +- As explained above, reason and figure out what the meaning of relevance is in this case, and what could be relevant and useful information for the given query. +{%- endif %} +- You can call the search tool multiple times. +- Search for related documents to the user's query from different angles. +- If needed, revise your search queries based on the documents you find in previous steps. +- Once you are confident that you have found all the related and somewhat related documents and there are no more related documents in the corpus, call the "final_results" tool to finish the task. +{%- if enforce_top_k %} +- When calling "final_results", you must select exactly the {{ top_k }} most relevant documents among all documents you have retrieved. +{%- endif %} +- When calling the "final_results" tool, the list of documents must be sorted in the decreasing level of relevance to the query. I.e., the first document is the most relevant to the query, the second document is the second most relevant to the query, and so on. + + + + +- You should be thorough and find all related and somewhat related documents. +- The goal is to increase the **Recall** of your search attempt. So, if multiple documents are relevant to the given query, you should find and report all of them even if only a subset of them is enough for answering the query. +{%- if with_init_docs %} +- **TIP**: you can look at the list of documents retrieved using the original query and think what other queries you can use to find the potentially related documents that are missing in these results. +{%- endif %} + diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/query_rewriter.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/query_rewriter.py new file mode 100644 index 000000000..81e9730f4 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/query_rewriter.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from . import llm_handler + +_SYSTEM_PROMPT = """You are the query rewriting component of an agentic retrieval system. + + +You are part of a larger agentic retrieval system, which works like following. The main retrieval agent is given a user query and a search tool. The goal of the retrieval agent is to find all the related documents related to the given query. It primarily does this by making multiple search calls with different related queries to explore the corpus. + + + + +Importantly, the main agent's search tool is powered by a **dense embedding** model (often trained on human-style queries). Unfortunately, the queries that the main agent generates are often different from the embedding model's training data and fail to find the correct documents. + + + + +You are given one of the sub-queries generated by the agent. You should infer the agent's intent and rewrite the sub-query such that it is more similar to the embedding model's training distribution and lead to better retrieval results. In addition to the generated sub-query, you are also givne the user's original query, so that you have a better understanding of the type of information that we are trying to retrieve. The user query is only provided for your information. You should not mention new details from the user query in the rewritten query if it is not present in the sub-query. Note that you only want to improve the language, style, and writing of the sub-query and NOT it is semantic content or meaning. + +Stay faithful to the given sub-query. I emphasize **DO NOT INCLUDE ANY EXTRA INFORMATION OR INTENT FROM THE USER QUERY IN THE REWRITTEN QUERY**. + +Also important: **INCLUDE ALL THE INFORMATION FROM THE SUB-QUERY IN THE REWRITTEN QUERY. DO NOT REMOVE ANY INFORMATION FROM THE SUB-QUERY.** + + + + +- Preserve meaning. of the sub-query. Do **NOT** add facts or speculate. +- Only add light, non-speculative disambiguation that is implied by the text (e.g., expand "D&I" to "Discourse and Identity (D&I)"; clarify "City Vision" as the South African newspaper if it's named alongside Gugulethu). Do **not** invent facts. +- One clear intent. in the rewritten query, e.g., `[who/what], [relation/action], [time/place], [object/type], etc.` +- Precisely include all information from the sub-query in the rewritten query (e.g., include exact dates if they are mentioned) +- The rewritten query should be a full sentence or direct question. + + + + +Only output the rewritten query. Do not ask questions or add commentary. +""" + + +async def rewrite_query(llm: llm_handler.LLM, main_query: str, sub_query: str) -> str: + main_query = main_query.replace("Query:", "").replace("query:", "").strip() + sub_query = sub_query.replace("Query:", "").replace("query:", "").strip() + txt_msg = f"Main Query: {main_query}\n\nSub-query: {sub_query}" + msg_list = [ + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": [{"type": "text", "text": txt_msg}]}, + ] + + response = await llm.acompletion( + messages=msg_list, + return_metadata=True, + ) + response = response["response"] + return response.choices[0].message.content diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/__init__.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/__init__.py new file mode 100644 index 000000000..b8f3d3ac6 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Selection agent subpackage for NeMo agentic retrieval.""" diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_agent.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_agent.py new file mode 100644 index 000000000..9d97cb10f --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_agent.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +from jinja2 import Environment, FileSystemLoader, Template + +from .. import llm_handler, utils +from ..logging_utils import get_logger_with_config +from . import tool_helpers + +logger, _ = get_logger_with_config() + + +async def log_io(io_log_data: dict, llm_instant_log: bool) -> None: + """Log the LLM's raw IO if needed.""" + if llm_instant_log: + return + await llm_handler.awrite_json(**io_log_data["input_json"]) + await llm_handler.awrite_json(**io_log_data["output_json"]) + + +class SelectionAgent: + def __init__( + self, + llm: llm_handler.LLM, + prompt_name: str, + max_steps: int, + extended_relevance: bool = False, + ) -> None: + self.llm = llm + self.max_steps = max_steps + self.extended_relevance = extended_relevance + + if Path(prompt_name).exists(): + self.system_prompt_template = Template(Path(prompt_name).read_text().strip()) + else: + env = Environment( + loader=FileSystemLoader( + Path(__file__).parent.joinpath("selection_prompts").absolute().resolve().as_posix() + ) + ) + self.system_prompt_template = env.get_template(prompt_name) + + self.auto_user_msg = ( + "Please continue on whatever approach you think is suitable.\n" + "If you think you have solved the task, please finish the interaction.\n" + "IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN RESPONSE.\n" + ) + + think_tool = tool_helpers.SelectionThinkTool(extended_relevance=self.extended_relevance) + self.tool_map = {think_tool.name: think_tool} + + async def select_topk( + self, + query: str, + documents: List[Dict[str, Any]], + top_k: int, + session_id: str, + task_info: Optional[Any] = None, + ): + steps = 0 + message_history: list = [] + + seen_docids: list[str] = [] + user_msg: Dict[str, Any] = { + "role": "user", + "content": [{"type": "text", "text": f"Query:\n{query}"}, {"type": "text", "text": "Candidate Documents:"}], + } + for doc in documents: + if doc["id"] in seen_docids: + continue + seen_docids.append(doc["id"]) + user_msg["content"].append({"type": "text", "text": f"Doc ID: {doc['id']}"}) + if doc.get("text", "").strip() != "": + user_msg["content"].append({"type": "text", "text": f"Doc Text: {doc['text']}"}) + if doc.get("image", None) is not None and str(doc["image"]).strip() != "": + user_msg["content"].append({"type": "image_url", "image_url": {"url": doc["image"]}}) + feasible_topk = min(int(top_k), len(set(seen_docids))) + + message_history.append( + { + "role": "system", + "content": self.system_prompt_template.render( + top_k=feasible_topk, extended_relevance=self.extended_relevance + ).strip(), + } + ) + message_history.append(user_msg) + + finish_tool = tool_helpers.LogSelectedDocs(top_k=feasible_topk, candidate_docids=seen_docids) + tool_map = {**self.tool_map, finish_tool.name: finish_tool} + tools = [t.spec for t in tool_map.values()] + _curr_time = datetime.now() + _uuid = uuid4().hex + + raw_log_subdir = _curr_time.strftime("%d-%m-%y_%H-%M-%S-%f") + "_" + _uuid[:8] + if session_id is not None: + raw_log_subdir = session_id + "_" + raw_log_subdir + raw_log_subdir = f"select_{top_k}_agent/" + raw_log_subdir + + await self.llm.log_extra_data_log_dir(subdir=raw_log_subdir, info=task_info) + log_exp_name = _curr_time.strftime("%d-%m_%H-%M") + "_" + _uuid[:4] + if session_id is not None: + log_exp_name = session_id + "_" + log_exp_name + trace_prefix = os.environ.get("LOG_TRACE_PREFIX") + if trace_prefix is not None: + log_exp_name = str(trace_prefix) + "_" + log_exp_name + log_exp_name = f"SL_{top_k}_" + log_exp_name + + api_response_extras: list = [] + + while True: + logging_kwargs = {"step": steps, "subdir": raw_log_subdir, "log_exp_name": log_exp_name} + + response = await self.llm.acompletion( + messages=message_history, + tools=tools, + logging_kwargs=logging_kwargs, + return_metadata=True, + ) + if llm_handler.is_error(response): + message_history.append(utils.AgentErrorMessage(content=response).model_dump()) + return message_history, None + + io_log_data = response["io_log_kwargs"] + if "api_response_extras" in response: + api_response_extras.append(response["api_response_extras"]) + response = response["response"] + + steps += 1 + if len(response.choices) != 1: + raise RuntimeError(f"Expected exactly 1 choice from LLM, got {len(response.choices)}") + + if response.choices[0].finish_reason == "tool_calls": + conv_msg = {"content": [], "role": "assistant", "tool_calls": []} + for call_info in response.choices[0].message.tool_calls: + conv_msg["tool_calls"].append(call_info.model_dump()) + message_history.append(conv_msg) + elif response.choices[0].finish_reason == "stop": + message_history.append( + { + "role": "assistant", + "content": [{"type": "text", "text": response.choices[0].message.content}], + } + ) + else: + message_history.append( + utils.AgentErrorMessage( + content=f"LLM failed with finish_reason '{response.choices[0].finish_reason}'" + ).model_dump() + ) + await log_io(io_log_data=io_log_data, llm_instant_log=self.llm.config.instant_log) + await self.llm.log_extra_data_log_dir( + subdir=raw_log_subdir, info=api_response_extras, filename="api_response_extras.json" + ) + return message_history, None + + if self.max_steps is not None and steps >= self.max_steps: + message_history.append( + utils.AgentErrorMessage(content="Selection Agent reached maximum allowed iterations").model_dump() + ) + await log_io(io_log_data=io_log_data, llm_instant_log=self.llm.config.instant_log) + await self.llm.log_extra_data_log_dir( + subdir=raw_log_subdir, info=api_response_extras, filename="api_response_extras.json" + ) + return message_history, None + + if len(message_history[-1]["content"]) != 0: + message_history.append({"role": "user", "content": [{"type": "text", "text": self.auto_user_msg}]}) + else: + should_end = False + tool_messages: list = [] + for call_info in message_history[-1]["tool_calls"]: + fn_name = call_info["function"]["name"] + err_msg = None + try: + fn_kwargs = json.loads(call_info["function"]["arguments"]) + except Exception: + err_msg = "Error parsing tool arguments. Tool arguments not correctly formatted." + if fn_name not in tool_map or err_msg is not None: + if err_msg is None: + err_msg = f"Error. Tool '{fn_name}' does not exist." + tool_messages.append( + {"content": err_msg, "role": "tool", "tool_call_id": call_info["id"], "name": fn_name} + ) + else: + res = tool_map[fn_name].call(**fn_kwargs) + if fn_name == finish_tool.name and res == finish_tool.correct_call_return_value: + should_end = True + end_kwargs = fn_kwargs + else: + tool_messages.append( + {"content": res, "role": "tool", "tool_call_id": call_info["id"], "name": fn_name} + ) + if should_end: + await log_io(io_log_data=io_log_data, llm_instant_log=self.llm.config.instant_log) + await self.llm.log_extra_data_log_dir( + subdir=raw_log_subdir, info=api_response_extras, filename="api_response_extras.json" + ) + return message_history, end_kwargs + message_history.extend(tool_messages) diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_prompts/00_demo.j2 b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_prompts/00_demo.j2 new file mode 100644 index 000000000..a2233c05b --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_prompts/00_demo.j2 @@ -0,0 +1 @@ +You are given a query and a list of documents. Select the {{ top_k }} most relevant documents for the given query. diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_prompts/01_v0.j2 b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_prompts/01_v0.j2 new file mode 100644 index 000000000..6438ededb --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/selection_prompts/01_v0.j2 @@ -0,0 +1,43 @@ +You are a document re-ranker agent, which is the final stage in an information retrieval pipeline. + + +You are given a search query and a list of retrieved candidate documents that are potentially relevant to the given query. Your goal is to help the users identify the most relevant documents to the given query from the list of candidate documents. + + +{% if extended_relevance %} + +- You should be careful, in the context of this task, what it means to be a "query", "document", and "relevant" can sometimes be very complex and might not follow the traditional definition of these terms in standard re-ranking and retrieval. +- In standard re-ranking/retrieval, a query is usually a user question (like a web search query), the document is some sort of content that provides information (e.g., a web page), and these two are considered relevant if the document provides information that answers the user's query. +- However, in our setting, this could be different. Here are some examples: + * the query is a programming problem and documents are programming language syntax references. A document is relevant if it contains the reference for the programming syntax used for solving the problem. + * both query and documents are descriptions programming problems and a query and document are relevant if the same approach is used to solve them. + * the query is a math problem and documents are theorems. Relevant documents (theorems) are the ones that are useful for solving the math problem. + * the query and document are both math problems. A query and a document are relevant if the same theorem is used for solving them. + * the query is a task description (e.g., for an API programmer) and documents are descriptions of available APIs. Relevant documents (e.g., APIs) are the ones needed for completing the task. +- This is not an exhaustive list. These are just some examples to show you the complexity of queries, documents, and the concept of relevance in this task. +- Note that even here, the relevant documents are still the ones that are useful for a user who is searching for the given query. But the relation is more nuanced. +- You should analyze the query and the available documents. And then reason about what could be a meaningful definition of relevance in this case, and what the user could be looking for. +- Moreover, sometimes, the query could be even a prompt that is given to a Large Language Model (LLM) and the user wants to find the useful documents for the LLM that help answering/solving this prompt. + + +{% endif %} + +* You are given a search query and a list of candidate documents. You have access to the ID and content of each candidate document. +* You should read the query carefully and understand it. +{%- if extended_relevance %} +* As explained above, reason and figure out what the meaning of relevance is in this case, and what could be relevant and useful information for the given query. +{%- endif %} +* Then you should compare the query with each one of the candidate documents. In this comparison, you want to identify if the document is relevant/useful for the given query and to what extent. +* Select the {{ top_k }} most relevant candidate documents for the given query. +* Note that just selecting the most relevant documents is not enough. You should identify the relative level of relevance between the query and selected documents. This helps you sort the selected documents later based on how relevant they are to the query. +* Once you have this information, you should call the "log_selected_documents" function to report the final results and signal the completion of the task. +* Note that the selected document IDs must be reported in the decreasing level of relevance. I.e., The first document in the list is the most relevant, the second is the second most relevant, and so on. This is similar to what a search engine (e.g., Google Search) does (it shows you the relevant results in a sorted order, where the most relevant results appear on top of the list). + + + + +* you have access to a "think" tool that you can use for complex thinking and analysis. Here are examples of cases where the think tool might be useful: + - complex analysis and thinking to understand the meaning and intent of the query. E.g., what is the user trying to find with this query? what kind of information is helpful for the user? + - extended thinking to analyze how each candidate document could or could not be relevant to the given query. + - reasoning to identify the relative level of relevance between the query and selected documents. It helps you sort the documents correctly when reporting the final answer. + diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/tool_helpers.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/tool_helpers.py new file mode 100644 index 000000000..cffb2510f --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/selection_agent/tool_helpers.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import List, Optional + +from ..tool_helpers import BaseTool, ToolError + + +class SelectionThinkTool(BaseTool): + """Tool for selection agent thinking.""" + + def __init__(self, extended_relevance: bool = False): + if extended_relevance: + ext = [ + "- When it is difficult to understand what is the intent of the user and what they are trying to find with this query, use this tool to think about potential definitions of relevance that could be meaningful/useful to the user for this task.", + "- If the intention of the user is vague especially given the available documents, use this tool to think how you should decide what documents are relevant and what the metric of relevance is.", + ] + ext = "\n".join(ext) + "\n" + else: + ext = "" + self.spec_dict = { + "type": "function", + "function": { + "name": "think", + "description": f"""Use the tool to think about something. It will not obtain new information or make any changes, but just log the thought. Use it when complex reasoning or brainstorming is needed. + +Common use cases: +{ext}- When processing a complex query, use this tool to organize your thoughts and think about how each document might be related to the given query. +- If a query is vague or hard to understand, you can use this tool to think about clues in the query that help you identify the connections between a document and the query. +- You can use this tool to think what pieces of information in each document are the most important or relevant for the given query. + +The tool simply logs your thought process for better transparency and does not make any changes.""", + "parameters": { + "type": "object", + "properties": { + "thought": { + "type": "string", + "description": "The thought to log.", + } + }, + "required": ["thought"], + }, + }, + } + + def _spec(self): + return self.spec_dict + + def _call(self, thought: str): + return "Your thought has been logged." + + async def _acall(self, thought: str): + return "Your thought has been logged." + + +class LogSelectedDocs(BaseTool): + """Tool for reporting selected doc IDs and ending the interaction.""" + + _name: Optional[str] = "log_selected_documents" + + def __init__(self, top_k: int, candidate_docids: List[str]): + self.correct_call_return_value = "The results have been successfully logged and the interaction ended." + self.top_k = int(top_k) + self.allowed_doc_ids = set(candidate_docids) + + desc = f"""Records the selected documents and signals the end of the task. + +Use this tool when you have carefully considered the candidate documents and have selected exactly the {self.top_k} most relevant documents to the query. + +The message argument should explain your reasoning and justification for selecting this specific set of documents as the most relevant to the query. + +**Note**: the list of document IDs passed as the `doc_ids` argument must be sorted in the decreasing level of relevance. In other words, the first document in `doc_ids` list is the most relevant to the query, the second document is the second most relevant document, and so on.""" + + self.spec_dict = { + "type": "function", + "function": { + "name": "log_selected_documents", + "description": desc, + "parameters": { + "type": "object", + "required": ["message", "doc_ids"], + "properties": { + "message": { + "type": "string", + "description": "A message for the user to explain why you think the selected are the most relevant to the query. Also, explain why this specific order of document IDs satisfies the most to least relevant ordering criteria.", + }, + "doc_ids": { + "type": "array", + "items": {"type": "string"}, + "description": f"The ID of the {self.top_k} most relevant documents to the given query. The IDs must be sorted in the decreasing level of relevance. I.e., the first document must be the most relevant to the query.", + }, + }, + }, + }, + } + + def _spec(self): + return self.spec_dict + + async def _acall(self, doc_ids: List[str], message: str): + return self._call(doc_ids=doc_ids, message=message) + + def _call(self, doc_ids: List[str], message: str): + if not isinstance(message, str): + raise TypeError(f"The `message` argument must be a string. Got `{type(message)}` type.") + if not isinstance(doc_ids, list): + raise TypeError(f"The `doc_ids` argument must be a list. Got `{type(doc_ids)}` type.") + if len(doc_ids) != self.top_k: + raise ToolError(f"You must select at least {self.top_k} documents. Got {len(doc_ids)} documents.") + if not all(isinstance(i, str) for i in doc_ids): + raise TypeError("Items in `doc_ids` must be of type string (i.e., python's `str` type).") + for i in doc_ids: + if i not in self.allowed_doc_ids: + raise ToolError(f"Document with ID `{i}` is not among the candidate documents.") + return self.correct_call_return_value diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/tool_helpers.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/tool_helpers.py new file mode 100644 index 000000000..76649cbba --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/tool_helpers.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tool definition and utilities for managing tools for an LLM. + +This is an internalized subset of the original external agent code. All MCP and +fastmcp integration has been removed; this module only provides: +- Base tool abstraction (BaseTool) +- Local tools used by the agent (ThinkTool, FinalResults) +- Retrieval output formatting helpers (retrieve_output_to_msg_content) +- Retrieval over-fetch-and-filter helper (retrieve_with_guarantees) +""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Set, Union + + +class ToolError(Exception): + """Tool-specific exception raised for invalid tool usage.""" + + +class BaseTool(ABC): + """Define a tool to be passed to the LLM.""" + + _name: Optional[str] = None + + @abstractmethod + def _spec(self) -> dict: + raise NotImplementedError + + def _call(self, **kwargs: Any) -> Any: + raise NotImplementedError + + async def _acall(self, **kwargs: Any) -> Any: + raise NotImplementedError + + def call(self, **kwargs: Any) -> Any: + try: + output = self._call(**kwargs) + except (TypeError, ToolError) as e: + output = f"Error calling '{self.name}' tool. {type(e).__name__}: {str(e)}" + return output + + async def acall(self, **kwargs: Any) -> Any: + try: + output = await self._acall(**kwargs) + except (TypeError, ToolError) as e: + output = f"Error calling '{self.name}' tool. {type(e).__name__}: {str(e)}" + return output + + def __call__(self, **kwargs: Any) -> Any: + return self.call(**kwargs) + + @property + def spec(self) -> dict: + return self._spec() + + @property + def name(self) -> str: + return str(self.spec["function"]["name"]) + + +class RetrieveToolBase(BaseTool): + """Marker base class for retrieve-like tools. + + The Agent uses `isinstance(tool, RetrieveToolBase)` to apply retrieval-specific + behaviors (dedup/guarantees/output formatting). + """ + + _default_top_k: int = 20 + + +class ThinkTool(BaseTool): + """Tool that allows the LLM to think with output tokens.""" + + def __init__(self, extended_relevance: bool = False): + if extended_relevance: + ext = [ + "- When it is difficult to understand what is the intent of the user and what they are trying to find with this query, use this tool to think about potential definitions of relevance that could be meaningful/useful to the user for this task.", + "- If the intention of the user is vague especially given the available documents, use this tool to think how you should decide what documents are relevant and what the metric of relevance is.", + ] + ext = "\n".join(ext) + "\n" + else: + ext = "" + self.spec_dict = { + "type": "function", + "function": { + "name": "think", + "description": f"""Use the tool to think about something. It will not obtain new information or make any changes, but just log the thought. Use it when complex reasoning or brainstorming is needed. + +Common use cases: +{ext}- When processing a complex query, use this tool to organize your thoughts and think about the sub queries that you need to search for to find the relevant information +- If a query is vague is very difficult to find information for it, you can use this tool to think about clues in the query that you can use to narrow down the search and spot relevant pieces of information. +- When finding related documents that help you create better search queries in the next step, use this tool to think about what pieces of information from these documents are helpful to search for. +- When you fail to find any related information to the query, use this tool to think about other search strategies that you can take to retrieve the related documents + +The tool simply logs your thought process for better transparency and does not make any changes.""", + "parameters": { + "type": "object", + "properties": { + "thought": { + "type": "string", + "description": "The thought to log.", + } + }, + "required": ["thought"], + }, + }, + } + + def _spec(self) -> dict: + return self.spec_dict + + def _call(self, thought: str) -> str: + return "Your thought has been logged." + + async def _acall(self, thought: str) -> str: + return "Your thought has been logged." + + +class FinalResults(BaseTool): + """Tool for logging selected document IDs and signaling the end of the interaction.""" + + _name: Optional[str] = "final_results" + + def __init__(self, top_k: Optional[int] = None): + self.correct_call_return_value = "The results have been successfully logged and the interaction ended." + self.top_k = top_k + + tk_ins = "" + if top_k is not None: + tk_ins = f"- You must choose exactly {top_k} document IDs when calling this function." + + desc = f"""Signals the completion of the search process for the current query. + +Use this tool when: +- You have found all the relevant documents to the query. +- Despite several attempts, you cannot find good documents for the given query. + +The message should include: +- A brief summary of your exploration and the results +- Explanation if the search was unsuccessful + +When reporting the selected document IDs, make sure: +- the list of document IDs is sorted in the decreasing level of relevance to the query. I.e., the first document in the list is the most relevant to the query, the second is the second most relevant to the query, and so on. +{tk_ins} + +The successful_search field should be set to true if you believed you have found the most relevant documents to the user's query, and false otherwise. And partial if it is in between.""" + self.spec_dict = { + "type": "function", + "function": { + "name": "final_results", + "description": desc, + "parameters": { + "type": "object", + "required": ["doc_ids", "message", "search_successful"], + "properties": { + "message": { + "type": "string", + "description": "A message for the user to explain why you think you found all the related documents and there is no related document is missing. Also, include a short description of your exploration process. If your attempts to find related documents were unsuccessful, explain why.", + }, + "doc_ids": { + "type": "array", + "items": {"type": "string"}, + "description": "List of document IDs that are relevant to the user's query sorted descending by their level of relevance to the user's query. I.e., the first document is the most relevant to the query, the second is the second most relevant to the query, and so on.", + }, + "search_successful": { + "type": "string", + "enum": ["true", "false", "partial"], + "description": "Whether you managed to find all the related documents to the query.", + }, + }, + }, + }, + } + + def _spec(self) -> dict: + return self.spec_dict + + async def _acall(self, doc_ids: List[str], message: str, search_successful: str) -> str: + return self._call(doc_ids=doc_ids, message=message, search_successful=search_successful) + + def _call(self, doc_ids: List[str], message: str, search_successful: str) -> str: + if not isinstance(message, str): + raise TypeError(f"The `message` argument must be a string. Got `{type(message)}` type.") + if not isinstance(doc_ids, list): + raise TypeError(f"The `doc_ids` argument must be a list. Got `{type(doc_ids)}` type.") + if len(doc_ids) == 0: + raise ToolError("`doc_ids` cannot be empty. You must choose at least one relevant document.") + if not all(isinstance(i, str) for i in doc_ids): + raise TypeError("Items in `doc_ids` must be of type string (i.e., python's `str` type).") + if not isinstance(search_successful, str): + raise TypeError(f"The `search_successful` argument must be a string. Got `{type(search_successful)}` type.") + if search_successful not in ["true", "false", "partial"]: + raise ToolError( + f"`search_successful` must be one of `true`, `false`, or `partial`. Got `{search_successful}` instead." + ) + if self.top_k is not None and len(doc_ids) != self.top_k: + raise ToolError( + f"`doc_ids` must contain exactly {self.top_k} documents. But got {len(doc_ids)} document IDs instead." + ) + return self.correct_call_return_value + + +def retrieve_output_to_msg_content(output: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]: + """Convert retrieve output into LLM message content blocks.""" + if isinstance(output, str): + if output.startswith("Error"): + return [{"type": "text", "text": output}] + raise RuntimeError("Received unexpected value from the retrieve tool.") + + content_list: List[Dict[str, Any]] = [] + for doc_in in output: + doc = {**doc_in} + if doc.get("text", "").strip() == "": + doc.pop("text", None) + img = doc.pop("image", None) + content_list.append({"type": "text", "text": json.dumps(doc)}) + if img is not None: + content_list.append({"type": "image_url", "image_url": {"url": img}}) + return content_list + + +async def retrieve_with_guarantees( + tool_caller: Callable[..., Any], + top_k: int, + seen_docids: Set[str], + exclude_docids: Set[str], +) -> Union[str, List[Dict[str, Any]]]: + """Call retrieve, ensuring `top_k` new docs and excluding `exclude_docids`.""" + seen_docids = set(seen_docids) + exclude_docids = set(exclude_docids) + res = await tool_caller(__art_top_k=top_k + len(seen_docids) + len(exclude_docids)) + if isinstance(res, str) and res.startswith("Error"): + return res + + res_list = list(sorted(res, key=lambda x: x["score"], reverse=True)) + + output_list: List[Dict[str, Any]] = [] + num_new = 0 + for item in res_list: + if item["id"] in exclude_docids: + continue + rec = {**item} + if rec["id"] not in seen_docids: + num_new += 1 + output_list.append(rec) + if num_new >= top_k: + break + return output_list diff --git a/retrieval-bench/src/retrieval_bench/nemo_agentic/utils.py b/retrieval-bench/src/retrieval_bench/nemo_agentic/utils.py new file mode 100644 index 000000000..e5e6805a3 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/nemo_agentic/utils.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from typing import Any, Dict, List + +from pydantic import BaseModel, ConfigDict + + +class AgentErrorMessage(BaseModel): + # a message type for when the agent hit an error + model_config = ConfigDict(extra="forbid") + role: str = "agent_error" + content: str + + +def rrf_from_subquery_results(retrieval_results: List[List[Dict[str, Any]]], k: int = 60) -> Dict[str, float]: + """Calculates the RRF score for retrieval results.""" + sorted_results: List[List[str]] = [] + for ret_rs in retrieval_results: + sorted_docs = sorted(ret_rs, key=lambda x: x["score"], reverse=True) + sorted_results.append([i["id"] for i in sorted_docs]) + return rrf(sorted_results=sorted_results, k=k) + + +def rrf(sorted_results: List[List[str]], k: int = 60) -> Dict[str, float]: + """Calculates the Reciprocal Rank Fusion score for each document.""" + rrf_scores = defaultdict(float) + for result_list in sorted_results: + for i, item in enumerate(result_list): + rank = i + 1 + rrf_scores[item] += 1 / (rank + k) + return dict(rrf_scores) diff --git a/retrieval-bench/src/retrieval_bench/pipeline_evaluation/__init__.py b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/__init__.py new file mode 100644 index 000000000..77b85bf43 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Python CLI tool for evaluating pipelines on vidore v3 datasets. +""" + +__version__ = "0.1.0" + +from vidore_benchmark.pipeline_evaluation.base_pipeline import BasePipeline +from vidore_benchmark.pipeline_evaluation.dataset_loader import print_dataset_info +from retrieval_bench.pipeline_evaluation.dataset_loader import ( + get_available_datasets, + load_vidore_dataset, +) +from retrieval_bench.pipeline_evaluation.evaluator import aggregate_results, evaluate_retrieval + +__all__ = [ + "BasePipeline", + "evaluate_retrieval", + "aggregate_results", + "load_vidore_dataset", + "get_available_datasets", + "print_dataset_info", +] diff --git a/retrieval-bench/src/retrieval_bench/pipeline_evaluation/dataset_loader.py b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/dataset_loader.py new file mode 100644 index 000000000..6778cb554 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/dataset_loader.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Dataset loader for vidore v3 benchmark datasets. + +Handles downloading and preparing vidore v3 datasets from HuggingFace, +including queries, corpus images, and ground truth relevance judgments. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset +from vidore_benchmark.pipeline_evaluation.dataset_loader import ( + get_available_datasets as _upstream_get_available_datasets, + load_vidore_dataset as _upstream_load_vidore_dataset, +) + + +BRIGHT_TASKS: tuple[str, ...] = ( + "biology", + "earth_science", + "economics", + "psychology", + "robotics", + "stackoverflow", + "sustainable_living", + "leetcode", + "pony", + "aops", + "theoremqa_questions", + "theoremqa_theorems", +) + + +def _repo_root() -> Path: + # /src/retrieval_bench/pipeline_evaluation/dataset_loader.py + return Path(__file__).resolve().parents[3] + + +def _preferred_bright_cache_dir() -> Optional[str]: + """ + Prefer using the sibling BRIGHT repo cache if available. + + Workflow assumption: BRIGHT repo lives at ../BRIGHT relative to this repo root. + """ + bright_cache = _repo_root().parent / "BRIGHT" / "cache" + if bright_cache.exists() and bright_cache.is_dir(): + return str(bright_cache) + return None + + +def _load_bright_split( + *, + task: str, + split: str = "test", + language: Optional[str] = None, +) -> Tuple[ + List[str], + List[str], + List[str], + List[Any], + List[str], + Dict[str, Dict[str, int]], + Dict[str, str], + Dict[str, List[str]], +]: + """ + Load one BRIGHT task as a dataset compatible with our pipeline evaluator. + + Default setting only: + - examples/config: 'examples' + - corpus/config: 'documents' + - qrels key: 'gold_ids' + """ + if split != "test": + raise ValueError("BRIGHT datasets only support split='test' in this integration.") + + if language and str(language).strip().lower() not in ("english", "en"): + raise ValueError("BRIGHT datasets are English-only in this integration (use --language english or omit).") + + task = str(task).strip() + if task not in BRIGHT_TASKS: + raise ValueError(f"Unknown BRIGHT task '{task}'. Expected one of: {', '.join(BRIGHT_TASKS)}") + + cache_dir = _preferred_bright_cache_dir() + + last_err: Optional[Exception] = None + examples_ds = None + docs_ds = None + for ds_name in ("xlangai/bright", "xlangai/BRIGHT"): + try: + examples_ds = load_dataset(ds_name, "examples", cache_dir=cache_dir)[task] + docs_ds = load_dataset(ds_name, "documents", cache_dir=cache_dir)[task] + last_err = None + break + except Exception as e: + last_err = e + examples_ds = None + docs_ds = None + continue + + if examples_ds is None or docs_ds is None: + raise RuntimeError(f"Failed to load BRIGHT task '{task}' from HuggingFace: {last_err}") from last_err + + corpus_ids: List[str] = [] + corpus_images: List[Any] = [] + corpus_texts: List[str] = [] + for dp in docs_ds: + did = str(dp.get("id", "")) + corpus_ids.append(did) + corpus_images.append(None) + corpus_texts.append(str(dp.get("content", ""))) + + query_ids: List[str] = [] + queries: List[str] = [] + qrels: Dict[str, Dict[str, int]] = {} + query_languages: Dict[str, str] = {} + excluded_ids_by_query: Dict[str, List[str]] = {} + + for e in examples_ds: + qid = str(e.get("id", "")) + q = str(e.get("query", "")) + excluded = e.get("excluded_ids", None) + if not isinstance(excluded, list): + excluded = ["N/A"] + excluded_list = [str(x) for x in excluded] + + gold = e.get("gold_ids", None) + if not isinstance(gold, list): + gold = [] + gold_ids = [str(x) for x in gold] + + overlap = set(excluded_list).intersection(set(gold_ids)) + overlap.discard("N/A") + if overlap: + raise ValueError(f"BRIGHT data error: excluded_ids overlaps gold_ids for query_id={qid}: {sorted(overlap)}") + + query_ids.append(qid) + queries.append(q) + query_languages[qid] = "english" + excluded_ids_by_query[qid] = excluded_list + + qrels[qid] = {gid: 1 for gid in gold_ids} + + if not queries: + raise ValueError(f"No queries found in BRIGHT task '{task}'.") + if not corpus_texts: + raise ValueError(f"No corpus documents found in BRIGHT task '{task}'.") + if not any(v for v in qrels.values()): + raise ValueError(f"No relevance judgments found in BRIGHT task '{task}'.") + + return query_ids, queries, corpus_ids, corpus_images, corpus_texts, qrels, query_languages, excluded_ids_by_query + + +def load_vidore_dataset(dataset_name: str, split: str = "test", language: str = None) -> Tuple[ + List[str], + List[str], + List[str], + List[Any], + List[str], + Dict[str, Dict[str, int]], + Dict[str, str], + Dict[str, List[str]], +]: + """ + Load a dataset for the pipeline evaluator. + + ViDoRe datasets are delegated to upstream vidore-benchmark loader. + BRIGHT tasks are handled locally and include excluded-ids semantics. + """ + if str(dataset_name).startswith("bright/"): + task = str(dataset_name).split("/", 1)[1] + return _load_bright_split(task=task, split=split, language=language) + + query_ids, queries, corpus_ids, corpus_images, corpus_texts, qrels, query_languages = _upstream_load_vidore_dataset( + dataset_name=dataset_name, + split=split, + language=language, + ) + excluded_ids_by_query = {str(qid): ["N/A"] for qid in query_ids} + return query_ids, queries, corpus_ids, corpus_images, corpus_texts, qrels, query_languages, excluded_ids_by_query + + +def get_available_datasets() -> List[str]: + """ + Get list of available vidore v3 datasets. + + Returns: + List of dataset names that can be used with load_vidore_dataset() + """ + datasets = list(_upstream_get_available_datasets()) + for name in (f"bright/{t}" for t in BRIGHT_TASKS): + if name not in datasets: + datasets.append(name) + return datasets diff --git a/retrieval-bench/src/retrieval_bench/pipeline_evaluation/evaluator.py b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/evaluator.py new file mode 100644 index 000000000..05cc43505 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/evaluator.py @@ -0,0 +1,520 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Core evaluation orchestration using pytrec_eval. +""" + +import time +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytrec_eval +from vidore_benchmark.pipeline_evaluation.base_pipeline import BasePipeline +from retrieval_bench.pipeline_evaluation.tracing import ( + dataset_trace_dir, + default_trace_run_name, + extract_run_and_time_ms, + load_trace_file, + trace_path, + write_trace_file, +) + + +def evaluate_retrieval( + pipeline: BasePipeline, + query_ids: List[str], + queries: List[str], + corpus_ids: List[str], + corpus_images: List[Any], + corpus_texts: List[str], + qrels: Dict[str, Dict[str, int]], + metrics: List[str] = None, + track_time: bool = True, + traces_dir: str = "traces", + trace_run_name: Optional[str] = None, + dataset_name: Optional[str] = None, + split: str = "test", + language: Optional[str] = None, + query_ids_selector: Optional[str] = None, + excluded_ids_by_query: Optional[Dict[str, List[str]]] = None, +) -> Dict[str, Dict[str, float]]: + """ + Evaluate a pipeline using pytrec_eval. + + Args: + pipeline: Instance of BasePipeline with user's pipeline logic + query_ids: List of query identifiers + queries: List of query texts + corpus_ids: List of corpus item identifiers + corpus_images: List of corpus images (PIL.Image objects) + corpus_texts: List of corpus texts (markdown strings) + qrels: Ground truth relevance judgments in pytrec_eval format + {query_id: {doc_id: relevance_score}} + metrics: List of metrics to calculate (default: ['ndcg_cut_10']) + track_time: Whether to track retrieval time (default: True) + traces_dir: Directory for per-query trace files (default: 'traces') + trace_run_name: Name for this trace run. If None, auto-generated from pipeline class and model ID. + dataset_name: Dataset identifier used for trace directory layout. + split: Data split (default: 'test') + language: Optional language filter passed through to traces. + query_ids_selector: Optional selector string recorded in traces for provenance. + excluded_ids_by_query: Per-query doc IDs to exclude from scored results (BRIGHT semantics). + + Returns: + Dictionary of evaluation results per query: + { + 'q1': {'ndcg_cut_10': 0.85, ...}, + 'q2': {'ndcg_cut_10': 0.72, ...}, + ... + } + If track_time=True, also includes timing information in a special '_timing' key. + """ + if metrics is None: + metrics = ["ndcg_cut_10"] + + wall_start = time.perf_counter() if track_time else None + + def _filtered_run_for_query(qid: str, run_q: Any) -> Any: + if not isinstance(run_q, dict): + return run_q + if not isinstance(excluded_ids_by_query, dict): + return run_q + excluded = excluded_ids_by_query.get(str(qid), None) + if not isinstance(excluded, list): + return run_q + out = dict(run_q) + for did in set(str(x) for x in excluded): + if did != "N/A": + out.pop(did, None) + return out + + # Dataset context (for trace directory layout). + dataset_name_eff = dataset_name or getattr(pipeline, "dataset_name", None) or "unknown_dataset" + + # Trace run name: always enabled; default is __. + trace_run_name_eff = trace_run_name or default_trace_run_name(pipeline) + dataset_dir = dataset_trace_dir(dataset_name_eff, split=split, language=language) + + # Build a quick lookup for query text by id. + query_by_id: Dict[str, str] = {qid: q for qid, q in zip(query_ids, queries)} + + # Load cached per-query runs/timing if present; otherwise schedule query for execution. + cached_run: Dict[str, Dict[str, float]] = {} + per_query_time_ms: Dict[str, float] = {} + to_run_ids: List[str] = [] + to_run_queries: List[str] = [] + + for qid in query_ids: + p = trace_path(traces_dir, trace_run_name_eff, dataset_dir, qid) + trace_obj = load_trace_file(p) + extracted = extract_run_and_time_ms(trace_obj) if trace_obj is not None else None + if extracted is None: + to_run_ids.append(qid) + to_run_queries.append(query_by_id.get(qid, "")) + else: + run_q, t_ms = extracted + cached_run[qid] = run_q + per_query_time_ms[qid] = t_ms + + # Indexing step: always call index() so the pipeline can set up embeddings/indices. + start_time_indexing = time.time() + pipeline.index(corpus_ids=corpus_ids, corpus_images=corpus_images, corpus_texts=corpus_texts) + indexing_time = time.time() - start_time_indexing + + if indexing_time < 1e-5: + indexing_time = 0.0 + + # Execute only the missing queries (if any) and always write traces for executed queries. + executed_run: Dict[str, Dict[str, float]] = {} + pipeline_infos: Optional[Dict[str, Any]] = None + pipeline_infos_public: Optional[Dict[str, Any]] = None + executed_call_wall_ms: Optional[float] = None + + if to_run_ids: + # Provide tracing context to pipelines that want to write per-query traces incrementally. + setattr( + pipeline, + "tracing_context", + { + "traces_dir": traces_dir, + "trace_run_name": trace_run_name_eff, + "dataset": dataset_name_eff, + "dataset_dir": dataset_dir, + "split": split, + "language": language, + "query_ids_selector": query_ids_selector, + "pipeline_class": pipeline.__class__.__name__, + "model_id": getattr(pipeline, "model_id", None), + "llm_model": getattr(pipeline, "llm_model", None), + }, + ) + setattr(pipeline, "excluded_ids_by_query", excluded_ids_by_query) + + call_start = time.perf_counter() if track_time else None + result = pipeline.search(query_ids=to_run_ids, queries=to_run_queries) + if isinstance(result, tuple): + executed_run, pipeline_infos = result + else: + executed_run, pipeline_infos = result, None + if track_time and call_start is not None: + executed_call_wall_ms = (time.perf_counter() - call_start) * 1000.0 + + if not isinstance(executed_run, dict): + raise ValueError(f"Pipeline must return a dict, got {type(executed_run)}") + + # Pull per-query timing from pipeline infos if provided, else fall back to equal-split wall time. + provided_times: Dict[str, Any] = {} + if isinstance(pipeline_infos, dict): + provided_times = pipeline_infos.get("per_query_retrieval_time_milliseconds", {}) or {} + + fallback_ms = None + if track_time and executed_call_wall_ms is not None and to_run_ids: + fallback_ms = executed_call_wall_ms / len(to_run_ids) + + for qid in to_run_ids: + t_ms = provided_times.get(qid, None) if isinstance(provided_times, dict) else None + if isinstance(t_ms, (int, float)): + per_query_time_ms[qid] = float(t_ms) + elif fallback_ms is not None: + per_query_time_ms[qid] = float(fallback_ms) + + per_query_trace: Dict[str, Any] = {} + if isinstance(pipeline_infos, dict): + pqt = pipeline_infos.get("per_query_trace", None) + if isinstance(pqt, dict): + per_query_trace = pqt + + if isinstance(pipeline_infos, dict): + pipeline_infos_public = dict(pipeline_infos) + pipeline_infos_public.pop("per_query_trace", None) + + llm_error_query_ids: List[str] = [] + total_prompt_tokens = 0 + total_completion_tokens = 0 + total_tokens = 0 + any_usage = False + _agg_trajectory_steps: List[int] = [] + _agg_llm_turns: List[int] = [] + _agg_retrieval_calls: List[int] = [] + for qid in to_run_ids: + extra = per_query_trace.get(qid, None) + if isinstance(extra, dict) and isinstance(extra.get("llm_error", None), str): + llm_error_query_ids.append(qid) + usage = extra.get("llm_usage", None) if isinstance(extra, dict) else None + if isinstance(usage, dict): + pt = usage.get("prompt_tokens", None) + ct = usage.get("completion_tokens", None) + tt = usage.get("total_tokens", None) + if isinstance(pt, int): + total_prompt_tokens += pt + any_usage = True + if isinstance(ct, int): + total_completion_tokens += ct + any_usage = True + if isinstance(tt, int): + total_tokens += tt + any_usage = True + if isinstance(extra, dict): + for _key, _lst in [ + ("trajectory_steps", _agg_trajectory_steps), + ("llm_turns", _agg_llm_turns), + ("retrieval_calls", _agg_retrieval_calls), + ]: + v = extra.get(_key, None) + if isinstance(v, int): + _lst.append(v) + pipeline_infos_public["llm_error_query_ids"] = llm_error_query_ids + if any_usage: + pipeline_infos_public["llm_total_prompt_tokens"] = total_prompt_tokens + pipeline_infos_public["llm_total_completion_tokens"] = total_completion_tokens + pipeline_infos_public["llm_total_tokens"] = total_tokens + if _agg_trajectory_steps: + pipeline_infos_public["avg_trajectory_steps"] = sum(_agg_trajectory_steps) / len(_agg_trajectory_steps) + if _agg_llm_turns: + pipeline_infos_public["avg_llm_turns"] = sum(_agg_llm_turns) / len(_agg_llm_turns) + if _agg_retrieval_calls: + pipeline_infos_public["avg_retrieval_calls"] = sum(_agg_retrieval_calls) / len(_agg_retrieval_calls) + + for qid in to_run_ids: + p = trace_path(traces_dir, trace_run_name_eff, dataset_dir, qid) + existing = load_trace_file(p) + if isinstance(existing, dict) and extract_run_and_time_ms(existing) is not None: + continue + + run_q = executed_run.get(qid, None) + run_q = _filtered_run_for_query(qid, run_q) + t_ms = per_query_time_ms.get(qid, None) + payload: Dict[str, Any] = { + "query_id": qid, + "dataset": dataset_name_eff, + "dataset_dir": dataset_dir, + "split": split, + "language": language, + "query_ids_selector": query_ids_selector, + "trace_run_name": trace_run_name_eff, + "pipeline_class": pipeline.__class__.__name__, + "model_id": getattr(pipeline, "model_id", None), + "retrieval_time_milliseconds": t_ms, + "run": run_q, + } + extra = per_query_trace.get(qid, None) + if isinstance(extra, dict): + payload["pipeline_trace"] = extra + write_trace_file(p, payload) + + # Combined run (cached + executed). Executed wins on collisions. + run: Dict[str, Dict[str, float]] = {} + run.update(cached_run) + run.update(executed_run) + + expected_query_ids = list(query_ids) + returned_set = set(run.keys()) + + evaluated_query_ids = [qid for qid in expected_query_ids if qid in returned_set] + missing_query_ids = [qid for qid in expected_query_ids if qid not in returned_set] + + if not evaluated_query_ids: + raise ValueError( + "Pipeline returned no results for any expected query_ids. " "Refusing to compute metrics on an empty set." + ) + + # Restrict evaluation to only queries we have both run + qrels for. + run_eval = {qid: _filtered_run_for_query(qid, run[qid]) for qid in evaluated_query_ids} + qrels_eval = {qid: qrels[qid] for qid in evaluated_query_ids if qid in qrels} + + # Create pytrec_eval evaluator + evaluator = pytrec_eval.RelevanceEvaluator(qrels_eval, set(metrics)) + + # Evaluate + results = evaluator.evaluate(run_eval) + + # Persist per-query evaluation metrics into trace files and aggregate pipeline-level + # summaries from traces (e.g. LLM token usage), so fully-cached runs still get totals. + llm_error_query_ids_from_traces: List[str] = [] + llm_total_prompt_tokens = 0 + llm_total_completion_tokens = 0 + llm_total_tokens = 0 + llm_any_usage = False + _trace_trajectory_steps: List[int] = [] + _trace_llm_turns: List[int] = [] + _trace_retrieval_calls: List[int] = [] + for qid in evaluated_query_ids: + p = trace_path(traces_dir, trace_run_name_eff, dataset_dir, qid) + trace_obj = load_trace_file(p) or {} + if not isinstance(trace_obj, dict): + trace_obj = {} + + pipeline_trace = trace_obj.get("pipeline_trace", None) + if isinstance(pipeline_trace, dict): + if isinstance(pipeline_trace.get("llm_error", None), str): + llm_error_query_ids_from_traces.append(qid) + usage = pipeline_trace.get("llm_usage", None) + if isinstance(usage, dict): + pt = usage.get("prompt_tokens", None) + ct = usage.get("completion_tokens", None) + tt = usage.get("total_tokens", None) + if isinstance(pt, int): + llm_total_prompt_tokens += pt + llm_any_usage = True + if isinstance(ct, int): + llm_total_completion_tokens += ct + llm_any_usage = True + if isinstance(tt, int): + llm_total_tokens += tt + llm_any_usage = True + for _keys, _lst in [ + (("trajectory_steps", "agent_steps"), _trace_trajectory_steps), + (("llm_turns", "num_agent_steps"), _trace_llm_turns), + (("retrieval_calls", "num_retrieval_steps"), _trace_retrieval_calls), + ]: + for _key in _keys: + v = pipeline_trace.get(_key, None) + if isinstance(v, int): + _lst.append(v) + break + + trace_obj.update( + { + "query_id": qid, + "dataset": dataset_name_eff, + "dataset_dir": dataset_dir, + "split": split, + "language": language, + "query_ids_selector": query_ids_selector, + "trace_run_name": trace_run_name_eff, + "pipeline_class": pipeline.__class__.__name__, + "model_id": getattr(pipeline, "model_id", None), + "per_query_metrics": results.get(qid, None), + } + ) + write_trace_file(p, trace_obj) + + # Add timing information if tracking + if track_time: + num_queries = len(query_ids) + num_corpus = len(corpus_ids) + wall_ms = (time.perf_counter() - wall_start) * 1000.0 if wall_start is not None else None + + total_retrieval_ms = 0.0 + for qid in evaluated_query_ids: + t_ms = per_query_time_ms.get(qid, None) + if isinstance(t_ms, (int, float)): + total_retrieval_ms += float(t_ms) + + avg_ms = (total_retrieval_ms / len(evaluated_query_ids)) if evaluated_query_ids else 0.0 + qps = (len(evaluated_query_ids) / (total_retrieval_ms / 1000.0)) if total_retrieval_ms > 0 else 0.0 + + num_loaded = len([qid for qid in evaluated_query_ids if qid in cached_run]) + num_executed = len([qid for qid in evaluated_query_ids if qid in executed_run]) + + executed_retrieval_ms = sum( + float(per_query_time_ms[qid]) + for qid in evaluated_query_ids + if qid in executed_run and isinstance(per_query_time_ms.get(qid, None), (int, float)) + ) + + results["_timing"] = { + "total_retrieval_time_milliseconds": total_retrieval_ms, + "total_retrieval_time_milliseconds_executed": executed_retrieval_ms, + "indexing_time_milliseconds": indexing_time * 1000, + "average_time_per_query_milliseconds": avg_ms, + "total_wall_time_milliseconds": wall_ms, + "expected_num_queries": num_queries, + "num_queries": len(evaluated_query_ids), + "num_corpus": num_corpus, + "missing_num_queries": len(missing_query_ids), + "queries_per_second": qps, + "num_queries_loaded_from_trace": num_loaded, + "num_queries_executed": num_executed, + } + + # Attach evaluation infos (and keep any pipeline-provided infos nested). + eval_infos: Dict[str, Any] = { + "tracing": { + "traces_dir": str(Path(traces_dir)), + "trace_run_name": trace_run_name_eff, + "dataset_dir": dataset_dir, + }, + } + + pipeline_infos_summary: Dict[str, Any] = {} + if pipeline_infos_public is not None: + pipeline_infos_summary.update(pipeline_infos_public) + elif isinstance(pipeline_infos, dict): + pipeline_infos_summary.update(pipeline_infos) + pipeline_infos_summary.pop("per_query_trace", None) + + pipeline_infos_summary["llm_error_query_ids"] = sorted( + llm_error_query_ids_from_traces, key=lambda x: int(x) if x.isdigit() else x + ) + if llm_any_usage: + pipeline_infos_summary["llm_total_prompt_tokens"] = llm_total_prompt_tokens + pipeline_infos_summary["llm_total_completion_tokens"] = llm_total_completion_tokens + pipeline_infos_summary["llm_total_tokens"] = llm_total_tokens + if _trace_trajectory_steps: + pipeline_infos_summary["avg_trajectory_steps"] = sum(_trace_trajectory_steps) / len(_trace_trajectory_steps) + if _trace_llm_turns: + pipeline_infos_summary["avg_llm_turns"] = sum(_trace_llm_turns) / len(_trace_llm_turns) + if _trace_retrieval_calls: + pipeline_infos_summary["avg_retrieval_calls"] = sum(_trace_retrieval_calls) / len(_trace_retrieval_calls) + + if pipeline_infos_summary: + eval_infos["pipeline_infos"] = pipeline_infos_summary + results["_infos"] = eval_infos + + return results + + +def aggregate_results( + results: Dict[str, Dict[str, float]], query_languages: Optional[Dict[str, str]] = None +) -> Dict[str, Any]: + """ + Calculate aggregate statistics across all queries. + + If query_languages is provided, also computes per-language aggregates. + + Args: + results: Per-query evaluation results from evaluate_retrieval() + query_languages: Optional mapping of query_id to language + + Returns: + Dictionary of aggregated metrics. If query_languages is provided: + { + 'overall': {'ndcg_cut_10': 0.85, ...}, + 'by_language': { + 'english': {'ndcg_cut_10': 0.87, ...}, + 'french': {'ndcg_cut_10': 0.82, ...}, + }, + 'timing': {...} # if timing info present + 'infos': {...} # if pipeline infos present + } + Otherwise, just returns flat aggregated metrics. + """ + if not results: + return {} + + # Extract meta information if present (without mutating input) + timing_info = results.get("_timing", None) + infos = results.get("_infos", None) + + # Filter to actual per-query metrics (ignore meta keys like _timing/_infos) + query_results = {qid: qres for qid, qres in results.items() if not str(qid).startswith("_")} + if not query_results: + final = {"timing": timing_info} if timing_info else {} + if infos is not None: + final["infos"] = infos + return final + + # Get all metric names from first query + metric_names = list(next(iter(query_results.values())).keys()) + + # If no language splitting requested, return simple aggregation + if query_languages is None: + aggregated = {} + for metric in metric_names: + scores = [query_results[qid][metric] for qid in query_results] + aggregated[metric] = sum(scores) / len(scores) + + if timing_info: + aggregated.update(timing_info) + if infos is not None: + aggregated["infos"] = infos + + return aggregated + + # Split results by language + results_by_language = defaultdict(dict) + for query_id, per_query_results in query_results.items(): + lang = query_languages.get(query_id, "unknown") + results_by_language[lang][query_id] = per_query_results + + # Compute overall aggregates + overall_aggregated = {} + for metric in metric_names: + scores = [query_results[qid][metric] for qid in query_results] + overall_aggregated[metric] = sum(scores) / len(scores) + + # Compute per-language aggregates + by_language_aggregated = {} + for lang, lang_results in results_by_language.items(): + lang_aggregated = {} + for metric in metric_names: + scores = [lang_results[qid][metric] for qid in lang_results] + lang_aggregated[metric] = sum(scores) / len(scores) + lang_aggregated["num_queries"] = len(lang_results) + by_language_aggregated[lang] = lang_aggregated + + # Build final result structure + final_result = { + "overall": overall_aggregated, + "by_language": by_language_aggregated, + } + + if timing_info: + final_result["timing"] = timing_info + if infos is not None: + final_result["infos"] = infos + + return final_result diff --git a/retrieval-bench/src/retrieval_bench/pipeline_evaluation/tracing.py b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/tracing.py new file mode 100644 index 000000000..fe280ea34 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipeline_evaluation/tracing.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Helpers for per-query trace caching and logging. + +Traces are written as one JSON file per query: + traces///.json +""" + +from __future__ import annotations + +import json +import logging +import re +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +logger = logging.getLogger(__name__) + + +def _slugify(value: str) -> str: + """ + Make a filesystem-friendly string. + + Keeps [A-Za-z0-9._-], replaces anything else with '_'. + + Important: we intentionally preserve double-underscore separators (e.g. `A__B__C`) + used in run naming for readability. + """ + value = (value or "").strip() + if not value: + return "unnamed" + value = value.replace("/", "_") + value = re.sub(r"[^A-Za-z0-9._-]+", "_", value) + # Preserve `__` separators but avoid pathological underscore runs produced by replacement. + value = re.sub(r"_{3,}", "__", value).strip("_") + return value or "unnamed" + + +def model_id_short(model_id: Optional[str]) -> Optional[str]: + if not model_id: + return None + return str(model_id).split("/")[-1] + + +def default_trace_run_name(pipeline: Any) -> str: + cls = pipeline.__class__.__name__ + mid = getattr(pipeline, "model_id", None) + mid_short = model_id_short(mid) + llm_mid = getattr(pipeline, "llm_model", None) + llm_short = model_id_short(llm_mid) + + # Prefer: ____ + parts = [cls] + if mid_short: + parts.append(mid_short) + if llm_short: + parts.append(llm_short) + return _slugify("__".join(parts)) + + +def dataset_trace_dir(dataset_name: str, split: str = "test", language: Optional[str] = None) -> str: + """ + Trace subdirectory name for a dataset. + + Simplified by design: we use a stable filesystem-friendly identifier. + + - ViDoRe: keep existing short-name behavior: + 'vidore/vidore_v3_finance_en' -> 'vidore_v3_finance_en' + - BRIGHT: include the dataset prefix to avoid collisions: + 'bright/biology' -> 'bright__biology' + + (split/language are intentionally ignored to keep paths stable and simple.) + """ + ds = str(dataset_name or "unknown_dataset").strip() + parts = [p for p in ds.split("/") if p] + if len(parts) >= 2 and parts[0].lower() == "bright": + return _slugify(f"bright__{parts[1]}") + return _slugify(parts[-1] if parts else "unknown_dataset") + + +def trace_path( + traces_dir: str, + trace_run_name: str, + dataset_dir: str, + query_id: str, +) -> Path: + return Path(traces_dir) / _slugify(trace_run_name) / _slugify(dataset_dir) / f"{_slugify(query_id)}.json" + + +def load_trace_file(path: Path) -> Optional[Dict[str, Any]]: + try: + if not path.exists(): + return None + with open(path, "r") as f: + obj = json.load(f) + if not isinstance(obj, dict): + return None + return obj + except Exception: + logger.debug("Failed to load trace file %s", path, exc_info=True) + return None + + +def extract_run_and_time_ms(trace_obj: Dict[str, Any]) -> Optional[Tuple[Dict[str, float], float]]: + """ + Returns (run, retrieval_time_ms) if trace has required fields; otherwise None. + """ + run = trace_obj.get("run", None) + t = trace_obj.get("retrieval_time_milliseconds", None) + if not isinstance(run, dict): + return None + if not isinstance(t, (int, float)): + return None + return run, float(t) + + +def write_trace_file(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + with open(tmp, "w") as f: + json.dump(payload, f, indent=2) + tmp.replace(path) + + +def write_query_trace( + *, + traces_dir: str, + trace_run_name: str, + dataset: str, + dataset_dir: str, + query_id: str, + pipeline_class: str, + model_id: Optional[str], + retrieval_time_milliseconds: Optional[float], + run: Optional[Dict[str, float]], + split: str = "test", + language: Optional[str] = None, + query_ids_selector: Optional[str] = None, + pipeline_trace: Optional[Dict[str, Any]] = None, +) -> Path: + """ + Write a per-query trace JSON file in the canonical evaluator format. + + This is intended for pipelines that want to write traces incrementally (per query) + while preserving evaluator caching semantics (which require `run` + numeric time). + """ + payload: Dict[str, Any] = { + "query_id": query_id, + "dataset": dataset, + "dataset_dir": dataset_dir, + "split": split, + "language": language, + "query_ids_selector": query_ids_selector, + "trace_run_name": trace_run_name, + "pipeline_class": pipeline_class, + "model_id": model_id, + "retrieval_time_milliseconds": retrieval_time_milliseconds, + "run": run, + } + if isinstance(pipeline_trace, dict): + payload["pipeline_trace"] = pipeline_trace + + p = trace_path(traces_dir, trace_run_name, dataset_dir, query_id) + write_trace_file(p, payload) + return p diff --git a/retrieval-bench/src/retrieval_bench/pipelines/__init__.py b/retrieval-bench/src/retrieval_bench/pipelines/__init__.py new file mode 100644 index 000000000..5a80541c4 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipelines/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Built-in retrieval pipelines for vidore-benchmark evaluation. +""" + +from retrieval_bench.pipelines.backends import VALID_BACKENDS, init_backend +from retrieval_bench.pipelines.dense import DenseRetrievalPipeline +from retrieval_bench.pipelines.agentic import AgenticRetrievalPipeline + +__all__ = [ + "DenseRetrievalPipeline", + "AgenticRetrievalPipeline", + "VALID_BACKENDS", + "init_backend", +] diff --git a/retrieval-bench/src/retrieval_bench/pipelines/agentic.py b/retrieval-bench/src/retrieval_bench/pipelines/agentic.py new file mode 100644 index 000000000..b923cfdcd --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipelines/agentic.py @@ -0,0 +1,594 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Agentic retrieval pipeline: dense retrieval augmented with an LLM agent. + +The agent iteratively refines results using tool calls (retrieve, think, +final_results). The dense backend is selected via the ``backend`` parameter +and initialised through the shared :func:`init_backend` helper. +""" + +from __future__ import annotations + +import asyncio +import contextvars +import json +import os +import re +import sys +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from dotenv import load_dotenv + +_CURRENT_QUERY_ID: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + "nemo_agentic_current_query_id", default=None +) + +try: + import torch +except ImportError: + print("Error: Required GPU dependencies not installed.") + print("Please install: pip install torch") + sys.exit(1) + +from vidore_benchmark.pipeline_evaluation.base_pipeline import BasePipeline +from retrieval_bench.pipeline_evaluation.tracing import write_query_trace +from retrieval_bench.pipelines.backends import VALID_BACKENDS, infer_bright_task_key, init_backend + +from retrieval_bench.nemo_agentic.agent import Agent +from retrieval_bench.nemo_agentic.configs import AgentConfig, LLMConfig +from retrieval_bench.nemo_agentic.llm_handler import LLM, is_error, normalize_messages_for_api +from retrieval_bench.nemo_agentic.tool_helpers import BaseTool, FinalResults, RetrieveToolBase, ThinkTool + + +# --------------------------------------------------------------------------- +# RetrieveTool adapter +# --------------------------------------------------------------------------- + + +class RetrieveTool(RetrieveToolBase): + """ + Adapter tool wrapping a vidore-benchmark retriever singleton for the agent. + + Expects ``retriever.retrieve(query, return_markdown=True, excluded_ids=...)`` + to return ``(scores_dict, markdown_dict)``. + """ + + def __init__(self, retriever: Any, excluded_ids: Optional[List[str]] = None, top_k: int = 20): + self.retriever = retriever + self.excluded_ids = excluded_ids or [] + self._default_top_k = int(top_k) + + def _spec(self) -> Dict[str, Any]: + return { + "type": "function", + "function": { + "name": "retrieve", + "description": "Search for documents related to a query using dense retrieval.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query.", + }, + "top_k": { + "type": "integer", + "description": "Number of documents to retrieve.", + "default": self._default_top_k, + }, + }, + "required": ["query"], + }, + }, + } + + async def _acall(self, query: str, top_k: Optional[int] = None, **kwargs: Any) -> List[Dict[str, Any]]: + effective_top_k = int(kwargs.pop("__art_top_k", top_k or self._default_top_k)) + + try: + scores, markdowns = self.retriever.retrieve( + str(query), + return_markdown=True, + excluded_ids=self.excluded_ids, + ) + except TypeError: + scores, markdowns = self.retriever.retrieve(str(query), return_markdown=True) + + results: List[Dict[str, Any]] = [] + for doc_id, score in scores.items(): + results.append( + { + "id": str(doc_id), + "score": float(score), + "text": str(markdowns.get(doc_id, "")), + } + ) + results.sort(key=lambda x: x["score"], reverse=True) + return results[:effective_top_k] + + +# --------------------------------------------------------------------------- +# Result extraction helpers +# --------------------------------------------------------------------------- + + +def extract_final_doc_ids(output_artifacts: Dict[str, Any], *, final_top_k: int = 10) -> Tuple[List[str], str]: + """ + Extract the agent's final ranked doc ids with structured fallbacks. + + Fallback order: + 1) ``final_results`` tool-call args (primary) + 2) ``rrf_scores`` (secondary) + 3) ``top{final_top_k}_selection_result`` (tertiary) + """ + traj = output_artifacts.get("agent_trajectories", []) or [] + for msg in reversed(traj): + for tc in msg.get("tool_calls", []) or []: + fn = tc.get("function", {}) or {} + if fn.get("name") != "final_results": + continue + args = fn.get("arguments") + if isinstance(args, str): + try: + args = json.loads(args) + except Exception: + args = None + if isinstance(args, dict): + doc_ids = args.get("doc_ids") + if isinstance(doc_ids, list) and all(isinstance(i, str) for i in doc_ids): + return list(doc_ids), "final_results" + + rrf_scores = output_artifacts.get("rrf_scores", None) + if isinstance(rrf_scores, dict) and rrf_scores: + try: + sorted_ids = sorted(rrf_scores, key=rrf_scores.get, reverse=True) + return [str(i) for i in sorted_ids[: int(final_top_k)]], "rrf" + except Exception: + pass + + sel = output_artifacts.get(f"top{int(final_top_k)}_selection_result", None) + if isinstance(sel, dict): + doc_ids = sel.get("doc_ids") + if isinstance(doc_ids, list) and doc_ids and all(isinstance(i, str) for i in doc_ids): + return list(doc_ids)[: int(final_top_k)], "selection_agent" + if isinstance(sel, list) and sel and all(isinstance(i, str) for i in sel): + return list(sel)[: int(final_top_k)], "selection_agent" + + return [], "none" + + +# --------------------------------------------------------------------------- +# LLM usage-tracking wrapper +# --------------------------------------------------------------------------- + + +def _wrap_llm_for_usage_tracking(llm: LLM) -> LLM: + llm._accumulated_usage = {"prompt_tokens": 0, "completion_tokens": 0} # type: ignore[attr-defined] + llm._per_query_usage = {} # type: ignore[attr-defined] + _original_acompletion = llm.acompletion + + async def _tracked_acompletion(*args: Any, **kwargs: Any) -> Any: + if "messages" in kwargs: + kwargs["messages"] = normalize_messages_for_api(kwargs["messages"]) + elif args: + args_list = list(args) + if isinstance(args_list[0], list): + args_list[0] = normalize_messages_for_api(args_list[0]) + args = tuple(args_list) + + resp = await _original_acompletion(*args, **kwargs) + + model_resp = resp + if isinstance(resp, dict) and "response" in resp: + model_resp = resp.get("response") + + usage = getattr(model_resp, "usage", None) + if usage is not None: + prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0) + completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0) + + llm._accumulated_usage["prompt_tokens"] += prompt_tokens # type: ignore[attr-defined] + llm._accumulated_usage["completion_tokens"] += completion_tokens # type: ignore[attr-defined] + + qid = _CURRENT_QUERY_ID.get() + if qid: + per_q = llm._per_query_usage.get(qid) # type: ignore[attr-defined] + if not isinstance(per_q, dict): + per_q = {"prompt_tokens": 0, "completion_tokens": 0} + llm._per_query_usage[qid] = per_q # type: ignore[attr-defined] + per_q["prompt_tokens"] = int(per_q.get("prompt_tokens", 0)) + prompt_tokens + per_q["completion_tokens"] = int(per_q.get("completion_tokens", 0)) + completion_tokens + return resp + + llm.acompletion = _tracked_acompletion # type: ignore[assignment] + return llm + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +class AgenticRetrievalPipeline(BasePipeline): + """ + Dense retrieval augmented with an LLM agent that iteratively refines results. + + The ``backend`` parameter selects which dense retriever runs underneath. + Additional backend-specific overrides can be passed via ``**backend_kwargs``. + """ + + def __init__( + self, + *, + backend: str, + retriever_top_k: int = 500, + # Agent / LLM knobs + num_concurrent: int = 1, + llm_model: str, + api_key: str = "os.environ/OPENAI_API_KEY", + base_url: Optional[str] = "os.environ/OPENAI_BASE_URL", + api_version: Optional[str] = None, + reasoning_effort: str = "high", + drop_params: bool = True, + allowed_openai_params: Optional[List[str]] = None, + raw_log_pardir: str = "nemo_agentic_logs", + instant_log: bool = False, + strict_error_handling: bool = False, + target_top_k: int = 10, + max_steps: int = 200, + main_agent_only: bool = False, + selection_topk_list: Optional[List[int]] = None, + # All remaining kwargs are forwarded as backend overrides. + **backend_kwargs: Any, + ) -> None: + load_dotenv() + + if backend not in VALID_BACKENDS: + raise ValueError(f"Unknown backend {backend!r}. " f"Must be one of: {', '.join(sorted(VALID_BACKENDS))}") + self.backend = backend + self.model_id = backend + self.retriever_top_k = int(retriever_top_k) + self.num_concurrent = max(1, int(num_concurrent)) + self._backend_kwargs = dict(backend_kwargs) + + # Resolve os.environ/... convention for base_url. + if base_url and str(base_url).strip().startswith("os.environ/"): + env_var = str(base_url).strip().removeprefix("os.environ/") + base_url = os.environ.get(env_var, None) + + if selection_topk_list is None: + selection_topk_list = [5, 10] + + self._agent_config = AgentConfig( + system_prompt="02_v1.j2", + target_top_k=int(target_top_k), + max_steps=int(max_steps), + main_agent_only=bool(main_agent_only), + selection_topk_list=list(selection_topk_list), + ) + self._llm_config = LLMConfig( + model=str(llm_model), + api_key=str(api_key), + base_url=str(base_url) if base_url else None, + api_version=str(api_version) if api_version else None, + reasoning_effort=str(reasoning_effort), + raw_log_pardir=str(raw_log_pardir), + instant_log=bool(instant_log), + strict_error_handling=bool(strict_error_handling), + drop_params=bool(drop_params), + allowed_openai_params=list(allowed_openai_params) if allowed_openai_params else None, + ) + + self.llm_model = self._llm_config.model + + if not torch.cuda.is_available(): + print("Error: CUDA is not available. This pipeline requires a GPU.") + sys.exit(1) + + # ----------------------------------------------------------------------- + # Async query loop + # ----------------------------------------------------------------------- + + @staticmethod + def _summarize_preflight_error(error: Exception, model: str, base_url: Optional[str]) -> str: + raw = re.sub(r"\s+", " ", str(error)).strip() + hint_parts: List[str] = [] + + low = raw.lower() + if "notfound" in low or "404" in low: + hint_parts.append("model not found for this endpoint/provider") + elif "unauthorized" in low or "401" in low or "invalid api key" in low: + hint_parts.append("authentication failed (check API key)") + elif "forbidden" in low or "403" in low: + hint_parts.append("access denied for this model/key") + + if str(model).startswith("unknown/"): + hint_parts.append("model id starts with 'unknown/' (provider prefix likely incorrect)") + + hint = "" + if hint_parts: + hint = " Hint: " + "; ".join(hint_parts) + "." + + return ( + f"LLM preflight check failed before evaluation. " + f"model={model!r}, base_url={base_url!r}.{hint} " + f"Error: {raw}" + ) + + async def _validate_llm_preflight(self, llm: LLM) -> None: + """ + Run a lightweight LLM call before processing the full dataset. + + This prevents confusing per-query fallback behavior when the endpoint/model + is misconfigured (e.g. model not found, invalid base URL, bad API key). + """ + probe_messages = [ + { + "role": "user", + "content": "Health check: reply with OK.", + } + ] + + try: + response = await llm.acompletion( + messages=probe_messages, + tools=None, + max_completion_tokens=8, + num_retries=0, + ) + except Exception as e: + raise RuntimeError( + self._summarize_preflight_error( + e, model=str(self._llm_config.model), base_url=self._llm_config.base_url + ) + ) from e + + if is_error(response): + # llm_handler returns string errors in non-strict mode. + err = str(response).replace("LLMError:", "", 1).strip() + raise RuntimeError( + self._summarize_preflight_error( + RuntimeError(err), model=str(self._llm_config.model), base_url=self._llm_config.base_url + ) + ) + + async def _run_all_queries( + self, + *, + query_ids: Sequence[str], + queries: Sequence[str], + llm: LLM, + excluded_ids_by_query: Dict[str, List[str]], + tracing_context: Optional[Dict[str, Any]], + ) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, Any]], Dict[str, float]]: + agent_config = self._agent_config + final_top_k = int(agent_config.target_top_k or 10) + + results: Dict[str, Dict[str, float]] = {} + per_query_trace: Dict[str, Dict[str, Any]] = {} + per_query_ms: Dict[str, float] = {} + total_queries = len(query_ids) + completed = 0 + completed_lock = asyncio.Lock() + sem = asyncio.Semaphore(max(1, int(getattr(self, "num_concurrent", 1) or 1))) + + async def _process_query(q_idx: int, qid: str, query_text: Any) -> None: + nonlocal completed + async with sem: + excluded = excluded_ids_by_query.get(qid, []) or [] + + t0 = time.perf_counter() + + trace_entry: Dict[str, Any] = {} + trace_entry["query_text"] = str(query_text) + trace_entry["llm_model"] = str( + getattr(self, "_llm_config", None).model if hasattr(self, "_llm_config") else "" + ) + + retrieve_tool = RetrieveTool( + retriever=self._active_retriever, + excluded_ids=list(excluded), + top_k=int(agent_config.target_top_k or 20), + ) + tool_map: Dict[str, BaseTool] = {"retrieve": retrieve_tool} + + tk = ( + int(agent_config.target_top_k) if agent_config.enforce_top_k and agent_config.target_top_k else None + ) + tool_map["final_results"] = FinalResults(top_k=tk) + + if not agent_config.disable_think: + think = ThinkTool(extended_relevance=agent_config.extended_relevance) + tool_map[think.name] = think + + agent = Agent(config=agent_config, llm=llm, tool_map=tool_map, session_id=qid) + + doc_ids: List[str] = [] + source = "none" + try: + if hasattr(llm, "_per_query_usage"): + llm._per_query_usage[str(qid)] = { # type: ignore[attr-defined] + "prompt_tokens": 0, + "completion_tokens": 0, + } + + token = _CURRENT_QUERY_ID.set(str(qid)) + try: + output = await agent.run_for_input( + query=str(query_text), + exclude_docids=set(excluded), + ) + finally: + _CURRENT_QUERY_ID.reset(token) + + doc_ids, source = extract_final_doc_ids(output, final_top_k=final_top_k) + + trajectories = output.get("agent_trajectories", []) or [] + trace_entry["trajectory_steps"] = len(trajectories) + trace_entry["llm_turns"] = sum(1 for m in trajectories if m.get("role") == "assistant") + trace_entry["retrieval_calls"] = sum( + 1 + for m in trajectories + for tc in (m.get("tool_calls") or []) + if (tc.get("function") or {}).get("name") == "retrieve" + ) + trace_entry["result_source"] = source + trace_entry["rrf_used"] = source == "rrf" + trace_entry["selection_agent_ran"] = any( + isinstance(k, str) and k.startswith("top") and k.endswith("_selection_result") + for k in output.keys() + ) + trace_entry["doc_ids"] = list(doc_ids) + retrieval_log = output.get("retrieval_log", []) if isinstance(output, dict) else [] + trace_entry["num_retrieval_calls"] = ( + int(len(retrieval_log)) if isinstance(retrieval_log, list) else 0 + ) + agent_extra = output.get("agent_extra_data", None) if isinstance(output, dict) else None + trace_entry["query_rewriting_used"] = bool(isinstance(agent_extra, dict) and len(agent_extra) > 0) + rrf_scores = output.get("rrf_scores", None) if isinstance(output, dict) else None + if isinstance(rrf_scores, dict) and rrf_scores: + try: + top_rrf_ids = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[:final_top_k] + trace_entry["rrf_scores_summary"] = [str(i) for i in top_rrf_ids] + except Exception: + trace_entry["rrf_scores_summary"] = [] + trace_entry["fallback_used"] = False + except Exception as e: + try: + scores = self._active_retriever.retrieve(str(query_text), excluded_ids=list(excluded)) + except TypeError: + scores = self._active_retriever.retrieve(str(query_text)) + doc_ids = [str(i) for i in sorted(scores, key=scores.get, reverse=True)[:final_top_k]] + trace_entry["llm_error"] = f"{type(e).__name__}: {e}" + trace_entry["result_source"] = "retriever_fallback" + trace_entry["doc_ids"] = list(doc_ids) + trace_entry["fallback_used"] = True + + t1 = time.perf_counter() + elapsed_ms = (t1 - t0) * 1000.0 + per_query_ms[qid] = float(elapsed_ms) + + per_q_usage = getattr(llm, "_per_query_usage", {}).get(str(qid), {}) # type: ignore[attr-defined] + pt = int(per_q_usage.get("prompt_tokens", 0) or 0) + ct = int(per_q_usage.get("completion_tokens", 0) or 0) + trace_entry["llm_usage"] = { + "prompt_tokens": pt, + "completion_tokens": ct, + "total_tokens": pt + ct, + } + trace_entry["elapsed_ms"] = float(elapsed_ms) + + run_for_q = {did: float(len(doc_ids) - rank) for rank, did in enumerate(doc_ids)} + results[qid] = run_for_q + per_query_trace[qid] = trace_entry + + if isinstance(tracing_context, dict): + try: + write_query_trace( + traces_dir=str(tracing_context.get("traces_dir", "traces")), + trace_run_name=str(tracing_context.get("trace_run_name", "unnamed")), + dataset=str(tracing_context.get("dataset", self.dataset_name)), + dataset_dir=str(tracing_context.get("dataset_dir", str(self.dataset_name).split("/")[-1])), + query_id=str(qid), + pipeline_class=str(tracing_context.get("pipeline_class", self.__class__.__name__)), + model_id=str(tracing_context.get("model_id", "")), + retrieval_time_milliseconds=float(elapsed_ms), + run=dict(run_for_q), + split=str(tracing_context.get("split", "test")), + language=tracing_context.get("language", None), + query_ids_selector=tracing_context.get("query_ids_selector", None), + pipeline_trace=dict(trace_entry), + ) + except Exception as e: + print(f"WARNING: failed to write per-query trace for query_id={qid}: {type(e).__name__}: {e}") + + async with completed_lock: + completed += 1 + if completed == 1 or completed % 10 == 0 or completed == total_queries: + print( + f" Agent queries completed: {completed}/{total_queries}" + f" (concurrency={self.num_concurrent})" + ) + + tasks = [ + asyncio.create_task(_process_query(q_idx=i, qid=str(qid), query_text=query_text)) + for i, (qid, query_text) in enumerate(zip(query_ids, queries)) + ] + await asyncio.gather(*tasks) + + return results, per_query_trace, per_query_ms + + # ----------------------------------------------------------------------- + # Main entry point + # ----------------------------------------------------------------------- + + def index(self, corpus_ids: List[str], corpus_images: List[Any], corpus_texts: List[str]) -> None: + super().index(corpus_ids=corpus_ids, corpus_images=corpus_images, corpus_texts=corpus_texts) + + dataset_name = self.dataset_name + task_key = infer_bright_task_key(dataset_name) + + corpus = [{"image": img, "markdown": md} for img, md in zip(corpus_images, corpus_texts)] + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t_init0 = time.perf_counter() + + active_retriever, effective_model_id, init_info = init_backend( + self.backend, + dataset_name=dataset_name, + corpus_ids=corpus_ids, + corpus=corpus, + top_k=self.retriever_top_k, + task_key=task_key, + overrides=self._backend_kwargs or None, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + retriever_init_ms = (time.perf_counter() - t_init0) * 1000.0 + + print(f"Using backend {self.backend} ({effective_model_id})") + print(f"Retriever init: {retriever_init_ms / 1000.0:.2f}s") + + self._active_retriever = active_retriever + self._init_info = init_info + self._retriever_init_ms = retriever_init_ms + + def search( + self, + query_ids: List[str], + queries: List[str], + ) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Any]]: + tracing_context = getattr(self, "tracing_context", None) + excluded_ids_by_query = getattr(self, "excluded_ids_by_query", None) or {} + + llm = _wrap_llm_for_usage_tracking(LLM(self._llm_config)) + + try: + # Fail fast once if the LLM endpoint/model is invalid. + asyncio.run(self._validate_llm_preflight(llm)) + + results, per_query_trace, per_query_ms = asyncio.run( + self._run_all_queries( + query_ids=query_ids, + queries=queries, + llm=llm, + excluded_ids_by_query=excluded_ids_by_query, + tracing_context=tracing_context, + ) + ) + finally: + self._active_retriever.unload() + + infos: Dict[str, Any] = { + **self._init_info, + "retriever_top_k": self.retriever_top_k, + "retriever_init_ms": float(self._retriever_init_ms), + "per_query_retrieval_time_milliseconds": per_query_ms, + "per_query_trace": per_query_trace, + } + return results, infos diff --git a/retrieval-bench/src/retrieval_bench/pipelines/backends.py b/retrieval-bench/src/retrieval_bench/pipelines/backends.py new file mode 100644 index 000000000..a05189d20 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipelines/backends.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Shared backend initialization for dense retrieval singletons. + +Used by both DenseRetrievalPipeline and AgenticRetrievalPipeline. +""" + +from __future__ import annotations + +import os +from typing import Any, Dict, Optional, Sequence, Tuple, Union + + +def infer_bright_task_key(dataset_name: Any) -> Optional[str]: + """Extract the BRIGHT task key (e.g. 'biology') from a dataset name like 'bright/biology'.""" + try: + ds = str(dataset_name or "").strip() + except Exception: + return None + if not ds: + return None + parts = [p for p in ds.split("/") if p] + if len(parts) >= 2 and parts[0].lower() == "bright": + return parts[1] + return None + + +VALID_BACKENDS = { + "llama-nv-embed-reasoning-3b", + "llama-nemoretriever-colembed-3b-v1", + "llama-nemotron-embed-vl-1b-v2", + "nemotron-colembed-vl-8b-v2", +} + +_BACKEND_DEFAULTS: Dict[str, Dict[str, Any]] = { + "llama-nv-embed-reasoning-3b": { + "model_id": "nvidia/llama-nv-embed-reasoning-3b", + "max_length": 8192, + "pooling": "mean", + "score_scale": 100.0, + "corpus_batch_size": 1, + "scoring_batch_size": 4096, + "preload_corpus_to_gpu": False, + "query_prefix_fallback": ( + "Instruct: Given the following post, retrieve relevant passages that help answer the post.\nQuery:" + ), + }, + "llama-nemoretriever-colembed-3b-v1": { + "model_id": "nvidia/llama-nemoretriever-colembed-3b-v1", + "batch_size": 32, + "corpus_batch_size": 32, + "corpus_chunk_size": 256, + "preload_corpus_to_gpu": True, + }, + "llama-nemotron-embed-vl-1b-v2": { + "model_id": "nvidia/llama-nemotron-embed-vl-1b-v2", + "device": "auto", + "doc_modality": "image_text", + "doc_max_length": "auto", + "query_max_length": 10240, + "corpus_batch_size": 4, + "corpus_chunk_size": 4096, + "preload_corpus_to_gpu": False, + "max_input_tiles": 6, + "use_thumbnail": True, + }, + "nemotron-colembed-vl-8b-v2": { + "model_id": "nvidia/nemotron-colembed-vl-8b-v2", + "corpus_batch_size": 8, + "corpus_chunk_size": 256, + "preload_corpus_to_gpu": False, + "max_input_tiles": 8, + "use_thumbnail": True, + "cache_dir": "cache/nemotron_colembed_vl_v2", + }, +} + + +class _NemotronEmbedVLAdapter: + """Adapter so nemotron_embed_vl matches the retrieve(excluded_ids=...) interface.""" + + def __init__(self, inner: Any) -> None: + self._inner = inner + + def retrieve( + self, + query: str, + *, + return_markdown: bool = False, + excluded_ids: Optional[Sequence[str]] = None, + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + return self._inner.retrieve( + str(query), + return_markdown=bool(return_markdown), + excluded_ids=excluded_ids, + ) + + def unload(self) -> None: + self._inner.unload() + + +def _import_retriever(backend: str) -> Any: + """Lazily import the singleton retriever for a given backend.""" + if backend == "llama-nv-embed-reasoning-3b": + from retrieval_bench.singletons.hf_dense_retriever import retriever + + return retriever + elif backend == "llama-nemoretriever-colembed-3b-v1": + from retrieval_bench.singletons.colembed_retriever import retriever + + return retriever + elif backend == "llama-nemotron-embed-vl-1b-v2": + from retrieval_bench.singletons.nemotron_embed_vl_dense_retriever import retriever + + return retriever + elif backend == "nemotron-colembed-vl-8b-v2": + from retrieval_bench.singletons.nemotron_colembed_vl_v2_retriever import retriever + + return retriever + else: + raise ValueError(f"Unknown backend {backend!r}. Must be one of: {', '.join(sorted(VALID_BACKENDS))}") + + +def get_backend_defaults(backend: str) -> Dict[str, Any]: + """Return a copy of the default kwargs for a backend.""" + if backend not in VALID_BACKENDS: + raise ValueError(f"Unknown backend {backend!r}. Must be one of: {', '.join(sorted(VALID_BACKENDS))}") + return dict(_BACKEND_DEFAULTS[backend]) + + +def init_backend( + backend: str, + *, + dataset_name: str, + corpus_ids: Any, + corpus: Any, + top_k: int = 100, + task_key: Optional[str] = None, + overrides: Optional[Dict[str, Any]] = None, +) -> Tuple[Any, str, Dict[str, Any]]: + """ + Initialize a retriever backend and return (active_retriever, effective_model_id, init_info). + + The returned retriever object has .retrieve() and .unload() methods. + ``init_info`` contains backend-specific metadata for the infos dict. + """ + if backend not in VALID_BACKENDS: + raise ValueError(f"Unknown backend {backend!r}. Must be one of: {', '.join(sorted(VALID_BACKENDS))}") + + cfg = get_backend_defaults(backend) + if overrides: + cfg.update(overrides) + + model_id = os.path.expanduser(str(cfg.pop("model_id"))) + retriever = _import_retriever(backend) + init_info: Dict[str, Any] = {"backend": backend} + + if backend == "llama-nv-embed-reasoning-3b": + from retrieval_bench.prompts.bright_instructions import ( + NEMO_REASONING_PASSAGE_PREFIX, + get_bright_query_prefix_nemo, + ) + + query_prefix_fallback = str(cfg.pop("query_prefix_fallback")) + query_prefix = get_bright_query_prefix_nemo(task_key=task_key, fallback=query_prefix_fallback) + + pooling = str(cfg.pop("pooling", "mean")) + max_length = int(cfg.pop("max_length", 8192)) + score_scale = float(cfg.pop("score_scale", 100.0)) + corpus_batch_size = int(cfg.pop("corpus_batch_size", 1)) + scoring_batch_size = int(cfg.pop("scoring_batch_size", 4096)) + preload_corpus_to_gpu = bool(cfg.pop("preload_corpus_to_gpu", False)) + + retriever.init( + dataset_name=dataset_name, + corpus_ids=corpus_ids, + corpus=corpus, + model_id=model_id, + device="cuda", + top_k=top_k, + max_length=max_length, + pooling=pooling, + doc_prefix=str(NEMO_REASONING_PASSAGE_PREFIX), + query_prefix=str(query_prefix), + task_description="Given the following post, retrieve relevant passages that help answer the post.", + score_scale=score_scale, + batch_size=1, + corpus_batch_size=corpus_batch_size, + scoring_batch_size=scoring_batch_size, + cache_dir="cache/hf_dense", + preload_corpus_to_gpu=preload_corpus_to_gpu, + ) + init_info.update( + { + "model_id": model_id, + "pooling": pooling, + "task_key": task_key, + "query_prefix": str(query_prefix), + "doc_prefix": str(NEMO_REASONING_PASSAGE_PREFIX), + "max_length": max_length, + "score_scale": score_scale, + } + ) + return retriever, model_id, init_info + + elif backend == "llama-nemoretriever-colembed-3b-v1": + batch_size = int(cfg.pop("batch_size", 32)) + corpus_batch_size = int(cfg.pop("corpus_batch_size", 32)) + corpus_chunk_size = int(cfg.pop("corpus_chunk_size", 256)) + preload_corpus_to_gpu = bool(cfg.pop("preload_corpus_to_gpu", True)) + + retriever.init( + dataset_name=dataset_name, + corpus_ids=corpus_ids, + corpus=corpus, + model_id=model_id, + top_k=top_k, + batch_size=batch_size, + corpus_batch_size=corpus_batch_size, + corpus_chunk_size=corpus_chunk_size, + cache_dir="cache", + preload_corpus_to_gpu=preload_corpus_to_gpu, + ) + init_info.update({"model_id": model_id}) + return retriever, model_id, init_info + + elif backend == "llama-nemotron-embed-vl-1b-v2": + device = str(cfg.pop("device", "auto")) + doc_modality = str(cfg.pop("doc_modality", "image_text")) + doc_max_length = cfg.pop("doc_max_length", "auto") + + # Auto-detect: fall back to text-only when the corpus has no images + # (e.g. BRIGHT text-only datasets). + if doc_modality != "text" and not any("image" in doc for doc in corpus[:5]): + doc_modality = "text" + + query_max_length = int(cfg.pop("query_max_length", 10240)) + corpus_batch_size = int(cfg.pop("corpus_batch_size", 4)) + corpus_chunk_size = int(cfg.pop("corpus_chunk_size", 4096)) + preload_corpus_to_gpu = bool(cfg.pop("preload_corpus_to_gpu", False)) + max_input_tiles = int(cfg.pop("max_input_tiles", 6)) + use_thumbnail = bool(cfg.pop("use_thumbnail", True)) + + retriever.init( + dataset_name=dataset_name, + corpus_ids=corpus_ids, + corpus=corpus, + model_id=model_id, + device=device, + top_k=top_k, + doc_modality=doc_modality, + doc_max_length=doc_max_length, + query_max_length=query_max_length, + corpus_batch_size=corpus_batch_size, + corpus_chunk_size=corpus_chunk_size, + cache_dir="cache/nemotron_vl_dense", + preload_corpus_to_gpu=preload_corpus_to_gpu, + max_input_tiles=max_input_tiles, + use_thumbnail=use_thumbnail, + ) + init_info.update( + { + "model_id": model_id, + "device": device, + "doc_modality": doc_modality, + "doc_max_length": doc_max_length, + "query_max_length": query_max_length, + "max_input_tiles": max_input_tiles, + "use_thumbnail": use_thumbnail, + "corpus_batch_size": corpus_batch_size, + "corpus_chunk_size": corpus_chunk_size, + "preload_corpus_to_gpu": preload_corpus_to_gpu, + } + ) + active = _NemotronEmbedVLAdapter(retriever) + return active, model_id, init_info + + else: # nemotron-colembed-vl-8b-v2 + corpus_batch_size = int(cfg.pop("corpus_batch_size", 8)) + corpus_chunk_size = int(cfg.pop("corpus_chunk_size", 256)) + preload_corpus_to_gpu = bool(cfg.pop("preload_corpus_to_gpu", False)) + max_input_tiles = int(cfg.pop("max_input_tiles", 8)) + use_thumbnail = bool(cfg.pop("use_thumbnail", True)) + cache_dir = str(cfg.pop("cache_dir", "cache/nemotron_colembed_vl_v2")) + + retriever.init( + dataset_name=str(dataset_name), + corpus_ids=corpus_ids, + corpus=corpus, + model_id=str(model_id), + device="cuda", + top_k=top_k, + corpus_batch_size=corpus_batch_size, + corpus_chunk_size=corpus_chunk_size, + cache_dir=cache_dir, + preload_corpus_to_gpu=preload_corpus_to_gpu, + max_input_tiles=max_input_tiles, + use_thumbnail=use_thumbnail, + ) + init_info.update({"model_id": model_id}) + return retriever, model_id, init_info diff --git a/retrieval-bench/src/retrieval_bench/pipelines/dense.py b/retrieval-bench/src/retrieval_bench/pipelines/dense.py new file mode 100644 index 000000000..9af791813 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/pipelines/dense.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Unified dense retrieval pipeline supporting multiple backends. + +Replaces the four standalone example pipelines: +- nemo_reasoning_dense_pipeline.py +- llama_nemotron_embed_vl_1b_v2_dense_pipeline.py +- nemotron_colembed_vl_8b_v2_pipeline.py +- colembed_pipeline.py +""" + +from __future__ import annotations + +import sys +import time +from typing import Any, Dict, List, Optional, Tuple + +try: + import torch +except ImportError: + print("Error: Required GPU dependencies not installed.") + print("Please install: pip install torch") + sys.exit(1) + +from vidore_benchmark.pipeline_evaluation.base_pipeline import BasePipeline +from retrieval_bench.pipelines.backends import VALID_BACKENDS, infer_bright_task_key, init_backend + + +class DenseRetrievalPipeline(BasePipeline): + """ + Dense retrieval pipeline with pluggable backend. + + Backends: + - llama-nv-embed-reasoning-3b (text-only dense, BRIGHT-style prefixes) + - llama-nemoretriever-colembed-3b-v1 (text-only late-interaction ColBERT) + - llama-nemotron-embed-vl-1b-v2 (multimodal dense, image+text) + - nemotron-colembed-vl-8b-v2 (multimodal late-interaction ColBERT) + """ + + def __init__(self, *, backend: str, top_k: int = 100, **kwargs: Any) -> None: + if backend not in VALID_BACKENDS: + raise ValueError(f"Unknown backend {backend!r}. " f"Must be one of: {', '.join(sorted(VALID_BACKENDS))}") + self.backend = backend + self.model_id = backend + self.top_k = int(top_k) + self._backend_overrides = dict(kwargs) + + if not torch.cuda.is_available(): + print("Error: CUDA is not available. This pipeline requires a GPU.") + sys.exit(1) + + def index(self, corpus_ids: List[str], corpus_images: List[Any], corpus_texts: List[str]) -> None: + super().index(corpus_ids=corpus_ids, corpus_images=corpus_images, corpus_texts=corpus_texts) + + dataset_name = self.dataset_name + task_key = infer_bright_task_key(dataset_name) + + corpus = [{"image": img, "markdown": md} for img, md in zip(corpus_images, corpus_texts)] + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t_init0 = time.perf_counter() + + active_retriever, effective_model_id, init_info = init_backend( + self.backend, + dataset_name=dataset_name, + corpus_ids=corpus_ids, + corpus=corpus, + top_k=self.top_k, + task_key=task_key, + overrides=self._backend_overrides or None, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + retriever_init_ms = (time.perf_counter() - t_init0) * 1000.0 + + print(f"Using backend {self.backend} ({effective_model_id})") + print(f"Retriever init: {retriever_init_ms / 1000.0:.2f}s") + + self._active_retriever = active_retriever + self._init_info = init_info + self._retriever_init_ms = retriever_init_ms + + def search( + self, + query_ids: List[str], + queries: List[str], + ) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Any]]: + results: Dict[str, Dict[str, float]] = {} + per_query_ms: Dict[str, float] = {} + + excluded_ids_by_query = getattr(self, "excluded_ids_by_query", None) or {} + + try: + for q_idx, (query_id, query_text) in enumerate(zip(query_ids, queries)): + if q_idx % 25 == 0: + print(f" Retrieving for query {q_idx + 1}/{len(query_ids)}...") + + excluded_ids: Optional[List[str]] = None + if isinstance(excluded_ids_by_query, dict): + ex = excluded_ids_by_query.get(str(query_id)) + if isinstance(ex, list): + excluded_ids = [str(d) for d in ex if str(d) != "N/A"] + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.perf_counter() + + try: + results[str(query_id)] = self._active_retriever.retrieve(str(query_text), excluded_ids=excluded_ids) + except TypeError: + results[str(query_id)] = self._active_retriever.retrieve(str(query_text)) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + per_query_ms[str(query_id)] = (t1 - t0) * 1000.0 + finally: + self._active_retriever.unload() + + total_ms = sum(per_query_ms.values()) + print(f"\nRetrieval complete in {total_ms / 1000.0:.2f} seconds (sum of per-query times)") + if query_ids: + print(f"Average time per query: {total_ms / len(query_ids):.2f} ms") + + infos: Dict[str, Any] = { + **self._init_info, + "retriever_init_ms": float(self._retriever_init_ms), + "per_query_retrieval_time_milliseconds": per_query_ms, + } + return results, infos diff --git a/retrieval-bench/src/retrieval_bench/prompts/__init__.py b/retrieval-bench/src/retrieval_bench/prompts/__init__.py new file mode 100644 index 000000000..a51bd4303 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/prompts/__init__.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Prompt registries and small prompt-building helpers. +""" + +from .bright_instructions import BRIGHT_SHORT_INSTRUCTIONS, get_task_description # noqa: F401 + +__all__ = [ + "BRIGHT_SHORT_INSTRUCTIONS", + "get_task_description", +] diff --git a/retrieval-bench/src/retrieval_bench/prompts/bright_instructions.py b/retrieval-bench/src/retrieval_bench/prompts/bright_instructions.py new file mode 100644 index 000000000..5525335eb --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/prompts/bright_instructions.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +""" +Short instruction prompts for BRIGHT datasets (future integration). + +These are the task descriptions that the BRIGHT authors used to train/evaluate +instruction-tuned embedding models. + +We keep them here so pipelines can reference them by a stable key (e.g. "theoremqa_theorems"). +""" + +from typing import Optional + + +BRIGHT_TASKS_POST: set[str] = { + "biology", + "earth_science", + "economics", + "psychology", + "robotics", + "stackoverflow", + "sustainable_living", +} + +BRIGHT_TASKS_QUESTION: set[str] = { + "pony", +} + +# These are derived to match BRIGHT's `configs/*/.json` instruction text. +# Important: we return the task description WITHOUT the "Instruct:" / "Query:" wrappers, +# since our retriever wraps as: +# Instruct: \nQuery: +BRIGHT_SHORT_INSTRUCTIONS: dict[str, str] = { + # StackExchange-style posts + # Matches BRIGHT configs (no trailing period): + # "Instruct: Given a {task} post, retrieve relevant passages that help answer the post\nQuery: " + "biology": "Given a biology post, retrieve relevant passages that help answer the post", + "earth_science": "Given a earth_science post, retrieve relevant passages that help answer the post", + "economics": "Given a economics post, retrieve relevant passages that help answer the post", + "psychology": "Given a psychology post, retrieve relevant passages that help answer the post", + "robotics": "Given a robotics post, retrieve relevant passages that help answer the post", + "stackoverflow": "Given a stackoverflow post, retrieve relevant passages that help answer the post", + "sustainable_living": "Given a sustainable_living post, retrieve relevant passages that help answer the post", + # Coding + # "Instruct: Given a coding problem, retrieve relevant examples that help answer the problem\nQuery: " + "leetcode": "Given a coding problem, retrieve relevant examples that help answer the problem", + # Pony + # "Instruct: Given a {task} question, retrieve relevant passages that help answer the question\nQuery: " + "pony": "Given a pony question, retrieve relevant passages that help answer the question", + # Theorem-based + # "Instruct: Given a Math problem, retrieve relevant examples/theorems that help answer the problem\nQuery: " + "aops": "Given a Math problem, retrieve relevant examples that help answer the problem", + "theoremqa_questions": "Given a Math problem, retrieve relevant examples that help answer the problem", + "theoremqa_theorems": "Given a Math problem, retrieve relevant theorems that help answer the problem", +} + + +# --------------------------------------------------------------------------- +# Nemo reasoning retriever prompt formatting +# --------------------------------------------------------------------------- +# +# Some instruction-tuned dense retrievers (including our Nemo reasoning checkpoint) +# were trained with a *full* query prefix that already includes the "Instruct:" and +# "Query:" wrappers. For those models, we replicate the training-time formatting at +# inference time by doing: +# formatted_query = prefix + query_text +# +# We normalize all returned prefixes to end with: "Query: " (colon + space). +BRIGHT_NEMO_QUERY_PREFIXES: dict[str, str] = { + "biology": "Instruct: Given a Biology post, retrieve relevant passages that help answer the post.\nQuery:", + "earth_science": "Instruct: Given an Earth Science post, retrieve relevant passages that help answer the post.\nQuery:", + "economics": "Instruct: Given an Economics post, retrieve relevant passages that help answer the post.\nQuery:", + "psychology": "Instruct: Given a Psychology post, retrieve relevant passages that help answer the post.\nQuery:", + "robotics": "Instruct: Given a Robotics post, retrieve relevant passages that help answer the post.\nQuery:", + "stackoverflow": "Instruct: Given a Stack Overflow post, retrieve relevant passages that help answer the post.\nQuery:", + "sustainable_living": "Instruct: Given a Sustainable Living post, retrieve relevant passages that help answer the post.\nQuery:", + "leetcode": "Instruct: Given a Coding problem, retrieve relevant examples that help answer the problem.\nQuery:", + "pony": "Instruct: Given a Pony question, retrieve relevant passages that help answer the question.\nQuery:", + "aops": "Instruct: Given a Math problem, retrieve relevant examples that help answer the problem.\nQuery:", + "theoremqa_questions": "Instruct: Given a Math problem, retrieve relevant examples that help answer the problem.\nQuery:", + "theoremqa_theorems": "Instruct: Given a Math problem, retrieve relevant theorems that help answer the problem.\nQuery:", +} + +# Doc prefix used for Nemo reasoning retrieval over text passages. +NEMO_REASONING_PASSAGE_PREFIX: str = "passage: " + + +def _ensure_query_colon_space(prefix: str) -> str: + p = str(prefix or "") + needle = "Query:" + idx = p.rfind(needle) + if idx == -1: + return (p.rstrip() + " ") if p.strip() else "" + head = p[: idx + len(needle)] + return head + " " + + +def get_bright_query_prefix_nemo(*, task_key: Optional[str], fallback: str) -> str: + """ + Resolve a full Nemo-style query prefix for a BRIGHT task key, else use fallback. + + Returned value is normalized to end with: "Query: " (colon + space). + """ + if isinstance(task_key, str): + v = BRIGHT_NEMO_QUERY_PREFIXES.get(task_key.strip(), None) + if isinstance(v, str) and v.strip(): + return _ensure_query_colon_space(v) + return _ensure_query_colon_space(str(fallback or "")) + + +def get_task_description(*, task_key: Optional[str], fallback: str) -> str: + """ + Resolve a task description from a BRIGHT short key (if provided), otherwise use fallback. + + This is intentionally small and permissive; callers control the default behavior via `fallback`. + """ + if isinstance(task_key, str): + v = BRIGHT_SHORT_INSTRUCTIONS.get(task_key.strip(), None) + if isinstance(v, str) and v.strip(): + return v.strip() + return str(fallback or "").strip() diff --git a/retrieval-bench/src/retrieval_bench/singletons/__init__.py b/retrieval-bench/src/retrieval_bench/singletons/__init__.py new file mode 100644 index 000000000..e3b65f4dd --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/singletons/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Lightweight singleton-style components used by pipelines. +""" + +__all__ = [] diff --git a/retrieval-bench/src/retrieval_bench/singletons/_shared.py b/retrieval-bench/src/retrieval_bench/singletons/_shared.py new file mode 100644 index 000000000..6e2c7b663 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/singletons/_shared.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared utilities used across singleton retrievers.""" + +from __future__ import annotations + +import hashlib +import logging +from typing import Optional, Sequence + +logger = logging.getLogger(__name__) + +try: + import torch +except ImportError: + torch = None # type: ignore[assignment] + + +def hash_corpus_ids10(corpus_ids: Sequence[str]) -> str: + """Return a 10-char hex hash of corpus IDs for cache keying.""" + h = hashlib.sha256() + for cid in corpus_ids: + h.update(str(cid).encode("utf-8")) + h.update(b"\n") + return h.hexdigest()[:10] + + +def slugify(value: str) -> str: + """Make a filesystem-friendly string from a model/dataset name.""" + v = (value or "").strip().replace("/", "__") + return v or "unnamed" + + +def try_preload_corpus_to_gpu(corpus_embeddings_cpu: "torch.Tensor", device: str) -> "Optional[torch.Tensor]": + """ + Attempt to move corpus embeddings to GPU; return None on OOM. + + Handles different OOM exception types across PyTorch versions. + """ + try: + return corpus_embeddings_cpu.to(device, non_blocking=True) + except Exception as e: + oom_types = tuple( + t + for t in ( + getattr(torch, "OutOfMemoryError", None), + getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None), + ) + if isinstance(t, type) + ) + + is_oom = False + if oom_types and isinstance(e, oom_types): + is_oom = True + elif isinstance(e, RuntimeError) and "out of memory" in str(e).lower(): + is_oom = True + + if not is_oom: + raise + + logger.debug("OOM preloading corpus to GPU; falling back to CPU scoring: %s", e) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return None diff --git a/retrieval-bench/src/retrieval_bench/singletons/bm25s_retriever.py b/retrieval-bench/src/retrieval_bench/singletons/bm25s_retriever.py new file mode 100644 index 000000000..ebc609210 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/singletons/bm25s_retriever.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +BM25S singleton retriever (keyword / lexical search). + +This module encapsulates BM25 retrieval via `bm25s` behind a small interface: + +- init(...): build/load a cached BM25 index once per dataset/corpus +- retrieve(query): run BM25 retrieval for a single query string +- unload(): free memory + +Important: BM25S returns positional doc indices; we always map indices back to +the provided `corpus_ids` so the pipeline returns evaluator-compatible ids. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import threading +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from retrieval_bench.singletons._shared import hash_corpus_ids10 as _hash_corpus_ids10 +from retrieval_bench.utils.corpus_language import corpus_language + + +def _stemmer_language(dataset_name: str) -> str: + # Workflow rule: stemmer is keyed off corpus language, which depends on dataset. + # (Delegated to shared helper so other components can reuse this decision.) + return corpus_language(dataset_name) + + +def _stopwords_language(stemmer_language: str) -> str: + # bm25s uses short language codes for built-in stopwords. + return "fr" if stemmer_language == "french" else "en" + + +@dataclass(frozen=True, slots=True) +class _CacheMeta: + dataset_name: str + stemmer_language: str + num_docs: int + corpus_ids_hash10: str + + def to_json(self) -> Dict[str, Any]: + return { + "dataset_name": self.dataset_name, + "stemmer_language": self.stemmer_language, + "num_docs": int(self.num_docs), + "corpus_ids_hash10": self.corpus_ids_hash10, + } + + +class _Bm25sState: + def __init__( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus_markdown: Sequence[str], + top_k: int, + cache_dir: Path, + ) -> None: + self.dataset_name = str(dataset_name) + self.top_k = int(top_k) + self.cache_dir = cache_dir + + self.corpus_ids: List[str] = [str(x) for x in corpus_ids] + self.corpus_markdown: List[str] = [str(x) for x in corpus_markdown] + self.stemmer_language = _stemmer_language(self.dataset_name) + self.stopwords_language = _stopwords_language(self.stemmer_language) + self.corpus_ids_hash10 = self._corpus_ids_hash10() + + self._bm25 = None + self._stemmer = None + + self._load_or_build_index() + + def _cache_key_hash10(self) -> str: + key = f"{self.dataset_name}::{self.stemmer_language}::bm25s" + return hashlib.sha256(key.encode("utf-8")).hexdigest()[:10] + + def _corpus_ids_hash10(self) -> str: + return _hash_corpus_ids10(self.corpus_ids) + + def _index_dir(self) -> Path: + ds_slug = self.dataset_name.replace("/", "__") + key_hash = self._cache_key_hash10() + return self.cache_dir / f"bm25s_index__{ds_slug}__{self.stemmer_language}__{key_hash}" + + def _meta_path(self) -> Path: + return self._index_dir() / "meta.json" + + def _build_meta(self) -> _CacheMeta: + return _CacheMeta( + dataset_name=self.dataset_name, + stemmer_language=self.stemmer_language, + num_docs=len(self.corpus_ids), + corpus_ids_hash10=self.corpus_ids_hash10, + ) + + def _load_meta(self) -> Optional[_CacheMeta]: + p = self._meta_path() + if not p.exists(): + return None + try: + data = json.loads(p.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + return _CacheMeta( + dataset_name=str(data.get("dataset_name", "")), + stemmer_language=str(data.get("stemmer_language", "")), + num_docs=int(data.get("num_docs", -1)), + corpus_ids_hash10=str(data.get("corpus_ids_hash10", "")), + ) + except Exception: + return None + + def _write_meta_atomic(self, meta: _CacheMeta) -> None: + p = self._meta_path() + p.parent.mkdir(parents=True, exist_ok=True) + tmp = p.with_suffix(".json.tmp") + tmp.write_text(json.dumps(meta.to_json(), ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + os.replace(tmp, p) + + def _meta_matches(self, meta: _CacheMeta) -> bool: + try: + if meta.dataset_name != self.dataset_name: + return False + if meta.stemmer_language != self.stemmer_language: + return False + if int(meta.num_docs) != len(self.corpus_ids): + return False + if meta.corpus_ids_hash10 != self.corpus_ids_hash10: + return False + return True + except Exception: + return False + + def _load_or_build_index(self) -> None: + try: + import bm25s # type: ignore + import Stemmer # type: ignore + except ImportError as e: # pragma: no cover + raise ImportError( + "Missing dependencies for BM25S singleton. Install with `pip install bm25s PyStemmer` " + "(or ensure they are in project dependencies)." + ) from e + + self._stemmer = Stemmer.Stemmer(self.stemmer_language) + index_dir = self._index_dir() + meta = self._load_meta() + + if meta is not None and self._meta_matches(meta): + try: + self._bm25 = bm25s.BM25.load(str(index_dir)) + return + except Exception: + # Fall through to rebuild. + self._bm25 = None + + # Build from scratch and persist. + corpus_tokens = bm25s.tokenize( + self.corpus_markdown, + stopwords=self.stopwords_language, + stemmer=self._stemmer, + ) + bm25 = bm25s.BM25() + bm25.index(corpus_tokens) + + index_dir.mkdir(parents=True, exist_ok=True) + bm25.save(str(index_dir)) + self._write_meta_atomic(self._build_meta()) + self._bm25 = bm25 + + def retrieve_one( + self, query: str, *, return_markdown: bool = False + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + if self._bm25 is None or self._stemmer is None: + raise RuntimeError("BM25S retriever not initialized. Call retriever.init(...) first.") + + import bm25s # type: ignore + + query_tokens = bm25s.tokenize( + [str(query)], + stopwords=self.stopwords_language, + stemmer=self._stemmer, + ) + + k = min(int(self.top_k), len(self.corpus_ids)) + results_idx, scores = self._bm25.retrieve(query_tokens, k=k) + + # bm25s returns arrays of shape (n_queries, k); we pass a single query. + idxs = results_idx[0] + scs = scores[0] + + run: Dict[str, float] = {} + markdown_by_id: Dict[str, str] = {} + + for doc_pos, score in zip(list(idxs), list(scs)): + pos = int(doc_pos) + if pos < 0 or pos >= len(self.corpus_ids): + continue + doc_id = self.corpus_ids[pos] + run[doc_id] = float(score) + if return_markdown: + markdown_by_id[doc_id] = self.corpus_markdown[pos] + + if not return_markdown: + return run + return run, markdown_by_id + + +class Bm25sSingletonRetriever: + """ + A module-level singleton facade for bm25s retrieval. + + This wrapper provides explicit lifecycle calls (init/unload) while still + maintaining a single global instance and hiding indexing/caching complexity. + """ + + def __init__(self) -> None: + self._lock = threading.RLock() + self._state: Optional[_Bm25sState] = None + + def init( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus: Sequence[Dict[str, Any]], + top_k: int = 100, + cache_dir: str | Path = "cache/bm25s", + ) -> None: + with self._lock: + cache_dir = Path(cache_dir) + + corpus_markdown = [str(doc.get("markdown", "")) for doc in corpus] + + # If already initialized for the same dataset and same corpus ids, keep as-is (fast path). + if ( + self._state is not None + and self._state.dataset_name == str(dataset_name) + and self._state.corpus_ids_hash10 == _hash_corpus_ids10(corpus_ids) + ): + self._state.top_k = int(top_k) + self._state.cache_dir = cache_dir + return + + self._state = _Bm25sState( + dataset_name=str(dataset_name), + corpus_ids=corpus_ids, + corpus_markdown=corpus_markdown, + top_k=int(top_k), + cache_dir=cache_dir, + ) + + def retrieve( + self, query: str, *, return_markdown: bool = False + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + with self._lock: + if self._state is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + return self._state.retrieve_one(query, return_markdown=return_markdown) + + def unload(self) -> None: + with self._lock: + self._state = None + + +# --------------------------------------------------------------------------- +# Module-level singleton instance +# --------------------------------------------------------------------------- +retriever = Bm25sSingletonRetriever() diff --git a/retrieval-bench/src/retrieval_bench/singletons/colembed_retriever.py b/retrieval-bench/src/retrieval_bench/singletons/colembed_retriever.py new file mode 100644 index 000000000..24a6972ae --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/singletons/colembed_retriever.py @@ -0,0 +1,342 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +ColEmbed singleton retriever (deep module). + +This module encapsulates all heavy retrieval operations for the NeMo Retriever +ColEmbed model behind a small interface: + +- init(...): load model + corpus embeddings (cached) once +- retrieve(query): run retrieval for a single query string +- unload(): free GPU/CPU memory + +The exported `retriever` object is a module-level singleton. +""" + +from __future__ import annotations + +import hashlib +import logging +import os +import threading +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +logger = logging.getLogger(__name__) + +try: + import torch +except ImportError as e: # pragma: no cover + raise ImportError( + "Required dependencies not installed for ColEmbed retriever. " + "Please install at least: torch (and for actual retrieval: transformers, optionally flash-attn)." + ) from e + +from retrieval_bench.singletons._shared import try_preload_corpus_to_gpu as _try_preload_corpus_to_gpu + + +class _ColEmbedState: + def __init__( + self, + *, + model_id: str, + device: str, + corpus_chunk_size: int, + batch_size: int, + corpus_batch_size: int, + top_k: int, + cache_dir: Path, + ) -> None: + self.model_id = model_id + self.device = device + self.corpus_chunk_size = corpus_chunk_size + self.batch_size = batch_size + self.corpus_batch_size = corpus_batch_size + self.top_k = top_k + self.cache_dir = cache_dir + + self.dataset_name: Optional[str] = None + self.corpus_ids: Optional[List[str]] = None + self.corpus_markdown: Optional[List[str]] = None + self.corpus_embeddings_cpu: Optional[torch.Tensor] = None + self.corpus_embeddings_gpu: Optional[torch.Tensor] = None + + self.model = self._load_model() + + def _load_model(self): + # CUDA is required (matching the original pipeline behavior). + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. ColEmbed retriever requires an NVIDIA GPU.") + + # Compatibility shim for torch/transformers version skew. + from retrieval_bench.utils.torch_compat import patch_torch_is_autocast_enabled + + patch_torch_is_autocast_enabled() + + # Lazy import so importing this module doesn't require transformers. + from transformers import AutoModel # type: ignore + + try: + model = AutoModel.from_pretrained( + self.model_id, + device_map="cuda", + torch_dtype=torch.bfloat16, + trust_remote_code=True, + attn_implementation="flash_attention_2", + ) + except Exception: + model = AutoModel.from_pretrained( + self.model_id, + device_map="cuda", + torch_dtype=torch.bfloat16, + trust_remote_code=True, + attn_implementation="eager", + ) + + model.eval() + return model + + def _corpus_cache_path(self, dataset_name: str) -> Path: + dataset_slug = dataset_name.replace("/", "__") + model_slug = self.model_id.split("/")[-1].replace("/", "__") + key = f"{dataset_name}::{self.model_id}" + key_hash = hashlib.sha256(key.encode("utf-8")).hexdigest()[:10] + filename = f"corpus_embeddings__{dataset_slug}__{model_slug}__{key_hash}.pt" + return self.cache_dir / filename + + def _embed_corpus_batched(self, corpus: Sequence[Any]) -> torch.Tensor: + corpus_embeddings: List[torch.Tensor] = [] + num_batches = (len(corpus) + self.corpus_batch_size - 1) // self.corpus_batch_size + + for i in range(0, len(corpus), self.corpus_batch_size): + batch_idx = i // self.corpus_batch_size + 1 + batch = corpus[i : i + self.corpus_batch_size] + + with torch.no_grad(): + batch_embeddings = self.model.forward_passages(batch, batch_size=len(batch)) + corpus_embeddings.append(batch_embeddings.cpu()) + + del batch_embeddings + torch.cuda.empty_cache() + + # lightweight progress marker (avoid spamming logs) + if batch_idx % 25 == 0 or batch_idx == num_batches: + _ = torch.cuda.memory_allocated() + + return torch.cat(corpus_embeddings, dim=0) + + def _load_or_build_corpus_embeddings( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus: Sequence[Any], + ) -> torch.Tensor: + cache_path = self._corpus_cache_path(dataset_name) + cache_path.parent.mkdir(parents=True, exist_ok=True) + + if cache_path.exists(): + try: + emb = torch.load(cache_path, map_location="cpu") + if not isinstance(emb, torch.Tensor): + raise TypeError(f"Expected torch.Tensor in cache, got {type(emb)}") + if emb.shape[0] != len(corpus_ids): + raise ValueError( + f"Cached embeddings mismatch: cached={emb.shape[0]} vs corpus_ids={len(corpus_ids)}" + ) + return emb + except Exception: + logger.debug("Cache load failed for %s, recomputing", cache_path, exc_info=True) + + t0 = time.time() + emb = self._embed_corpus_batched(corpus) + elapsed = time.time() - t0 + print(f"[cache] corpus embedding took {elapsed:.1f}s ({len(corpus)} docs)") + tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") + torch.save(emb, tmp_path) + os.replace(tmp_path, cache_path) + return emb + + def _embed_query(self, query: str) -> torch.Tensor: + # Returns CPU tensor [seq_len, embed_dim] + with torch.no_grad(): + q_emb = self.model.forward_queries([query], batch_size=1).cpu() + return q_emb[0] # [seq_len, dim] + + def _score_query(self, query_embedding_cpu: torch.Tensor) -> torch.Tensor: + if self.corpus_embeddings_cpu is None: + raise RuntimeError("corpus_embeddings_cpu is not set; call init() first") + + num_corpus = self.corpus_embeddings_cpu.shape[0] + scores_cpu = torch.empty(num_corpus, dtype=torch.float32, device="cpu") + chunk = self.corpus_chunk_size + device = self.device + + with torch.no_grad(): + q_gpu = query_embedding_cpu.to(device, non_blocking=True) # [q_seq, dim] + q_t = q_gpu.transpose(0, 1) # [dim, q_seq] + + for c_start in range(0, num_corpus, chunk): + c_end = min(c_start + chunk, num_corpus) + + if self.corpus_embeddings_gpu is not None: + c_gpu = self.corpus_embeddings_gpu[c_start:c_end] + else: + c_gpu = self.corpus_embeddings_cpu[c_start:c_end].to(device, non_blocking=True) + + token_sims = torch.matmul(c_gpu, q_t) # [chunk, c_seq, q_seq] + chunk_scores = token_sims.max(dim=1).values.float().sum(dim=1) # [chunk] + scores_cpu[c_start:c_end] = chunk_scores.cpu() + + return scores_cpu + + def retrieve_one( + self, query: str, *, return_markdown: bool = False + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + if self.corpus_ids is None or self.corpus_embeddings_cpu is None or self.corpus_markdown is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + + query_embedding_cpu = self._embed_query(query) + scores_cpu = self._score_query(query_embedding_cpu) + + k = min(self.top_k, len(self.corpus_ids)) + topk_scores, topk_indices = torch.topk(scores_cpu, k) + + corpus_ids = self.corpus_ids + run = {corpus_ids[int(idx)]: float(score) for idx, score in zip(topk_indices.tolist(), topk_scores.tolist())} + + if not return_markdown: + return run + + corpus_markdown = self.corpus_markdown + markdown_by_id = {corpus_ids[int(idx)]: corpus_markdown[int(idx)] for idx in topk_indices.tolist()} + return run, markdown_by_id + + +class ColEmbedSingletonRetriever: + """ + A module-level singleton facade for ColEmbed retrieval. + + This wrapper provides explicit lifecycle calls (init/unload) while still + maintaining a single global instance and hiding all retrieval complexity. + """ + + def __init__(self) -> None: + self._lock = threading.RLock() + self._state: Optional[_ColEmbedState] = None + + def init( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus: Sequence[Dict[str, Any]], + model_id: str = "nvidia/llama-nemoretriever-colembed-1b-v1", + device: str = "cuda", + top_k: int = 100, + batch_size: int = 32, + corpus_batch_size: int = 32, + corpus_chunk_size: int = 256, + cache_dir: str | Path = "cache", + preload_corpus_to_gpu: bool = True, + ) -> None: + """ + Initialize (or re-initialize) the singleton for a given dataset/corpus. + + - Model is loaded once per process (unless model_id/device changes). + - Corpus embeddings are loaded from cache or computed once per dataset. + """ + with self._lock: + cache_dir = Path(cache_dir) + + # If state exists but model configuration changed, fully unload. + if self._state is not None and (self._state.model_id != model_id or self._state.device != device): + self.unload() + + if self._state is None: + self._state = _ColEmbedState( + model_id=model_id, + device=device, + corpus_chunk_size=corpus_chunk_size, + batch_size=batch_size, + corpus_batch_size=corpus_batch_size, + top_k=top_k, + cache_dir=cache_dir, + ) + else: + # Update tunables (safe). + self._state.top_k = top_k + self._state.batch_size = batch_size + self._state.corpus_batch_size = corpus_batch_size + self._state.corpus_chunk_size = corpus_chunk_size + self._state.cache_dir = cache_dir + + # If already initialized for the same dataset with same corpus_ids length, keep as-is. + if ( + self._state.dataset_name == dataset_name + and self._state.corpus_ids is not None + and len(self._state.corpus_ids) == len(corpus_ids) + ): + return + + corpus_images = [doc["image"] for doc in corpus] + corpus_markdown = [doc["markdown"] for doc in corpus] + + # (Re)load corpus embeddings for this dataset/corpus. + corpus_embeddings_cpu = self._state._load_or_build_corpus_embeddings( + dataset_name=dataset_name, corpus_ids=corpus_ids, corpus=corpus_images + ) + + self._state.dataset_name = dataset_name + self._state.corpus_ids = list(corpus_ids) + self._state.corpus_markdown = corpus_markdown + self._state.corpus_embeddings_cpu = corpus_embeddings_cpu + + # Optional preload to GPU for faster repeated retrieval. + self._state.corpus_embeddings_gpu = None + if preload_corpus_to_gpu: + self._state.corpus_embeddings_gpu = _try_preload_corpus_to_gpu( + corpus_embeddings_cpu, self._state.device + ) + + def retrieve( + self, query: str, *, return_markdown: bool = False + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + """ + Retrieve top-k corpus items for a single query. + + Note: This method intentionally uses a lock to remain safe if called from + multiple threads in the future (GPU inference + shared model). + """ + with self._lock: + if self._state is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + return self._state.retrieve_one(query, return_markdown=return_markdown) + + def unload(self) -> None: + """Free model + embeddings and release GPU memory.""" + with self._lock: + if self._state is None: + return + + try: + if self._state.corpus_embeddings_gpu is not None: + del self._state.corpus_embeddings_gpu + if self._state.corpus_embeddings_cpu is not None: + del self._state.corpus_embeddings_cpu + if self._state.corpus_markdown is not None: + del self._state.corpus_markdown + if getattr(self._state, "model", None) is not None: + del self._state.model + finally: + self._state = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Module-level singleton instance +# --------------------------------------------------------------------------- +retriever = ColEmbedSingletonRetriever() diff --git a/retrieval-bench/src/retrieval_bench/singletons/hf_dense_retriever.py b/retrieval-bench/src/retrieval_bench/singletons/hf_dense_retriever.py new file mode 100644 index 000000000..859f5b71c --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/singletons/hf_dense_retriever.py @@ -0,0 +1,566 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +HuggingFace dense-text singleton retriever. + +This retriever is designed for instruction-tuned embedding models like: + - hanhainebula/reason-embed-llama-3.1-8b-0928 + +Workflow assumptions (intentional): +- GPU-only: this model is too large for practical CPU use in our workflow. +- Corpus documents are embedded from the ViDoRe v3 `markdown` field (text). +- Queries default to the wrapper format: + Instruct: \nQuery: + but can alternatively use a full `query_prefix` (for models trained that way). +- Pooling is configurable: mean (default) or last-token pooling. + +The exported `retriever` object is a module-level singleton with explicit lifecycle: + init(...) -> retrieve(...) -> unload() +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import os +import time +import threading +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +logger = logging.getLogger(__name__) + +try: + import torch + import torch.nn.functional as F # noqa: N812 +except ImportError as e: # pragma: no cover + raise ImportError("Required dependencies not installed for HF dense retriever. Install: torch") from e + +from retrieval_bench.singletons._shared import hash_corpus_ids10 as _hash_corpus_ids10 +from retrieval_bench.singletons._shared import slugify as _slugify +from retrieval_bench.singletons._shared import try_preload_corpus_to_gpu as _try_preload_corpus_to_gpu + + +def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """ + Pool sentence embedding using the last non-pad token (handles left vs right padding). + + Copied from the reason-embed model card reference implementation. + """ + # attention_mask: [bs, seq] + # last_hidden_states: [bs, seq, dim] + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if bool(left_padding): + return last_hidden_states[:, -1] + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] + + +def _mean_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """ + Mean pool over non-pad tokens. + + attention_mask: [bs, seq] (0/1) + last_hidden_states: [bs, seq, dim] + """ + # Upcast is important for numeric stability (matches training/eval reference). + last_hidden_states = last_hidden_states.to(torch.float32) + last_hidden_states_masked = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + denom = attention_mask.sum(dim=1)[..., None].to(torch.float32).clamp(min=1.0) + embedding = last_hidden_states_masked.sum(dim=1) / denom + embedding = F.normalize(embedding, dim=-1) + + # Keep stored embeddings compact (CPU cache is large); compute in fp32, store in fp16. + return embedding.to(torch.float16) + + +def _is_pathlike_model_id(model_id: str) -> bool: + m = str(model_id or "") + if not m: + return False + if m.startswith(("~", "/", "./", "../")): + return True + # If the expanded path exists locally, treat it as a path. + try: + p = Path(os.path.expanduser(m)) + return p.exists() + except Exception: + return False + + +def _normalize_model_id(model_id: str) -> str: + m = str(model_id or "") + if _is_pathlike_model_id(m): + return os.path.expanduser(m) + return m + + +def _wrap_instruct(task_description: str, query: str) -> str: + return f"Instruct: {task_description}\nQuery: {query}" + + +@dataclass(frozen=True, slots=True) +class _CacheMeta: + dataset_name: str + model_id: str + max_length: int + pooling: str + doc_prefix: str + num_docs: int + corpus_ids_hash10: str + + def to_json(self) -> Dict[str, Any]: + return { + "dataset_name": self.dataset_name, + "model_id": self.model_id, + "max_length": int(self.max_length), + "pooling": str(self.pooling), + "doc_prefix": str(self.doc_prefix), + "num_docs": int(self.num_docs), + "corpus_ids_hash10": self.corpus_ids_hash10, + } + + +class _HfDenseState: + def __init__( + self, + *, + model_id: str, + device: str, + max_length: int, + pooling: str, + doc_prefix: str, + query_prefix: Optional[str], + task_description: str, + score_scale: float, + batch_size: int, + corpus_batch_size: int, + scoring_batch_size: int, + top_k: int, + cache_dir: Path, + ) -> None: + self.model_id = str(model_id) + self.device = str(device) + self.max_length = int(max_length) + self.pooling = str(pooling) + self.doc_prefix = str(doc_prefix) + self.query_prefix = str(query_prefix) if isinstance(query_prefix, str) else None + self.task_description = str(task_description) + self.score_scale = float(score_scale) + self.batch_size = int(batch_size) + self.corpus_batch_size = int(corpus_batch_size) + self.scoring_batch_size = int(scoring_batch_size) + self.top_k = int(top_k) + self.cache_dir = cache_dir + + self.dataset_name: Optional[str] = None + self.corpus_ids: Optional[List[str]] = None + self.corpus_id_to_idx: Optional[Dict[str, int]] = None + self.corpus_markdown: Optional[List[str]] = None + self.corpus_embeddings_cpu: Optional[torch.Tensor] = None # [n, dim] float16 + self.corpus_embeddings_gpu: Optional[torch.Tensor] = None # [n, dim] float16 + + self.tokenizer, self.model = self._load_model_and_tokenizer() + + def _pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + mode = str(self.pooling or "last_token").strip().lower() + if mode in ("mean", "avg", "average"): + return _mean_pool(last_hidden_states, attention_mask) + # Default / legacy. + return _last_token_pool(last_hidden_states, attention_mask) + + def _load_model_and_tokenizer(self): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. This dense retriever requires an NVIDIA GPU.") + if not str(self.device).startswith("cuda"): + raise RuntimeError( + f"Invalid device '{self.device}'. This dense retriever is GPU-only; use 'cuda'/'cuda:0'." + ) + + # Compatibility shim for torch/transformers version skew. + from retrieval_bench.utils.torch_compat import patch_torch_is_autocast_enabled + + patch_torch_is_autocast_enabled() + + from transformers import AutoModel, AutoTokenizer # type: ignore + + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + model = AutoModel.from_pretrained(self.model_id, trust_remote_code=True) + model.eval() + model.to(self.device) + model.half() + return tokenizer, model + + def _index_dir(self, dataset_name: str, *, corpus_ids_hash10: str) -> Path: + ds_slug = _slugify(dataset_name) + model_slug = _slugify(self.model_id.split("/")[-1]) + doc_slug = _slugify(self.doc_prefix)[:32] + pool_slug = _slugify(self.pooling) + key = f"{dataset_name}::{self.model_id}::{self.max_length}::{pool_slug}::{self.doc_prefix}::{corpus_ids_hash10}" + key_hash = hashlib.sha256(key.encode("utf-8")).hexdigest()[:10] + return ( + self.cache_dir + / f"hf_dense__{ds_slug}__{model_slug}__len{self.max_length}__{pool_slug}__doc{doc_slug}__{key_hash}" + ) + + def _meta_path(self, dataset_name: str, *, corpus_ids_hash10: str) -> Path: + return self._index_dir(dataset_name, corpus_ids_hash10=corpus_ids_hash10) / "meta.json" + + def _emb_path(self, dataset_name: str, *, corpus_ids_hash10: str) -> Path: + return self._index_dir(dataset_name, corpus_ids_hash10=corpus_ids_hash10) / "embeddings.pt" + + def _build_meta(self, *, dataset_name: str, corpus_ids_hash10: str, num_docs: int) -> _CacheMeta: + return _CacheMeta( + dataset_name=str(dataset_name), + model_id=str(self.model_id), + max_length=int(self.max_length), + pooling=str(self.pooling), + doc_prefix=str(self.doc_prefix), + num_docs=int(num_docs), + corpus_ids_hash10=str(corpus_ids_hash10), + ) + + def _load_meta(self, dataset_name: str, *, corpus_ids_hash10: str) -> Optional[_CacheMeta]: + p = self._meta_path(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + if not p.exists(): + return None + try: + data = json.loads(p.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + return _CacheMeta( + dataset_name=str(data.get("dataset_name", "")), + model_id=str(data.get("model_id", "")), + max_length=int(data.get("max_length", -1)), + pooling=str(data.get("pooling", "")), + doc_prefix=str(data.get("doc_prefix", "")), + num_docs=int(data.get("num_docs", -1)), + corpus_ids_hash10=str(data.get("corpus_ids_hash10", "")), + ) + except Exception: + return None + + def _meta_matches(self, meta: _CacheMeta, *, dataset_name: str, corpus_ids_hash10: str, num_docs: int) -> bool: + try: + if meta.dataset_name != str(dataset_name): + return False + if meta.model_id != str(self.model_id): + return False + if int(meta.max_length) != int(self.max_length): + return False + if str(meta.pooling) != str(self.pooling): + return False + if str(meta.doc_prefix) != str(self.doc_prefix): + return False + if int(meta.num_docs) != int(num_docs): + return False + if meta.corpus_ids_hash10 != str(corpus_ids_hash10): + return False + return True + except Exception: + return False + + def _write_meta_atomic(self, meta: _CacheMeta, *, dataset_name: str, corpus_ids_hash10: str) -> None: + p = self._meta_path(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + p.parent.mkdir(parents=True, exist_ok=True) + tmp = p.with_suffix(".json.tmp") + tmp.write_text(json.dumps(meta.to_json(), ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + os.replace(tmp, p) + + def _tokenize(self, texts: Sequence[str]) -> Dict[str, torch.Tensor]: + batch = self.tokenizer( + list(texts), + max_length=int(self.max_length), + padding=True, + truncation=True, + return_tensors="pt", + pad_to_multiple_of=8, + ) + return {k: v.to(self.device) for k, v in batch.items()} + + def _embed_texts_batched(self, texts: Sequence[str], *, batch_size: int) -> torch.Tensor: + out: List[torch.Tensor] = [] + bs = max(1, int(batch_size)) + + with torch.no_grad(): + for i in range(0, len(texts), bs): + chunk = texts[i : i + bs] + batch = self._tokenize(chunk) + outputs = self.model(**batch) + pooled = self._pool(outputs.last_hidden_state, batch["attention_mask"]) + # Mean/avg pooling path already normalizes in fp32 and returns fp16. + mode = str(self.pooling or "last_token").strip().lower() + if mode not in ("mean", "avg", "average"): + pooled = F.normalize(pooled, p=2, dim=1) + out.append(pooled.detach().to("cpu")) + return torch.cat(out, dim=0) if out else torch.empty((0, 0), dtype=torch.float16, device="cpu") + + def _load_or_build_corpus_embeddings( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus_markdown: Sequence[str], + ) -> torch.Tensor: + corpus_ids_hash10 = _hash_corpus_ids10(corpus_ids) + emb_path = self._emb_path(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + meta = self._load_meta(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + + if meta is not None and self._meta_matches( + meta, dataset_name=dataset_name, corpus_ids_hash10=corpus_ids_hash10, num_docs=len(corpus_ids) + ): + try: + emb = torch.load(emb_path, map_location="cpu") + if not isinstance(emb, torch.Tensor): + raise TypeError(f"Expected torch.Tensor in cache, got {type(emb)}") + if emb.shape[0] != len(corpus_ids): + raise ValueError(f"Cached embeddings mismatch: cached={emb.shape[0]} vs corpus={len(corpus_ids)}") + return emb + except Exception: + logger.debug("Cache load failed for %s, recomputing", emb_path, exc_info=True) + + # Build from scratch. + emb_path.parent.mkdir(parents=True, exist_ok=True) + t0 = time.time() + emb = self._embed_texts_batched(corpus_markdown, batch_size=int(self.corpus_batch_size)) + elapsed = time.time() - t0 + print(f"[cache] corpus embedding took {elapsed:.1f}s ({len(corpus_markdown)} docs)") + + tmp = emb_path.with_suffix(".pt.tmp") + torch.save(emb, tmp) + os.replace(tmp, emb_path) + self._write_meta_atomic( + self._build_meta(dataset_name=dataset_name, corpus_ids_hash10=corpus_ids_hash10, num_docs=len(corpus_ids)), + dataset_name=dataset_name, + corpus_ids_hash10=corpus_ids_hash10, + ) + return emb + + def embed_query(self, query_text: str) -> torch.Tensor: + if isinstance(self.query_prefix, str) and self.query_prefix: + q = str(self.query_prefix) + str(query_text) + else: + q = _wrap_instruct(self.task_description, str(query_text)) + emb = self._embed_texts_batched([q], batch_size=1) + if emb.ndim != 2 or emb.shape[0] != 1: + raise RuntimeError(f"Unexpected query embedding shape: {tuple(emb.shape)}") + return emb[0] # [dim] on CPU + + def score_query(self, query_embedding_cpu: torch.Tensor) -> torch.Tensor: + if self.corpus_embeddings_cpu is None: + raise RuntimeError("corpus_embeddings_cpu is not set; call init() first") + num_docs = self.corpus_embeddings_cpu.shape[0] + scores_cpu = torch.empty((num_docs,), dtype=torch.float32, device="cpu") + + chunk = max(1, int(self.scoring_batch_size)) + scale = float(self.score_scale) + + with torch.no_grad(): + q_gpu = query_embedding_cpu.to(self.device, non_blocking=True) # [dim] + q_gpu = q_gpu.unsqueeze(1) # [dim, 1] + + for c_start in range(0, num_docs, chunk): + c_end = min(c_start + chunk, num_docs) + if self.corpus_embeddings_gpu is not None: + c_gpu = self.corpus_embeddings_gpu[c_start:c_end] + else: + c_gpu = self.corpus_embeddings_cpu[c_start:c_end].to(self.device, non_blocking=True) + + # [chunk, dim] @ [dim, 1] -> [chunk, 1] + chunk_scores = torch.matmul(c_gpu, q_gpu).squeeze(1).float() * scale + scores_cpu[c_start:c_end] = chunk_scores.to("cpu") + + return scores_cpu + + def retrieve_one( + self, + query: str, + *, + return_markdown: bool = False, + excluded_ids: Optional[Sequence[str]] = None, + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + if self.corpus_ids is None or self.corpus_embeddings_cpu is None or self.corpus_markdown is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + + q_emb_cpu = self.embed_query(query) + scores_cpu = self.score_query(q_emb_cpu) + + # Apply per-query excluded ids BEFORE top-k selection (BRIGHT semantics). + # This prevents excluded docs from "stealing" slots in top-k. + if excluded_ids and self.corpus_id_to_idx: + for did in set(str(x) for x in excluded_ids): + if did == "N/A": + continue + idx = self.corpus_id_to_idx.get(did, None) + if idx is None: + continue + try: + scores_cpu[int(idx)] = float("-inf") + except Exception: + # Ignore malformed indices; keep scoring robust. + pass + + k = min(int(self.top_k), len(self.corpus_ids)) + topk_scores, topk_indices = torch.topk(scores_cpu, k) + + ids = self.corpus_ids + run = {ids[int(idx)]: float(score) for idx, score in zip(topk_indices.tolist(), topk_scores.tolist())} + + if not return_markdown: + return run + + md = self.corpus_markdown + markdown_by_id = {ids[int(idx)]: md[int(idx)] for idx in topk_indices.tolist()} + return run, markdown_by_id + + +class HfDenseSingletonRetriever: + """ + Module-level singleton facade. + """ + + def __init__(self) -> None: + self._lock = threading.RLock() + self._state: Optional[_HfDenseState] = None + + def init( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus: Sequence[Dict[str, Any]], + model_id: str, + device: str = "cuda", + top_k: int = 100, + max_length: int = 8192, + pooling: str = "mean", + doc_prefix: str = "", + query_prefix: Optional[str] = None, + task_description: str = "Given the following post, retrieve relevant passages that help answer the post.", + score_scale: float = 100.0, + batch_size: int = 1, + corpus_batch_size: int = 1, + scoring_batch_size: int = 4096, + cache_dir: str | Path = "cache/hf_dense", + preload_corpus_to_gpu: bool = False, + ) -> None: + """ + Initialize (or re-initialize) the singleton for a given dataset/corpus. + """ + with self._lock: + cache_dir = Path(cache_dir) + model_id_norm = _normalize_model_id(model_id) + + if self._state is not None and ( + self._state.model_id != str(model_id_norm) + or self._state.device != str(device) + or int(self._state.max_length) != int(max_length) + or str(self._state.pooling) != str(pooling) + or str(self._state.doc_prefix) != str(doc_prefix) + ): + self.unload() + + if self._state is None: + self._state = _HfDenseState( + model_id=str(model_id_norm), + device=str(device), + max_length=int(max_length), + pooling=str(pooling), + doc_prefix=str(doc_prefix), + query_prefix=query_prefix, + task_description=str(task_description), + score_scale=float(score_scale), + batch_size=int(batch_size), + corpus_batch_size=int(corpus_batch_size), + scoring_batch_size=int(scoring_batch_size), + top_k=int(top_k), + cache_dir=cache_dir, + ) + else: + # Update tunables. + self._state.top_k = int(top_k) + self._state.batch_size = int(batch_size) + self._state.corpus_batch_size = int(corpus_batch_size) + self._state.scoring_batch_size = int(scoring_batch_size) + self._state.cache_dir = cache_dir + self._state.task_description = str(task_description) + self._state.query_prefix = str(query_prefix) if isinstance(query_prefix, str) else None + self._state.score_scale = float(score_scale) + + corpus_markdown = [str(doc.get("markdown", "")) for doc in corpus] + corpus_ids_list = [str(x) for x in corpus_ids] + corpus_texts_for_embed = [str(self._state.doc_prefix) + md for md in corpus_markdown] + + corpus_ids_hash10 = _hash_corpus_ids10(corpus_ids_list) + if ( + self._state.dataset_name == str(dataset_name) + and self._state.corpus_ids is not None + and _hash_corpus_ids10(self._state.corpus_ids) == corpus_ids_hash10 + and self._state.corpus_embeddings_cpu is not None + ): + # Already initialized for the same corpus; only (possibly) update GPU preload. + if preload_corpus_to_gpu and self._state.corpus_embeddings_gpu is None: + self._state.corpus_embeddings_gpu = _try_preload_corpus_to_gpu( + self._state.corpus_embeddings_cpu, self._state.device + ) + if (not preload_corpus_to_gpu) and self._state.corpus_embeddings_gpu is not None: + self._state.corpus_embeddings_gpu = None + return + + emb_cpu = self._state._load_or_build_corpus_embeddings( + dataset_name=str(dataset_name), + corpus_ids=corpus_ids_list, + corpus_markdown=corpus_texts_for_embed, + ) + + self._state.dataset_name = str(dataset_name) + self._state.corpus_ids = corpus_ids_list + self._state.corpus_id_to_idx = {cid: i for i, cid in enumerate(corpus_ids_list)} + self._state.corpus_markdown = corpus_markdown + self._state.corpus_embeddings_cpu = emb_cpu + + self._state.corpus_embeddings_gpu = None + if preload_corpus_to_gpu: + self._state.corpus_embeddings_gpu = _try_preload_corpus_to_gpu(emb_cpu, self._state.device) + + def retrieve( + self, query: str, *, return_markdown: bool = False, excluded_ids: Optional[Sequence[str]] = None + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + with self._lock: + if self._state is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + return self._state.retrieve_one( + str(query), + return_markdown=bool(return_markdown), + excluded_ids=excluded_ids, + ) + + def unload(self) -> None: + with self._lock: + if self._state is None: + return + try: + if self._state.corpus_embeddings_gpu is not None: + del self._state.corpus_embeddings_gpu + if self._state.corpus_embeddings_cpu is not None: + del self._state.corpus_embeddings_cpu + if self._state.corpus_markdown is not None: + del self._state.corpus_markdown + if getattr(self._state, "model", None) is not None: + del self._state.model + if getattr(self._state, "tokenizer", None) is not None: + del self._state.tokenizer + finally: + self._state = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Module-level singleton instance +# --------------------------------------------------------------------------- +retriever = HfDenseSingletonRetriever() diff --git a/retrieval-bench/src/retrieval_bench/singletons/nemotron_colembed_vl_v2_retriever.py b/retrieval-bench/src/retrieval_bench/singletons/nemotron_colembed_vl_v2_retriever.py new file mode 100644 index 000000000..711e3d280 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/singletons/nemotron_colembed_vl_v2_retriever.py @@ -0,0 +1,475 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Nemotron ColEmbed-VL v2 singleton retriever (late interaction, ColBERT-style MaxSim). + +Backed model: + - nvidia/nemotron-colembed-vl-8b-v2 + +Workflow: + - Corpus documents are embedded from the ViDoRe `image` field (PIL images). + - Queries are embedded from text. + - Scoring is explicit ColBERT MaxSim (matmul + max over doc tokens + sum over query tokens), + computed on GPU in bounded corpus chunks. + +Interface (module-level singleton): + - init(...): load model + load/build cached corpus embeddings once per dataset/corpus + - retrieve(query): retrieve top-k for a single query; optionally return markdown for agentic prompts + - unload(): free model + embeddings and release GPU memory +""" + +from __future__ import annotations + +import hashlib +import logging +import os +import time +import threading +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +logger = logging.getLogger(__name__) + +try: + import torch +except ImportError as e: # pragma: no cover + raise ImportError( + "Required dependencies not installed for Nemotron ColEmbed-VL v2 retriever. " + "Please install at least: torch (and for actual retrieval: transformers, optionally flash-attn)." + ) from e + +from retrieval_bench.singletons._shared import try_preload_corpus_to_gpu as _try_preload_corpus_to_gpu + + +def _set_tiling_knobs_if_present(model: Any, *, max_input_tiles: int, use_thumbnail: bool) -> None: + """ + Best-effort configuration of the remote-code processor tiling knobs. + + The model card recommends: + - max_input_tiles = 8 + - use_thumbnails = True + """ + + proc = getattr(model, "processor", None) + if proc is None: + return + + if hasattr(proc, "max_input_tiles"): + try: + setattr(proc, "max_input_tiles", int(max_input_tiles)) + except Exception: + pass + + # Some implementations call it `use_thumbnail`, some `use_thumbnails`. + if hasattr(proc, "use_thumbnail"): + try: + setattr(proc, "use_thumbnail", bool(use_thumbnail)) + except Exception: + pass + elif hasattr(proc, "use_thumbnails"): + try: + setattr(proc, "use_thumbnails", bool(use_thumbnail)) + except Exception: + pass + + +class _NemotronColEmbedVLV2State: + def __init__( + self, + *, + model_id: str, + device: str, + corpus_chunk_size: int, + corpus_batch_size: int, + top_k: int, + cache_dir: Path, + max_input_tiles: int, + use_thumbnail: bool, + ) -> None: + self.model_id = str(model_id) + self.device = str(device) + self.corpus_chunk_size = int(corpus_chunk_size) + self.corpus_batch_size = int(corpus_batch_size) + self.top_k = int(top_k) + self.cache_dir = cache_dir + self.max_input_tiles = int(max_input_tiles) + self.use_thumbnail = bool(use_thumbnail) + + self.dataset_name: Optional[str] = None + self.corpus_ids: Optional[List[str]] = None + self.corpus_id_to_idx: Optional[Dict[str, int]] = None + self.corpus_markdown: Optional[List[str]] = None + self.corpus_embeddings_cpu: Optional[torch.Tensor] = None + self.corpus_embeddings_gpu: Optional[torch.Tensor] = None + self.corpus_token_lengths_cpu: Optional[torch.Tensor] = None + self.corpus_token_lengths_gpu: Optional[torch.Tensor] = None + + self.model = self._load_model() + _set_tiling_knobs_if_present( + self.model, + max_input_tiles=int(self.max_input_tiles), + use_thumbnail=bool(self.use_thumbnail), + ) + + def _load_model(self): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Nemotron ColEmbed-VL v2 retriever requires an NVIDIA GPU.") + + # Compatibility shim for torch/transformers version skew. + from retrieval_bench.utils.torch_compat import patch_torch_is_autocast_enabled + + patch_torch_is_autocast_enabled() + + # Lazy import so importing this module doesn't require transformers. + from transformers import AutoModel # type: ignore + + def _from_pretrained(*, attn_implementation: str): + common_kwargs = { + "trust_remote_code": True, + "attn_implementation": str(attn_implementation), + } + # Newer HF stacks deprecate `torch_dtype` in favor of `dtype`. + try: + return AutoModel.from_pretrained( + self.model_id, + dtype=torch.bfloat16, + **common_kwargs, + ) + except TypeError: + return AutoModel.from_pretrained( + self.model_id, + torch_dtype=torch.bfloat16, + **common_kwargs, + ) + + # Prefer FlashAttention2 when available; fall back to eager. + try: + model = _from_pretrained(attn_implementation="flash_attention_2") + except Exception: + model = _from_pretrained(attn_implementation="eager") + + model.to("cuda") + model.eval() + return model + + def _corpus_cache_path(self, dataset_name: str) -> Path: + dataset_slug = str(dataset_name).replace("/", "__") + model_slug = self.model_id.split("/")[-1].replace("/", "__") + key = ( + f"{dataset_name}::{self.model_id}::images" + f"::max_input_tiles={int(self.max_input_tiles)}" + f"::use_thumbnail={bool(self.use_thumbnail)}" + ) + key_hash = hashlib.sha256(key.encode("utf-8")).hexdigest()[:10] + filename = f"corpus_image_embeddings__{dataset_slug}__{model_slug}__{key_hash}.pt" + return self.cache_dir / filename + + def _embed_corpus_batched(self, corpus_images: Sequence[Any]) -> Tuple[torch.Tensor, torch.Tensor]: + if not corpus_images: + return ( + torch.empty((0, 0, 0), dtype=torch.float32, device="cpu"), + torch.empty((0,), dtype=torch.int32, device="cpu"), + ) + + imgs = [img.convert("RGB") for img in corpus_images] + bs = max(1, int(self.corpus_batch_size)) + + with torch.inference_mode(): + emb = self.model.forward_images(imgs, batch_size=bs) + + if not isinstance(emb, torch.Tensor): + raise RuntimeError(f"forward_images returned unexpected type: {type(emb)}") + + emb = emb.to("cpu") + + try: + token_norms = emb.float().norm(dim=-1) + lengths = (token_norms > 1e-6).sum(dim=1).to(dtype=torch.int32) + except Exception: + lengths = torch.full((emb.shape[0],), emb.shape[1], dtype=torch.int32) + + return emb, lengths + + def _load_or_build_corpus_embeddings( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus_images: Sequence[Any], + ) -> Tuple[torch.Tensor, torch.Tensor]: + cache_path = self._corpus_cache_path(dataset_name) + cache_path.parent.mkdir(parents=True, exist_ok=True) + + if not cache_path.exists(): + print(f"[cache] dense rebuild: {cache_path} (missing)") + else: + try: + obj = torch.load(cache_path, map_location="cpu") + if isinstance(obj, torch.Tensor): + emb = obj + lengths = torch.full((int(emb.shape[0]),), int(emb.shape[1]), dtype=torch.int32, device="cpu") + elif isinstance(obj, dict) and isinstance(obj.get("embeddings", None), torch.Tensor): + emb = obj["embeddings"] + lengths_obj = obj.get("lengths", None) + if isinstance(lengths_obj, torch.Tensor): + lengths = lengths_obj.to("cpu", dtype=torch.int32) + else: + lengths = torch.full((int(emb.shape[0]),), int(emb.shape[1]), dtype=torch.int32, device="cpu") + else: + raise TypeError(f"Expected torch.Tensor or dict cache, got {type(obj)}") + + if int(emb.shape[0]) != int(len(corpus_ids)): + raise ValueError( + f"Cached embeddings mismatch: cached={emb.shape[0]} vs corpus_ids={len(corpus_ids)}" + ) + if int(lengths.shape[0]) != int(emb.shape[0]): + lengths = torch.full((int(emb.shape[0]),), int(emb.shape[1]), dtype=torch.int32, device="cpu") + lengths = torch.clamp(lengths.to("cpu", dtype=torch.int32), min=0, max=int(emb.shape[1])) + print(f"[cache] dense hit: {cache_path}") + return emb, lengths + except Exception: + # fall through to recompute + print(f"[cache] dense rebuild: {cache_path} (load_error)") + + t0 = time.time() + emb, lengths = self._embed_corpus_batched(corpus_images) + elapsed = time.time() - t0 + print(f"[cache] corpus embedding took {elapsed:.1f}s ({len(corpus_images)} docs)") + tmp_path = cache_path.with_suffix(cache_path.suffix + ".tmp") + torch.save({"embeddings": emb, "lengths": lengths}, tmp_path) + os.replace(tmp_path, cache_path) + return emb, lengths + + def _embed_query(self, query: str) -> torch.Tensor: + # Returns CPU tensor [q_seq, dim] + with torch.no_grad(): + q_emb = self.model.forward_queries([str(query)], batch_size=1).detach().to("cpu") + if not isinstance(q_emb, torch.Tensor) or q_emb.ndim != 3 or q_emb.shape[0] != 1: + raise RuntimeError(f"Unexpected query embedding shape: {getattr(q_emb, 'shape', None)}") + return q_emb[0] + + def _score_query(self, query_embedding_cpu: torch.Tensor) -> torch.Tensor: + if self.corpus_embeddings_cpu is None: + raise RuntimeError("corpus_embeddings_cpu is not set; call init() first") + if self.corpus_token_lengths_cpu is None: + raise RuntimeError("corpus_token_lengths_cpu is not set; call init() first") + + num_corpus = self.corpus_embeddings_cpu.shape[0] + scores_cpu = torch.empty((num_corpus,), dtype=torch.float32, device="cpu") + + chunk = max(1, int(self.corpus_chunk_size)) + device = str(self.device) + + with torch.no_grad(): + q_gpu = query_embedding_cpu.to(device, non_blocking=True) # [q_seq, dim] + q_t = q_gpu.transpose(0, 1) # [dim, q_seq] + + for c_start in range(0, num_corpus, chunk): + c_end = min(c_start + chunk, num_corpus) + + if self.corpus_embeddings_gpu is not None: + c_gpu = self.corpus_embeddings_gpu[c_start:c_end] + else: + c_gpu = self.corpus_embeddings_cpu[c_start:c_end].to(device, non_blocking=True) + + if self.corpus_token_lengths_gpu is not None: + len_gpu = self.corpus_token_lengths_gpu[c_start:c_end] + else: + len_gpu = self.corpus_token_lengths_cpu[c_start:c_end].to(device, non_blocking=True) + + # c_gpu: [chunk, c_seq, dim] + # q_t: [dim, q_seq] + token_sims = torch.matmul(c_gpu, q_t) # [chunk, c_seq, q_seq] + # Mask padded tokens so they can never win the max. + c_seq = int(token_sims.shape[1]) + pos = torch.arange(c_seq, device=device).unsqueeze(0) # [1, c_seq] + valid = pos < len_gpu.to(device=device).unsqueeze(1) # [chunk, c_seq] + token_sims = token_sims.masked_fill(~valid.unsqueeze(-1), float("-inf")) + chunk_scores = token_sims.max(dim=1).values.float().sum(dim=1) # [chunk] + scores_cpu[c_start:c_end] = chunk_scores.detach().to("cpu") + + return scores_cpu + + def retrieve_one( + self, + query: str, + *, + return_markdown: bool = False, + excluded_ids: Optional[Sequence[str]] = None, + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + if self.corpus_ids is None or self.corpus_embeddings_cpu is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + + q_emb_cpu = self._embed_query(str(query)) + scores_cpu = self._score_query(q_emb_cpu) + + # Apply per-query excluded ids BEFORE top-k selection (BRIGHT semantics). + if excluded_ids and self.corpus_id_to_idx: + for did in set(str(x) for x in excluded_ids): + if did == "N/A": + continue + idx = self.corpus_id_to_idx.get(did, None) + if idx is None: + continue + try: + scores_cpu[int(idx)] = float("-inf") + except Exception: + pass + + k = min(int(self.top_k), len(self.corpus_ids)) + topk_scores, topk_indices = torch.topk(scores_cpu, k) + ids = self.corpus_ids + run = {ids[int(idx)]: float(score) for idx, score in zip(topk_indices.tolist(), topk_scores.tolist())} + + if not return_markdown: + return run + + md = self.corpus_markdown or [""] * len(ids) + markdown_by_id = {ids[int(idx)]: str(md[int(idx)]) for idx in topk_indices.tolist()} + return run, markdown_by_id + + +class NemotronColEmbedVLV2SingletonRetriever: + """ + Module-level singleton facade for Nemotron ColEmbed-VL v2 retrieval. + """ + + def __init__(self) -> None: + self._lock = threading.RLock() + self._state: Optional[_NemotronColEmbedVLV2State] = None + + def init( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus: Sequence[Dict[str, Any]], + model_id: str = "nvidia/nemotron-colembed-vl-8b-v2", + device: str = "cuda", + top_k: int = 100, + corpus_batch_size: int = 8, + corpus_chunk_size: int = 256, + cache_dir: str | Path = "cache/nemotron_colembed_vl_v2", + preload_corpus_to_gpu: bool = False, + max_input_tiles: int = 8, + use_thumbnail: bool = True, + ) -> None: + with self._lock: + cache_dir_p = Path(cache_dir) + + # If state exists but model configuration changed, unload. + if self._state is not None and (self._state.model_id != str(model_id) or self._state.device != str(device)): + self.unload() + + if self._state is None: + self._state = _NemotronColEmbedVLV2State( + model_id=str(model_id), + device=str(device), + corpus_chunk_size=int(corpus_chunk_size), + corpus_batch_size=int(corpus_batch_size), + top_k=int(top_k), + cache_dir=cache_dir_p, + max_input_tiles=int(max_input_tiles), + use_thumbnail=bool(use_thumbnail), + ) + else: + # Update tunables. + self._state.top_k = int(top_k) + self._state.corpus_batch_size = int(corpus_batch_size) + self._state.corpus_chunk_size = int(corpus_chunk_size) + self._state.cache_dir = cache_dir_p + self._state.max_input_tiles = int(max_input_tiles) + self._state.use_thumbnail = bool(use_thumbnail) + + # If already initialized for same dataset and same corpus length, keep as-is (fast path). + if ( + self._state.dataset_name == str(dataset_name) + and self._state.corpus_ids is not None + and len(self._state.corpus_ids) == len(corpus_ids) + and self._state.corpus_embeddings_cpu is not None + and self._state.corpus_token_lengths_cpu is not None + ): + # Only adjust GPU preload. + if preload_corpus_to_gpu and self._state.corpus_embeddings_gpu is None: + self._state.corpus_embeddings_gpu = _try_preload_corpus_to_gpu( + self._state.corpus_embeddings_cpu, self._state.device + ) + if preload_corpus_to_gpu and self._state.corpus_token_lengths_gpu is None: + self._state.corpus_token_lengths_gpu = self._state.corpus_token_lengths_cpu.to( + self._state.device, non_blocking=True + ) + if (not preload_corpus_to_gpu) and self._state.corpus_embeddings_gpu is not None: + self._state.corpus_embeddings_gpu = None + if (not preload_corpus_to_gpu) and self._state.corpus_token_lengths_gpu is not None: + self._state.corpus_token_lengths_gpu = None + return + + corpus_images = [doc["image"] for doc in corpus] + corpus_markdown = [str(doc.get("markdown", "")) for doc in corpus] + + emb_cpu, lengths_cpu = self._state._load_or_build_corpus_embeddings( + dataset_name=str(dataset_name), + corpus_ids=corpus_ids, + corpus_images=corpus_images, + ) + + self._state.dataset_name = str(dataset_name) + corpus_ids_list = [str(x) for x in corpus_ids] + self._state.corpus_ids = corpus_ids_list + self._state.corpus_id_to_idx = {cid: i for i, cid in enumerate(corpus_ids_list)} + self._state.corpus_markdown = corpus_markdown + self._state.corpus_embeddings_cpu = emb_cpu + self._state.corpus_token_lengths_cpu = lengths_cpu + + self._state.corpus_embeddings_gpu = None + self._state.corpus_token_lengths_gpu = None + if preload_corpus_to_gpu: + self._state.corpus_embeddings_gpu = _try_preload_corpus_to_gpu(emb_cpu, self._state.device) + self._state.corpus_token_lengths_gpu = self._state.corpus_token_lengths_cpu.to( + self._state.device, non_blocking=True + ) + + def retrieve( + self, + query: str, + *, + return_markdown: bool = False, + excluded_ids: Optional[Sequence[str]] = None, + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + with self._lock: + if self._state is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + return self._state.retrieve_one( + str(query), + return_markdown=bool(return_markdown), + excluded_ids=excluded_ids, + ) + + def unload(self) -> None: + with self._lock: + if self._state is None: + return + try: + if self._state.corpus_embeddings_gpu is not None: + del self._state.corpus_embeddings_gpu + if self._state.corpus_token_lengths_gpu is not None: + del self._state.corpus_token_lengths_gpu + if self._state.corpus_embeddings_cpu is not None: + del self._state.corpus_embeddings_cpu + if self._state.corpus_token_lengths_cpu is not None: + del self._state.corpus_token_lengths_cpu + if self._state.corpus_markdown is not None: + del self._state.corpus_markdown + if getattr(self._state, "model", None) is not None: + del self._state.model + finally: + self._state = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Module-level singleton instance +# --------------------------------------------------------------------------- +retriever = NemotronColEmbedVLV2SingletonRetriever() diff --git a/retrieval-bench/src/retrieval_bench/singletons/nemotron_embed_vl_dense_retriever.py b/retrieval-bench/src/retrieval_bench/singletons/nemotron_embed_vl_dense_retriever.py new file mode 100644 index 000000000..4e6c256da --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/singletons/nemotron_embed_vl_dense_retriever.py @@ -0,0 +1,610 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Nemotron Embed VL v2 singleton dense retriever (multimodal). + +This module encapsulates heavy retrieval operations for: + nvidia/llama-nemotron-embed-vl-1b-v2 + +Interface (module-level singleton): + - init(...): load model + corpus embeddings (cached) once + - retrieve(query): run retrieval for a single query string + - unload(): free GPU/CPU memory + +Design notes: + - GPU-only by design (no CPU fallback); this model is too slow on CPU for our workflow. + - Corpus embedding supports modality: image, text, image_text (default). + - Query embedding is text-only via model.encode_queries(). + - Scores are cosine similarity via dot-product over L2-normalized embeddings. + - Corpus embeddings are cached to disk (embeddings.pt + meta.json). +""" + +from __future__ import annotations + +import hashlib +import json +import os +import time +import threading +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +try: + import torch +except ImportError as e: # pragma: no cover + raise ImportError("Required dependencies not installed for Nemotron-VL dense retriever. Install: torch") from e + +from retrieval_bench.singletons._shared import hash_corpus_ids10 as _hash_corpus_ids10 +from retrieval_bench.singletons._shared import slugify as _slugify +from retrieval_bench.singletons._shared import try_preload_corpus_to_gpu as _try_preload_corpus_to_gpu + + +def _l2_normalize_fp32(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + x32 = x.to(torch.float32) + return x32 / (x32.norm(p=2, dim=-1, keepdim=True) + eps) + + +def _doc_len_by_modality(modality: str) -> int: + m = str(modality or "").strip().lower() + if m == "image": + return 2048 + if m == "text": + return 8192 + if m in ("image_text", "imagetext", "image+text"): + return 10240 + raise ValueError(f"Unknown doc_modality '{modality}'. Expected: 'image', 'text', or 'image_text'.") + + +@dataclass(frozen=True, slots=True) +class _CacheMeta: + dataset_name: str + model_id: str + doc_modality: str + doc_max_length: int + query_max_length: int + max_input_tiles: int + use_thumbnail: bool + num_docs: int + corpus_ids_hash10: str + + def to_json(self) -> Dict[str, Any]: + return { + "dataset_name": str(self.dataset_name), + "model_id": str(self.model_id), + "doc_modality": str(self.doc_modality), + "doc_max_length": int(self.doc_max_length), + "query_max_length": int(self.query_max_length), + "max_input_tiles": int(self.max_input_tiles), + "use_thumbnail": bool(self.use_thumbnail), + "num_docs": int(self.num_docs), + "corpus_ids_hash10": str(self.corpus_ids_hash10), + } + + +class _NemotronVLDenseState: + def __init__( + self, + *, + model_id: str, + device: str, + top_k: int, + doc_modality: str, + doc_max_length: int, + query_max_length: int, + corpus_batch_size: int, + corpus_chunk_size: int, + cache_dir: Path, + max_input_tiles: int, + use_thumbnail: bool, + ) -> None: + self.model_id = str(model_id) + self.device = str(device) + self.top_k = int(top_k) + self.doc_modality = str(doc_modality) + self.doc_max_length = int(doc_max_length) + self.query_max_length = int(query_max_length) + self.corpus_batch_size = int(corpus_batch_size) + self.corpus_chunk_size = int(corpus_chunk_size) + self.cache_dir = cache_dir + self.max_input_tiles = int(max_input_tiles) + self.use_thumbnail = bool(use_thumbnail) + + self.dataset_name: Optional[str] = None + self.corpus_ids: Optional[List[str]] = None + self.corpus_id_to_idx: Optional[Dict[str, int]] = None + self.corpus_markdown: Optional[List[str]] = None + self.corpus_embeddings_cpu: Optional[torch.Tensor] = None # [n, dim] float16 + self.corpus_embeddings_gpu: Optional[torch.Tensor] = None # [n, dim] float16 + + self.model = self._load_model() + self.processor = self._get_processor() + self._fix_embedding_tokenizer_mismatch() + self._configure_processor() + + def _load_model(self): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Nemotron-VL dense retriever is GPU-only.") + + # Compatibility shim for torch/transformers version skew (used elsewhere in repo). + from retrieval_bench.utils.torch_compat import patch_torch_is_autocast_enabled + + patch_torch_is_autocast_enabled() + + try: + from transformers import AutoModel # type: ignore + except Exception as e: # pragma: no cover + raise ImportError("Missing dependency: transformers. Install it to use this retriever.") from e + + def _from_pretrained(*, attn_implementation: str): + common_kwargs = { + "trust_remote_code": True, + "attn_implementation": str(attn_implementation), + } + # Newer HF stacks deprecate `torch_dtype` in favor of `dtype`. + try: + return AutoModel.from_pretrained( + self.model_id, + dtype=torch.bfloat16, + **common_kwargs, + ) + except TypeError: + return AutoModel.from_pretrained( + self.model_id, + torch_dtype=torch.bfloat16, + **common_kwargs, + ) + + # Prefer FlashAttention2 when available; fall back to eager. + try: + model = _from_pretrained(attn_implementation="flash_attention_2") + except Exception: + model = _from_pretrained(attn_implementation="eager") + + model.to("cuda") + model.eval() + return model + + def _get_processor(self): + proc = getattr(self.model, "processor", None) + if proc is None: + raise RuntimeError( + "Nemotron-VL model did not expose `model.processor`. " + "This retriever expects a trust_remote_code model implementation with a processor attached." + ) + return proc + + def _fix_embedding_tokenizer_mismatch(self) -> None: + tokenizer = getattr(self.processor, "tokenizer", None) + if tokenizer is None: + return + tok_vocab_size = len(tokenizer) + emb = self.model.get_input_embeddings() + if emb is None: + return + emb_rows = emb.weight.shape[0] + if tok_vocab_size > emb_rows: + dim = emb.weight.shape[1] + new_weight = torch.zeros(tok_vocab_size, dim, dtype=emb.weight.dtype, device=emb.weight.device) + new_weight[:emb_rows] = emb.weight.data + emb.weight = torch.nn.Parameter(new_weight) + emb.num_embeddings = tok_vocab_size + + def _configure_processor(self) -> None: + # Tiling settings (model card defaults). + if hasattr(self.processor, "max_input_tiles"): + setattr(self.processor, "max_input_tiles", int(self.max_input_tiles)) + # Some implementations use `use_thumbnail`, others `use_thumbnails`. + if hasattr(self.processor, "use_thumbnail"): + setattr(self.processor, "use_thumbnail", bool(self.use_thumbnail)) + elif hasattr(self.processor, "use_thumbnails"): + setattr(self.processor, "use_thumbnails", bool(self.use_thumbnail)) + + def _set_processor_max_length_for_call(self, *, p_max_length: int) -> None: + # The model card uses `processor.p_max_length` to control token budget. + if hasattr(self.processor, "p_max_length"): + setattr(self.processor, "p_max_length", int(p_max_length)) + + # Best-effort: some processors also honor these. + if hasattr(self.processor, "max_length"): + setattr(self.processor, "max_length", int(p_max_length)) + + def _index_dir(self, dataset_name: str, *, corpus_ids_hash10: str) -> Path: + ds_slug = _slugify(dataset_name) + model_slug = _slugify(self.model_id.split("/")[-1]) + mod_slug = _slugify(self.doc_modality) + key = ( + f"{dataset_name}::{self.model_id}::{mod_slug}::" + f"dlen{self.doc_max_length}::qlen{self.query_max_length}::" + f"tiles{self.max_input_tiles}::thumb{int(self.use_thumbnail)}::{corpus_ids_hash10}" + ) + key_hash = hashlib.sha256(key.encode("utf-8")).hexdigest()[:10] + return self.cache_dir / ( + f"nemotron_vl_dense__{ds_slug}__{model_slug}__{mod_slug}__dlen{self.doc_max_length}__" + f"qlen{self.query_max_length}__tiles{self.max_input_tiles}__thumb{int(self.use_thumbnail)}__{key_hash}" + ) + + def _meta_path(self, dataset_name: str, *, corpus_ids_hash10: str) -> Path: + return self._index_dir(dataset_name, corpus_ids_hash10=corpus_ids_hash10) / "meta.json" + + def _emb_path(self, dataset_name: str, *, corpus_ids_hash10: str) -> Path: + return self._index_dir(dataset_name, corpus_ids_hash10=corpus_ids_hash10) / "embeddings.pt" + + def _load_meta(self, dataset_name: str, *, corpus_ids_hash10: str) -> Optional[_CacheMeta]: + p = self._meta_path(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + if not p.exists(): + return None + try: + data = json.loads(p.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return None + return _CacheMeta( + dataset_name=str(data.get("dataset_name", "")), + model_id=str(data.get("model_id", "")), + doc_modality=str(data.get("doc_modality", "")), + doc_max_length=int(data.get("doc_max_length", -1)), + query_max_length=int(data.get("query_max_length", -1)), + max_input_tiles=int(data.get("max_input_tiles", -1)), + use_thumbnail=bool(data.get("use_thumbnail", False)), + num_docs=int(data.get("num_docs", -1)), + corpus_ids_hash10=str(data.get("corpus_ids_hash10", "")), + ) + except Exception: + return None + + def _meta_matches(self, meta: _CacheMeta, *, dataset_name: str, corpus_ids_hash10: str, num_docs: int) -> bool: + try: + if meta.dataset_name != str(dataset_name): + return False + if meta.model_id != str(self.model_id): + return False + if str(meta.doc_modality) != str(self.doc_modality): + return False + if int(meta.doc_max_length) != int(self.doc_max_length): + return False + if int(meta.query_max_length) != int(self.query_max_length): + return False + if int(meta.max_input_tiles) != int(self.max_input_tiles): + return False + if bool(meta.use_thumbnail) != bool(self.use_thumbnail): + return False + if int(meta.num_docs) != int(num_docs): + return False + if str(meta.corpus_ids_hash10) != str(corpus_ids_hash10): + return False + return True + except Exception: + return False + + def _write_meta_atomic(self, meta: _CacheMeta, *, dataset_name: str, corpus_ids_hash10: str) -> None: + p = self._meta_path(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + p.parent.mkdir(parents=True, exist_ok=True) + tmp = p.with_suffix(".json.tmp") + tmp.write_text(json.dumps(meta.to_json(), ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + os.replace(tmp, p) + + def _embed_corpus_batched(self, corpus: Sequence[Dict[str, Any]]) -> torch.Tensor: + bs = max(1, int(self.corpus_batch_size)) + out: List[torch.Tensor] = [] + modality = str(self.doc_modality).strip().lower() + + # Set doc max length for document calls. + self._set_processor_max_length_for_call(p_max_length=int(self.doc_max_length)) + + with torch.inference_mode(): + for i in range(0, len(corpus), bs): + batch = corpus[i : i + bs] + + if modality == "image": + images = [doc["image"].convert("RGB") for doc in batch] + emb = self.model.encode_documents(images=images) + elif modality == "text": + texts = [str(doc.get("markdown", "")) for doc in batch] + emb = self.model.encode_documents(texts=texts) + else: # image_text + images = [doc["image"].convert("RGB") for doc in batch] + texts = [str(doc.get("markdown", "")) for doc in batch] + emb = self.model.encode_documents(images=images, texts=texts) + + if not isinstance(emb, torch.Tensor): + raise RuntimeError(f"encode_documents returned unexpected type: {type(emb)}") + if emb.ndim != 2: + raise RuntimeError(f"Unexpected document embedding shape: {tuple(emb.shape)}") + + emb = _l2_normalize_fp32(emb).to(torch.float16).detach().to("cpu") + out.append(emb) + + return torch.cat(out, dim=0) if out else torch.empty((0, 0), dtype=torch.float16, device="cpu") + + def _load_or_build_corpus_embeddings( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus: Sequence[Dict[str, Any]], + ) -> torch.Tensor: + corpus_ids_list = [str(x) for x in corpus_ids] + corpus_ids_hash10 = _hash_corpus_ids10(corpus_ids_list) + + emb_path = self._emb_path(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + meta = self._load_meta(dataset_name, corpus_ids_hash10=corpus_ids_hash10) + + if meta is None: + print(f"[cache] dense rebuild: {emb_path} (missing)") + elif not self._meta_matches( + meta, dataset_name=dataset_name, corpus_ids_hash10=corpus_ids_hash10, num_docs=len(corpus_ids_list) + ): + print(f"[cache] dense rebuild: {emb_path} (meta_mismatch)") + else: + try: + emb = torch.load(emb_path, map_location="cpu") + if not isinstance(emb, torch.Tensor): + raise TypeError(f"Expected torch.Tensor in cache, got {type(emb)}") + if emb.shape[0] != len(corpus_ids_list): + raise ValueError( + f"Cached embeddings mismatch: cached={emb.shape[0]} vs corpus={len(corpus_ids_list)}" + ) + print(f"[cache] dense hit: {emb_path}") + return emb + except Exception: + print(f"[cache] dense rebuild: {emb_path} (load_error)") + + emb_path.parent.mkdir(parents=True, exist_ok=True) + t0 = time.time() + emb = self._embed_corpus_batched(corpus) + elapsed = time.time() - t0 + print(f"[cache] corpus embedding took {elapsed:.1f}s ({len(corpus)} docs)") + + tmp = emb_path.with_suffix(".pt.tmp") + torch.save(emb, tmp) + os.replace(tmp, emb_path) + + self._write_meta_atomic( + _CacheMeta( + dataset_name=str(dataset_name), + model_id=str(self.model_id), + doc_modality=str(self.doc_modality), + doc_max_length=int(self.doc_max_length), + query_max_length=int(self.query_max_length), + max_input_tiles=int(self.max_input_tiles), + use_thumbnail=bool(self.use_thumbnail), + num_docs=int(len(corpus_ids_list)), + corpus_ids_hash10=str(corpus_ids_hash10), + ), + dataset_name=dataset_name, + corpus_ids_hash10=corpus_ids_hash10, + ) + return emb + + def embed_query(self, query_text: str) -> torch.Tensor: + # Set query max length for query call. + self._set_processor_max_length_for_call(p_max_length=int(self.query_max_length)) + + with torch.inference_mode(): + emb = self.model.encode_queries([str(query_text)]) + + if not isinstance(emb, torch.Tensor): + raise RuntimeError(f"encode_queries returned unexpected type: {type(emb)}") + if emb.ndim != 2 or emb.shape[0] != 1: + raise RuntimeError(f"Unexpected query embedding shape: {tuple(emb.shape)}") + + emb1 = emb[0] + emb1 = _l2_normalize_fp32(emb1).to(torch.float16).detach().to("cpu") + return emb1 # [dim] on CPU + + def score_query(self, query_embedding_cpu: torch.Tensor) -> torch.Tensor: + if self.corpus_embeddings_cpu is None: + raise RuntimeError("corpus_embeddings_cpu is not set; call init() first") + + num_docs = self.corpus_embeddings_cpu.shape[0] + scores_cpu = torch.empty((num_docs,), dtype=torch.float32, device="cpu") + + chunk = max(1, int(self.corpus_chunk_size)) + device = str(self.device) + + with torch.inference_mode(): + q_gpu = query_embedding_cpu.to(device, non_blocking=True) # [dim] + q_gpu = q_gpu.unsqueeze(1) # [dim, 1] + + for c_start in range(0, num_docs, chunk): + c_end = min(c_start + chunk, num_docs) + if self.corpus_embeddings_gpu is not None: + c_gpu = self.corpus_embeddings_gpu[c_start:c_end] + else: + c_gpu = self.corpus_embeddings_cpu[c_start:c_end].to(device, non_blocking=True) + + chunk_scores = torch.matmul(c_gpu, q_gpu).squeeze(1).float() # [chunk] + scores_cpu[c_start:c_end] = chunk_scores.to("cpu") + + return scores_cpu + + def retrieve_one( + self, + query: str, + *, + return_markdown: bool = False, + excluded_ids: Optional[Sequence[str]] = None, + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + if self.corpus_ids is None or self.corpus_embeddings_cpu is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + + q_emb_cpu = self.embed_query(str(query)) + scores_cpu = self.score_query(q_emb_cpu) + + # Apply per-query excluded ids BEFORE top-k selection (BRIGHT semantics). + if excluded_ids and self.corpus_id_to_idx: + for did in set(str(x) for x in excluded_ids): + if did == "N/A": + continue + idx = self.corpus_id_to_idx.get(did, None) + if idx is None: + continue + try: + scores_cpu[int(idx)] = float("-inf") + except Exception: + pass + + k = min(int(self.top_k), len(self.corpus_ids)) + topk_scores, topk_indices = torch.topk(scores_cpu, k) + ids = self.corpus_ids + run = {ids[int(idx)]: float(score) for idx, score in zip(topk_indices.tolist(), topk_scores.tolist())} + + if not return_markdown: + return run + + md = self.corpus_markdown or [""] * len(ids) + markdown_by_id = {ids[int(idx)]: str(md[int(idx)]) for idx in topk_indices.tolist()} + return run, markdown_by_id + + +class NemotronEmbedVLDenseSingletonRetriever: + def __init__(self) -> None: + self._lock = threading.RLock() + self._state: Optional[_NemotronVLDenseState] = None + + def init( + self, + *, + dataset_name: str, + corpus_ids: Sequence[str], + corpus: Sequence[Dict[str, Any]], + model_id: str = "nvidia/llama-nemotron-embed-vl-1b-v2", + device: str = "auto", + top_k: int = 100, + doc_modality: str = "image_text", + doc_max_length: Union[int, str] = "auto", + query_max_length: int = 10240, + corpus_batch_size: int = 4, + corpus_chunk_size: int = 4096, + cache_dir: str | Path = "cache/nemotron_vl_dense", + preload_corpus_to_gpu: bool = False, + max_input_tiles: int = 6, + use_thumbnail: bool = True, + ) -> None: + with self._lock: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Nemotron-VL dense retriever is GPU-only.") + + device_eff = str(device or "auto").strip().lower() + if device_eff in ("auto",): + device_eff = "cuda" + if not device_eff.startswith("cuda"): + raise RuntimeError(f"Invalid device '{device}'. This retriever is GPU-only; use 'cuda'/'cuda:0'.") + + modality_eff = str(doc_modality or "image_text").strip() + if isinstance(doc_max_length, str) and doc_max_length.strip().lower() == "auto": + doc_max_length_eff = _doc_len_by_modality(modality_eff) + else: + doc_max_length_eff = int(doc_max_length) # may raise, intended + + cache_dir_p = Path(cache_dir) + + # If state exists but config changed, unload. + if self._state is not None and ( + self._state.model_id != str(model_id) + or self._state.device != str(device_eff) + or self._state.doc_modality != str(modality_eff) + or int(self._state.doc_max_length) != int(doc_max_length_eff) + or int(self._state.query_max_length) != int(query_max_length) + or int(self._state.max_input_tiles) != int(max_input_tiles) + or bool(self._state.use_thumbnail) != bool(use_thumbnail) + ): + self.unload() + + if self._state is None: + self._state = _NemotronVLDenseState( + model_id=str(model_id), + device=str(device_eff), + top_k=int(top_k), + doc_modality=str(modality_eff), + doc_max_length=int(doc_max_length_eff), + query_max_length=int(query_max_length), + corpus_batch_size=int(corpus_batch_size), + corpus_chunk_size=int(corpus_chunk_size), + cache_dir=cache_dir_p, + max_input_tiles=int(max_input_tiles), + use_thumbnail=bool(use_thumbnail), + ) + else: + # Update tunables. + self._state.top_k = int(top_k) + self._state.corpus_batch_size = int(corpus_batch_size) + self._state.corpus_chunk_size = int(corpus_chunk_size) + self._state.cache_dir = cache_dir_p + + corpus_ids_list = [str(x) for x in corpus_ids] + corpus_ids_hash10 = _hash_corpus_ids10(corpus_ids_list) + + if ( + self._state.dataset_name == str(dataset_name) + and self._state.corpus_ids is not None + and _hash_corpus_ids10(self._state.corpus_ids) == corpus_ids_hash10 + and self._state.corpus_embeddings_cpu is not None + ): + # Already initialized for the same corpus; only adjust GPU preload. + if preload_corpus_to_gpu and self._state.corpus_embeddings_gpu is None: + self._state.corpus_embeddings_gpu = _try_preload_corpus_to_gpu( + self._state.corpus_embeddings_cpu, self._state.device + ) + if (not preload_corpus_to_gpu) and self._state.corpus_embeddings_gpu is not None: + self._state.corpus_embeddings_gpu = None + return + + emb_cpu = self._state._load_or_build_corpus_embeddings( + dataset_name=str(dataset_name), + corpus_ids=corpus_ids_list, + corpus=list(corpus), + ) + + self._state.dataset_name = str(dataset_name) + self._state.corpus_ids = corpus_ids_list + self._state.corpus_id_to_idx = {cid: i for i, cid in enumerate(corpus_ids_list)} + self._state.corpus_markdown = [str(doc.get("markdown", "")) for doc in corpus] + self._state.corpus_embeddings_cpu = emb_cpu + + self._state.corpus_embeddings_gpu = None + if preload_corpus_to_gpu: + self._state.corpus_embeddings_gpu = _try_preload_corpus_to_gpu(emb_cpu, self._state.device) + + def retrieve( + self, + query: str, + *, + return_markdown: bool = False, + excluded_ids: Optional[Sequence[str]] = None, + ) -> Union[Dict[str, float], Tuple[Dict[str, float], Dict[str, str]]]: + with self._lock: + if self._state is None: + raise RuntimeError("Retriever not initialized. Call retriever.init(...) first.") + return self._state.retrieve_one( + str(query), + return_markdown=bool(return_markdown), + excluded_ids=excluded_ids, + ) + + def unload(self) -> None: + with self._lock: + if self._state is None: + return + try: + if self._state.corpus_embeddings_gpu is not None: + del self._state.corpus_embeddings_gpu + if self._state.corpus_embeddings_cpu is not None: + del self._state.corpus_embeddings_cpu + if self._state.corpus_markdown is not None: + del self._state.corpus_markdown + if getattr(self._state, "model", None) is not None: + del self._state.model + if getattr(self._state, "processor", None) is not None: + del self._state.processor + finally: + self._state = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Module-level singleton instance +# --------------------------------------------------------------------------- +retriever = NemotronEmbedVLDenseSingletonRetriever() diff --git a/retrieval-bench/src/retrieval_bench/utils/corpus_language.py b/retrieval-bench/src/retrieval_bench/utils/corpus_language.py new file mode 100644 index 000000000..7fef1e900 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/utils/corpus_language.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Literal + +# Corpus-language mapping used by our workflow. +# +# Important: this is keyed off *corpus language* (not query language). Some datasets +# have a French corpus even when filtering to English-subset queries. +FRENCH_CORPUS_DATASET_SHORTS = { + "vidore_v3_finance_fr", + "vidore_v3_energy", + "vidore_v3_physics", +} + + +def dataset_short(dataset_name: str) -> str: + try: + return str(dataset_name).split("/")[-1] + except Exception: + return str(dataset_name) + + +def corpus_language(dataset_name: str) -> Literal["english", "french"]: + short = dataset_short(dataset_name) + return "french" if short in FRENCH_CORPUS_DATASET_SHORTS else "english" + + +def is_french_corpus(dataset_name: str) -> bool: + return corpus_language(dataset_name) == "french" diff --git a/retrieval-bench/src/retrieval_bench/utils/torch_compat.py b/retrieval-bench/src/retrieval_bench/utils/torch_compat.py new file mode 100644 index 000000000..4e45f22a3 --- /dev/null +++ b/retrieval-bench/src/retrieval_bench/utils/torch_compat.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +""" +Compatibility shims for mismatched torch/transformers versions. + +We occasionally update `transformers` on systems where `torch` is older (or vice versa). +Some transformers versions call `torch.is_autocast_enabled(device_type)` while older +torch versions only support `torch.is_autocast_enabled()` (no args). +""" + +from typing import Any + + +def patch_torch_is_autocast_enabled() -> None: + """ + Patch `torch.is_autocast_enabled` to accept optional args on older torch. + + Safe on newer torch: no-op. + Safe when torch is unavailable: no-op. + """ + try: + import torch # type: ignore + except Exception: + return + + fn = getattr(torch, "is_autocast_enabled", None) + if not callable(fn): + return + + # Newer torch accepts a device_type argument (e.g. "cuda"). + # Older torch raises: "takes no arguments (1 given)". + try: + _ = fn("cuda") + return # already compatible + except TypeError: + pass + except Exception: + # If it failed for some other reason, don't patch. + return + + orig = fn + + def _wrapped(*args: Any, **kwargs: Any) -> bool: + # Ignore any provided args (device_type) and defer to older torch behavior. + return bool(orig()) + + try: + setattr(torch, "is_autocast_enabled", _wrapped) + except Exception: + return diff --git a/retrieval-bench/tests/pipeline_evaluation/test_evaluator.py b/retrieval-bench/tests/pipeline_evaluation/test_evaluator.py new file mode 100644 index 000000000..88a21fc2d --- /dev/null +++ b/retrieval-bench/tests/pipeline_evaluation/test_evaluator.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for pipeline evaluation evaluator functions. +""" + +from typing import Any, Dict, List, Optional +import uuid + +import pytest + +from retrieval_bench.pipeline_evaluation import BasePipeline +from retrieval_bench.pipeline_evaluation.evaluator import aggregate_results, evaluate_retrieval + + +class MockPipeline(BasePipeline): + """Mock pipeline that returns predefined results.""" + + def __init__(self, results: Dict[str, Dict[str, float]], infos: Optional[Dict[str, Any]] = None): + self.results = results + self.infos = infos + # Make trace run names unique per test instance to avoid cross-test cache reuse. + self.model_id = f"mock-{uuid.uuid4().hex}" + + def search(self, query_ids: List[str], queries: List[str]): + if self.infos is not None: + return self.results, self.infos + return self.results + + +class TestEvaluateRetrieval: + """Tests for evaluate_retrieval function.""" + + @pytest.fixture + def simple_qrels(self): + """Simple qrels for testing.""" + return { + "q1": {"doc1": 1, "doc2": 1}, + "q2": {"doc2": 1, "doc3": 1}, + } + + @pytest.fixture + def perfect_results(self): + """Results that should achieve perfect NDCG.""" + return { + "q1": {"doc1": 0.95, "doc2": 0.90}, + "q2": {"doc2": 0.95, "doc3": 0.90}, + } + + @pytest.fixture + def sample_inputs(self): + """Standard inputs for testing.""" + return { + "query_ids": ["q1", "q2"], + "queries": ["Query 1", "Query 2"], + "corpus_ids": ["doc1", "doc2", "doc3"], + "corpus_images": [None, None, None], + "corpus_texts": ["Text 1", "Text 2", "Text 3"], + } + + def test_evaluate_retrieval_returns_dict(self, simple_qrels, perfect_results, sample_inputs): + """Test that evaluate_retrieval returns a dictionary.""" + pipeline = MockPipeline(perfect_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + assert isinstance(results, dict) + + def test_evaluate_retrieval_contains_query_results(self, simple_qrels, perfect_results, sample_inputs): + """Test that results contain entries for each query.""" + pipeline = MockPipeline(perfect_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + # Should have results for each query (excluding special keys like _timing) + query_results = {k: v for k, v in results.items() if not k.startswith("_")} + assert "q1" in query_results + assert "q2" in query_results + + def test_default_metric_is_ndcg_cut_10(self, simple_qrels, perfect_results, sample_inputs): + """Test that default metric is ndcg_cut_10.""" + pipeline = MockPipeline(perfect_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + assert "ndcg_cut_10" in results["q1"] + assert "ndcg_cut_10" in results["q2"] + + def test_custom_metrics(self, simple_qrels, perfect_results, sample_inputs): + """Test using custom metrics.""" + pipeline = MockPipeline(perfect_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + metrics=["ndcg_cut_5", "map"], + ) + + assert "ndcg_cut_5" in results["q1"] + assert "map" in results["q1"] + + def test_timing_information_included_by_default(self, simple_qrels, perfect_results, sample_inputs): + """Test that timing information is included by default.""" + pipeline = MockPipeline(perfect_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + assert "_timing" in results + timing = results["_timing"] + assert "total_retrieval_time_milliseconds" in timing + assert "num_queries" in timing + assert "queries_per_second" in timing + assert timing["num_queries"] == 2 + + def test_timing_can_be_disabled(self, simple_qrels, perfect_results, sample_inputs): + """Test that timing can be disabled.""" + pipeline = MockPipeline(perfect_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + track_time=False, + ) + + assert "_timing" not in results + + def test_pipeline_infos_included(self, simple_qrels, sample_inputs): + """Test that pipeline infos are included when returned.""" + results_data = {"q1": {"doc1": 0.9}, "q2": {"doc2": 0.9}} + infos = {"cost": 0.50, "num_gpus": 2} + pipeline = MockPipeline(results_data, infos=infos) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + assert "_infos" in results + assert "pipeline_infos" in results["_infos"] + assert results["_infos"]["pipeline_infos"]["cost"] == 0.50 + assert results["_infos"]["pipeline_infos"]["num_gpus"] == 2 + + def test_missing_query_results_are_reported(self, simple_qrels, sample_inputs): + """Test that missing query results are reported in timing metadata.""" + # Pipeline only returns results for q1, not q2 + partial_results = {"q1": {"doc1": 0.9}} + pipeline = MockPipeline(partial_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + # Evaluator scores only returned queries and reports missing ones in timing. + assert "q2" not in results + assert "_timing" in results + assert results["_timing"]["missing_num_queries"] == 1 + + def test_invalid_pipeline_return_raises_error(self, simple_qrels, sample_inputs): + """Test that invalid pipeline return type raises ValueError.""" + + class InvalidReturnPipeline(BasePipeline): + def search(self, query_ids: List[str], queries: List[str]): + return "not_a_dict" + + pipeline = InvalidReturnPipeline() + + with pytest.raises(ValueError) as exc_info: + evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + assert "dict" in str(exc_info.value).lower() + + def test_perfect_retrieval_gets_high_ndcg(self, simple_qrels, perfect_results, sample_inputs): + """Test that perfect retrieval gets high NDCG score.""" + pipeline = MockPipeline(perfect_results) + + results = evaluate_retrieval( + pipeline, + sample_inputs["query_ids"], + sample_inputs["queries"], + sample_inputs["corpus_ids"], + sample_inputs["corpus_images"], + sample_inputs["corpus_texts"], + simple_qrels, + ) + + # Perfect NDCG should be 1.0 + assert abs(results["q1"]["ndcg_cut_10"] - 1.0) < 1e-7 + assert abs(results["q2"]["ndcg_cut_10"] - 1.0) < 1e-7 + + +class TestAggregateResults: + """Tests for aggregate_results function.""" + + def test_aggregate_empty_results_returns_empty(self): + """Test that empty results return empty aggregation.""" + assert aggregate_results({}) == {} + + def test_simple_aggregation_without_languages(self): + """Test simple aggregation without language information.""" + results = { + "q1": {"ndcg_cut_10": 0.8, "map": 0.7}, + "q2": {"ndcg_cut_10": 0.6, "map": 0.5}, + } + + aggregated = aggregate_results(results) + + assert abs(aggregated["ndcg_cut_10"] - 0.7) < 1e-7 # (0.8 + 0.6) / 2 + assert abs(aggregated["map"] - 0.6) < 1e-7 # (0.7 + 0.5) / 2 + + def test_aggregation_with_timing_info(self): + """Test that timing info is preserved in aggregation.""" + results = { + "q1": {"ndcg_cut_10": 0.8}, + "q2": {"ndcg_cut_10": 0.6}, + "_timing": { + "total_retrieval_time_milliseconds": 1000, + "num_queries": 2, + "queries_per_second": 2.0, + }, + } + + aggregated = aggregate_results(results) + + assert "total_retrieval_time_milliseconds" in aggregated + assert aggregated["total_retrieval_time_milliseconds"] == 1000 + + def test_aggregation_with_language_info(self): + """Test aggregation with language breakdown.""" + results = { + "q1": {"ndcg_cut_10": 0.9}, + "q2": {"ndcg_cut_10": 0.8}, + "q3": {"ndcg_cut_10": 0.7}, + "q4": {"ndcg_cut_10": 0.6}, + } + query_languages = { + "q1": "english", + "q2": "english", + "q3": "french", + "q4": "french", + } + + aggregated = aggregate_results(results, query_languages) + + assert "overall" in aggregated + assert "by_language" in aggregated + assert "english" in aggregated["by_language"] + assert "french" in aggregated["by_language"] + + # Overall average + assert abs(aggregated["overall"]["ndcg_cut_10"] - (0.9 + 0.8 + 0.7 + 0.6) / 4) < 1e-7 + + # English average: (0.9 + 0.8) / 2 = 0.85 + assert abs(aggregated["by_language"]["english"]["ndcg_cut_10"] - (0.9 + 0.8) / 2) < 1e-7 + # French average: (0.7 + 0.6) / 2 = 0.65 + assert abs(aggregated["by_language"]["french"]["ndcg_cut_10"] - (0.7 + 0.6) / 2) < 1e-7 + + def test_language_aggregation_includes_query_counts(self): + """Test that language aggregation includes query counts.""" + results = { + "q1": {"ndcg_cut_10": 0.9}, + "q2": {"ndcg_cut_10": 0.8}, + "q3": {"ndcg_cut_10": 0.7}, + } + query_languages = { + "q1": "english", + "q2": "english", + "q3": "french", + } + + aggregated = aggregate_results(results, query_languages) + + assert aggregated["by_language"]["english"]["num_queries"] == 2 + assert aggregated["by_language"]["french"]["num_queries"] == 1 + + def test_unknown_language_handling(self): + """Test that queries without language mapping are labeled 'unknown'.""" + results = { + "q1": {"ndcg_cut_10": 0.9}, + "q2": {"ndcg_cut_10": 0.8}, + } + query_languages = { + "q1": "english", + # q2 is missing from language mapping + } + + aggregated = aggregate_results(results, query_languages) + + assert "unknown" in aggregated["by_language"] + assert abs(aggregated["by_language"]["unknown"]["ndcg_cut_10"] - 0.8) < 1e-7 + + def test_timing_info_in_language_aggregation(self): + """Test that timing info is included in language aggregation.""" + results = { + "q1": {"ndcg_cut_10": 0.9}, + "_timing": { + "total_retrieval_time_milliseconds": 500, + "num_queries": 1, + "queries_per_second": 2.0, + }, + } + query_languages = {"q1": "english"} + + aggregated = aggregate_results(results, query_languages) + + assert "timing" in aggregated + assert aggregated["timing"]["total_retrieval_time_milliseconds"] == 500 + + def test_single_query_aggregation(self): + """Test aggregation with a single query.""" + results = {"q1": {"ndcg_cut_10": 0.85}} + + aggregated = aggregate_results(results) + + assert abs(aggregated["ndcg_cut_10"] - 0.85) < 1e-7 + + def test_multiple_metrics_aggregation(self): + """Test aggregation with multiple metrics.""" + results = { + "q1": {"ndcg_cut_5": 0.9, "ndcg_cut_10": 0.95, "map": 0.8}, + "q2": {"ndcg_cut_5": 0.7, "ndcg_cut_10": 0.75, "map": 0.6}, + } + + aggregated = aggregate_results(results) + + assert abs(aggregated["ndcg_cut_5"] - (0.9 + 0.7) / 2) < 1e-7 + assert abs(aggregated["ndcg_cut_10"] - (0.95 + 0.75) / 2) < 1e-7 + assert abs(aggregated["map"] - (0.8 + 0.6) / 2) < 1e-7 + + def test_timing_only_results(self): + """Test results that only contain timing info.""" + results = { + "_timing": { + "total_retrieval_time_milliseconds": 1000, + "num_queries": 5, + "queries_per_second": 5.0, + } + } + + aggregated = aggregate_results(results) + + assert "timing" in aggregated + assert aggregated["timing"]["num_queries"] == 5 diff --git a/retrieval-bench/tests/pipeline_evaluation/test_pipelines.py b/retrieval-bench/tests/pipeline_evaluation/test_pipelines.py new file mode 100644 index 000000000..e6012cd49 --- /dev/null +++ b/retrieval-bench/tests/pipeline_evaluation/test_pipelines.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for pipeline modules in retrieval_bench. +""" + +import pytest + +from retrieval_bench.pipelines.backends import VALID_BACKENDS, get_backend_defaults, init_backend + + +def test_valid_backends_is_non_empty_set(): + assert isinstance(VALID_BACKENDS, set) + assert VALID_BACKENDS + + +def test_get_backend_defaults_returns_copy(): + backend = next(iter(VALID_BACKENDS)) + + cfg1 = get_backend_defaults(backend) + cfg2 = get_backend_defaults(backend) + + assert isinstance(cfg1, dict) + assert isinstance(cfg2, dict) + assert cfg1 == cfg2 + + cfg1["_sentinel"] = 1 + assert "_sentinel" not in cfg2 + + +def test_get_backend_defaults_rejects_unknown_backend(): + with pytest.raises(ValueError, match="Unknown backend"): + get_backend_defaults("does-not-exist") + + +def test_init_backend_rejects_unknown_backend(): + with pytest.raises(ValueError, match="Unknown backend"): + init_backend( + "does-not-exist", + dataset_name="vidore/test", + corpus_ids=[], + corpus=[], + ) diff --git a/setup.cfg b/setup.cfg index d22cbbf67..3af80e13e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -100,6 +100,8 @@ extend-ignore = per-file-ignores = # imported but unused __init__.py: F401, E402 + # retrieval-bench prompt strings and tool descriptions exceed 120 chars + retrieval-bench/**/*.py: E501 # Ignore additional deps needed for examples examples/*.py: F821 # Cython Exclusions