diff --git a/tests/test_lancedb_config.py b/tests/test_lancedb_config.py new file mode 100644 index 000000000..cf2d6e12f --- /dev/null +++ b/tests/test_lancedb_config.py @@ -0,0 +1,280 @@ +"""Offline unit tests for LanceDB config objects. + +These tests don't require a running LanceDB instance. They freeze the +contract for the three IVF-family indexes (IVF_PQ / IVF_HNSW_SQ / +IVF_HNSW_PQ) so that any future refactor that accidentally diverges their +code paths (e.g. introduces index-type-specific branches) will fail CI. + +Background: IVF_HNSW_SQ / IVF_HNSW_PQ share the same code path as IVF_PQ +and the CLI is wired up for all three. These tests encode that claim. +""" + +import typing +from typing import Annotated, get_type_hints + +from vectordb_bench.backend.clients.api import IndexType, MetricType +from vectordb_bench.backend.clients.lancedb.config import ( + LanceDBAutoIndexConfig, + LanceDBIndexConfig, + LanceDBIVFHNSWPQConfig, + LanceDBIVFHNSWSQConfig, + LanceDBNoIndexConfig, + _lancedb_case_config, +) + +# --------------------------------------------------------------------------- +# Registry mapping +# --------------------------------------------------------------------------- + + +def test_registry_covers_all_lancedb_index_types(): + """Every LanceDB-supported IndexType must resolve to a config class.""" + required = { + IndexType.IVFPQ, + IndexType.AUTOINDEX, + IndexType.IVF_HNSW_SQ, + IndexType.IVF_HNSW_PQ, + IndexType.NONE, + } + assert required.issubset(_lancedb_case_config.keys()) + + # HNSW is kept for backwards compatibility and must map to IVF_HNSW_SQ. 
+ assert _lancedb_case_config[IndexType.HNSW] is LanceDBIVFHNSWSQConfig + + +# --------------------------------------------------------------------------- +# index_param() / search_param() contract +# --------------------------------------------------------------------------- + + +def test_ivfpq_default_params_are_minimal(): + cfg = LanceDBIndexConfig() + assert cfg.index == IndexType.IVFPQ + p = cfg.index_param() + # Always present + assert p["metric"] == "cosine" or p["metric"] == "l2" + assert "num_bits" in p + # Zero-valued optionals stay out of the param dict so LanceDB uses its + # own defaults. + assert "num_partitions" not in p + assert "num_sub_vectors" not in p + # search_param() is empty when all tunables are zero. + assert cfg.search_param() == {} + + +def test_ivfpq_tuned_params_are_forwarded(): + cfg = LanceDBIndexConfig( + metric_type=MetricType.COSINE, + num_partitions=256, + num_sub_vectors=96, + nbits=8, + nprobes=20, + refine_factor=10, + ) + p = cfg.index_param() + assert p == { + "metric": "cosine", + "num_bits": 8, + "sample_rate": 256, + "max_iterations": 50, + "num_partitions": 256, + "num_sub_vectors": 96, + } + assert cfg.search_param() == {"nprobes": 20, "refine_factor": 10} + + +def test_ivf_hnsw_sq_params_are_forwarded(): + cfg = LanceDBIVFHNSWSQConfig( + metric_type=MetricType.COSINE, + num_partitions=256, + m=16, + ef_construction=200, + ef=128, + nprobes=20, + refine_factor=10, + ) + p = cfg.index_param() + assert p["index_type"] == "IVF_HNSW_SQ" + assert p["num_partitions"] == 256 + assert p["m"] == 16 + assert p["ef_construction"] == 200 + assert cfg.search_param() == { + "ef": 128, + "nprobes": 20, + "refine_factor": 10, + } + + +def test_ivf_hnsw_pq_params_are_forwarded(): + cfg = LanceDBIVFHNSWPQConfig( + metric_type=MetricType.COSINE, + num_partitions=256, + num_sub_vectors=96, + m=16, + ef_construction=200, + ef=128, + nprobes=20, + refine_factor=10, + ) + p = cfg.index_param() + assert p["index_type"] == "IVF_HNSW_PQ" + 
assert p["num_partitions"] == 256 + assert p["num_sub_vectors"] == 96 + assert p["m"] == 16 + assert p["ef_construction"] == 200 + assert cfg.search_param() == { + "ef": 128, + "nprobes": 20, + "refine_factor": 10, + } + + +# --------------------------------------------------------------------------- +# Code-path unification +# --------------------------------------------------------------------------- + + +def test_ivf_family_shares_search_knobs(): + """IVF_PQ exposes nprobes+refine_factor; IVF_HNSW_SQ/PQ additionally + expose ef. These are the only keys lancedb.py's search_embedding knows + how to forward, so any new index type must stay inside this union. + """ + allowed = {"nprobes", "ef", "refine_factor"} + + ivfpq = LanceDBIndexConfig(nprobes=10, refine_factor=5) + sq = LanceDBIVFHNSWSQConfig(nprobes=10, ef=64, refine_factor=5) + pq = LanceDBIVFHNSWPQConfig(nprobes=10, ef=64, refine_factor=5) + + for cfg in (ivfpq, sq, pq): + assert set(cfg.search_param().keys()).issubset(allowed) + + +def test_no_index_and_autoindex_are_well_formed(): + none_cfg = LanceDBNoIndexConfig() + assert none_cfg.index == IndexType.NONE + assert none_cfg.index_param() == {} + + auto_cfg = LanceDBAutoIndexConfig() + assert auto_cfg.index == IndexType.AUTOINDEX + assert "metric" in auto_cfg.index_param() + + +# --------------------------------------------------------------------------- +# CLI wiring — validate via static TypedDict introspection (no heavy runtime +# imports needed, avoids hdrh / streamlit / etc. dependency chains) +# --------------------------------------------------------------------------- + + +def _extract_click_option_names(typed_dict_cls: type) -> set[str]: + """Extract ``--option-name`` strings from a Click-annotated TypedDict. + + Each field is ``Annotated[T, click.option("--name", ...)]``. + We pull the option strings from the ``click.Option`` metadata. 
+ """ + hints = get_type_hints(typed_dict_cls, include_extras=True) + names: set[str] = set() + for _field, hint in hints.items(): + if typing.get_origin(hint) is Annotated: + for meta in hint.__metadata__: + # click.option(...) produces a click.core.Decorator / functools.partial + # but in this project it's stored as a click.Argument or a + # ``functools.partial`` wrapping ``click.option``. Extract the + # first positional string that starts with "--". + if hasattr(meta, "name") and isinstance(meta.name, str): + names.add(meta.name) + # click.option() returns a decorator whose .args[0] is the + # option flag(s). We try a few accessor patterns. + for attr in ("args", "decls"): + for val in getattr(meta, attr, ()): + if isinstance(val, str) and val.startswith("--"): + names.add(val) + # For click.option stored as click.core.Option or Decorator + if hasattr(meta, "opts"): + for opt in meta.opts: + if isinstance(opt, str) and opt.startswith("--"): + names.add(opt) + return names + + +def test_cli_typed_dicts_define_all_expected_commands(): + """Every expected CLI TypedDict class and command function must be defined in lancedb/cli.py.""" + # We deliberately do not import the module, to avoid pulling the entire runtime: + # lancedb/cli.py imports ....cli.cli which triggers hdrh etc. + # Instead we parse the source statically and verify the definitions exist.
+ import ast + from pathlib import Path + + cli_path = Path("vectordb_bench/backend/clients/lancedb/cli.py") + tree = ast.parse(cli_path.read_text()) + + class_names: set[str] = set() + func_names: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + class_names.add(node.name) + elif isinstance(node, ast.FunctionDef): + func_names.add(node.name) + + # TypedDicts that drive CLI options + assert { + "LanceDBTypedDict", + "LanceDBIVFPQTypedDict", + "LanceDBIVFHNSWSQTypedDict", + "LanceDBIVFHNSWPQTypedDict", + }.issubset(class_names) + + # Command functions registered via @cli.command() + assert {"LanceDB", "LanceDBAutoIndex", "LanceDBIVFPQ", "LanceDBIVFHNSWSQ", "LanceDBIVFHNSWPQ"}.issubset(func_names) + + +def test_cli_typeddict_ivfpq_has_search_knobs(): + """IVF_PQ TypedDict must define nprobes / refine_factor / num_partitions.""" + import ast + from pathlib import Path + + cli_src = Path("vectordb_bench/backend/clients/lancedb/cli.py").read_text() + tree = ast.parse(cli_src) + + def _fields_of(cls_name: str) -> set[str]: + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == cls_name: + return { + item.target.id + for item in node.body + if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name) + } + return set() + + ivfpq = _fields_of("LanceDBIVFPQTypedDict") + assert {"nprobes", "refine_factor", "num_partitions", "num_sub_vectors", "nbits"}.issubset(ivfpq) + + +def test_cli_typeddict_hnsw_variants_are_superset_of_ivfpq(): + """Both HNSW TypedDicts must have nprobes + refine_factor + graph knobs.""" + import ast + from pathlib import Path + + cli_src = Path("vectordb_bench/backend/clients/lancedb/cli.py").read_text() + tree = ast.parse(cli_src) + + def _fields_of(cls_name: str) -> set[str]: + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == cls_name: + return { + item.target.id + for item in node.body + if isinstance(item, ast.AnnAssign) and 
isinstance(item.target, ast.Name) + } + return set() + + sq_fields = _fields_of("LanceDBIVFHNSWSQTypedDict") + pq_fields = _fields_of("LanceDBIVFHNSWPQTypedDict") + + # Both must have shared search knobs + shared_knobs = {"nprobes", "refine_factor", "num_partitions", "m", "ef", "ef_construction"} + assert shared_knobs.issubset(sq_fields), f"SQ missing: {shared_knobs - sq_fields}" + assert shared_knobs.issubset(pq_fields), f"PQ missing: {shared_knobs - pq_fields}" + + # PQ variant has num_sub_vectors, SQ does not + assert "num_sub_vectors" in pq_fields + assert "num_sub_vectors" not in sq_fields diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index f507abe33..2876df239 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -50,6 +50,8 @@ class IndexType(StrEnum): SVS_VAMANA_LEANVEC = "SVS_VAMANA_LEANVEC" Hologres_HGraph = "HGraph" Hologres_Graph = "Graph" + IVF_HNSW_SQ = "IVF_HNSW_SQ" + IVF_HNSW_PQ = "IVF_HNSW_PQ" NONE = "NONE" diff --git a/vectordb_bench/backend/clients/lancedb/cli.py b/vectordb_bench/backend/clients/lancedb/cli.py index 219ae114e..494c28a2b 100644 --- a/vectordb_bench/backend/clients/lancedb/cli.py +++ b/vectordb_bench/backend/clients/lancedb/cli.py @@ -1,3 +1,4 @@ +import os from typing import Annotated, Unpack import click @@ -22,21 +23,134 @@ class LanceDBTypedDict(CommonTypedDict): str | None, click.option("--token", type=str, help="Authentication token", required=False), ] + cos_secret_id: Annotated[ + str | None, + click.option( + "--cos-secret-id", + type=str, + help="Tencent COS secret ID (or set COS_SECRET_ID env var)", + required=False, + ), + ] + cos_secret_key: Annotated[ + str | None, + click.option( + "--cos-secret-key", + type=str, + help="Tencent COS secret key (or set COS_SECRET_KEY env var)", + required=False, + ), + ] + cos_endpoint: Annotated[ + str | None, + click.option( + "--cos-endpoint", + type=str, + help="Tencent COS endpoint (or set 
COS_ENDPOINT env var)", + required=False, + ), + ] + cos_region: Annotated[ + str | None, + click.option( + "--cos-region", + type=str, + help="Tencent COS region (or set TENCENTCLOUD_REGION env var)", + required=False, + ), + ] + + +def _build_storage_options(**parameters) -> dict[str, str] | None: + """Build storage_options based on URI scheme. + + Supports: + - cos:// / s3:// → Tencent COS credentials + - goosefs:// → GooseFS authentication options + """ + uri = parameters.get("uri", "") + + # --- GooseFS storage options --- + if uri.startswith("goosefs://"): + return _build_goosefs_storage_options(**parameters) + + # --- COS / S3 storage options --- + if uri.startswith(("cos://", "s3://")): + return _build_cos_storage_options(**parameters) + + return None + + +def _build_cos_storage_options(**parameters) -> dict[str, str] | None: + """Build storage_options for COS/S3 if credentials are provided.""" + secret_id = parameters.get("cos_secret_id") or os.environ.get("COS_SECRET_ID") + secret_key = parameters.get("cos_secret_key") or os.environ.get("COS_SECRET_KEY") + endpoint = parameters.get("cos_endpoint") or os.environ.get("COS_ENDPOINT") + region = parameters.get("cos_region") or os.environ.get("TENCENTCLOUD_REGION") + + if not (secret_id and secret_key): + return None + + storage_options = { + "aws_access_key_id": secret_id, + "aws_secret_access_key": secret_key, + } + if endpoint: + storage_options["endpoint"] = endpoint + if region: + storage_options["region"] = region + + return storage_options + + +def _build_goosefs_storage_options(**_parameters: str) -> dict[str, str] | None: + """Build storage_options for GooseFS. + + Recognized environment variables: + - GOOSEFS_AUTH_TYPE → goosefs_auth_type (simple / nosasl) + - GOOSEFS_AUTH_USERNAME → goosefs_auth_username + - GOOSEFS_WRITE_TYPE → goosefs_write_type (CACHE_THROUGH etc.) 
+ - GOOSEFS_BLOCK_SIZE → goosefs_block_size + - GOOSEFS_CHUNK_SIZE → goosefs_chunk_size + """ + storage_options: dict[str, str] = {} + + _goosefs_env_keys = { + "GOOSEFS_AUTH_TYPE": "goosefs_auth_type", + "GOOSEFS_AUTH_USERNAME": "goosefs_auth_username", + "GOOSEFS_WRITE_TYPE": "goosefs_write_type", + "GOOSEFS_BLOCK_SIZE": "goosefs_block_size", + "GOOSEFS_CHUNK_SIZE": "goosefs_chunk_size", + } + + for env_key, opt_key in _goosefs_env_keys.items(): + value = os.environ.get(env_key) + if value: + storage_options[opt_key] = value + + return storage_options if storage_options else None + + +def _build_db_config(**parameters): + from .config import LanceDBConfig + + return LanceDBConfig( + db_label=parameters["db_label"], + uri=parameters["uri"], + token=SecretStr(parameters["token"]) if parameters.get("token") else None, + storage_options=_build_storage_options(**parameters), + ) @cli.command() @click_parameter_decorators_from_typed_dict(LanceDBTypedDict) def LanceDB(**parameters: Unpack[LanceDBTypedDict]): - from .config import LanceDBConfig, _lancedb_case_config + from .config import LanceDBNoIndexConfig run( db=DB.LanceDB, - db_config=LanceDBConfig( - db_label=parameters["db_label"], - uri=parameters["uri"], - token=SecretStr(parameters["token"]) if parameters.get("token") else None, - ), - db_case_config=_lancedb_case_config.get("NONE")(), + db_config=_build_db_config(**parameters), + db_case_config=LanceDBNoIndexConfig(), **parameters, ) @@ -44,16 +158,12 @@ def LanceDB(**parameters: Unpack[LanceDBTypedDict]): @cli.command() @click_parameter_decorators_from_typed_dict(LanceDBTypedDict) def LanceDBAutoIndex(**parameters: Unpack[LanceDBTypedDict]): - from .config import LanceDBConfig, _lancedb_case_config + from .config import LanceDBAutoIndexConfig run( db=DB.LanceDB, - db_config=LanceDBConfig( - db_label=parameters["db_label"], - uri=parameters["uri"], - token=SecretStr(parameters["token"]) if parameters.get("token") else None, - ), - 
db_case_config=_lancedb_case_config.get(IndexType.AUTOINDEX)(), + db_config=_build_db_config(**parameters), + db_case_config=LanceDBAutoIndexConfig(), **parameters, ) @@ -65,7 +175,8 @@ class LanceDBIVFPQTypedDict(CommonTypedDict, LanceDBTypedDict): "--num-partitions", type=int, default=0, - help="Number of partitions for IVFPQ index, unset = use LanceDB default", + help="Number of partitions for IVF_PQ index, 0 = use LanceDB default", + show_default=True, ), ] num_sub_vectors: Annotated[ @@ -74,7 +185,8 @@ class LanceDBIVFPQTypedDict(CommonTypedDict, LanceDBTypedDict): "--num-sub-vectors", type=int, default=0, - help="Number of sub-vectors for IVFPQ index, unset = use LanceDB default", + help="Number of sub-vectors for IVF_PQ index, 0 = use LanceDB default", + show_default=True, ), ] nbits: Annotated[ @@ -83,13 +195,28 @@ class LanceDBIVFPQTypedDict(CommonTypedDict, LanceDBTypedDict): "--nbits", type=int, default=8, - help="Number of bits for IVFPQ index (must be 4 or 8), unset = use LanceDB default", + help="Number of bits for quantization (4 or 8)", + show_default=True, ), ] nprobes: Annotated[ int, click.option( - "--nprobes", type=int, default=0, help="Number of probes for IVFPQ search, unset = use LanceDB default" + "--nprobes", + type=int, + default=0, + help="Number of probes for IVF search, 0 = use LanceDB default", + show_default=True, + ), + ] + refine_factor: Annotated[ + int, + click.option( + "--refine-factor", + type=int, + default=0, + help="Refine factor for better recall, 0 = disabled", + show_default=True, ), ] @@ -97,50 +224,171 @@ class LanceDBIVFPQTypedDict(CommonTypedDict, LanceDBTypedDict): @cli.command() @click_parameter_decorators_from_typed_dict(LanceDBIVFPQTypedDict) def LanceDBIVFPQ(**parameters: Unpack[LanceDBIVFPQTypedDict]): - from .config import LanceDBConfig, LanceDBIndexConfig + from .config import LanceDBIndexConfig run( db=DB.LanceDB, - db_config=LanceDBConfig( - db_label=parameters["db_label"], - uri=parameters["uri"], - 
token=SecretStr(parameters["token"]) if parameters.get("token") else None, - ), + db_config=_build_db_config(**parameters), db_case_config=LanceDBIndexConfig( index=IndexType.IVFPQ, num_partitions=parameters["num_partitions"], num_sub_vectors=parameters["num_sub_vectors"], nbits=parameters["nbits"], nprobes=parameters["nprobes"], + refine_factor=parameters["refine_factor"], ), **parameters, ) -class LanceDBHNSWTypedDict(CommonTypedDict, LanceDBTypedDict): - m: Annotated[int, click.option("--m", type=int, default=0, help="HNSW parameter m")] +class LanceDBIVFHNSWSQTypedDict(CommonTypedDict, LanceDBTypedDict): + num_partitions: Annotated[ + int, + click.option( + "--num-partitions", + type=int, + default=0, + help="Number of IVF partitions, 0 = use LanceDB default", + show_default=True, + ), + ] + m: Annotated[ + int, + click.option("--m", type=int, default=0, help="HNSW parameter m, 0 = use LanceDB default", show_default=True), + ] ef_construction: Annotated[ - int, click.option("--ef-construction", type=int, default=0, help="HNSW parameter ef_construction") + int, + click.option( + "--ef-construction", + type=int, + default=0, + help="HNSW ef_construction, 0 = use LanceDB default", + show_default=True, + ), + ] + ef: Annotated[ + int, + click.option("--ef", type=int, default=0, help="HNSW search ef, 0 = use LanceDB default", show_default=True), + ] + nprobes: Annotated[ + int, + click.option( + "--nprobes", + type=int, + default=0, + help="Number of probes for IVF search, 0 = use LanceDB default", + show_default=True, + ), + ] + refine_factor: Annotated[ + int, + click.option( + "--refine-factor", + type=int, + default=0, + help="Refine factor for better recall, 0 = disabled", + show_default=True, + ), ] - ef: Annotated[int, click.option("--ef", type=int, default=0, help="HNSW search parameter ef")] @cli.command() -@click_parameter_decorators_from_typed_dict(LanceDBHNSWTypedDict) -def LanceDBHNSW(**parameters: Unpack[LanceDBHNSWTypedDict]): - from .config import 
LanceDBConfig, LanceDBHNSWIndexConfig +@click_parameter_decorators_from_typed_dict(LanceDBIVFHNSWSQTypedDict) +def LanceDBIVFHNSWSQ(**parameters: Unpack[LanceDBIVFHNSWSQTypedDict]): + from .config import LanceDBIVFHNSWSQConfig run( db=DB.LanceDB, - db_config=LanceDBConfig( - db_label=parameters["db_label"], - uri=parameters["uri"], - token=SecretStr(parameters["token"]) if parameters.get("token") else None, + db_config=_build_db_config(**parameters), + db_case_config=LanceDBIVFHNSWSQConfig( + num_partitions=parameters["num_partitions"], + m=parameters["m"], + ef_construction=parameters["ef_construction"], + ef=parameters["ef"], + nprobes=parameters["nprobes"], + refine_factor=parameters["refine_factor"], + ), + **parameters, + ) + + +class LanceDBIVFHNSWPQTypedDict(CommonTypedDict, LanceDBTypedDict): + num_partitions: Annotated[ + int, + click.option( + "--num-partitions", + type=int, + default=0, + help="Number of IVF partitions, 0 = use LanceDB default", + show_default=True, + ), + ] + num_sub_vectors: Annotated[ + int, + click.option( + "--num-sub-vectors", + type=int, + default=0, + help="Number of PQ sub-vectors, 0 = use LanceDB default", + show_default=True, ), - db_case_config=LanceDBHNSWIndexConfig( + ] + m: Annotated[ + int, + click.option("--m", type=int, default=0, help="HNSW parameter m, 0 = use LanceDB default", show_default=True), + ] + ef_construction: Annotated[ + int, + click.option( + "--ef-construction", + type=int, + default=0, + help="HNSW ef_construction, 0 = use LanceDB default", + show_default=True, + ), + ] + ef: Annotated[ + int, + click.option("--ef", type=int, default=0, help="HNSW search ef, 0 = use LanceDB default", show_default=True), + ] + nprobes: Annotated[ + int, + click.option( + "--nprobes", + type=int, + default=0, + help="Number of probes for IVF search, 0 = use LanceDB default", + show_default=True, + ), + ] + refine_factor: Annotated[ + int, + click.option( + "--refine-factor", + type=int, + default=0, + help="Refine factor 
for better recall, 0 = disabled", + show_default=True, + ), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(LanceDBIVFHNSWPQTypedDict) +def LanceDBIVFHNSWPQ(**parameters: Unpack[LanceDBIVFHNSWPQTypedDict]): + from .config import LanceDBIVFHNSWPQConfig + + run( + db=DB.LanceDB, + db_config=_build_db_config(**parameters), + db_case_config=LanceDBIVFHNSWPQConfig( + num_partitions=parameters["num_partitions"], + num_sub_vectors=parameters["num_sub_vectors"], m=parameters["m"], ef_construction=parameters["ef_construction"], ef=parameters["ef"], + nprobes=parameters["nprobes"], + refine_factor=parameters["refine_factor"], ), **parameters, ) diff --git a/vectordb_bench/backend/clients/lancedb/config.py b/vectordb_bench/backend/clients/lancedb/config.py index b0621c22e..8a98c36f3 100644 --- a/vectordb_bench/backend/clients/lancedb/config.py +++ b/vectordb_bench/backend/clients/lancedb/config.py @@ -6,18 +6,22 @@ class LanceDBConfig(DBConfig): """LanceDB connection configuration.""" - db_label: str - uri: str + db_label: str = "" + uri: str = "/tmp/lancedb" token: SecretStr | None = None + storage_options: dict[str, str] | None = None def to_dict(self) -> dict: return { "uri": self.uri, "token": self.token.get_secret_value() if self.token else None, + "storage_options": self.storage_options, } class LanceDBIndexConfig(BaseModel, DBCaseConfig): + """Default IVF_PQ index configuration.""" + index: IndexType = IndexType.IVFPQ metric_type: MetricType = MetricType.L2 num_partitions: int = 0 @@ -26,91 +30,166 @@ class LanceDBIndexConfig(BaseModel, DBCaseConfig): sample_rate: int = 256 max_iterations: int = 50 nprobes: int = 0 + refine_factor: int = 0 + + def parse_metric(self) -> str: + if self.metric_type in (MetricType.L2, MetricType.COSINE): + return self.metric_type.value.lower() + if self.metric_type in (MetricType.IP, MetricType.DP): + return "dot" + msg = f"Metric type {self.metric_type} is not supported for LanceDB!" 
+ raise ValueError(msg) def index_param(self) -> dict: - if self.index not in [ - IndexType.IVFPQ, - IndexType.HNSW, - IndexType.AUTOINDEX, - IndexType.NONE, - ]: - msg = f"Index type {self.index} is not supported for LanceDB!" - raise ValueError(msg) - - # See https://lancedb.github.io/lancedb/python/python/#lancedb.table.Table.create_index params = { "metric": self.parse_metric(), "num_bits": self.nbits, "sample_rate": self.sample_rate, "max_iterations": self.max_iterations, } - if self.num_partitions > 0: params["num_partitions"] = self.num_partitions if self.num_sub_vectors > 0: params["num_sub_vectors"] = self.num_sub_vectors - return params def search_param(self) -> dict: params = {} if self.nprobes > 0: params["nprobes"] = self.nprobes - + if self.refine_factor > 0: + params["refine_factor"] = self.refine_factor return params - def parse_metric(self) -> str: - if self.metric_type in [MetricType.L2, MetricType.COSINE]: - return self.metric_type.value.lower() - if self.metric_type in [MetricType.IP, MetricType.DP]: - return "dot" - msg = f"Metric type {self.metric_type} is not supported for LanceDB!" 
- raise ValueError(msg) - class LanceDBNoIndexConfig(LanceDBIndexConfig): + """No index — brute-force scan.""" + index: IndexType = IndexType.NONE def index_param(self) -> dict: return {} + def search_param(self) -> dict: + params = {} + if self.refine_factor > 0: + params["refine_factor"] = self.refine_factor + return params + class LanceDBAutoIndexConfig(LanceDBIndexConfig): + """AutoIndex — let LanceDB decide.""" + index: IndexType = IndexType.AUTOINDEX def index_param(self) -> dict: - return {} + return {"metric": self.parse_metric()} + def search_param(self) -> dict: + params = {} + if self.nprobes > 0: + params["nprobes"] = self.nprobes + if self.refine_factor > 0: + params["refine_factor"] = self.refine_factor + return params + + +class LanceDBIVFHNSWSQConfig(BaseModel, DBCaseConfig): + """IVF_HNSW_SQ index — IVF partitioning + HNSW graph + scalar quantization.""" -class LanceDBHNSWIndexConfig(LanceDBIndexConfig): - index: IndexType = IndexType.HNSW + index: IndexType = IndexType.IVF_HNSW_SQ + metric_type: MetricType = MetricType.L2 + num_partitions: int = 0 m: int = 0 ef_construction: int = 0 ef: int = 0 + nprobes: int = 0 + refine_factor: int = 0 - def index_param(self) -> dict: - params = LanceDBIndexConfig.index_param(self) + def parse_metric(self) -> str: + if self.metric_type in (MetricType.L2, MetricType.COSINE): + return self.metric_type.value.lower() + if self.metric_type in (MetricType.IP, MetricType.DP): + return "dot" + msg = f"Metric type {self.metric_type} is not supported for LanceDB!" 
+ raise ValueError(msg) - # See https://lancedb.github.io/lancedb/python/python/#lancedb.index.HnswSq - params["index_type"] = "IVF_HNSW_SQ" + def index_param(self) -> dict: + params = { + "metric": self.parse_metric(), + "index_type": "IVF_HNSW_SQ", + } + if self.num_partitions > 0: + params["num_partitions"] = self.num_partitions if self.m > 0: params["m"] = self.m if self.ef_construction > 0: params["ef_construction"] = self.ef_construction - return params def search_param(self) -> dict: params = {} - if self.ef != 0: - params = {"ef": self.ef} + if self.ef > 0: + params["ef"] = self.ef + if self.nprobes > 0: + params["nprobes"] = self.nprobes + if self.refine_factor > 0: + params["refine_factor"] = self.refine_factor + return params + + +class LanceDBIVFHNSWPQConfig(BaseModel, DBCaseConfig): + """IVF_HNSW_PQ index — IVF partitioning + HNSW graph + product quantization.""" + index: IndexType = IndexType.IVF_HNSW_PQ + metric_type: MetricType = MetricType.L2 + num_partitions: int = 0 + num_sub_vectors: int = 0 + m: int = 0 + ef_construction: int = 0 + ef: int = 0 + nprobes: int = 0 + refine_factor: int = 0 + + def parse_metric(self) -> str: + if self.metric_type in (MetricType.L2, MetricType.COSINE): + return self.metric_type.value.lower() + if self.metric_type in (MetricType.IP, MetricType.DP): + return "dot" + msg = f"Metric type {self.metric_type} is not supported for LanceDB!" 
+ raise ValueError(msg) + + def index_param(self) -> dict: + params = { + "metric": self.parse_metric(), + "index_type": "IVF_HNSW_PQ", + } + if self.num_partitions > 0: + params["num_partitions"] = self.num_partitions + if self.num_sub_vectors > 0: + params["num_sub_vectors"] = self.num_sub_vectors + if self.m > 0: + params["m"] = self.m + if self.ef_construction > 0: + params["ef_construction"] = self.ef_construction + return params + + def search_param(self) -> dict: + params = {} + if self.ef > 0: + params["ef"] = self.ef + if self.nprobes > 0: + params["nprobes"] = self.nprobes + if self.refine_factor > 0: + params["refine_factor"] = self.refine_factor return params _lancedb_case_config = { IndexType.IVFPQ: LanceDBIndexConfig, IndexType.AUTOINDEX: LanceDBAutoIndexConfig, - IndexType.HNSW: LanceDBHNSWIndexConfig, + IndexType.IVF_HNSW_SQ: LanceDBIVFHNSWSQConfig, + IndexType.IVF_HNSW_PQ: LanceDBIVFHNSWPQConfig, + IndexType.HNSW: LanceDBIVFHNSWSQConfig, # backward compat: HNSW maps to IVF_HNSW_SQ IndexType.NONE: LanceDBNoIndexConfig, } diff --git a/vectordb_bench/backend/clients/lancedb/lancedb.py b/vectordb_bench/backend/clients/lancedb/lancedb.py index 65330e2cb..41ba7bd9e 100644 --- a/vectordb_bench/backend/clients/lancedb/lancedb.py +++ b/vectordb_bench/backend/clients/lancedb/lancedb.py @@ -1,29 +1,44 @@ +"""Wrapper around the LanceDB vector database over VectorDB""" + import logging +import os from contextlib import contextmanager import lancedb import pyarrow as pa -from lancedb.pydantic import LanceModel + +from vectordb_bench.backend.filter import Filter, FilterOp from ..api import IndexType, VectorDB -from .config import LanceDBConfig, LanceDBIndexConfig +from .config import LanceDBIndexConfig log = logging.getLogger(__name__) - -class VectorModel(LanceModel): - id: int - vector: list[float] +# Rows per ``table.add`` call. Each call produces a new Lance data file +# (fragment), so enlarging this value directly controls the on-disk fragment +# size. 
Override via the ``LANCEDB_BATCH_SIZE`` environment variable. +# +# Rough sizing for float32 vectors: bytes_per_fragment ≈ rows * dim * 4. +# Example: dim=768, rows=170000 -> ~498 MB per fragment. +LANCEDB_BATCH_SIZE = int(os.environ.get("LANCEDB_BATCH_SIZE", "5000")) class LanceDB(VectorDB): + supported_filter_types: list[FilterOp] = [ + FilterOp.NonFilter, + FilterOp.NumGE, + FilterOp.StrEqual, + ] + thread_safe: bool = False + def __init__( self, dim: int, - db_config: LanceDBConfig, + db_config: dict, db_case_config: LanceDBIndexConfig, collection_name: str = "vector_bench_test", drop_old: bool = False, + with_scalar_labels: bool = False, **kwargs, ): self.name = "LanceDB" @@ -32,79 +47,195 @@ def __init__( self.table_name = collection_name self.dim = dim self.uri = db_config["uri"] - # avoid the search_param being called every time during the search process - self.search_config = db_case_config.search_param() + self.storage_options = db_config.get("storage_options") or None + self.with_scalar_labels = with_scalar_labels + self.where_clause = None + + self._id_field = "id" + self._vector_field = "vector" + self._label_field = "label" - log.info(f"Search config: {self.search_config}") + # cache search params to avoid repeated calls + self.search_config = db_case_config.search_param() + log.info(f"LanceDB search config: {self.search_config}") - db = lancedb.connect(self.uri) + connect_kwargs = {} + if self.storage_options: + connect_kwargs["storage_options"] = self.storage_options + db = lancedb.connect(self.uri, **connect_kwargs) if drop_old: try: db.drop_table(self.table_name) + log.info(f"LanceDB dropped old table: {self.table_name}") except Exception as e: log.warning(f"Failed to drop table {self.table_name}: {e}") - - try: - db.open_table(self.table_name) - except Exception: - schema = pa.schema( - [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.float32(), list_size=self.dim))] - ) + # Always create a fresh table with the correct schema after 
drop. + # On remote storage (e.g. GooseFS) drop_table may not fully purge + # metadata immediately, causing open_table to succeed with a stale + # schema that is missing expected fields like 'id'. + schema = self._build_schema() db.create_table(self.table_name, schema=schema, mode="overwrite") + log.info(f"LanceDB created table: {self.table_name} (schema: {schema})") + else: + try: + db.open_table(self.table_name) + except Exception: + schema = self._build_schema() + db.create_table(self.table_name, schema=schema, mode="overwrite") + log.info(f"LanceDB created table: {self.table_name} (schema: {schema})") + + def _build_schema(self) -> pa.Schema: + fields = [ + pa.field(self._id_field, pa.int64()), + pa.field(self._vector_field, pa.list_(pa.float32(), list_size=self.dim)), + ] + if self.with_scalar_labels: + fields.append(pa.field(self._label_field, pa.utf8())) + return pa.schema(fields) @contextmanager def init(self): - self.db = lancedb.connect(self.uri) + connect_kwargs = {} + if self.storage_options: + connect_kwargs["storage_options"] = self.storage_options + self.db = lancedb.connect(self.uri, **connect_kwargs) self.table = self.db.open_table(self.table_name) yield self.db = None self.table = None + def __deepcopy__(self, memo: dict) -> "LanceDB": + """Custom deepcopy: skip live connection/table handles. + + The LanceDB ``Connection`` / ``Table`` objects wrap Rust bindings that + are not picklable. ``ConcurrentInsertRunner`` deep-copies the client + per thread for non-thread-safe DBs; the caller will then invoke + ``init()`` on the copy, which re-opens a fresh connection. 
+ """ + cls = self.__class__ + new_obj = cls.__new__(cls) + memo[id(self)] = new_obj + from copy import deepcopy as _dc + + for k, v in self.__dict__.items(): + if k in ("db", "table"): + new_obj.__dict__[k] = None + else: + new_obj.__dict__[k] = _dc(v, memo) + return new_obj + def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], + labels_data: list[str] | None = None, **kwargs, ) -> tuple[int, Exception | None]: + assert self.table is not None, "Please call self.init() before" try: - data = [{"id": meta, "vector": emb} for meta, emb in zip(metadata, embeddings, strict=False)] - self.table.add(data) + log.info( + f"LanceDB insert_embeddings called with {len(embeddings)} rows, " + f"LANCEDB_BATCH_SIZE={LANCEDB_BATCH_SIZE} -> " + f"{(len(embeddings) + LANCEDB_BATCH_SIZE - 1) // LANCEDB_BATCH_SIZE} fragment(s)" + ) + for offset in range(0, len(embeddings), LANCEDB_BATCH_SIZE): + batch_emb = embeddings[offset : offset + LANCEDB_BATCH_SIZE] + batch_ids = metadata[offset : offset + LANCEDB_BATCH_SIZE] + + id_arr = pa.array(batch_ids, type=pa.int64()) + vec_arr = pa.FixedSizeListArray.from_arrays( + pa.array([v for emb in batch_emb for v in emb], type=pa.float32()), + list_size=self.dim, + ) + + if self.with_scalar_labels and labels_data is not None: + batch_labels = labels_data[offset : offset + LANCEDB_BATCH_SIZE] + label_arr = pa.array(batch_labels, type=pa.utf8()) + batch_table = pa.table( + { + self._id_field: id_arr, + self._vector_field: vec_arr, + self._label_field: label_arr, + } + ) + else: + batch_table = pa.table( + { + self._id_field: id_arr, + self._vector_field: vec_arr, + } + ) + self.table.add(batch_table) + return len(metadata), None except Exception as e: log.warning(f"Failed to insert data into LanceDB table ({self.table_name}), error: {e}") return 0, e + def prepare_filter(self, filters: Filter): + if filters.type == FilterOp.NonFilter: + self.where_clause = None + elif filters.type == FilterOp.NumGE: + self.where_clause = 
f"{self._id_field} >= {filters.int_value}" + elif filters.type == FilterOp.StrEqual: + self.where_clause = f"{self._label_field} = '{filters.label_value}'" + else: + msg = f"Unsupported filter for LanceDB: {filters}" + raise ValueError(msg) + def search_embedding( self, query: list[float], k: int = 100, - filters: dict | None = None, + **kwargs, ) -> list[int]: - if filters: - results = self.table.search(query).select(["id"]).where(f"id >= {filters['id']}", prefilter=True).limit(k) - if self.case_config.index == IndexType.IVFPQ and "nprobes" in self.search_config: - results = results.nprobes(self.search_config["nprobes"]).to_list() - elif self.case_config.index == IndexType.HNSW and "ef" in self.search_config: - results = results.ef(self.search_config["ef"]).to_list() - else: - results = results.to_list() - else: - results = self.table.search(query).select(["id"]).limit(k) - if self.case_config.index == IndexType.IVFPQ and "nprobes" in self.search_config: - results = results.nprobes(self.search_config["nprobes"]).to_list() - elif self.case_config.index == IndexType.HNSW and "ef" in self.search_config: - results = results.ef(self.search_config["ef"]).to_list() - else: - results = results.to_list() - - return [int(result["id"]) for result in results] + assert self.table is not None, "Please call self.init() before" + + # Include ``_distance`` in the projection to opt in to LanceDB's + # upcoming default behaviour. Without it, lance logs a per-query + # deprecation warning ("This search specified output columns but did + # not include `_distance`"). We only consume ``id`` downstream, so the + # extra column adds negligible overhead. 
+ q = self.table.search(query).select([self._id_field, "_distance"]).limit(k) + + # apply filter + if self.where_clause: + q = q.where(self.where_clause, prefilter=True) + + # apply search parameters based on config + search_cfg = self.search_config + if "nprobes" in search_cfg: + q = q.nprobes(search_cfg["nprobes"]) + if "ef" in search_cfg: + q = q.ef(search_cfg["ef"]) + if "refine_factor" in search_cfg: + q = q.refine_factor(search_cfg["refine_factor"]) + + results = q.to_list() + return [int(r[self._id_field]) for r in results] def optimize(self, data_size: int | None = None): - if self.table and hasattr(self, "case_config") and self.case_config.index != IndexType.NONE: - log.info(f"Creating index for LanceDB table ({self.table_name})") - log.info(f"Index parameters: {self.case_config.index_param()}") - self.table.create_index(**self.case_config.index_param()) - # Better recall with IVF_PQ (though still bad) but breaks HNSW: https://github.com/lancedb/lancedb/issues/2369 - if self.case_config.index in (IndexType.IVFPQ, IndexType.AUTOINDEX): + assert self.table is not None, "Please call self.init() before" + + # Build index if configured + if self.case_config.index != IndexType.NONE: + index_params = self.case_config.index_param() + log.info(f"LanceDB creating index on table ({self.table_name}), params: {index_params}") + self.table.create_index(**index_params) + + # Compact fragments and clean up old versions for better performance. + # Prefer the unified ``table.optimize()`` API (lancedb >= 0.10), which + # internally handles both compaction and version cleanup without + # requiring the optional ``pylance`` package. Fall back to the legacy + # split APIs only if ``optimize`` is unavailable. 
+ try: + if hasattr(self.table, "optimize"): self.table.optimize() + log.info(f"LanceDB optimize completed for table ({self.table_name})") + else: + self.table.compact_files() + self.table.cleanup_old_versions() + log.info(f"LanceDB compact_files + cleanup_old_versions completed for table ({self.table_name})") + except Exception as e: + log.warning(f"LanceDB optimize failed (non-fatal): {e}") diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 3e21746aa..beb11266c 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -13,7 +13,13 @@ ) from ..backend.clients.endee.cli import Endee from ..backend.clients.hologres.cli import HologresHGraph -from ..backend.clients.lancedb.cli import LanceDB +from ..backend.clients.lancedb.cli import ( + LanceDB, + LanceDBAutoIndex, + LanceDBIVFHNSWPQ, + LanceDBIVFHNSWSQ, + LanceDBIVFPQ, +) from ..backend.clients.lindorm.cli import LindormHNSW, LindormIVFBQ, LindormIVFPQ from ..backend.clients.mariadb.cli import MariaDBHNSW from ..backend.clients.memorydb.cli import MemoryDB @@ -68,6 +74,10 @@ cli.add_command(Clickhouse) cli.add_command(Vespa) cli.add_command(LanceDB) +cli.add_command(LanceDBAutoIndex) +cli.add_command(LanceDBIVFPQ) +cli.add_command(LanceDBIVFHNSWSQ) +cli.add_command(LanceDBIVFHNSWPQ) cli.add_command(HologresHGraph) cli.add_command(QdrantCloud) cli.add_command(QdrantLocal)