diff --git a/java/lance-jni/src/blocking_dataset.rs b/java/lance-jni/src/blocking_dataset.rs index 04eaf4d0730..1afe2762dcd 100644 --- a/java/lance-jni/src/blocking_dataset.rs +++ b/java/lance-jni/src/blocking_dataset.rs @@ -26,7 +26,6 @@ use jni::sys::{jbyteArray, jlong}; use jni::{objects::JObject, JNIEnv}; use lance::dataset::builder::DatasetBuilder; use lance::dataset::cleanup::{CleanupPolicy, RemovalStats}; -use lance::dataset::index::LanceIndexStoreExt; use lance::dataset::optimize::{compact_files, CompactionOptions as RustCompactionOptions}; use lance::dataset::refs::{Ref, TagContents}; use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt}; @@ -41,7 +40,6 @@ use lance::table::format::{BasePath, Fragment}; use lance_core::datatypes::Schema as LanceSchema; use lance_index::optimize::OptimizeOptions; use lance_index::scalar::btree::BTreeParameters; -use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::DatasetIndexExt; use lance_index::{IndexParams, IndexType}; use lance_io::object_store::ObjectStoreRegistry; @@ -975,44 +973,12 @@ fn inner_merge_index_metadata( unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; RT.block_on(async { - let index_store = LanceIndexStore::from_dataset_for_new(&dataset_guard.inner, &index_uuid)?; - let object_store = dataset_guard.inner.object_store(); - let index_dir = dataset_guard.inner.indices_dir().child(index_uuid); - - match index_type { - IndexType::Inverted => lance_index::scalar::inverted::builder::merge_index_files( - object_store, - &index_dir, - Arc::new(index_store), - ) - .await - .map_err(|e| { - Error::runtime_error(format!( - "Cannot create index of type: {:?}. 
Caused by: {:?}", - index_type, - e.to_string() - )) - }), - IndexType::BTree => lance_index::scalar::btree::merge_index_files( - object_store, - &index_dir, - Arc::new(index_store), - batch_readhead, - ) + dataset_guard + .inner + .merge_index_metadata(&index_uuid, index_type, batch_readhead) .await - .map_err(|e| { - Error::runtime_error(format!( - "Cannot create index of type: {:?}. Caused by: {:?}", - index_type, - e.to_string() - )) - }), - _ => Err(Error::input_error(format!( - "Cannot merge index type: {:?}. Only supports BTREE and INVERTED now.", - index_type - ))), - } - }) + })?; + Ok(()) } #[no_mangle] diff --git a/java/lance-jni/src/lib.rs b/java/lance-jni/src/lib.rs index 566f77dd110..b3fa0402f38 100644 --- a/java/lance-jni/src/lib.rs +++ b/java/lance-jni/src/lib.rs @@ -57,6 +57,7 @@ mod storage_options; pub mod traits; mod transaction; pub mod utils; +mod vector_trainer; pub use error::Error; pub use error::Result; diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 02c5596d74b..78677599215 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -3,8 +3,9 @@ use std::sync::Arc; -use arrow::array::Float32Array; -use jni::objects::{JMap, JObject, JString, JValue, JValueGen}; +use arrow::array::{ArrayRef, FixedSizeListArray, Float32Array}; +use arrow_schema::{DataType, Field}; +use jni::objects::{JFloatArray, JMap, JObject, JString, JValue, JValueGen}; use jni::sys::{jboolean, jfloat, jlong}; use jni::JNIEnv; use lance::dataset::optimize::CompactionOptions; @@ -256,7 +257,7 @@ pub fn get_vector_index_params( let shuffle_partition_concurrency = env .get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionConcurrency")?; - let ivf_params = IvfBuildParams { + let mut ivf_params = IvfBuildParams { num_partitions: Some(num_partitions), max_iters, sample_rate, @@ -264,6 +265,44 @@ pub fn get_vector_index_params( shuffle_partition_concurrency, ..Default::default() }; + + // Optional pre-trained IVF 
centroids from Java IvfBuildParams + // Method signature: float[] getCentroids() + let centroids_obj = env + .call_method(&ivf_params_obj, "getCentroids", "()[F", &[])? + .l()?; + + if !centroids_obj.is_null() { + let jarray: JFloatArray = centroids_obj.into(); + let length = env.get_array_length(&jarray)?; + if length > 0 { + if (length as usize) % num_partitions != 0 { + return Err(Error::input_error(format!( + "Invalid IVF centroids: length {} is not divisible by num_partitions {}", + length, num_partitions + ))); + } + let mut buffer = vec![0.0f32; length as usize]; + env.get_float_array_region(&jarray, 0, &mut buffer)?; + let dimension = buffer.len() / num_partitions; + + let values = Float32Array::from(buffer); + let fsl = FixedSizeListArray::try_new( + Arc::new(Field::new("item", DataType::Float32, false)), + dimension as i32, + Arc::new(values) as ArrayRef, + None, + ) + .map_err(|e| { + Error::input_error(format!( + "Failed to construct FixedSizeListArray for IVF centroids: {e}" + )) + })?; + + ivf_params.centroids = Some(Arc::new(fsl)); + } + } + stages.push(StageParams::Ivf(ivf_params)); // Parse HnswBuildParams @@ -305,13 +344,34 @@ pub fn get_vector_index_params( env.get_int_as_usize_from_method(&pq_obj, "getKmeansRedos")?; let sample_rate = env.get_int_as_usize_from_method(&pq_obj, "getSampleRate")?; + // Optional pre-trained PQ codebook from Java PQBuildParams + // Method signature: float[] getCodebook() + let codebook_obj = env + .call_method(&pq_obj, "getCodebook", "()[F", &[])? 
+ .l()?; + + let codebook = if !codebook_obj.is_null() { + let jarray: JFloatArray = codebook_obj.into(); + let length = env.get_array_length(&jarray)?; + if length > 0 { + let mut buffer = vec![0.0f32; length as usize]; + env.get_float_array_region(&jarray, 0, &mut buffer)?; + let values = Float32Array::from(buffer); + Some(Arc::new(values) as _) + } else { + None + } + } else { + None + }; + Ok(PQBuildParams { num_sub_vectors, num_bits, max_iters, kmeans_redos, + codebook, sample_rate, - ..Default::default() }) }, )?; diff --git a/java/lance-jni/src/vector_trainer.rs b/java/lance-jni/src/vector_trainer.rs new file mode 100755 index 00000000000..e2d6012859e --- /dev/null +++ b/java/lance-jni/src/vector_trainer.rs @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; +use crate::error::{Error, Result}; +use crate::ffi::JNIEnvExt; +use crate::RT; + +use arrow::array::{FixedSizeListArray, Float32Array}; +use jni::objects::{JClass, JFloatArray, JObject, JString}; +use jni::sys::jfloatArray; +use jni::JNIEnv; +use lance::index::vector::utils::get_vector_dim; +use lance_index::vector::ivf::builder::IvfBuildParams as RustIvfBuildParams; +use lance_index::vector::pq::builder::PQBuildParams as RustPQBuildParams; +use lance_linalg::distance::MetricType; + +/// Flatten a FixedSizeList into a contiguous Vec. 
+fn flatten_fixed_size_list_to_f32(arr: &FixedSizeListArray) -> Result<Vec<f32>> { + let values = arr + .values() + .as_any() + .downcast_ref::<Float32Array>() + .ok_or_else(|| { + Error::input_error(format!( + "Expected FixedSizeList, got value type {}", + arr.value_type() + )) + })?; + + Ok(values.values().to_vec()) +} + +fn build_ivf_params_from_java( + env: &mut JNIEnv, + ivf_params_obj: &JObject, +) -> Result<RustIvfBuildParams> { + let num_partitions = env.get_int_as_usize_from_method(ivf_params_obj, "getNumPartitions")?; + let max_iters = env.get_int_as_usize_from_method(ivf_params_obj, "getMaxIters")?; + let sample_rate = env.get_int_as_usize_from_method(ivf_params_obj, "getSampleRate")?; + let shuffle_partition_batches = + env.get_int_as_usize_from_method(ivf_params_obj, "getShufflePartitionBatches")?; + let shuffle_partition_concurrency = + env.get_int_as_usize_from_method(ivf_params_obj, "getShufflePartitionConcurrency")?; + + Ok(RustIvfBuildParams { + num_partitions: Some(num_partitions), + max_iters, + sample_rate, + shuffle_partition_batches, + shuffle_partition_concurrency, + ..Default::default() + }) +} + +fn build_pq_params_from_java( + env: &mut JNIEnv, + pq_params_obj: &JObject, +) -> Result<RustPQBuildParams> { + let num_sub_vectors = env.get_int_as_usize_from_method(pq_params_obj, "getNumSubVectors")?; + let num_bits = env.get_int_as_usize_from_method(pq_params_obj, "getNumBits")?; + let max_iters = env.get_int_as_usize_from_method(pq_params_obj, "getMaxIters")?; + let kmeans_redos = env.get_int_as_usize_from_method(pq_params_obj, "getKmeansRedos")?; + let sample_rate = env.get_int_as_usize_from_method(pq_params_obj, "getSampleRate")?; + + Ok(RustPQBuildParams { + num_sub_vectors, + num_bits, + max_iters, + kmeans_redos, + codebook: None, + sample_rate, + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainIvfCentroids<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + dataset_obj: JObject<'local>, // org.lance.Dataset + column_jstr: 
JString<'local>, // java.lang.String + ivf_params_obj: JObject<'local>, // org.lance.index.vector.IvfBuildParams +) -> jfloatArray { + ok_or_throw_with_return!( + env, + inner_train_ivf_centroids(&mut env, dataset_obj, column_jstr, ivf_params_obj) + .map(|arr| arr.into_raw()), + JFloatArray::default().into_raw() + ) +} + +fn inner_train_ivf_centroids<'local>( + env: &mut JNIEnv<'local>, + dataset_obj: JObject<'local>, + column_jstr: JString<'local>, + ivf_params_obj: JObject<'local>, +) -> Result<JFloatArray<'local>> { + let column: String = env.get_string(&column_jstr)?.into(); + let ivf_params = build_ivf_params_from_java(env, &ivf_params_obj)?; + + let flattened: Vec<f32> = { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(dataset_obj, NATIVE_DATASET) }?; + let dataset = &dataset_guard.inner; + + let dim = get_vector_dim(dataset.schema(), &column)?; + + // For now we default to L2 metric; tests and Java bindings currently use L2. + let metric_type = MetricType::L2; + + let ivf_model = RT.block_on(lance::index::vector::ivf::build_ivf_model( + dataset, + &column, + dim, + metric_type, + &ivf_params, + ))?; + + let centroids = ivf_model + .centroids + .ok_or_else(|| Error::runtime_error("IVF model missing centroids".to_string()))?; + + flatten_fixed_size_list_to_f32(&centroids)? 
+ }; + + let jarray = env.new_float_array(flattened.len() as i32)?; + env.set_float_array_region(&jarray, 0, &flattened)?; + Ok(jarray) +} + +#[no_mangle] +pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainPqCodebook<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + dataset_obj: JObject<'local>, // org.lance.Dataset + column_jstr: JString<'local>, // java.lang.String + pq_params_obj: JObject<'local>, // org.lance.index.vector.PQBuildParams +) -> jfloatArray { + ok_or_throw_with_return!( + env, + inner_train_pq_codebook(&mut env, dataset_obj, column_jstr, pq_params_obj) + .map(|arr| arr.into_raw()), + JFloatArray::default().into_raw() + ) +} + +fn inner_train_pq_codebook<'local>( + env: &mut JNIEnv<'local>, + dataset_obj: JObject<'local>, + column_jstr: JString<'local>, + pq_params_obj: JObject<'local>, +) -> Result<JFloatArray<'local>> { + let column: String = env.get_string(&column_jstr)?.into(); + let pq_params = build_pq_params_from_java(env, &pq_params_obj)?; + + let flattened: Vec<f32> = { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(dataset_obj, NATIVE_DATASET) }?; + let dataset = &dataset_guard.inner; + + let dim = get_vector_dim(dataset.schema(), &column)?; + let metric_type = MetricType::L2; + + let pq = RT.block_on(lance::index::vector::pq::build_pq_model( + dataset, + &column, + dim, + metric_type, + &pq_params, + None, + ))?; + + flatten_fixed_size_list_to_f32(&pq.codebook)? 
+ }; + + let jarray = env.new_float_array(flattened.len() as i32)?; + env.set_float_array_region(&jarray, 0, &flattened)?; + Ok(jarray) +} diff --git a/java/src/main/java/org/lance/Dataset.java b/java/src/main/java/org/lance/Dataset.java index 3ee0033f45a..239de47b67a 100644 --- a/java/src/main/java/org/lance/Dataset.java +++ b/java/src/main/java/org/lance/Dataset.java @@ -880,7 +880,7 @@ public Index createIndex(IndexOptions options) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); return nativeCreateIndex( options.getColumns(), - options.getIndexType().ordinal(), + options.getIndexType().getValue(), options.getIndexName(), options.getIndexParams(), options.isReplace(), diff --git a/java/src/main/java/org/lance/index/vector/IvfBuildParams.java b/java/src/main/java/org/lance/index/vector/IvfBuildParams.java index c9a795a03cc..4b8ace8786f 100644 --- a/java/src/main/java/org/lance/index/vector/IvfBuildParams.java +++ b/java/src/main/java/org/lance/index/vector/IvfBuildParams.java @@ -29,6 +29,7 @@ public class IvfBuildParams { private final int shufflePartitionBatches; private final int shufflePartitionConcurrency; private final boolean useResidual; + private final float[] centroids; private IvfBuildParams(Builder builder) { this.numPartitions = builder.numPartitions; @@ -37,6 +38,7 @@ private IvfBuildParams(Builder builder) { this.shufflePartitionBatches = builder.shufflePartitionBatches; this.shufflePartitionConcurrency = builder.shufflePartitionConcurrency; this.useResidual = builder.useResidual; + this.centroids = builder.centroids; } public static class Builder { @@ -46,6 +48,7 @@ public static class Builder { private int shufflePartitionBatches = 1024 * 10; private int shufflePartitionConcurrency = 2; private boolean useResidual = true; + private float[] centroids = null; /** * Parameters for building an IVF index. Train IVF centroids for the given vector column. 
This @@ -125,6 +128,19 @@ public Builder setUseResidual(boolean useResidual) { return this; } + /** + * Set pre-trained IVF centroids. + * + *

The centroids are flattened as [numPartitions][dimension]. + * + * @param centroids pre-trained IVF centroids + * @return Builder + */ + public Builder setCentroids(float[] centroids) { + this.centroids = centroids; + return this; + } + public IvfBuildParams build() { return new IvfBuildParams(this); } @@ -154,6 +170,10 @@ public boolean useResidual() { return useResidual; } + public float[] getCentroids() { + return centroids; + } + @Override public String toString() { return MoreObjects.toStringHelper(this) @@ -163,6 +183,7 @@ public String toString() { .add("shufflePartitionBatches", shufflePartitionBatches) .add("shufflePartitionConcurrency", shufflePartitionConcurrency) .add("useResidual", useResidual) + .add("hasCentroids", centroids != null) .toString(); } } diff --git a/java/src/main/java/org/lance/index/vector/PQBuildParams.java b/java/src/main/java/org/lance/index/vector/PQBuildParams.java index 8d076bc44fc..1b414e4dd28 100644 --- a/java/src/main/java/org/lance/index/vector/PQBuildParams.java +++ b/java/src/main/java/org/lance/index/vector/PQBuildParams.java @@ -29,6 +29,7 @@ public class PQBuildParams { private final int maxIters; private final int kmeansRedos; private final int sampleRate; + private final float[] codebook; private PQBuildParams(Builder builder) { this.numSubVectors = builder.numSubVectors; @@ -36,6 +37,7 @@ private PQBuildParams(Builder builder) { this.maxIters = builder.maxIters; this.kmeansRedos = builder.kmeansRedos; this.sampleRate = builder.sampleRate; + this.codebook = builder.codebook; } public static class Builder { @@ -44,6 +46,7 @@ public static class Builder { private int maxIters = 50; private int kmeansRedos = 1; private int sampleRate = 256; + private float[] codebook = null; /** Create a new builder for training a PQ model. */ public Builder() {} @@ -96,6 +99,19 @@ public Builder setSampleRate(int sampleRate) { return this; } + /** + * Set pre-trained PQ codebook. + * + *

The codebook is flattened as [num_centroids][dimension]. + * + * @param codebook pre-trained PQ codebook + * @return Builder + */ + public Builder setCodebook(float[] codebook) { + this.codebook = codebook; + return this; + } + public PQBuildParams build() { return new PQBuildParams(this); } @@ -121,6 +137,10 @@ public int getSampleRate() { return sampleRate; } + public float[] getCodebook() { + return codebook; + } + @Override public String toString() { return MoreObjects.toStringHelper(this) @@ -129,6 +149,7 @@ public String toString() { .add("maxIters", maxIters) .add("kmeansRedos", kmeansRedos) .add("sampleRate", sampleRate) + .add("hasCodebook", codebook != null) .toString(); } } diff --git a/java/src/main/java/org/lance/index/vector/VectorIndexParams.java b/java/src/main/java/org/lance/index/vector/VectorIndexParams.java index c80e8e053fb..07159cdd048 100644 --- a/java/src/main/java/org/lance/index/vector/VectorIndexParams.java +++ b/java/src/main/java/org/lance/index/vector/VectorIndexParams.java @@ -43,9 +43,6 @@ private void validate() { if (hnswParams.isPresent() && !pqParams.isPresent() && !sqParams.isPresent()) { throw new IllegalArgumentException("HNSW must be combined with either PQ or SQ"); } - if (sqParams.isPresent() && !hnswParams.isPresent()) { - throw new IllegalArgumentException("IVF + SQ is not supported"); - } } /** diff --git a/java/src/main/java/org/lance/index/vector/VectorTrainer.java b/java/src/main/java/org/lance/index/vector/VectorTrainer.java new file mode 100755 index 00000000000..03081176bf1 --- /dev/null +++ b/java/src/main/java/org/lance/index/vector/VectorTrainer.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.vector; + +import org.lance.Dataset; +import org.lance.JniLoader; + +import org.apache.arrow.util.Preconditions; + +/** + * Training utilities for vector indexes. + * + *

These helpers expose the underlying Lance training routines so that callers can pre-train + * models (IVF centroids, PQ codebooks, SQ params) and then pass the resulting artifacts into + * distributed index build flows. + */ +public final class VectorTrainer { + + static { + JniLoader.ensureLoaded(); + } + + private VectorTrainer() {} + + /** + * Train IVF centroids for the given dataset column. + * + * @param dataset the dataset to sample training data from + * @param column the vector column name + * @param params IVF build parameters (numPartitions, sampleRate, etc.) + * @return a flattened array of centroids laid out as [numPartitions][dimension] + */ + public static float[] trainIvfCentroids(Dataset dataset, String column, IvfBuildParams params) { + Preconditions.checkArgument(dataset != null, "dataset cannot be null"); + Preconditions.checkArgument( + column != null && !column.isEmpty(), "column cannot be null or empty"); + Preconditions.checkArgument(params != null, "params cannot be null"); + return nativeTrainIvfCentroids(dataset, column, params); + } + + /** + * Train a PQ codebook for the given dataset column. + * + * @param dataset the dataset to sample training data from + * @param column the vector column name + * @param params PQ build parameters (numSubVectors, numBits, sampleRate, etc.) 
+ * @return a flattened array of codebook entries laid out as [num_centroids][dimension] + */ + public static float[] trainPqCodebook(Dataset dataset, String column, PQBuildParams params) { + Preconditions.checkArgument(dataset != null, "dataset cannot be null"); + Preconditions.checkArgument( + column != null && !column.isEmpty(), "column cannot be null or empty"); + Preconditions.checkArgument(params != null, "params cannot be null"); + return nativeTrainPqCodebook(dataset, column, params); + } + + private static native float[] nativeTrainIvfCentroids( + Dataset dataset, String column, IvfBuildParams params); + + private static native float[] nativeTrainPqCodebook( + Dataset dataset, String column, PQBuildParams params); +} diff --git a/java/src/test/java/org/lance/JNITest.java b/java/src/test/java/org/lance/JNITest.java index 8bf335e2fa8..4b09de66631 100644 --- a/java/src/test/java/org/lance/JNITest.java +++ b/java/src/test/java/org/lance/JNITest.java @@ -172,17 +172,17 @@ public void testInvalidCombinationHnswWithoutPqOrSq() { } @Test - public void testInvalidCombinationSqWithoutHnsw() { + public void testValidCombinationIvfSqWithoutHnsw() { IvfBuildParams ivf = new IvfBuildParams.Builder().setNumPartitions(10).build(); SQBuildParams sq = new SQBuildParams.Builder().build(); - assertThrows( - IllegalArgumentException.class, - () -> { - new VectorIndexParams.Builder(ivf) - .setDistanceType(DistanceType.L2) - .setSqParams(sq) - .build(); - }); + JniTestHelper.parseIndexParams( + IndexParams.builder() + .setVectorIndexParams( + new VectorIndexParams.Builder(ivf) + .setDistanceType(DistanceType.L2) + .setSqParams(sq) + .build()) + .build()); } } diff --git a/java/src/test/java/org/lance/TestVectorDataset.java b/java/src/test/java/org/lance/TestVectorDataset.java index 96420902e9b..f05c7dc7abb 100644 --- a/java/src/test/java/org/lance/TestVectorDataset.java +++ b/java/src/test/java/org/lance/TestVectorDataset.java @@ -102,6 +102,8 @@ private FragmentMetadata 
createFragment(int batchIndex) throws IOException { for (int j = 0; j < 32; j++) { vecItemsVector.setSafe(i * 32 + j, (float) (i * 32 + j)); } + // Mark the fixed-size list value as non-null + vecVector.setNotNull(i); } root.setRowCount(80); @@ -127,6 +129,8 @@ public Dataset appendNewData() throws IOException { for (int j = 0; j < 32; j++) { vecItemsVector.setSafe(i * 32 + j, (float) i); } + // Mark the fixed-size list value as non-null + vecVector.setNotNull(i); } root.setRowCount(10); diff --git a/java/src/test/java/org/lance/ScalarIndexTest.java b/java/src/test/java/org/lance/index/ScalarIndexTest.java similarity index 98% rename from java/src/test/java/org/lance/ScalarIndexTest.java rename to java/src/test/java/org/lance/index/ScalarIndexTest.java index 61ae66c94d2..adb5fa83468 100644 --- a/java/src/test/java/org/lance/ScalarIndexTest.java +++ b/java/src/test/java/org/lance/index/ScalarIndexTest.java @@ -11,12 +11,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.lance; +package org.lance.index; -import org.lance.index.Index; -import org.lance.index.IndexOptions; -import org.lance.index.IndexParams; -import org.lance.index.IndexType; +import org.lance.Dataset; +import org.lance.Fragment; +import org.lance.TestUtils; +import org.lance.Transaction; +import org.lance.WriteParams; import org.lance.index.scalar.ScalarIndexParams; import org.lance.ipc.LanceScanner; import org.lance.ipc.ScanOptions; @@ -107,7 +108,7 @@ public void testCreateBTreeIndex() throws Exception { } @Test - public void testCreateBTreeIndexDistributedly() throws Exception { + public void testCreateBTreeIndexDistributively() throws Exception { String datasetPath = tempDir.resolve("build_index_distributedly").toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { TestUtils.SimpleTestDataset testDataset = diff --git a/java/src/test/java/org/lance/index/VectorIndexTest.java b/java/src/test/java/org/lance/index/VectorIndexTest.java new file mode 100755 index 00000000000..9d4b9791949 --- /dev/null +++ b/java/src/test/java/org/lance/index/VectorIndexTest.java @@ -0,0 +1,363 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.index; + +import org.lance.Dataset; +import org.lance.Fragment; +import org.lance.TestVectorDataset; +import org.lance.Transaction; +import org.lance.index.vector.IvfBuildParams; +import org.lance.index.vector.PQBuildParams; +import org.lance.index.vector.SQBuildParams; +import org.lance.index.vector.VectorIndexParams; +import org.lance.index.vector.VectorTrainer; +import org.lance.operation.CreateIndex; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class VectorIndexTest { + + @TempDir Path tempDir; + + @Test + public void testCreateIvfFlatIndexDistributively() throws Exception { + try (TestVectorDataset testVectorDataset = + new TestVectorDataset(tempDir.resolve("merge_ivfflat_index_metadata"))) { + try (Dataset dataset = testVectorDataset.create()) { + List fragments = dataset.getFragments(); + assertTrue( + fragments.size() >= 2, + "Expected dataset to have at least two fragments for distributed indexing"); + + int numPartitions = 2; + + IvfBuildParams ivfTrainParams = + new IvfBuildParams.Builder().setNumPartitions(numPartitions).setMaxIters(1).build(); + + float[] centroids = + VectorTrainer.trainIvfCentroids( + dataset, TestVectorDataset.vectorColumnName, ivfTrainParams); + + IvfBuildParams ivfParams = + new IvfBuildParams.Builder() + .setNumPartitions(numPartitions) + .setMaxIters(1) + .setCentroids(centroids) + .build(); + + VectorIndexParams vectorIndexParams = + new VectorIndexParams.Builder(ivfParams).setDistanceType(DistanceType.L2).build(); + + IndexParams indexParams = + 
IndexParams.builder().setVectorIndexParams(vectorIndexParams).build(); + + UUID indexUUID = UUID.randomUUID(); + + // Partially create index on the first fragment + dataset.createIndex( + IndexOptions.builder( + Collections.singletonList(TestVectorDataset.vectorColumnName), + IndexType.IVF_FLAT, + indexParams) + .withIndexName(TestVectorDataset.indexName) + .withIndexUUID(indexUUID.toString()) + .withFragmentIds(Collections.singletonList(fragments.get(0).getId())) + .build()); + + // Partially create index on the second fragment with the same UUID + dataset.createIndex( + IndexOptions.builder( + Collections.singletonList(TestVectorDataset.vectorColumnName), + IndexType.IVF_FLAT, + indexParams) + .withIndexName(TestVectorDataset.indexName) + .withIndexUUID(indexUUID.toString()) + .withFragmentIds(Collections.singletonList(fragments.get(1).getId())) + .build()); + + // The index should not be visible before metadata merge & commit + assertFalse( + dataset.listIndexes().contains(TestVectorDataset.indexName), + "Partially created IVF_FLAT index should not present before commit"); + + // Merge index metadata for all fragment-level pieces + dataset.mergeIndexMetadata(indexUUID.toString(), IndexType.IVF_FLAT, Optional.empty()); + + int fieldId = + dataset.getLanceSchema().fields().stream() + .filter(f -> f.getName().equals(TestVectorDataset.vectorColumnName)) + .findAny() + .orElseThrow( + () -> new RuntimeException("Cannot find vector field for TestVectorDataset")) + .getId(); + + long datasetVersion = dataset.version(); + + Index index = + Index.builder() + .uuid(indexUUID) + .name(TestVectorDataset.indexName) + .fields(Collections.singletonList(fieldId)) + .datasetVersion(datasetVersion) + .indexVersion(0) + .fragments( + fragments.stream().limit(2).map(Fragment::getId).collect(Collectors.toList())) + .build(); + + CreateIndex createIndexOp = + CreateIndex.builder().withNewIndices(Collections.singletonList(index)).build(); + + Transaction createIndexTx = + 
dataset.newTransactionBuilder().operation(createIndexOp).build(); + + try (Dataset newDataset = createIndexTx.commit()) { + assertEquals(datasetVersion + 1, newDataset.version()); + assertTrue(newDataset.listIndexes().contains(TestVectorDataset.indexName)); + } + } + } + } + + @Test + public void testCreateIvfPqIndexDistributively() throws Exception { + try (TestVectorDataset testVectorDataset = + new TestVectorDataset(tempDir.resolve("merge_ivfpq_index_metadata"))) { + try (Dataset dataset = testVectorDataset.create()) { + List fragments = dataset.getFragments(); + assertTrue( + fragments.size() >= 2, + "Expected dataset to have at least two fragments for distributed indexing"); + + int numPartitions = 2; + int numSubVectors = 2; + int numBits = 8; + + IvfBuildParams ivfTrainParams = + new IvfBuildParams.Builder().setNumPartitions(numPartitions).setMaxIters(1).build(); + + PQBuildParams pqTrainParams = + new PQBuildParams.Builder() + .setNumSubVectors(numSubVectors) + .setNumBits(numBits) + .setMaxIters(2) + .setSampleRate(256) + .build(); + + float[] centroids = + VectorTrainer.trainIvfCentroids( + dataset, TestVectorDataset.vectorColumnName, ivfTrainParams); + + float[] codebook = + VectorTrainer.trainPqCodebook( + dataset, TestVectorDataset.vectorColumnName, pqTrainParams); + + IvfBuildParams ivfParams = + new IvfBuildParams.Builder() + .setNumPartitions(numPartitions) + .setMaxIters(1) + .setCentroids(centroids) + .build(); + + PQBuildParams pqParams = + new PQBuildParams.Builder() + .setNumSubVectors(numSubVectors) + .setNumBits(numBits) + .setMaxIters(2) + .setSampleRate(256) + .setCodebook(codebook) + .build(); + + VectorIndexParams vectorIndexParams = + VectorIndexParams.withIvfPqParams(DistanceType.L2, ivfParams, pqParams); + + IndexParams indexParams = + IndexParams.builder().setVectorIndexParams(vectorIndexParams).build(); + + UUID indexUUID = UUID.randomUUID(); + + dataset.createIndex( + IndexOptions.builder( + 
Collections.singletonList(TestVectorDataset.vectorColumnName), + IndexType.IVF_PQ, + indexParams) + .withIndexName(TestVectorDataset.indexName) + .withIndexUUID(indexUUID.toString()) + .withFragmentIds(Collections.singletonList(fragments.get(0).getId())) + .build()); + + dataset.createIndex( + IndexOptions.builder( + Collections.singletonList(TestVectorDataset.vectorColumnName), + IndexType.IVF_PQ, + indexParams) + .withIndexName(TestVectorDataset.indexName) + .withIndexUUID(indexUUID.toString()) + .withFragmentIds(Collections.singletonList(fragments.get(1).getId())) + .build()); + + assertFalse( + dataset.listIndexes().contains(TestVectorDataset.indexName), + "Partially created IVF_PQ index should not present before commit"); + + dataset.mergeIndexMetadata(indexUUID.toString(), IndexType.IVF_PQ, Optional.empty()); + + int fieldId = + dataset.getLanceSchema().fields().stream() + .filter(f -> f.getName().equals(TestVectorDataset.vectorColumnName)) + .findAny() + .orElseThrow( + () -> new RuntimeException("Cannot find vector field for TestVectorDataset")) + .getId(); + + long datasetVersion = dataset.version(); + + Index index = + Index.builder() + .uuid(indexUUID) + .name(TestVectorDataset.indexName) + .fields(Collections.singletonList(fieldId)) + .datasetVersion(datasetVersion) + .indexVersion(0) + .fragments( + fragments.stream().limit(2).map(Fragment::getId).collect(Collectors.toList())) + .build(); + + CreateIndex createIndexOp = + CreateIndex.builder().withNewIndices(Collections.singletonList(index)).build(); + + Transaction createIndexTx = + dataset.newTransactionBuilder().operation(createIndexOp).build(); + + try (Dataset newDataset = createIndexTx.commit()) { + assertEquals(datasetVersion + 1, newDataset.version()); + assertTrue(newDataset.listIndexes().contains(TestVectorDataset.indexName)); + } + } + } + } + + @Test + public void testCreateIvfSqIndexDistributively() throws Exception { + try (TestVectorDataset testVectorDataset = + new 
TestVectorDataset(tempDir.resolve("merge_ivfsq_index_metadata"))) { + try (Dataset dataset = testVectorDataset.create()) { + List fragments = dataset.getFragments(); + assertTrue( + fragments.size() >= 2, + "Expected dataset to have at least two fragments for distributed indexing"); + + int numPartitions = 2; + short numBits = 8; + + IvfBuildParams ivfTrainParams = + new IvfBuildParams.Builder().setNumPartitions(numPartitions).setMaxIters(1).build(); + + SQBuildParams sqParams = + new SQBuildParams.Builder().setNumBits(numBits).setSampleRate(256).build(); + + float[] centroids = + VectorTrainer.trainIvfCentroids( + dataset, TestVectorDataset.vectorColumnName, ivfTrainParams); + + IvfBuildParams ivfParams = + new IvfBuildParams.Builder() + .setNumPartitions(numPartitions) + .setMaxIters(1) + .setCentroids(centroids) + .build(); + + VectorIndexParams vectorIndexParams = + new VectorIndexParams.Builder(ivfParams) + .setDistanceType(DistanceType.L2) + .setSqParams(sqParams) + .build(); + + IndexParams indexParams = + IndexParams.builder().setVectorIndexParams(vectorIndexParams).build(); + + UUID indexUUID = UUID.randomUUID(); + + dataset.createIndex( + IndexOptions.builder( + Collections.singletonList(TestVectorDataset.vectorColumnName), + IndexType.IVF_SQ, + indexParams) + .withIndexName(TestVectorDataset.indexName) + .withIndexUUID(indexUUID.toString()) + .withFragmentIds(Collections.singletonList(fragments.get(0).getId())) + .build()); + + dataset.createIndex( + IndexOptions.builder( + Collections.singletonList(TestVectorDataset.vectorColumnName), + IndexType.IVF_SQ, + indexParams) + .withIndexName(TestVectorDataset.indexName) + .withIndexUUID(indexUUID.toString()) + .withFragmentIds(Collections.singletonList(fragments.get(1).getId())) + .build()); + + assertFalse( + dataset.listIndexes().contains(TestVectorDataset.indexName), + "Partially created IVF_SQ index should not present before commit"); + + dataset.mergeIndexMetadata(indexUUID.toString(), IndexType.IVF_SQ, 
Optional.empty()); + + int fieldId = + dataset.getLanceSchema().fields().stream() + .filter(f -> f.getName().equals(TestVectorDataset.vectorColumnName)) + .findAny() + .orElseThrow( + () -> new RuntimeException("Cannot find vector field for TestVectorDataset")) + .getId(); + + long datasetVersion = dataset.version(); + + Index index = + Index.builder() + .uuid(indexUUID) + .name(TestVectorDataset.indexName) + .fields(Collections.singletonList(fieldId)) + .datasetVersion(datasetVersion) + .indexVersion(0) + .fragments( + fragments.stream().limit(2).map(Fragment::getId).collect(Collectors.toList())) + .build(); + + CreateIndex createIndexOp = + CreateIndex.builder().withNewIndices(Collections.singletonList(index)).build(); + + Transaction createIndexTx = + dataset.newTransactionBuilder().operation(createIndexOp).build(); + + try (Dataset newDataset = createIndexTx.commit()) { + assertEquals(datasetVersion + 1, newDataset.version()); + assertTrue(newDataset.listIndexes().contains(TestVectorDataset.indexName)); + } + } + } + } +} diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 09f28218241..2c4eb737bd6 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -34,7 +34,6 @@ use pyo3::{prelude::*, IntoPyObjectExt}; use snafu::location; use lance::dataset::cleanup::CleanupPolicyBuilder; -use lance::dataset::index::LanceIndexStoreExt; use lance::dataset::refs::{Ref, TagContents}; use lance::dataset::scanner::{ ColumnOrdering, DatasetRecordBatchStream, ExecutionStatsCallback, MaterializationStyle, @@ -66,7 +65,6 @@ use lance_file::reader::FileReaderOptions; use lance_index::scalar::inverted::query::{ BooleanQuery, BoostQuery, FtsQuery, MatchQuery, MultiMatchQuery, Operator, PhraseQuery, }; -use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::{ infer_system_index_type, metrics::NoOpMetricsCollector, scalar::inverted::query::Occur, }; @@ -2080,53 +2078,9 @@ impl Dataset { batch_readhead: Option, ) -> PyResult<()> { 
rt().block_on(None, async { - let store = LanceIndexStore::from_dataset_for_new(self.ds.as_ref(), index_uuid)?; - let index_dir = self.ds.indices_dir().child(index_uuid); - let index_type_up = index_type.to_uppercase(); - log::info!( - "merge_index_metadata called with index_type={} (upper={})", - index_type, - index_type_up - ); - match index_type_up.as_str() { - "INVERTED" | "FTS" => { - // Call merge_index_files function for inverted index - lance_index::scalar::inverted::builder::merge_index_files( - self.ds.object_store(), - &index_dir, - Arc::new(store), - ) - .await - } - "BTREE" => { - // Call merge_index_files function for btree index - lance_index::scalar::btree::merge_index_files( - self.ds.object_store(), - &index_dir, - Arc::new(store), - batch_readhead, - ) - .await - } - // Precise vector index types: IVF_FLAT, IVF_PQ, IVF_SQ - "IVF_FLAT" | "IVF_PQ" | "IVF_SQ" | "VECTOR" => { - // Merge distributed vector index partials and finalize root index via Lance IVF helper - lance::index::vector::ivf::finalize_distributed_merge( - self.ds.object_store(), - &index_dir, - Some(&index_type_up), - ) - .await?; - Ok(()) - } - _ => Err(lance::Error::InvalidInput { - source: Box::new(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - format!("Unsupported index type (patched): {}", index_type_up), - )), - location: location!(), - }), - } + self.ds + .merge_index_metadata(index_uuid, IndexType::try_from(index_type)?, batch_readhead) + .await })? 
.map_err(|err| PyValueError::new_err(err.to_string())) } diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 5ed4638b6cb..776619e5036 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -191,15 +191,13 @@ impl TryFrom<&str> for IndexType { fn try_from(value: &str) -> Result { match value { - "BTree" => Ok(Self::BTree), - "Bitmap" => Ok(Self::Bitmap), - "LabelList" => Ok(Self::LabelList), - "Inverted" => Ok(Self::Inverted), - "NGram" => Ok(Self::NGram), - "FragmentReuse" => Ok(Self::FragmentReuse), - "MemWal" => Ok(Self::MemWal), - "ZoneMap" => Ok(Self::ZoneMap), - "Vector" => Ok(Self::Vector), + "BTree" | "BTREE" => Ok(Self::BTree), + "Bitmap" | "BITMAP" => Ok(Self::Bitmap), + "LabelList" | "LABELLIST" => Ok(Self::LabelList), + "Inverted" | "INVERTED" => Ok(Self::Inverted), + "NGram" | "NGRAM" => Ok(Self::NGram), + "ZoneMap" | "ZONEMAP" => Ok(Self::ZoneMap), + "Vector" | "VECTOR" => Ok(Self::Vector), "IVF_FLAT" => Ok(Self::IvfFlat), "IVF_SQ" => Ok(Self::IvfSq), "IVF_PQ" => Ok(Self::IvfPq), @@ -207,6 +205,8 @@ impl TryFrom<&str> for IndexType { "IVF_HNSW_FLAT" => Ok(Self::IvfHnswFlat), "IVF_HNSW_SQ" => Ok(Self::IvfHnswSq), "IVF_HNSW_PQ" => Ok(Self::IvfHnswPq), + "FragmentReuse" => Ok(Self::FragmentReuse), + "MemWal" => Ok(Self::MemWal), _ => Err(Error::invalid_input( format!("invalid index type: {}", value), location!(), diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index e5e71887147..bf4bede7d50 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -33,7 +33,7 @@ use lance_datafusion::projection::ProjectionPlan; use lance_file::datatypes::populate_schema_dictionary; use lance_file::reader::FileReaderOptions; use lance_file::version::LanceFileVersion; -use lance_index::DatasetIndexExt; +use lance_index::{DatasetIndexExt, IndexType}; use lance_io::object_store::{ LanceNamespaceStorageOptionsProvider, ObjectStore, ObjectStoreParams, StorageOptions, StorageOptionsAccessor, 
StorageOptionsProvider,
@@ -111,6 +111,7 @@ pub use blob::BlobFile;
 use hash_joiner::HashJoiner;
 use lance_core::box_error;
 pub use lance_core::ROW_ID;
+use lance_index::scalar::lance_format::LanceIndexStore;
 use lance_namespace::models::{
     CreateEmptyTableRequest, DeclareTableRequest, DeclareTableResponse, DescribeTableRequest,
 };
@@ -125,6 +126,7 @@ pub use write::merge_insert::{
     WhenNotMatched, WhenNotMatchedBySource,
 };
 
+use crate::dataset::index::LanceIndexStoreExt;
 pub use write::update::{UpdateBuilder, UpdateJob};
 #[allow(deprecated)]
 pub use write::{
@@ -2748,6 +2750,71 @@ impl Dataset {
         let stream = Box::new(stream);
         self.merge_impl(stream, left_on, right_on).await
     }
+
+    /// Merges the partial index files written under a shared index UUID into a
+    /// single finalized index.
+    ///
+    /// The files are looked up in this dataset's indices directory under
+    /// `index_uuid`; which merge routine runs depends on `index_type`:
+    ///
+    /// * `Inverted` and `BTree` go through the scalar `merge_index_files` helpers.
+    /// * `IvfFlat`, `IvfPq`, `IvfSq` and the generic `Vector` type go through
+    ///   `finalize_distributed_merge`.
+    ///
+    /// `batch_readhead` is forwarded unchanged to the BTree merge and ignored by
+    /// every other index type.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::InvalidInput` for any other index type, and propagates
+    /// errors from the index store and the underlying merge routines.
+    pub async fn merge_index_metadata(
+        &self,
+        index_uuid: &str,
+        index_type: IndexType,
+        batch_readhead: Option<usize>,
+    ) -> Result<()> {
+        let index_dir = self.indices_dir().child(index_uuid);
+        match index_type {
+            IndexType::Inverted => {
+                let store = LanceIndexStore::from_dataset_for_new(self, index_uuid)?;
+                lance_index::scalar::inverted::builder::merge_index_files(
+                    self.object_store(),
+                    &index_dir,
+                    Arc::new(store),
+                )
+                .await
+            }
+            IndexType::BTree => {
+                let store = LanceIndexStore::from_dataset_for_new(self, index_uuid)?;
+                lance_index::scalar::btree::merge_index_files(
+                    self.object_store(),
+                    &index_dir,
+                    Arc::new(store),
+                    batch_readhead,
+                )
+                .await
+            }
+            // IVF_FLAT, IVF_PQ, IVF_SQ and the generic VECTOR type share one merge path:
+            // combine the per-shard partials and write the root index file.
+            IndexType::IvfFlat | IndexType::IvfPq | IndexType::IvfSq | IndexType::Vector => {
+                crate::index::vector::ivf::finalize_distributed_merge(
+                    self.object_store(),
+                    &index_dir,
+                    Some(index_type),
+                )
+                .await
+            }
+            // Reject unsupported types before any index store is created.
+            _ => Err(Error::InvalidInput {
+                source: Box::new(std::io::Error::new(
+                    std::io::ErrorKind::InvalidInput,
+                    format!("Unsupported index type: {}", index_type),
+                )),
+                location: location!(),
+            }),
+        }
+    }
 }
 
/// # Dataset metadata APIs diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 354c0ecaf87..c14cdeada81 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -1867,7 +1867,7 @@ async fn write_ivf_hnsw_file( pub async fn finalize_distributed_merge( object_store: &ObjectStore, index_dir: &object_store::path::Path, - requested_index_type: Option<&str>, + requested_index_type: Option, ) -> Result<()> { // Merge per-shard auxiliary files into a unified auxiliary.idx. lance_index::vector::distributed::index_merger::merge_partial_vector_auxiliary_files( @@ -1889,7 +1889,7 @@ pub async fn finalize_distributed_merge( fh, None, Arc::default(), - &lance_core::cache::LanceCache::no_cache(), + &LanceCache::no_cache(), V2ReaderOptions::default(), ) .await?;