Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 5 additions & 39 deletions java/lance-jni/src/blocking_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use jni::sys::{jbyteArray, jlong};
use jni::{objects::JObject, JNIEnv};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::{CleanupPolicy, RemovalStats};
use lance::dataset::index::LanceIndexStoreExt;
use lance::dataset::optimize::{compact_files, CompactionOptions as RustCompactionOptions};
use lance::dataset::refs::{Ref, TagContents};
use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt};
Expand All @@ -41,7 +40,6 @@ use lance::table::format::{BasePath, Fragment};
use lance_core::datatypes::Schema as LanceSchema;
use lance_index::optimize::OptimizeOptions;
use lance_index::scalar::btree::BTreeParameters;
use lance_index::scalar::lance_format::LanceIndexStore;
use lance_index::DatasetIndexExt;
use lance_index::{IndexParams, IndexType};
use lance_io::object_store::ObjectStoreRegistry;
Expand Down Expand Up @@ -975,44 +973,12 @@ fn inner_merge_index_metadata(
unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;

RT.block_on(async {
let index_store = LanceIndexStore::from_dataset_for_new(&dataset_guard.inner, &index_uuid)?;
let object_store = dataset_guard.inner.object_store();
let index_dir = dataset_guard.inner.indices_dir().child(index_uuid);

match index_type {
IndexType::Inverted => lance_index::scalar::inverted::builder::merge_index_files(
object_store,
&index_dir,
Arc::new(index_store),
)
.await
.map_err(|e| {
Error::runtime_error(format!(
"Cannot create index of type: {:?}. Caused by: {:?}",
index_type,
e.to_string()
))
}),
IndexType::BTree => lance_index::scalar::btree::merge_index_files(
object_store,
&index_dir,
Arc::new(index_store),
batch_readhead,
)
dataset_guard
.inner
.merge_index_metadata(&index_uuid, index_type, batch_readhead)
.await
.map_err(|e| {
Error::runtime_error(format!(
"Cannot create index of type: {:?}. Caused by: {:?}",
index_type,
e.to_string()
))
}),
_ => Err(Error::input_error(format!(
"Cannot merge index type: {:?}. Only supports BTREE and INVERTED now.",
index_type
))),
}
})
})?;
Ok(())
}

#[no_mangle]
Expand Down
1 change: 1 addition & 0 deletions java/lance-jni/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ mod storage_options;
pub mod traits;
mod transaction;
pub mod utils;
mod vector_trainer;

pub use error::Error;
pub use error::Result;
Expand Down
68 changes: 64 additions & 4 deletions java/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

use std::sync::Arc;

use arrow::array::Float32Array;
use jni::objects::{JMap, JObject, JString, JValue, JValueGen};
use arrow::array::{ArrayRef, FixedSizeListArray, Float32Array};
use arrow_schema::{DataType, Field};
use jni::objects::{JFloatArray, JMap, JObject, JString, JValue, JValueGen};
use jni::sys::{jboolean, jfloat, jlong};
use jni::JNIEnv;
use lance::dataset::optimize::CompactionOptions;
Expand Down Expand Up @@ -256,14 +257,52 @@ pub fn get_vector_index_params(
let shuffle_partition_concurrency = env
.get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionConcurrency")?;

let ivf_params = IvfBuildParams {
let mut ivf_params = IvfBuildParams {
num_partitions: Some(num_partitions),
max_iters,
sample_rate,
shuffle_partition_batches,
shuffle_partition_concurrency,
..Default::default()
};

// Optional pre-trained IVF centroids from Java IvfBuildParams
// Method signature: float[] getCentroids()
let centroids_obj = env
.call_method(&ivf_params_obj, "getCentroids", "()[F", &[])?
.l()?;

if !centroids_obj.is_null() {
let jarray: JFloatArray = centroids_obj.into();
let length = env.get_array_length(&jarray)?;
if length > 0 {
if (length as usize) % num_partitions != 0 {
return Err(Error::input_error(format!(
"Invalid IVF centroids: length {} is not divisible by num_partitions {}",
length, num_partitions
)));
}
let mut buffer = vec![0.0f32; length as usize];
env.get_float_array_region(&jarray, 0, &mut buffer)?;
let dimension = buffer.len() / num_partitions;

let values = Float32Array::from(buffer);
let fsl = FixedSizeListArray::try_new(
Arc::new(Field::new("item", DataType::Float32, false)),
dimension as i32,
Arc::new(values) as ArrayRef,
None,
)
.map_err(|e| {
Error::input_error(format!(
"Failed to construct FixedSizeListArray for IVF centroids: {e}"
))
})?;

ivf_params.centroids = Some(Arc::new(fsl));
}
}

stages.push(StageParams::Ivf(ivf_params));

// Parse HnswBuildParams
Expand Down Expand Up @@ -305,13 +344,34 @@ pub fn get_vector_index_params(
env.get_int_as_usize_from_method(&pq_obj, "getKmeansRedos")?;
let sample_rate = env.get_int_as_usize_from_method(&pq_obj, "getSampleRate")?;

// Optional pre-trained PQ codebook from Java PQBuildParams
// Method signature: float[] getCodebook()
let codebook_obj = env
.call_method(&pq_obj, "getCodebook", "()[F", &[])?
.l()?;

let codebook = if !codebook_obj.is_null() {
let jarray: JFloatArray = codebook_obj.into();
let length = env.get_array_length(&jarray)?;
if length > 0 {
let mut buffer = vec![0.0f32; length as usize];
env.get_float_array_region(&jarray, 0, &mut buffer)?;
let values = Float32Array::from(buffer);
Some(Arc::new(values) as _)
} else {
None
}
} else {
None
};

Ok(PQBuildParams {
num_sub_vectors,
num_bits,
max_iters,
kmeans_redos,
codebook,
sample_rate,
..Default::default()
})
},
)?;
Expand Down
179 changes: 179 additions & 0 deletions java/lance-jni/src/vector_trainer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET};
use crate::error::{Error, Result};
use crate::ffi::JNIEnvExt;
use crate::RT;

