4242from .blob import BlobFile
4343from .dependencies import (
4444 _check_for_numpy ,
45+ _check_for_torch ,
4546 torch ,
4647)
4748from .dependencies import numpy as np
@@ -2546,7 +2547,7 @@ def create_index(
25462547 train : bool = True ,
25472548 # distributed indexing parameters
25482549 fragment_ids : Optional [List [int ]] = None ,
2549- fragment_uuid : Optional [str ] = None ,
2550+ index_uuid : Optional [str ] = None ,
25502551 * ,
25512552 target_partition_size : Optional [int ] = None ,
25522553 ** kwargs ,
@@ -2624,7 +2625,7 @@ def create_index(
26242625 method creates temporary index metadata but does not commit the index
26252626 to the dataset. The index can be committed later using
26262627 merge_index_metadata(index_uuid, "VECTOR", column=..., index_name=...).
2627- fragment_uuid : str, optional
2628+ index_uuid : str, optional
26282629 A UUID to use for fragment-level distributed indexing. Multiple
26292630 fragment-level indices need to share UUID for later merging.
26302631 If not provided, a new UUID will be generated.
@@ -2795,6 +2796,34 @@ def create_index(
27952796
27962797 # Handle timing for various parts of accelerated builds
27972798 timers = {}
2799+
2800+ # Early detection and gating: Torch detected ⇒ enforce single-node
2801+ # & skip distributed keys. Also normalize index_file_version for
2802+ # downstream accelerator behavior.
2803+ idx_ver_obj = kwargs .get ("index_file_version" )
2804+ idx_ver_str = None
2805+ try :
2806+ if isinstance (idx_ver_obj , str ):
2807+ idx_ver_str = idx_ver_obj
2808+ elif hasattr (idx_ver_obj , "value" ):
2809+ idx_ver_str = str (idx_ver_obj .value )
2810+ elif hasattr (idx_ver_obj , "name" ):
2811+ idx_ver_str = str (idx_ver_obj .name )
2812+ else :
2813+ idx_ver_str = str (idx_ver_obj )
2814+ except Exception :
2815+ idx_ver_str = None
2816+ # NOTE: Do not pass any distributed-related params when torch is involved
2817+ torch_detected_early = accelerator is not None
2818+ if torch_detected_early :
2819+ if fragment_ids is not None or index_uuid is not None :
2820+ LOGGER .info (
2821+ "Torch detected (early); enforce single-node indexing "
2822+ "(distributed is CPU-only)."
2823+ )
2824+ fragment_ids = None
2825+ index_uuid = None
2826+
27982827 if accelerator is not None :
27992828 from .vector import (
28002829 one_pass_assign_ivf_pq_on_accelerator ,
@@ -2843,10 +2872,21 @@ def create_index(
28432872 )
28442873 LOGGER .info ("ivf+pq transform time: %ss" , ivfpq_assign_time )
28452874
2846- kwargs ["precomputed_shuffle_buffers" ] = shuffle_buffers
2847- kwargs ["precomputed_shuffle_buffers_path" ] = os .path .join (
2848- shuffle_output_dir , "data"
2849- )
2875+ # IMPORTANT: For V3 index file version, avoid passing precomputed
2876+ # PQ shuffle buffers to prevent PQ codebook mismatch (Rust retrains
2877+ # quantizer and ignores provided codebook).
2878+ ver = (idx_ver_str or "V3" ).upper ()
2879+ if ver == "LEGACY" :
2880+ kwargs ["precomputed_shuffle_buffers" ] = shuffle_buffers
2881+ kwargs ["precomputed_shuffle_buffers_path" ] = os .path .join (
2882+ shuffle_output_dir , "data"
2883+ )
2884+ else :
2885+ LOGGER .info (
2886+ "IndexFileVersion=%s detected; skip precomputed shuffle "
2887+ "buffers to stabilize IVF_PQ" ,
2888+ ver ,
2889+ )
28502890 if index_type .startswith ("IVF" ):
28512891 if (ivf_centroids is not None ) and (ivf_centroids_file is not None ):
28522892 raise ValueError (
@@ -2941,8 +2981,11 @@ def create_index(
29412981 )
29422982 kwargs ["num_sub_vectors" ] = num_sub_vectors
29432983
2944- if pq_codebook is not None :
2945- # User provided IVF centroids
2984+ # Only attach PQ codebook for LEGACY format; V3 retrains PQ and
2985+ # ignores user codebook.
2986+ ver = (idx_ver_str or "V3" ).upper ()
2987+ if pq_codebook is not None and ver == "LEGACY" :
2988+ # User provided PQ codebook
29462989 if _check_for_numpy (pq_codebook ) and isinstance (
29472990 pq_codebook , np .ndarray
29482991 ):
@@ -2968,18 +3011,56 @@ def create_index(
29683011 [pq_codebook ], ["_pq_codebook" ]
29693012 )
29703013 kwargs ["pq_codebook" ] = pq_codebook_batch
3014+ elif pq_codebook is not None :
3015+ LOGGER .info (
3016+ "IndexFileVersion=%s detected; skip passing pq_codebook "
3017+ "to avoid mismatch" ,
3018+ ver ,
3019+ )
29713020
29723021 if shuffle_partition_batches is not None :
29733022 kwargs ["shuffle_partition_batches" ] = shuffle_partition_batches
29743023 if shuffle_partition_concurrency is not None :
29753024 kwargs ["shuffle_partition_concurrency" ] = shuffle_partition_concurrency
29763025
2977- # Add fragment_ids and fragment_uuid to kwargs if provided for
3026+ # Add fragment_ids and index_uuid to kwargs if provided for
29783027 # distributed indexing
3028+ # IMPORTANT: Distributed indexing is CPU-only. Enforce single-node when
3029+ # accelerator or torch-related path is detected.
3030+ torch_detected = False
3031+ try :
3032+ if accelerator is not None :
3033+ torch_detected = True
3034+ else :
3035+ impl = kwargs .get ("implementation" )
3036+ use_torch_flag = kwargs .get ("use_torch" ) is True
3037+ one_pass_flag = kwargs .get ("one_pass_ivfpq" ) is True
3038+ torch_centroids = _check_for_torch (ivf_centroids )
3039+ torch_codebook = _check_for_torch (pq_codebook )
3040+ if (
3041+ (isinstance (impl , str ) and impl .lower () == "torch" )
3042+ or use_torch_flag
3043+ or one_pass_flag
3044+ or torch_centroids
3045+ or torch_codebook
3046+ ):
3047+ torch_detected = True
3048+ except Exception :
3049+ # Be conservative: if detection fails, do not modify behavior
3050+ pass
3051+
3052+ if torch_detected :
3053+ if fragment_ids is not None or index_uuid is not None :
3054+ LOGGER .info (
3055+ "Torch detected; "
3056+ "enforce single-node indexing (distributed is CPU-only)."
3057+ )
3058+ fragment_ids = None
3059+ index_uuid = None
29793060 if fragment_ids is not None :
29803061 kwargs ["fragment_ids" ] = fragment_ids
2981- if fragment_uuid is not None :
2982- kwargs ["fragment_uuid " ] = fragment_uuid
3062+ if index_uuid is not None :
3063+ kwargs ["index_uuid " ] = index_uuid
29833064
29843065 timers ["final_create_index:start" ] = time .time ()
29853066 self ._ds .create_index (
0 commit comments