diff --git a/README.md b/README.md index b9eb134..514a6e2 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ SELECT count(*) FROM lance_ns.main.my_dataset; ```sql -- Search a vector column, returning distances in `_distance` (smaller is closer) SELECT id, label, _distance -FROM lance_vector_search('path/to/dataset.lance', 'vec', [0.1, 0.2, 0.3, 0.4]::FLOAT[], +FROM lance_vector_search('path/to/dataset.lance', 'vec', [0.1, 0.2, 0.3, 0.4]::FLOAT[4], k = 5, prefilter = true) ORDER BY _distance ASC; ``` @@ -146,7 +146,7 @@ ORDER BY _distance ASC; - Positional arguments: - `uri` (VARCHAR): Dataset root path or object store URI (e.g. `s3://...`). - `vector_column` (VARCHAR): Vector column name. - - `query_vector` (FLOAT[] or DOUBLE[]): Query vector (must be non-empty; values are cast to float32). + - `query_vector` (FLOAT[dim] or DOUBLE[dim], preferred): Query vector (must be non-empty; values are cast to float32). `FLOAT[]` / `DOUBLE[]` are also accepted. - Named parameters: - `k` (BIGINT, default `10`): Number of results to return. - `prefilter` (BOOLEAN, default `false`): If `true`, filters are applied before top-k selection. @@ -181,7 +181,7 @@ ORDER BY _score DESC; -- Combine vector and text scores, returning `_hybrid_score` in addition to `_distance` / `_score` SELECT id, _hybrid_score, _distance, _score FROM lance_hybrid_search('path/to/dataset.lance', - 'vec', [0.1, 0.2, 0.3, 0.4]::FLOAT[], + 'vec', [0.1, 0.2, 0.3, 0.4]::FLOAT[4], 'text', 'puppy', k = 10, prefilter = false, alpha = 0.5, oversample_factor = 4) @@ -192,7 +192,7 @@ ORDER BY _hybrid_score DESC; - Positional arguments: - `uri` (VARCHAR): Dataset root path or object store URI (e.g. `s3://...`). - `vector_column` (VARCHAR): Vector column name. - - `query_vector` (FLOAT[] or DOUBLE[]): Query vector (must be non-empty; values are cast to float32). + - `query_vector` (FLOAT[dim] or DOUBLE[dim], preferred): Query vector (must be non-empty; values are cast to float32). `FLOAT[]` / `DOUBLE[]` are also accepted. - `text_column` (VARCHAR): Text column name. - `query` (VARCHAR): Query string. - Named parameters: diff --git a/rust/ffi/write.rs b/rust/ffi/write.rs index cb96336..e026771 100644 --- a/rust/ffi/write.rs +++ b/rust/ffi/write.rs @@ -3,12 +3,16 @@ use std::ffi::{c_char, c_void, CStr}; use std::ptr; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::Arc; use std::sync::Mutex; -use std::sync::RwLock; use std::thread::JoinHandle; -use arrow_array::{make_array, RecordBatch, RecordBatchReader, StructArray}; -use arrow_schema::{ArrowError, DataType, Schema, SchemaRef}; +use arrow_array::builder::{FixedSizeListBuilder, Float32Builder, Float64Builder}; +use arrow_array::{ + make_array, Array, FixedSizeListArray, Float32Array, Float64Array, LargeListArray, ListArray, + RecordBatch, RecordBatchReader, StructArray, +}; +use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef}; use lance::dataset::{CommitBuilder, Dataset, InsertBuilder, WriteMode, WriteParams}; use lance::io::ObjectStoreParams; @@ -60,10 +64,9 @@ impl RecordBatchReader for ReceiverRecordBatchReader { } struct WriterHandle { - schema: SchemaRef, + input_schema: SchemaRef, data_type: DataType, - sender: RwLock>>, - join: Mutex>>>, + state: Mutex, batches_sent: AtomicU64, } @@ -72,24 +75,453 @@ enum WriterResult { Uncommitted(lance::dataset::transaction::Transaction), } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum WriterKind { + Committed, + Uncommitted, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum VectorListKind { + List, + LargeList, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum VectorElementType { + Float32, + Float64, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct VectorConversion { + col_idx: usize, + dim: usize, + list_kind: VectorListKind, + element_type: VectorElementType, +} + +struct WriterState { + kind: WriterKind, + path: String, + params: WriteParams, + + vector_candidates: Vec, + buffered_batches: Vec, + + output_schema: Option, + output_sender: Option>, + output_join: Option>>, +} + impl Drop for WriterHandle { fn drop(&mut self) { - let sender = self - .sender - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .take(); + let (sender, join) = { + let mut guard = self.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + (guard.output_sender.take(), guard.output_join.take()) + }; drop(sender); - let mut guard = self - .join - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - if let Some(join) = guard.take() { + if let Some(join) = join { let _ = join.join(); } } } +const MAX_VECTOR_DIM_INFERENCE_BATCHES: usize = 4; + +fn is_variable_list_vector_type(dt: &DataType) -> Option<(VectorListKind, VectorElementType)> { + match dt { + DataType::List(field) => match field.data_type() { + DataType::Float32 => Some((VectorListKind::List, VectorElementType::Float32)), + DataType::Float64 => Some((VectorListKind::List, VectorElementType::Float64)), + _ => None, + }, + DataType::LargeList(field) => match field.data_type() { + DataType::Float32 => Some((VectorListKind::LargeList, VectorElementType::Float32)), + DataType::Float64 => Some((VectorListKind::LargeList, VectorElementType::Float64)), + _ => None, + }, + _ => None, + } +} + +fn infer_vector_dim_from_array( + array: &dyn Array, + list_kind: VectorListKind, +) -> Option> { + match list_kind { + VectorListKind::List => { + let list = array.as_any().downcast_ref::()?; + for i in 0..list.len() { + if list.is_null(i) { + continue; + } + let dim = list.value_length(i) as usize; + if dim == 0 { + return Some(Err("vector dim must be non-zero".to_string())); + } + return Some(Ok(dim)); + } + None + } + VectorListKind::LargeList => { + let list = array.as_any().downcast_ref::()?; + for i in 0..list.len() { + if list.is_null(i) { + continue; + } + let dim = list.value_length(i) as usize; + if dim == 0 { + return Some(Err("vector dim must be non-zero".to_string())); + } + return Some(Ok(dim)); + } + None + } + } +} + +fn validate_list_vector_dim( + array: &dyn Array, + list_kind: VectorListKind, + expected_dim: usize, +) -> Result<(), String> { + match list_kind { + VectorListKind::List => { + let list = array + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector column is not ListArray".to_string())?; + for i in 0..list.len() { + if list.is_null(i) { + continue; + } + let dim = list.value_length(i) as usize; + if dim != expected_dim { + return Err(format!( + "vector dim mismatch: expected {expected_dim} got {dim}" + )); + } + } + Ok(()) + } + VectorListKind::LargeList => { + let list = array + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector column is not LargeListArray".to_string())?; + for i in 0..list.len() { + if list.is_null(i) { + continue; + } + let dim = list.value_length(i) as usize; + if dim != expected_dim { + return Err(format!( + "vector dim mismatch: expected {expected_dim} got {dim}" + )); + } + } + Ok(()) + } + } +} + +fn convert_list_array_to_fixed_size( + array: &dyn Array, + list_kind: VectorListKind, + element_type: VectorElementType, + dim: usize, +) -> Result { + let dim_i32 = i32::try_from(dim).map_err(|_| "vector dim is too large".to_string())?; + + match (list_kind, element_type) { + (VectorListKind::List, VectorElementType::Float32) => { + let list = array + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector column is not ListArray".to_string())?; + let values = list + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector values are not Float32".to_string())?; + let field = match list.data_type() { + DataType::List(field) => field.clone(), + _ => return Err("vector column has unexpected data type".to_string()), + }; + + let mut builder = + FixedSizeListBuilder::with_capacity(Float32Builder::new(), dim_i32, list.len()) + .with_field(field); + let offsets = list.value_offsets(); + for i in 0..list.len() { + if list.is_null(i) { + for _ in 0..dim { + builder.values().append_null(); + } + builder.append(false); + continue; + } + let len = list.value_length(i) as usize; + if len != dim { + return Err(format!( + "vector dim mismatch: expected {dim} got {len}" + )); + } + let start = offsets[i] as usize; + for j in 0..dim { + let idx = start + j; + if idx >= values.len() { + return Err("vector offsets are out of bounds".to_string()); + } + if values.is_null(idx) { + builder.values().append_null(); + } else { + builder.values().append_value(values.value(idx)); + } + } + builder.append(true); + } + Ok(builder.finish()) + } + (VectorListKind::List, VectorElementType::Float64) => { + let list = array + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector column is not ListArray".to_string())?; + let values = list + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector values are not Float64".to_string())?; + let field = match list.data_type() { + DataType::List(field) => field.clone(), + _ => return Err("vector column has unexpected data type".to_string()), + }; + + let mut builder = + FixedSizeListBuilder::with_capacity(Float64Builder::new(), dim_i32, list.len()) + .with_field(field); + let offsets = list.value_offsets(); + for i in 0..list.len() { + if list.is_null(i) { + for _ in 0..dim { + builder.values().append_null(); + } + builder.append(false); + continue; + } + let len = list.value_length(i) as usize; + if len != dim { + return Err(format!( + "vector dim mismatch: expected {dim} got {len}" + )); + } + let start = offsets[i] as usize; + for j in 0..dim { + let idx = start + j; + if idx >= values.len() { + return Err("vector offsets are out of bounds".to_string()); + } + if values.is_null(idx) { + builder.values().append_null(); + } else { + builder.values().append_value(values.value(idx)); + } + } + builder.append(true); + } + Ok(builder.finish()) + } + (VectorListKind::LargeList, VectorElementType::Float32) => { + let list = array + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector column is not LargeListArray".to_string())?; + let values = list + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector values are not Float32".to_string())?; + let field = match list.data_type() { + DataType::LargeList(field) => field.clone(), + _ => return Err("vector column has unexpected data type".to_string()), + }; + + let mut builder = + FixedSizeListBuilder::with_capacity(Float32Builder::new(), dim_i32, list.len()) + .with_field(field); + let offsets = list.value_offsets(); + for i in 0..list.len() { + if list.is_null(i) { + for _ in 0..dim { + builder.values().append_null(); + } + builder.append(false); + continue; + } + let len = list.value_length(i) as usize; + if len != dim { + return Err(format!( + "vector dim mismatch: expected {dim} got {len}" + )); + } + let start = offsets[i] as usize; + for j in 0..dim { + let idx = start + j; + if idx >= values.len() { + return Err("vector offsets are out of bounds".to_string()); + } + if values.is_null(idx) { + builder.values().append_null(); + } else { + builder.values().append_value(values.value(idx)); + } + } + builder.append(true); + } + Ok(builder.finish()) + } + (VectorListKind::LargeList, VectorElementType::Float64) => { + let list = array + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector column is not LargeListArray".to_string())?; + let values = list + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| "vector values are not Float64".to_string())?; + let field = match list.data_type() { + DataType::LargeList(field) => field.clone(), + _ => return Err("vector column has unexpected data type".to_string()), + }; + + let mut builder = + FixedSizeListBuilder::with_capacity(Float64Builder::new(), dim_i32, list.len()) + .with_field(field); + let offsets = list.value_offsets(); + for i in 0..list.len() { + if list.is_null(i) { + for _ in 0..dim { + builder.values().append_null(); + } + builder.append(false); + continue; + } + let len = list.value_length(i) as usize; + if len != dim { + return Err(format!( + "vector dim mismatch: expected {dim} got {len}" + )); + } + let start = offsets[i] as usize; + for j in 0..dim { + let idx = start + j; + if idx >= values.len() { + return Err("vector offsets are out of bounds".to_string()); + } + if values.is_null(idx) { + builder.values().append_null(); + } else { + builder.values().append_value(values.value(idx)); + } + } + builder.append(true); + } + Ok(builder.finish()) + } + } +} + +fn build_output_schema( + input_schema: &SchemaRef, + conversions: &[VectorConversion], +) -> Result { + if conversions.is_empty() { + return Ok(input_schema.clone()); + } + let mut fields = input_schema.fields().as_ref().to_vec(); + for conv in conversions { + let idx = conv.col_idx; + if idx >= fields.len() { + return Err("vector column index is out of bounds".to_string()); + } + let original = fields[idx].as_ref(); + let (list_kind, element_type) = is_variable_list_vector_type(original.data_type()) + .ok_or_else(|| "vector column has unexpected data type".to_string())?; + if list_kind != conv.list_kind || element_type != conv.element_type { + return Err("vector column has unexpected data type".to_string()); + } + let child_field = match original.data_type() { + DataType::List(field) | DataType::LargeList(field) => field.clone(), + _ => return Err("vector column has unexpected data type".to_string()), + }; + let dim_i32 = + i32::try_from(conv.dim).map_err(|_| "vector dim is too large".to_string())?; + fields[idx] = Arc::new(Field::new( + original.name(), + DataType::FixedSizeList(child_field, dim_i32), + original.is_nullable(), + )); + } + Ok(Arc::new(Schema::new(fields))) +} + +fn convert_record_batch( + input_batch: &RecordBatch, + output_schema: &SchemaRef, + conversions: &[VectorConversion], +) -> Result { + if conversions.is_empty() { + return Ok(RecordBatch::try_new(output_schema.clone(), input_batch.columns().to_vec()) + .map_err(|e| e.to_string())?); + } + let mut cols = input_batch.columns().to_vec(); + for conv in conversions { + let arr = cols + .get(conv.col_idx) + .ok_or_else(|| "vector column index is out of bounds".to_string())? + .as_ref(); + validate_list_vector_dim(arr, conv.list_kind, conv.dim)?; + let fixed = convert_list_array_to_fixed_size(arr, conv.list_kind, conv.element_type, conv.dim)?; + cols[conv.col_idx] = Arc::new(fixed); + } + RecordBatch::try_new(output_schema.clone(), cols).map_err(|e| e.to_string()) +} + +fn spawn_writer_thread( + kind: WriterKind, + path: String, + params: WriteParams, + schema: SchemaRef, + receiver: Receiver, +) -> JoinHandle> { + std::thread::spawn(move || -> Result { + let reader = ReceiverRecordBatchReader::new(schema, receiver); + match kind { + WriterKind::Committed => { + let fut = Dataset::write(reader, &path, Some(params)); + match runtime::block_on(fut) { + Ok(Ok(_)) => Ok(WriterResult::Committed), + Ok(Err(err)) => Err(err.to_string()), + Err(err) => Err(format!("runtime: {err}")), + } + } + WriterKind::Uncommitted => { + let source: Box = Box::new(reader); + let builder = InsertBuilder::new(path.as_str()).with_params(¶ms); + let fut = builder.execute_uncommitted_stream(source); + match runtime::block_on(fut) { + Ok(Ok(txn)) => Ok(WriterResult::Uncommitted(txn)), + Ok(Err(err)) => Err(err.to_string()), + Err(err) => Err(format!("runtime: {err}")), + } + } + } + }) +} + #[no_mangle] pub unsafe extern "C" fn lance_open_writer_with_storage_options( path: *const c_char, @@ -249,8 +681,6 @@ fn open_uncommitted_writer_inner( ) })?; - let (sender, receiver) = sync_channel::(2); - let mut store_params = ObjectStoreParams::default(); if !storage_options.is_empty() { store_params.storage_options = Some(storage_options); @@ -264,25 +694,31 @@ fn open_uncommitted_writer_inner( store_params: Some(store_params), ..Default::default() }; - - let schema_for_thread = schema.clone(); - let join = std::thread::spawn(move || -> Result { - let reader = ReceiverRecordBatchReader::new(schema_for_thread, receiver); - let source: Box = Box::new(reader); - let builder = InsertBuilder::new(path.as_str()).with_params(¶ms); - let fut = builder.execute_uncommitted_stream(source); - match runtime::block_on(fut) { - Ok(Ok(txn)) => Ok(WriterResult::Uncommitted(txn)), - Ok(Err(err)) => Err(err.to_string()), - Err(err) => Err(format!("runtime: {err}")), + let mut vector_candidates = Vec::::new(); + for (idx, field) in schema.fields().iter().enumerate() { + if let Some((list_kind, element_type)) = is_variable_list_vector_type(field.data_type()) { + vector_candidates.push(VectorConversion { + col_idx: idx, + dim: 0, + list_kind, + element_type, + }); } - }); + } Ok(WriterHandle { - schema, + input_schema: schema.clone(), data_type, - sender: RwLock::new(Some(sender)), - join: Mutex::new(Some(join)), + state: Mutex::new(WriterState { + kind: WriterKind::Uncommitted, + path, + params, + vector_candidates, + buffered_batches: Vec::new(), + output_schema: None, + output_sender: None, + output_join: None, + }), batches_sent: AtomicU64::new(0), }) } @@ -378,8 +814,6 @@ fn open_writer_inner( ) })?; - let (sender, receiver) = sync_channel::(2); - let mut store_params = ObjectStoreParams::default(); if !storage_options.is_empty() { store_params.storage_options = Some(storage_options); @@ -394,22 +828,31 @@ fn open_writer_inner( ..Default::default() }; - let schema_for_thread = schema.clone(); - let join = std::thread::spawn(move || -> Result { - let reader = ReceiverRecordBatchReader::new(schema_for_thread, receiver); - let fut = Dataset::write(reader, &path, Some(params)); - match runtime::block_on(fut) { - Ok(Ok(_)) => Ok(WriterResult::Committed), - Ok(Err(err)) => Err(err.to_string()), - Err(err) => Err(format!("runtime: {err}")), + let mut vector_candidates = Vec::::new(); + for (idx, field) in schema.fields().iter().enumerate() { + if let Some((list_kind, element_type)) = is_variable_list_vector_type(field.data_type()) { + vector_candidates.push(VectorConversion { + col_idx: idx, + dim: 0, + list_kind, + element_type, + }); } - }); + } Ok(WriterHandle { - schema, + input_schema: schema.clone(), data_type, - sender: RwLock::new(Some(sender)), - join: Mutex::new(Some(join)), + state: Mutex::new(WriterState { + kind: WriterKind::Committed, + path, + params, + vector_candidates, + buffered_batches: Vec::new(), + output_schema: None, + output_sender: None, + output_join: None, + }), batches_sent: AtomicU64::new(0), }) } @@ -437,18 +880,6 @@ fn writer_write_batch_inner(writer: *mut c_void, array: *mut c_void) -> FfiResul } let handle = unsafe { &*(writer as *const WriterHandle) }; - let sender = { - let guard = handle - .sender - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - guard - .as_ref() - .ok_or_else(|| { - FfiError::new(ErrorCode::DatasetWriteBatch, "writer is already finished") - })? - .clone() - }; let raw_array = unsafe { ptr::read(array as *mut RawArrowArray) }; unsafe { @@ -468,17 +899,119 @@ fn writer_write_batch_inner(writer: *mut c_void, array: *mut c_void) -> FfiResul .downcast_ref::() .ok_or_else(|| FfiError::new(ErrorCode::DatasetWriteBatch, "array is not a struct"))?; - let batch = RecordBatch::try_new(handle.schema.clone(), struct_array.columns().to_vec()) + let input_batch = RecordBatch::try_new(handle.input_schema.clone(), struct_array.columns().to_vec()) .map_err(|err| { FfiError::new(ErrorCode::DatasetWriteBatch, format!("record batch: {err}")) })?; - sender.send(batch).map_err(|_| { - FfiError::new( - ErrorCode::DatasetWriteBatch, - "writer background task exited", - ) - })?; + let (sender, to_send) = { + let mut guard = handle.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + + if guard.output_sender.is_none() { + guard.buffered_batches.push(input_batch); + + if !guard.vector_candidates.is_empty() { + let batches = guard.buffered_batches.clone(); + for cand in guard.vector_candidates.iter_mut() { + if cand.dim != 0 { + continue; + } + for batch in batches.iter() { + let arr = batch + .column(cand.col_idx) + .as_ref(); + match infer_vector_dim_from_array(arr, cand.list_kind) { + Some(Ok(dim)) => { + cand.dim = dim; + break; + } + Some(Err(e)) => { + return Err(FfiError::new(ErrorCode::DatasetWriteBatch, e)); + } + None => {} + } + } + } + + let batches = guard.buffered_batches.clone(); + for cand in guard.vector_candidates.iter() { + if cand.dim == 0 { + continue; + } + for batch in batches.iter() { + let arr = batch.column(cand.col_idx).as_ref(); + if let Err(e) = validate_list_vector_dim(arr, cand.list_kind, cand.dim) { + return Err(FfiError::new(ErrorCode::DatasetWriteBatch, e)); + } + } + } + } + + let can_start = guard.vector_candidates.iter().all(|c| c.dim != 0) + || guard.buffered_batches.len() >= MAX_VECTOR_DIM_INFERENCE_BATCHES; + if can_start { + let conversions: Vec = guard + .vector_candidates + .iter() + .filter(|c| c.dim != 0) + .cloned() + .collect(); + + let output_schema = build_output_schema(&handle.input_schema, &conversions) + .map_err(|e| FfiError::new(ErrorCode::DatasetWriteBatch, e))?; + let (sender, receiver) = sync_channel::(2); + let join = spawn_writer_thread( + guard.kind, + guard.path.clone(), + guard.params.clone(), + output_schema.clone(), + receiver, + ); + + let buffered = std::mem::take(&mut guard.buffered_batches); + let mut out_batches = Vec::with_capacity(buffered.len()); + for b in buffered.iter() { + let out = convert_record_batch(b, &output_schema, &conversions) + .map_err(|e| FfiError::new(ErrorCode::DatasetWriteBatch, e))?; + out_batches.push(out); + } + + guard.output_schema = Some(output_schema); + guard.output_sender = Some(sender.clone()); + guard.output_join = Some(join); + (Some(sender), out_batches) + } else { + (None, Vec::new()) + } + } else { + let sender = guard.output_sender.as_ref().cloned(); + let schema = guard + .output_schema + .as_ref() + .ok_or_else(|| FfiError::new(ErrorCode::DatasetWriteBatch, "writer is not initialized"))? + .clone(); + let conversions: Vec = guard + .vector_candidates + .iter() + .filter(|c| c.dim != 0) + .cloned() + .collect(); + let out = convert_record_batch(&input_batch, &schema, &conversions) + .map_err(|e| FfiError::new(ErrorCode::DatasetWriteBatch, e))?; + (sender, vec![out]) + } + }; + + if let Some(sender) = sender { + for batch in to_send { + sender.send(batch).map_err(|_| { + FfiError::new( + ErrorCode::DatasetWriteBatch, + "writer background task exited", + ) + })?; + } + } handle.batches_sent.fetch_add(1, Ordering::Relaxed); @@ -505,43 +1038,65 @@ fn writer_finish_inner(writer: *mut c_void) -> FfiResult<()> { } let handle = unsafe { &*(writer as *const WriterHandle) }; - - if handle.batches_sent.load(Ordering::Acquire) == 0 { - let sender = { - let guard = handle - .sender - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - guard.clone() - }; - if let Some(sender) = sender { - let empty = RecordBatch::new_empty(handle.schema.clone()); - sender.send(empty).map_err(|_| { - FfiError::new( - ErrorCode::DatasetWriteFinish, - "writer background task exited", - ) - })?; + let (sender, join, to_send) = { + let mut guard = handle.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + if guard.output_sender.is_none() { + let conversions: Vec = guard + .vector_candidates + .iter() + .filter(|c| c.dim != 0) + .cloned() + .collect(); + let output_schema = build_output_schema(&handle.input_schema, &conversions) + .map_err(|e| FfiError::new(ErrorCode::DatasetWriteFinish, e))?; + let (sender, receiver) = sync_channel::(2); + let join = spawn_writer_thread( + guard.kind, + guard.path.clone(), + guard.params.clone(), + output_schema.clone(), + receiver, + ); + let buffered = std::mem::take(&mut guard.buffered_batches); + let mut out_batches = Vec::with_capacity(buffered.len() + 1); + for b in buffered.iter() { + let out = convert_record_batch(b, &output_schema, &conversions) + .map_err(|e| FfiError::new(ErrorCode::DatasetWriteFinish, e))?; + out_batches.push(out); + } + if handle.batches_sent.load(Ordering::Acquire) == 0 { + out_batches.push(RecordBatch::new_empty(output_schema.clone())); + } + guard.output_schema = Some(output_schema); + guard.output_sender = Some(sender.clone()); + guard.output_join = Some(join); + guard.buffered_batches = out_batches; } - } - let sender = handle - .sender - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .take(); - drop(sender); - - let join = { - let mut guard = handle - .join - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - guard.take().ok_or_else(|| { - FfiError::new(ErrorCode::DatasetWriteFinish, "writer is already finished") - })? + let sender = guard + .output_sender + .as_ref() + .cloned() + .ok_or_else(|| FfiError::new(ErrorCode::DatasetWriteFinish, "writer is not initialized"))?; + let join = guard + .output_join + .take() + .ok_or_else(|| FfiError::new(ErrorCode::DatasetWriteFinish, "writer is already finished"))?; + let to_send = std::mem::take(&mut guard.buffered_batches); + guard.output_sender = None; + (sender, join, to_send) }; + for b in to_send { + sender.send(b).map_err(|_| { + FfiError::new( + ErrorCode::DatasetWriteFinish, + "writer background task exited", + ) + })?; + } + drop(sender); + match join.join() { Ok(Ok(WriterResult::Committed)) => Ok(()), Ok(Ok(WriterResult::Uncommitted(_))) => Err(FfiError::new( @@ -588,46 +1143,72 @@ fn writer_finish_uncommitted_inner( } let handle = unsafe { &*(writer as *const WriterHandle) }; + let (sender, join, to_send) = { + let mut guard = handle.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + if guard.output_sender.is_none() { + let conversions: Vec = guard + .vector_candidates + .iter() + .filter(|c| c.dim != 0) + .cloned() + .collect(); + let output_schema = build_output_schema(&handle.input_schema, &conversions) + .map_err(|e| FfiError::new(ErrorCode::DatasetWriteFinishUncommitted, e))?; + let (sender, receiver) = sync_channel::(2); + let join = spawn_writer_thread( + guard.kind, + guard.path.clone(), + guard.params.clone(), + output_schema.clone(), + receiver, + ); + let buffered = std::mem::take(&mut guard.buffered_batches); + let mut out_batches = Vec::with_capacity(buffered.len() + 1); + for b in buffered.iter() { + let out = convert_record_batch(b, &output_schema, &conversions) + .map_err(|e| FfiError::new(ErrorCode::DatasetWriteFinishUncommitted, e))?; + out_batches.push(out); + } + if handle.batches_sent.load(Ordering::Acquire) == 0 { + out_batches.push(RecordBatch::new_empty(output_schema.clone())); + } + guard.output_schema = Some(output_schema); + guard.output_sender = Some(sender.clone()); + guard.output_join = Some(join); + guard.buffered_batches = out_batches; + } - if handle.batches_sent.load(Ordering::Acquire) == 0 { - let sender = { - let guard = handle - .sender - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - guard.clone() - }; - if let Some(sender) = sender { - let empty = RecordBatch::new_empty(handle.schema.clone()); - sender.send(empty).map_err(|_| { + let sender = guard + .output_sender + .as_ref() + .cloned() + .ok_or_else(|| { FfiError::new( ErrorCode::DatasetWriteFinishUncommitted, - "writer background task exited", + "writer is not initialized", ) })?; - } - } - - let sender = handle - .sender - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .take(); - drop(sender); - - let join = { - let mut guard = handle - .join - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - guard.take().ok_or_else(|| { + let join = guard.output_join.take().ok_or_else(|| { FfiError::new( ErrorCode::DatasetWriteFinishUncommitted, "writer is already finished", ) - })? + })?; + let to_send = std::mem::take(&mut guard.buffered_batches); + guard.output_sender = None; + (sender, join, to_send) }; + for b in to_send { + sender.send(b).map_err(|_| { + FfiError::new( + ErrorCode::DatasetWriteFinishUncommitted, + "writer background task exited", + ) + })?; + } + drop(sender); + let txn = match join.join() { Ok(Ok(WriterResult::Uncommitted(txn))) => txn, Ok(Ok(WriterResult::Committed)) => { diff --git a/src/lance_search.cpp b/src/lance_search.cpp index a1d3df1..21e9b1e 100644 --- a/src/lance_search.cpp +++ b/src/lance_search.cpp @@ -79,11 +79,17 @@ static vector ParseQueryVector(const Value &value, throw InvalidInputException(function_name + " requires a non-null query vector"); } - if (value.type().id() != LogicalTypeId::LIST) { + if (value.type().id() != LogicalTypeId::LIST && + value.type().id() != LogicalTypeId::ARRAY) { throw InvalidInputException(function_name + - " requires query vector to be a LIST"); + " requires query vector to be a LIST or ARRAY"); + } + vector children; + if (value.type().id() == LogicalTypeId::LIST) { + children = ListValue::GetChildren(value); + } else { + children = ArrayValue::GetChildren(value); } - auto children = ListValue::GetChildren(value); if (children.empty()) { throw InvalidInputException(function_name + " requires a non-empty query vector"); diff --git a/test/sql/bench_bigann_tiny.test b/test/sql/bench_bigann_tiny.test index 6531e4c..43de8e5 100644 --- a/test/sql/bench_bigann_tiny.test +++ b/test/sql/bench_bigann_tiny.test @@ -51,7 +51,7 @@ WITH got AS ( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, use_index = false ) r @@ -95,7 +95,7 @@ WITH got AS ( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, prefilter = true, use_index = false @@ -141,7 +141,7 @@ WITH got AS ( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, use_index = true ) r @@ -173,7 +173,7 @@ WITH got AS ( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, use_index = true ) r @@ -211,7 +211,7 @@ FROM lance_vector_search( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, prefilter = true, use_index = true @@ -243,7 +243,7 @@ WITH got AS ( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, prefilter = true, use_index = true @@ -285,7 +285,7 @@ FROM lance_vector_search( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 1, use_index = true ); @@ -315,7 +315,7 @@ FROM lance_vector_search( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 1, use_index = false ); diff --git a/test/sql/index_ddl.test b/test/sql/index_ddl.test index 9d68efe..3d5d226 100644 --- a/test/sql/index_ddl.test +++ b/test/sql/index_ddl.test @@ -59,7 +59,7 @@ FROM lance_vector_search( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 1, use_index = true, explain_verbose = true @@ -90,7 +90,7 @@ FROM lance_vector_search( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 1, use_index = false, explain_verbose = true @@ -121,7 +121,7 @@ WITH got AS ( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, use_index = true ) @@ -153,7 +153,7 @@ WITH got AS ( -0.3673502, -0.68120277, 0.41665298 - ]::FLOAT[], + ]::FLOAT[16], k = 10, use_index = true ) r diff --git a/test/sql/s3_scan_minio.test b/test/sql/s3_scan_minio.test index 3c5a763..aff8e22 100644 --- a/test/sql/s3_scan_minio.test +++ b/test/sql/s3_scan_minio.test @@ -58,7 +58,7 @@ SELECT id FROM lance_hybrid_search( 's3://${LANCE_S3_BUCKET}/${LANCE_S3_SEARCH_DATASET_PREFIX}', 'vec', - [0.0, 0.0, 0.0, 0.0]::FLOAT[], + [0.0, 0.0, 0.0, 0.0]::FLOAT[4], 'text', 'puppy', k = 3, diff --git a/test/sql/search_functions.test b/test/sql/search_functions.test index af2889d..f9af15b 100644 --- a/test/sql/search_functions.test +++ b/test/sql/search_functions.test @@ -6,7 +6,7 @@ require lance # Invalid dataset root statement error -SELECT * FROM lance_vector_search('dummy_path.lance', 'vec', [1.0]::FLOAT[], k = 1) +SELECT * FROM lance_vector_search('dummy_path.lance', 'vec', [1.0]::FLOAT[1], k = 1) ---- IO Error: Failed to open Lance dataset: dummy_path.lance (Lance error: dataset open 'dummy_path.lance': @@ -18,7 +18,7 @@ Invalid Input Error: lance_vector_search requires a non-empty query vector # Non-positive k is rejected statement error -SELECT * FROM lance_vector_search('test/data/test_data.lance', 'vec', [1.0]::FLOAT[], k = 0) +SELECT * FROM lance_vector_search('test/data/test_data.lance', 'vec', [1.0]::FLOAT[1], k = 0) ---- Invalid Input Error: lance_vector_search requires k > 0 @@ -49,7 +49,7 @@ SELECT id FROM lance_vector_search( 'test/data/search_test_data.lance', 'vec', - [0.0, 0.0, 0.0, 0.0]::FLOAT[], + [0.0, 0.0, 0.0, 0.0]::FLOAT[4], k = 3, use_index = false ) @@ -65,7 +65,7 @@ SELECT id FROM lance_vector_search( 'test/data/search_test_data.lance', 'vec', - [0.0, 0.0, 0.0, 0.0]::FLOAT[], + [0.0, 0.0, 0.0, 0.0]::FLOAT[4], k = 10, prefilter = true, use_index = false @@ -76,13 +76,87 @@ ORDER BY _distance 4 5 +# Vector search over ARRAY vectors written by DuckDB (preferred usage) +statement ok +COPY ( + SELECT * + FROM ( + VALUES + (1::BIGINT, 'duck'::VARCHAR, [0.9, 0.7, 0.1]::FLOAT[3]), + (2::BIGINT, 'horse'::VARCHAR, [0.3, 0.1, 0.5]::FLOAT[3]), + (3::BIGINT, 'dragon'::VARCHAR, [0.5, 0.2, 0.7]::FLOAT[3]) + ) AS t(id, animal, vec) +) TO 'test/.tmp/knn_array_f32.lance' (FORMAT lance, mode 'overwrite'); + +query I +SELECT id +FROM lance_vector_search( + 'test/.tmp/knn_array_f32.lance', + 'vec', + [0.8, 0.7, 0.2]::FLOAT[3], + k = 1, + use_index = false +) +ORDER BY _distance; +---- +1 + +statement ok +COPY ( + SELECT * + FROM ( + VALUES + (1::BIGINT, 'duck'::VARCHAR, [0.9, 0.7, 0.1]::DOUBLE[3]), + (2::BIGINT, 'horse'::VARCHAR, [0.3, 0.1, 0.5]::DOUBLE[3]), + (3::BIGINT, 'dragon'::VARCHAR, [0.5, 0.2, 0.7]::DOUBLE[3]) + ) AS t(id, animal, vec) +) TO 'test/.tmp/knn_array_f64.lance' (FORMAT lance, mode 'overwrite'); + +query I +SELECT id +FROM lance_vector_search( + 'test/.tmp/knn_array_f64.lance', + 'vec', + [0.8, 0.7, 0.2]::DOUBLE[3], + k = 1, + use_index = false +) +ORDER BY _distance; +---- +1 + +# Vector search over LIST vectors written by DuckDB is supported via writer normalization (issue #117) +statement ok +COPY ( + SELECT * + FROM ( + VALUES + (1::BIGINT, 'duck'::VARCHAR, [0.9, 0.7, 0.1]::FLOAT[]), + (2::BIGINT, 'horse'::VARCHAR, [0.3, 0.1, 0.5]::FLOAT[]), + (3::BIGINT, 'dragon'::VARCHAR, [0.5, 0.2, 0.7]::FLOAT[]) + ) AS t(id, animal, vec) +) TO 'test/.tmp/knn_list_f32.lance' (FORMAT lance, mode 'overwrite'); + +query I +SELECT id +FROM lance_vector_search( + 'test/.tmp/knn_list_f32.lance', + 'vec', + [0.8, 0.7, 0.2]::FLOAT[3], + k = 1, + use_index = false +) +ORDER BY _distance; +---- +1 + # Hybrid search (vector + fts) query I SELECT id FROM lance_hybrid_search( 'test/data/search_test_data.lance', 'vec', - [0.0, 0.0, 0.0, 0.0]::FLOAT[], + [0.0, 0.0, 0.0, 0.0]::FLOAT[4], 'text', 'puppy', k = 3, @@ -102,7 +176,7 @@ SELECT id FROM lance_hybrid_search( 'test/data/search_test_data.lance', 'vec', - [0.0, 0.0, 0.0, 0.0]::FLOAT[], + [0.0, 0.0, 0.0, 0.0]::FLOAT[4], 'text', 'puppy', k = 3,