use arrow::array::{FixedSizeListArray, Float32Array};
use jni::objects::{JClass, JFloatArray, JObject, JString};
use jni::sys::jfloatArray;
use jni::JNIEnv;
use lance::index::vector::utils::get_vector_dim;
use lance_index::vector::ivf::builder::IvfBuildParams as RustIvfBuildParams;
use lance_index::vector::pq::builder::PQBuildParams as RustPQBuildParams;
use lance_linalg::distance::MetricType;

/// Flatten a FixedSizeList<Float32> into a contiguous Vec<f32>.
fn flatten_fixed_size_list_to_f32(arr: &FixedSizeListArray) -> Result<Vec<f32>> {
let values = arr
.values()
.as_any()
.downcast_ref::<Float32Array>()
.ok_or_else(|| {
Error::input_error(format!(
"Expected FixedSizeList<Float32>, got value type {}",
arr.value_type()
))
})?;

Ok(values.values().to_vec())
}

/// Builds Rust IVF build parameters from a Java `IvfBuildParams` object.
///
/// Five integer getters are called on `ivf_params_obj`; every other field of
/// the returned struct keeps its `Default` value. Pre-trained centroids are
/// not read here.
///
/// # Errors
///
/// Propagates the first error returned by any getter call.
fn build_ivf_params_from_java(
    env: &mut JNIEnv,
    ivf_params_obj: &JObject,
) -> Result<RustIvfBuildParams> {
    // Field initializers are evaluated top to bottom, so the Java getters are
    // invoked in the order they are listed here.
    Ok(RustIvfBuildParams {
        num_partitions: Some(
            env.get_int_as_usize_from_method(ivf_params_obj, "getNumPartitions")?,
        ),
        max_iters: env.get_int_as_usize_from_method(ivf_params_obj, "getMaxIters")?,
        sample_rate: env.get_int_as_usize_from_method(ivf_params_obj, "getSampleRate")?,
        shuffle_partition_batches: env
            .get_int_as_usize_from_method(ivf_params_obj, "getShufflePartitionBatches")?,
        shuffle_partition_concurrency: env
            .get_int_as_usize_from_method(ivf_params_obj, "getShufflePartitionConcurrency")?,
        ..Default::default()
    })
}

