diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs index f8158b2211a2..4e6a72e87aac 100644 --- a/parquet-variant-compute/src/shred_variant.rs +++ b/parquet-variant-compute/src/shred_variant.rs @@ -326,10 +326,11 @@ impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> { mod tests { use super::*; use crate::VariantArrayBuilder; - use arrow::array::{Array, Float64Array, Int64Array}; + use arrow::array::{Array, FixedSizeBinaryArray, Float64Array, Int64Array}; use arrow::datatypes::{DataType, Field, Fields}; use parquet_variant::{ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder}; use std::sync::Arc; + use uuid::Uuid; #[test] fn test_already_shredded_input_error() { @@ -369,6 +370,73 @@ mod tests { shred_variant(&input, &list_schema).expect_err("unsupported"); } + #[test] + fn test_invalid_fixed_size_binary_shredding() { + let mock_uuid_1 = Uuid::new_v4(); + + let input = VariantArray::from_iter([Some(Variant::from(mock_uuid_1)), None]); + + // shred_variant only supports FixedSizeBinary(16). Any other length will err. + let err = shred_variant(&input, &DataType::FixedSizeBinary(17)).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: FixedSizeBinary(17) is not a valid variant shredding type. Only FixedSizeBinary(16) for UUID is supported." + ); + } + + #[test] + fn test_uuid_shredding() { + let mock_uuid_1 = Uuid::new_v4(); + let mock_uuid_2 = Uuid::new_v4(); + + let input = VariantArray::from_iter([ + Some(Variant::from(mock_uuid_1)), + None, + Some(Variant::from(false)), + Some(Variant::from(mock_uuid_2)), + ]); + + let variant_array = shred_variant(&input, &DataType::FixedSizeBinary(16)).unwrap(); + + // // inspect the typed_value Field and make sure it contains the canonical Uuid extension type + // let typed_value_field = variant_array + // .inner() + // .fields() + // .into_iter() + // .find(|f| f.name() == "typed_value") + // .unwrap(); + + // assert!( + // typed_value_field + // .try_extension_type::() + // .is_ok() + // ); + + // probe the downcasted typed_value array to make sure uuids are shredded correctly + let uuids = variant_array + .typed_value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(uuids.len(), 4); + + assert!(!uuids.is_null(0)); + + let got_uuid_1: &[u8] = uuids.value(0); + assert_eq!(got_uuid_1, mock_uuid_1.as_bytes()); + + assert!(uuids.is_null(1)); + assert!(uuids.is_null(2)); + + assert!(!uuids.is_null(3)); + + let got_uuid_2: &[u8] = uuids.value(3); + assert_eq!(got_uuid_2, mock_uuid_2.as_bytes()); + } + #[test] fn test_primitive_shredding_comprehensive() { // Test mixed scenarios in a single array @@ -869,6 +937,187 @@ mod tests { assert!(value_field3.is_null(0)); // fully shredded, no remaining fields } + #[test] + fn test_uuid_shredding_in_objects() { + let mock_uuid_1 = Uuid::new_v4(); + let mock_uuid_2 = Uuid::new_v4(); + let mock_uuid_3 = Uuid::new_v4(); + + let mut builder = VariantArrayBuilder::new(6); + + // Row 0: Fully shredded object with both UUID fields + builder + .new_object() + .with_field("id", mock_uuid_1) + .with_field("session_id", mock_uuid_2) + .finish(); + + // Row 1: Partially shredded object - UUID fields plus extra field + builder + .new_object() + .with_field("id", mock_uuid_2) + .with_field("session_id", mock_uuid_3) + .with_field("name", "test_user") + .finish(); + + // Row 2: Missing UUID field (no session_id) + builder.new_object().with_field("id", mock_uuid_1).finish(); + + // Row 3: Type mismatch - id is UUID but session_id is a string + builder + .new_object() + .with_field("id", mock_uuid_3) + .with_field("session_id", "not-a-uuid") + .finish(); + + // Row 4: Object with non-UUID value in id field + builder + .new_object() + .with_field("id", 12345i64) + .with_field("session_id", mock_uuid_1) + .finish(); + + // Row 5: Null + builder.append_null(); + + let input = builder.build(); + + let fields = Fields::from(vec![ + Field::new("id", DataType::FixedSizeBinary(16), true), + Field::new("session_id", DataType::FixedSizeBinary(16), true), + ]); + let target_schema = DataType::Struct(fields); + + let result = shred_variant(&input, &target_schema).unwrap(); + + assert!(result.value_field().is_some()); + assert!(result.typed_value_field().is_some()); + assert_eq!(result.len(), 6); + + let metadata = result.metadata_field(); + let value = result.value_field().unwrap(); + let typed_value = result + .typed_value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Extract id and session_id fields from typed_value struct + let id_field = + ShreddedVariantFieldArray::try_new(typed_value.column_by_name("id").unwrap()).unwrap(); + let session_id_field = + ShreddedVariantFieldArray::try_new(typed_value.column_by_name("session_id").unwrap()) + .unwrap(); + + let id_value = id_field + .value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let id_typed_value = id_field + .typed_value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let session_id_value = session_id_field + .value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let session_id_typed_value = session_id_field + .typed_value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + // Row 0: Fully shredded - both UUID fields shred successfully + assert!(result.is_valid(0)); + + assert!(value.is_null(0)); // fully shredded, no remaining fields + assert!(id_value.is_null(0)); + assert!(session_id_value.is_null(0)); + + assert!(typed_value.is_valid(0)); + assert!(id_typed_value.is_valid(0)); + assert!(session_id_typed_value.is_valid(0)); + + assert_eq!(id_typed_value.value(0), mock_uuid_1.as_bytes()); + assert_eq!(session_id_typed_value.value(0), mock_uuid_2.as_bytes()); + + // Row 1: Partially shredded - value contains extra name field + assert!(result.is_valid(1)); + + assert!(value.is_valid(1)); // contains unshredded "name" field + assert!(typed_value.is_valid(1)); + + assert!(id_value.is_null(1)); + assert!(id_typed_value.is_valid(1)); + assert_eq!(id_typed_value.value(1), mock_uuid_2.as_bytes()); + + assert!(session_id_value.is_null(1)); + assert!(session_id_typed_value.is_valid(1)); + assert_eq!(session_id_typed_value.value(1), mock_uuid_3.as_bytes()); + + // Verify the value field contains the name field + let row_1_variant = Variant::new(metadata.value(1), value.value(1)); + let Variant::Object(obj) = row_1_variant else { + panic!("Expected object"); + }; + + assert_eq!(obj.get("name"), Some(Variant::from("test_user"))); + + // Row 2: Missing session_id field + assert!(result.is_valid(2)); + + assert!(value.is_null(2)); // fully shredded, no extra fields + assert!(typed_value.is_valid(2)); + + assert!(id_value.is_null(2)); + assert!(id_typed_value.is_valid(2)); + assert_eq!(id_typed_value.value(2), mock_uuid_1.as_bytes()); + + assert!(session_id_value.is_null(2)); + assert!(session_id_typed_value.is_null(2)); // missing field + + // Row 3: Type mismatch - session_id is a string, not UUID + assert!(result.is_valid(3)); + + assert!(value.is_null(3)); // no extra fields + assert!(typed_value.is_valid(3)); + + assert!(id_value.is_null(3)); + assert!(id_typed_value.is_valid(3)); + assert_eq!(id_typed_value.value(3), mock_uuid_3.as_bytes()); + + assert!(session_id_value.is_valid(3)); // type mismatch, stored in value + assert!(session_id_typed_value.is_null(3)); + let session_id_variant = Variant::new(metadata.value(3), session_id_value.value(3)); + assert_eq!(session_id_variant, Variant::from("not-a-uuid")); + + // Row 4: Type mismatch - id is int64, not UUID + assert!(result.is_valid(4)); + + assert!(value.is_null(4)); // no extra fields + assert!(typed_value.is_valid(4)); + + assert!(id_value.is_valid(4)); // type mismatch, stored in value + assert!(id_typed_value.is_null(4)); + let id_variant = Variant::new(metadata.value(4), id_value.value(4)); + assert_eq!(id_variant, Variant::from(12345i64)); + + assert!(session_id_value.is_null(4)); + assert!(session_id_typed_value.is_valid(4)); + assert_eq!(session_id_typed_value.value(4), mock_uuid_1.as_bytes()); + + // Row 5: Null + assert!(result.is_null(5)); + } + #[test] fn test_spec_compliance() { let input = VariantArray::from_iter(vec![Variant::from(42i64), Variant::from("hello")]); diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs index faf64c20d0a2..998de36d18d3 100644 --- a/parquet-variant-compute/src/variant_to_arrow.rs +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -16,7 +16,8 @@ // under the License. use arrow::array::{ - ArrayRef, BinaryViewArray, BooleanBuilder, NullArray, NullBufferBuilder, PrimitiveBuilder, + ArrayRef, BinaryViewArray, BooleanBuilder, FixedSizeBinaryBuilder, NullArray, + NullBufferBuilder, PrimitiveBuilder, }; use arrow::compute::{CastOptions, DecimalCast}; use arrow::datatypes::{self, DataType, DecimalType}; @@ -60,6 +61,7 @@ pub(crate) enum PrimitiveVariantToArrowRowBuilder<'a> { TimestampNanoNtz(VariantToTimestampNtzArrowRowBuilder<'a, datatypes::TimestampNanosecondType>), Time(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Time64MicrosecondType>), Date(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Date32Type>), + Uuid(VariantToUuidArrowRowBuilder<'a>), } /// Builder for converting variant values into strongly typed Arrow arrays. @@ -101,6 +103,7 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { TimestampNanoNtz(b) => b.append_null(), Time(b) => b.append_null(), Date(b) => b.append_null(), + Uuid(b) => b.append_null(), } } @@ -130,6 +133,7 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { TimestampNanoNtz(b) => b.append_value(value), Time(b) => b.append_value(value), Date(b) => b.append_value(value), + Uuid(b) => b.append_value(value), } } @@ -159,6 +163,7 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { TimestampNanoNtz(b) => b.finish(), Time(b) => b.finish(), Date(b) => b.finish(), + Uuid(b) => b.finish(), } } } @@ -200,98 +205,116 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>( ) -> Result> { use PrimitiveVariantToArrowRowBuilder::*; - let builder = - match data_type { - DataType::Null => Null(VariantToNullArrowRowBuilder::new(cast_options, capacity)), - DataType::Boolean => { - Boolean(VariantToBooleanArrowRowBuilder::new(cast_options, capacity)) - } - DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Decimal32(precision, scale) => Decimal32( - VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?, - ), - DataType::Decimal64(precision, scale) => Decimal64( - VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?, - ), - DataType::Decimal128(precision, scale) => Decimal128( - VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?, - ), - DataType::Decimal256(precision, scale) => Decimal256( - VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?, - ), - DataType::Timestamp(TimeUnit::Microsecond, None) => TimestampMicroNtz( - VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), - ), - DataType::Timestamp(TimeUnit::Microsecond, tz) => TimestampMicro( - VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), - ), - DataType::Timestamp(TimeUnit::Nanosecond, None) => TimestampNanoNtz( - VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), - ), - DataType::Timestamp(TimeUnit::Nanosecond, tz) => TimestampNano( - VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), - ), - DataType::Date32 => Date(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Time64(TimeUnit::Microsecond) => Time( - VariantToPrimitiveArrowRowBuilder::new(cast_options, capacity), - ), - _ if data_type.is_primitive() => { - return Err(ArrowError::NotYetImplemented(format!( - "Primitive data_type {data_type:?} not yet implemented" - ))); - } - _ => { - return Err(ArrowError::InvalidArgumentError(format!( - "Not a primitive type: {data_type:?}" - ))); - } - }; + let builder = match data_type { + DataType::Null => Null(VariantToNullArrowRowBuilder::new(cast_options, capacity)), + DataType::Boolean => Boolean(VariantToBooleanArrowRowBuilder::new(cast_options, capacity)), + DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Decimal32(precision, scale) => Decimal32(VariantToDecimalArrowRowBuilder::new( + cast_options, + capacity, + *precision, + *scale, + )?), + DataType::Decimal64(precision, scale) => Decimal64(VariantToDecimalArrowRowBuilder::new( + cast_options, + capacity, + *precision, + *scale, + )?), + DataType::Decimal128(precision, scale) => Decimal128(VariantToDecimalArrowRowBuilder::new( + cast_options, + capacity, + *precision, + *scale, + )?), + DataType::Decimal256(precision, scale) => Decimal256(VariantToDecimalArrowRowBuilder::new( + cast_options, + capacity, + *precision, + *scale, + )?), + DataType::Timestamp(TimeUnit::Microsecond, None) => TimestampMicroNtz( + VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => TimestampMicro( + VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), + ), + DataType::Timestamp(TimeUnit::Nanosecond, None) => TimestampNanoNtz( + VariantToTimestampNtzArrowRowBuilder::new(cast_options, capacity), + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz) => TimestampNano( + VariantToTimestampArrowRowBuilder::new(cast_options, capacity, tz.clone()), + ), + DataType::Date32 => Date(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Time64(TimeUnit::Microsecond) => Time(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::FixedSizeBinary(16) => { + Uuid(VariantToUuidArrowRowBuilder::new(cast_options, capacity)) + } + DataType::FixedSizeBinary(size) => { + return Err(ArrowError::InvalidArgumentError(format!( + "FixedSizeBinary({size}) is not a valid variant shredding type. Only FixedSizeBinary(16) for UUID is supported." + ))); + } + _ if data_type.is_primitive() => { + return Err(ArrowError::NotYetImplemented(format!( + "Primitive data_type {data_type:?} not yet implemented" + ))); + } + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Not a primitive type: {data_type:?}" + ))); + } + }; Ok(builder) } @@ -519,6 +542,49 @@ where } } +/// Builder for converting variant values to FixedSizeBinary(16) for UUIDs +pub(crate) struct VariantToUuidArrowRowBuilder<'a> { + builder: FixedSizeBinaryBuilder, + cast_options: &'a CastOptions<'a>, +} + +impl<'a> VariantToUuidArrowRowBuilder<'a> { + fn new(cast_options: &'a CastOptions<'a>, capacity: usize) -> Self { + Self { + builder: FixedSizeBinaryBuilder::with_capacity(capacity, 16), + cast_options, + } + } + + fn append_null(&mut self) -> Result<()> { + self.builder.append_null(); + Ok(()) + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + match value.as_uuid() { + Some(uuid) => { + self.builder + .append_value(uuid.as_bytes()) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + + Ok(true) + } + None if self.cast_options.safe => { + self.builder.append_null(); + Ok(false) + } + None => Err(ArrowError::CastError(format!( + "Failed to extract UUID from variant {value:?}", + ))), + } + } + + fn finish(mut self) -> Result { + Ok(Arc::new(self.builder.finish())) + } +} + /// Builder for creating VariantArray output (for path extraction without type conversion) pub(crate) struct VariantToBinaryVariantArrowRowBuilder { metadata: BinaryViewArray,