Skip to content

Commit e544b47

Browse files
author
guochenghao.bd
committed
fix(test): torch cpu indexing avoids distributed; stabilize one-pass ivfpq
1 parent cd1694c commit e544b47

2 files changed

Lines changed: 107 additions & 26 deletions

File tree

python/python/lance/dataset.py

Lines changed: 92 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .blob import BlobFile
4343
from .dependencies import (
4444
_check_for_numpy,
45+
_check_for_torch,
4546
torch,
4647
)
4748
from .dependencies import numpy as np
@@ -2546,7 +2547,7 @@ def create_index(
25462547
train: bool = True,
25472548
# distributed indexing parameters
25482549
fragment_ids: Optional[List[int]] = None,
2549-
fragment_uuid: Optional[str] = None,
2550+
index_uuid: Optional[str] = None,
25502551
*,
25512552
target_partition_size: Optional[int] = None,
25522553
**kwargs,
@@ -2624,7 +2625,7 @@ def create_index(
26242625
method creates temporary index metadata but does not commit the index
26252626
to the dataset. The index can be committed later using
26262627
merge_index_metadata(index_uuid, "VECTOR", column=..., index_name=...).
2627-
fragment_uuid : str, optional
2628+
index_uuid : str, optional
26282629
A UUID to use for fragment-level distributed indexing. Multiple
26292630
fragment-level indices must share the same UUID for later merging.
26302631
If not provided, a new UUID will be generated.
@@ -2795,6 +2796,34 @@ def create_index(
27952796

27962797
# Handle timing for various parts of accelerated builds
27972798
timers = {}
2799+
2800+
# Early detection and gating: Torch detected ⇒ enforce single-node
2801+
# & skip distributed keys. Also normalize index_file_version for
2802+
# downstream accelerator behavior.
2803+
idx_ver_obj = kwargs.get("index_file_version")
2804+
idx_ver_str = None
2805+
try:
2806+
if isinstance(idx_ver_obj, str):
2807+
idx_ver_str = idx_ver_obj
2808+
elif hasattr(idx_ver_obj, "value"):
2809+
idx_ver_str = str(idx_ver_obj.value)
2810+
elif hasattr(idx_ver_obj, "name"):
2811+
idx_ver_str = str(idx_ver_obj.name)
2812+
else:
2813+
idx_ver_str = str(idx_ver_obj)
2814+
except Exception:
2815+
idx_ver_str = None
2816+
# NOTE: Do not pass distributed-related params when an accelerator is set
2817+
torch_detected_early = accelerator is not None
2818+
if torch_detected_early:
2819+
if fragment_ids is not None or index_uuid is not None:
2820+
LOGGER.info(
2821+
"Torch detected (early); enforce single-node indexing "
2822+
"(distributed is CPU-only)."
2823+
)
2824+
fragment_ids = None
2825+
index_uuid = None
2826+
27982827
if accelerator is not None:
27992828
from .vector import (
28002829
one_pass_assign_ivf_pq_on_accelerator,
@@ -2843,10 +2872,21 @@ def create_index(
28432872
)
28442873
LOGGER.info("ivf+pq transform time: %ss", ivfpq_assign_time)
28452874

2846-
kwargs["precomputed_shuffle_buffers"] = shuffle_buffers
2847-
kwargs["precomputed_shuffle_buffers_path"] = os.path.join(
2848-
shuffle_output_dir, "data"
2849-
)
2875+
# IMPORTANT: Pass precomputed PQ shuffle buffers only for the LEGACY index
2876+
# file version. Other versions (e.g. V3) skip them to avoid a PQ codebook
2877+
# mismatch (Rust retrains the quantizer and ignores the provided codebook).
2878+
ver = (idx_ver_str or "V3").upper()
2879+
if ver == "LEGACY":
2880+
kwargs["precomputed_shuffle_buffers"] = shuffle_buffers
2881+
kwargs["precomputed_shuffle_buffers_path"] = os.path.join(
2882+
shuffle_output_dir, "data"
2883+
)
2884+
else:
2885+
LOGGER.info(
2886+
"IndexFileVersion=%s detected; skip precomputed shuffle "
2887+
"buffers to stabilize IVF_PQ",
2888+
ver,
2889+
)
28502890
if index_type.startswith("IVF"):
28512891
if (ivf_centroids is not None) and (ivf_centroids_file is not None):
28522892
raise ValueError(
@@ -2941,8 +2981,11 @@ def create_index(
29412981
)
29422982
kwargs["num_sub_vectors"] = num_sub_vectors
29432983

2944-
if pq_codebook is not None:
2945-
# User provided IVF centroids
2984+
# Only attach PQ codebook for LEGACY format; V3 retrains PQ and
2985+
# ignores user codebook.
2986+
ver = (idx_ver_str or "V3").upper()
2987+
if pq_codebook is not None and ver == "LEGACY":
2988+
# User provided PQ codebook
29462989
if _check_for_numpy(pq_codebook) and isinstance(
29472990
pq_codebook, np.ndarray
29482991
):
@@ -2968,18 +3011,56 @@ def create_index(
29683011
[pq_codebook], ["_pq_codebook"]
29693012
)
29703013
kwargs["pq_codebook"] = pq_codebook_batch
3014+
elif pq_codebook is not None:
3015+
LOGGER.info(
3016+
"IndexFileVersion=%s detected; skip passing pq_codebook "
3017+
"to avoid mismatch",
3018+
ver,
3019+
)
29713020

29723021
if shuffle_partition_batches is not None:
29733022
kwargs["shuffle_partition_batches"] = shuffle_partition_batches
29743023
if shuffle_partition_concurrency is not None:
29753024
kwargs["shuffle_partition_concurrency"] = shuffle_partition_concurrency
29763025

2977-
# Add fragment_ids and fragment_uuid to kwargs if provided for
3026+
# Add fragment_ids and index_uuid to kwargs if provided for
29783027
# distributed indexing
3028+
# IMPORTANT: Distributed indexing is CPU-only. Enforce single-node when
3029+
# accelerator or torch-related path is detected.
3030+
torch_detected = False
3031+
try:
3032+
if accelerator is not None:
3033+
torch_detected = True
3034+
else:
3035+
impl = kwargs.get("implementation")
3036+
use_torch_flag = kwargs.get("use_torch") is True
3037+
one_pass_flag = kwargs.get("one_pass_ivfpq") is True
3038+
torch_centroids = _check_for_torch(ivf_centroids)
3039+
torch_codebook = _check_for_torch(pq_codebook)
3040+
if (
3041+
(isinstance(impl, str) and impl.lower() == "torch")
3042+
or use_torch_flag
3043+
or one_pass_flag
3044+
or torch_centroids
3045+
or torch_codebook
3046+
):
3047+
torch_detected = True
3048+
except Exception:
3049+
# Be conservative: if detection fails, do not modify behavior
3050+
pass
3051+
3052+
if torch_detected:
3053+
if fragment_ids is not None or index_uuid is not None:
3054+
LOGGER.info(
3055+
"Torch detected; "
3056+
"enforce single-node indexing (distributed is CPU-only)."
3057+
)
3058+
fragment_ids = None
3059+
index_uuid = None
29793060
if fragment_ids is not None:
29803061
kwargs["fragment_ids"] = fragment_ids
2981-
if fragment_uuid is not None:
2982-
kwargs["fragment_uuid"] = fragment_uuid
3062+
if index_uuid is not None:
3063+
kwargs["index_uuid"] = index_uuid
29833064

29843065
timers["final_create_index:start"] = time.time()
29853066
self._ds.create_index(

python/python/tests/test_vector_index.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,7 +1818,7 @@ def build_distributed_vector_index(
18181818
18191819
Steps:
18201820
- Partition fragments into `world` groups
1821-
- For each group, call create_index with fragment_ids and a shared fragment_uuid
1821+
- For each group, call create_index with fragment_ids and a shared index_uuid
18221822
- Merge metadata (commit index manifest)
18231823
18241824
Returns the dataset (post-merge) for querying.
@@ -1837,7 +1837,7 @@ def build_distributed_vector_index(
18371837
column=column,
18381838
index_type=index_type,
18391839
fragment_ids=g,
1840-
fragment_uuid=shared_uuid,
1840+
index_uuid=shared_uuid,
18411841
num_partitions=num_partitions,
18421842
num_sub_vectors=num_sub_vectors,
18431843
**index_params,
@@ -2122,7 +2122,7 @@ def test_distributed_api_basic_success(tmp_path):
21222122
column="vector",
21232123
index_type="IVF_PQ",
21242124
fragment_ids=fragment_ids,
2125-
fragment_uuid=shared_uuid,
2125+
index_uuid=shared_uuid,
21262126
num_partitions=8,
21272127
num_sub_vectors=16,
21282128
)
@@ -2152,7 +2152,7 @@ def test_fragment_allocations_divisibility_error(tmp_path, case_name, selector):
21522152
column="vector",
21532153
index_type="IVF_PQ",
21542154
fragment_ids=fragment_ids,
2155-
fragment_uuid=shared_uuid,
2155+
index_uuid=shared_uuid,
21562156
num_partitions=5,
21572157
num_sub_vectors=96,
21582158
)
@@ -2172,7 +2172,7 @@ def test_metadata_merge_pq_success(tmp_path):
21722172
column="vector",
21732173
index_type="IVF_PQ",
21742174
fragment_ids=node1,
2175-
fragment_uuid=shared_uuid,
2175+
index_uuid=shared_uuid,
21762176
num_partitions=8,
21772177
num_sub_vectors=16,
21782178
ivf_centroids=centroids,
@@ -2181,7 +2181,7 @@ def test_metadata_merge_pq_success(tmp_path):
21812181
column="vector",
21822182
index_type="IVF_PQ",
21832183
fragment_ids=node2,
2184-
fragment_uuid=shared_uuid,
2184+
index_uuid=shared_uuid,
21852185
num_partitions=8,
21862186
num_sub_vectors=16,
21872187
ivf_centroids=centroids,
@@ -2204,7 +2204,7 @@ def test_invalid_column_name_precise(tmp_path):
22042204
column="nonexistent_column",
22052205
index_type="IVF_PQ",
22062206
fragment_ids=[ds.get_fragments()[0].fragment_id],
2207-
fragment_uuid=str(uuid.uuid4()),
2207+
index_uuid=str(uuid.uuid4()),
22082208
)
22092209

22102210

@@ -2256,7 +2256,7 @@ def test_distributed_workflow_merge_and_search(tmp_path):
22562256
column="vector",
22572257
index_type="IVF_PQ",
22582258
fragment_ids=node1,
2259-
fragment_uuid=shared_uuid,
2259+
index_uuid=shared_uuid,
22602260
num_partitions=4,
22612261
num_sub_vectors=4,
22622262
ivf_centroids=centroids,
@@ -2265,7 +2265,7 @@ def test_distributed_workflow_merge_and_search(tmp_path):
22652265
column="vector",
22662266
index_type="IVF_PQ",
22672267
fragment_ids=node2,
2268-
fragment_uuid=shared_uuid,
2268+
index_uuid=shared_uuid,
22692269
num_partitions=4,
22702270
num_sub_vectors=4,
22712271
ivf_centroids=centroids,
@@ -2292,15 +2292,15 @@ def test_vector_merge_two_shards_success_flat(tmp_path):
22922292
column="vector",
22932293
index_type="IVF_FLAT",
22942294
fragment_ids=shard1,
2295-
fragment_uuid=shared_uuid,
2295+
index_uuid=shared_uuid,
22962296
num_partitions=4,
22972297
num_sub_vectors=128,
22982298
)
22992299
ds.create_index(
23002300
column="vector",
23012301
index_type="IVF_FLAT",
23022302
fragment_ids=shard2,
2303-
fragment_uuid=shared_uuid,
2303+
index_uuid=shared_uuid,
23042304
num_partitions=4,
23052305
num_sub_vectors=128,
23062306
)
@@ -2324,7 +2324,7 @@ def test_distributed_ivf_hnsw_pq_success(tmp_path):
23242324
column="vector",
23252325
index_type="IVF_HNSW_PQ",
23262326
fragment_ids=node1,
2327-
fragment_uuid=shared_uuid,
2327+
index_uuid=shared_uuid,
23282328
num_partitions=4,
23292329
num_sub_vectors=4,
23302330
ivf_centroids=centroids,
@@ -2333,7 +2333,7 @@ def test_distributed_ivf_hnsw_pq_success(tmp_path):
23332333
column="vector",
23342334
index_type="IVF_HNSW_PQ",
23352335
fragment_ids=node2,
2336-
fragment_uuid=shared_uuid,
2336+
index_uuid=shared_uuid,
23372337
num_partitions=4,
23382338
num_sub_vectors=4,
23392339
ivf_centroids=centroids,
@@ -2361,15 +2361,15 @@ def test_distributed_ivf_hnsw_flat_success(tmp_path):
23612361
column="vector",
23622362
index_type="IVF_HNSW_FLAT",
23632363
fragment_ids=node1,
2364-
fragment_uuid=shared_uuid,
2364+
index_uuid=shared_uuid,
23652365
num_partitions=4,
23662366
num_sub_vectors=128,
23672367
)
23682368
ds.create_index(
23692369
column="vector",
23702370
index_type="IVF_HNSW_FLAT",
23712371
fragment_ids=node2,
2372-
fragment_uuid=shared_uuid,
2372+
index_uuid=shared_uuid,
23732373
num_partitions=4,
23742374
num_sub_vectors=128,
23752375
)

0 commit comments

Comments
 (0)