/// Builds Rust PQ build parameters from a Java `PQBuildParams` object.
///
/// Five integer getters are called on `pq_params_obj`. The codebook is always
/// left as `None`; any codebook held by the Java object is not read here.
///
/// # Errors
///
/// Propagates the first error returned by any getter call.
fn build_pq_params_from_java(
    env: &mut JNIEnv,
    pq_params_obj: &JObject,
) -> Result<RustPQBuildParams> {
    // The Java getters must run in this order (sample rate last), so the
    // fallible reads are listed in call order rather than grouped by meaning.
    Ok(RustPQBuildParams {
        num_sub_vectors: env.get_int_as_usize_from_method(pq_params_obj, "getNumSubVectors")?,
        num_bits: env.get_int_as_usize_from_method(pq_params_obj, "getNumBits")?,
        max_iters: env.get_int_as_usize_from_method(pq_params_obj, "getMaxIters")?,
        kmeans_redos: env.get_int_as_usize_from_method(pq_params_obj, "getKmeansRedos")?,
        sample_rate: env.get_int_as_usize_from_method(pq_params_obj, "getSampleRate")?,
        codebook: None,
    })
}

#[no_mangle]
pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainIvfCentroids<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
dataset_obj: JObject<'local>, // org.lance.Dataset
column_jstr: JString<'local>, // java.lang.String
ivf_params_obj: JObject<'local>, // org.lance.index.vector.IvfBuildParams
) -> jfloatArray {
ok_or_throw_with_return!(
env,
inner_train_ivf_centroids(&mut env, dataset_obj, column_jstr, ivf_params_obj)
.map(|arr| arr.into_raw()),
JFloatArray::default().into_raw()
)
}

fn inner_train_ivf_centroids<'local>(
env: &mut JNIEnv<'local>,
dataset_obj: JObject<'local>,
column_jstr: JString<'local>,
ivf_params_obj: JObject<'local>,
) -> Result<JFloatArray<'local>> {
let column: String = env.get_string(&column_jstr)?.into();
let ivf_params = build_ivf_params_from_java(env, &ivf_params_obj)?;

let flattened: Vec<f32> = {
let dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(dataset_obj, NATIVE_DATASET) }?;
let dataset = &dataset_guard.inner;

let dim = get_vector_dim(dataset.schema(), &column)?;

// For now we default to L2 metric; tests and Java bindings currently use L2.
let metric_type = MetricType::L2;

let ivf_model = RT.block_on(lance::index::vector::ivf::build_ivf_model(
dataset,
&column,
dim,
metric_type,
&ivf_params,
))?;

let centroids = ivf_model
.centroids
.ok_or_else(|| Error::runtime_error("IVF model missing centroids".to_string()))?;

flatten_fixed_size_list_to_f32(&centroids)?
};

let jarray = env.new_float_array(flattened.len() as i32)?;
env.set_float_array_region(&jarray, 0, &flattened)?;
Ok(jarray)
}

#[no_mangle]
pub extern "system" fn Java_org_lance_index_vector_VectorTrainer_nativeTrainPqCodebook<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
dataset_obj: JObject<'local>, // org.lance.Dataset
column_jstr: JString<'local>, // java.lang.String
pq_params_obj: JObject<'local>, // org.lance.index.vector.PQBuildParams
) -> jfloatArray {
ok_or_throw_with_return!(
env,
inner_train_pq_codebook(&mut env, dataset_obj, column_jstr, pq_params_obj)
.map(|arr| arr.into_raw()),
JFloatArray::default().into_raw()
)
}

fn inner_train_pq_codebook<'local>(
env: &mut JNIEnv<'local>,
dataset_obj: JObject<'local>,
column_jstr: JString<'local>,
pq_params_obj: JObject<'local>,
) -> Result<JFloatArray<'local>> {
let column: String = env.get_string(&column_jstr)?.into();
let pq_params = build_pq_params_from_java(env, &pq_params_obj)?;

let flattened: Vec<f32> = {
let dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(dataset_obj, NATIVE_DATASET) }?;
let dataset = &dataset_guard.inner;

let dim = get_vector_dim(dataset.schema(), &column)?;
let metric_type = MetricType::L2;

let pq = RT.block_on(lance::index::vector::pq::build_pq_model(
dataset,
&column,
dim,
metric_type,
&pq_params,
None,
))?;

flatten_fixed_size_list_to_f32(&pq.codebook)?
};

let jarray = env.new_float_array(flattened.len() as i32)?;
env.set_float_array_region(&jarray, 0, &flattened)?;
Ok(jarray)
}
2 changes: 1 addition & 1 deletion java/src/main/java/org/lance/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ public Index createIndex(IndexOptions options) {
Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
return nativeCreateIndex(
options.getColumns(),
options.getIndexType().ordinal(),
options.getIndexType().getValue(),
options.getIndexName(),
options.getIndexParams(),
options.isReplace(),
Expand Down
Loading