From f38fd5d9a8cc6af98d7a1c43e36653bcc74a4bd9 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Tue, 18 Nov 2025 22:38:18 +0200 Subject: [PATCH 1/4] feat: Support Ref schemas in lookup [Avro] --- .../src/avro_to_arrow/arrow_array_reader.rs | 961 +++++++++++++++++- 1 file changed, 952 insertions(+), 9 deletions(-) diff --git a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs index 5b1f534ad78b..20128729240c 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs @@ -17,7 +17,7 @@ //! Avro to Arrow array readers -use apache_avro::schema::RecordSchema; +use apache_avro::schema::{EnumSchema, FixedSchema, Name, RecordSchema}; use apache_avro::{ error::Details as AvroErrorDetails, schema::{Schema as AvroSchema, SchemaKind}, @@ -45,8 +45,8 @@ use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::arrow_err; use datafusion_common::error::{DataFusionError, Result}; +use datafusion_common::{arrow_err, HashMap}; use num_traits::NumCast; use std::collections::BTreeMap; use std::io::Read; @@ -54,6 +54,11 @@ use std::sync::Arc; type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; +pub struct UnresolvedNames { + names_map: HashMap, + to_resolve: HashMap, +} + pub struct AvroArrowArrayReader<'a, R: Read> { reader: AvroReader<'a, R>, schema: SchemaRef, @@ -61,10 +66,51 @@ pub struct AvroArrowArrayReader<'a, R: Read> { } impl AvroArrowArrayReader<'_, R> { + fn resolve_refs( + unresolved: UnresolvedNames, + schema_lookup: &mut BTreeMap, + ) -> Result<(), ArrowError> { + let UnresolvedNames { + names_map, + to_resolve, + } = unresolved; + + to_resolve.into_iter().try_for_each(|(field_path, name)| { + // Get the original schema that was defiend with the name + let original_schema_location = + names_map.get(&name).ok_or(SchemaError(format!( + "Needed to resolve name {name:?} but it does not exist in schema" + )))?; + + // Get all paths that exit in this schema, but replace the initial path with the field_path we wish to resolve + let resolved = schema_lookup + .iter() + .filter(|(path, _)| { + path.starts_with(original_schema_location) + && *path != original_schema_location + }) + .map(|(path, pos)| { + let resolved_path = + path.replacen(original_schema_location, &field_path, 1); + (resolved_path, *pos) + }) + .collect::>(); + + // Extend our schema lookup with the resolved paths + schema_lookup.extend(resolved); + + Result::<_, ArrowError>::Ok(()) + }) + } + pub fn try_new(reader: R, schema: SchemaRef) -> Result { let reader = AvroReader::new(reader)?; let writer_schema = reader.writer_schema().clone(); - let schema_lookup = Self::schema_lookup(writer_schema)?; + + let (mut schema_lookup, unresolved) = Self::schema_lookup(writer_schema)?; + + Self::resolve_refs(unresolved, &mut schema_lookup)?; + Ok(Self { reader, schema, @@ -72,15 +118,30 @@ impl AvroArrowArrayReader<'_, R> { }) } - pub fn schema_lookup(schema: AvroSchema) -> Result> { + pub fn schema_lookup( + schema: AvroSchema, + ) -> Result<(BTreeMap, UnresolvedNames)> { match schema { AvroSchema::Record(RecordSchema { - fields, mut lookup, .. + fields, + mut lookup, + name, + .. }) => { + // Insert the root into our names map + let mut unresolved_names = UnresolvedNames { + names_map: HashMap::from([(name.clone(), "".to_string())]), + to_resolve: HashMap::new(), + }; for field in fields { - Self::child_schema_lookup(&field.name, &field.schema, &mut lookup)?; + Self::child_schema_lookup( + &field.name, + &field.schema, + &mut lookup, + &mut unresolved_names, + )?; } - Ok(lookup) + Ok((lookup, unresolved_names)) } _ => arrow_err!(SchemaError( "expected avro schema to be a record".to_string(), @@ -92,6 +153,7 @@ impl AvroArrowArrayReader<'_, R> { parent_field_name: &str, schema: &AvroSchema, schema_lookup: &'b mut BTreeMap, + unresolved_names: &'b mut UnresolvedNames, ) -> Result<&'b BTreeMap> { match schema { AvroSchema::Union(us) => { @@ -111,11 +173,20 @@ impl AvroArrowArrayReader<'_, R> { parent_field_name, sub_schema, schema_lookup, + unresolved_names, )?; } } } - AvroSchema::Record(RecordSchema { fields, lookup, .. }) => { + AvroSchema::Record(RecordSchema { + fields, + lookup, + name, + .. + }) => { + unresolved_names + .names_map + .insert(name.clone(), parent_field_name.to_string()); lookup.iter().for_each(|(field_name, pos)| { schema_lookup .insert(format!("{parent_field_name}.{field_name}"), *pos); @@ -128,6 +199,7 @@ impl AvroArrowArrayReader<'_, R> { &sub_parent_field_name, &field.schema, schema_lookup, + unresolved_names, )?; } } @@ -136,8 +208,33 @@ impl AvroArrowArrayReader<'_, R> { parent_field_name, &schema.items, schema_lookup, + unresolved_names, + )?; + } + AvroSchema::Map(map_schema) => { + let sub_parent_field_name = format!("{parent_field_name}.value"); + Self::child_schema_lookup( + &sub_parent_field_name, + &map_schema.types, + schema_lookup, + unresolved_names, )?; } + AvroSchema::Fixed(FixedSchema { name, .. }) => { + unresolved_names + .names_map + .insert(name.clone(), parent_field_name.to_string()); + } + AvroSchema::Enum(EnumSchema { name, .. }) => { + unresolved_names + .names_map + .insert(name.clone(), parent_field_name.to_string()); + } + AvroSchema::Ref { name } => { + unresolved_names + .to_resolve + .insert(parent_field_name.to_string(), name.clone()); + } _ => (), } Ok(schema_lookup) @@ -1033,7 +1130,7 @@ where #[cfg(test)] mod test { use crate::avro_to_arrow::{Reader, ReaderBuilder}; - use arrow::array::Array; + use arrow::array::{Array, FixedSizeBinaryArray, StringArray}; use arrow::datatypes::{DataType, Fields}; use arrow::datatypes::{Field, TimeUnit}; use datafusion_common::assert_batches_eq; @@ -1804,4 +1901,850 @@ mod test { ]; assert_batches_eq!(expected, &[batch]); } + + #[test] + fn test_avro_record_ref() { + // This schema defines an Address record once, then references it multiple times + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "Person", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "home_address", + "type": { + "type": "record", + "name": "Address", + "fields": [ + { + "name": "street", + "type": "string" + }, + { + "name": "city", + "type": "string" + }, + { + "name": "zip", + "type": "int" + } + ] + } + }, + { + "name": "work_address", + "type": "Address" + }, + { + "name": "billing_address", + "type": ["null", "Address"] + } + ] + }"#, + ) + .unwrap(); + + let person1 = apache_avro::to_value(serde_json::json!({ + "name": "Alice", + "home_address": { + "street": "123 Main St", + "city": "Springfield", + "zip": 12345 + }, + "work_address": { + "street": "456 Business Ave", + "city": "Metropolis", + "zip": 67890 + }, + "billing_address": { + "street": "789 Payment Ln", + "city": "Capital City", + "zip": 11111 + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let person2 = apache_avro::to_value(serde_json::json!({ + "name": "Bob", + "home_address": { + "street": "321 Oak Dr", + "city": "Shelbyville", + "zip": 54321 + }, + "work_address": { + "street": "654 Corporate Blvd", + "city": "Tech City", + "zip": 98765 + }, + "billing_address": null + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(person1).unwrap(); + w.append(person2).unwrap(); + let bytes = w.into_inner().unwrap(); + + // Define the Arrow schema explicitly to avoid schema conversion with Refs + let address_fields = Fields::from(vec![ + Field::new("street", DataType::Utf8, false), + Field::new("city", DataType::Utf8, false), + Field::new("zip", DataType::Int32, false), + ]); + + let arrow_schema = Arc::new(arrow::datatypes::Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new( + "home_address", + DataType::Struct(address_fields.clone()), + false, + ), + Field::new( + "work_address", + DataType::Struct(address_fields.clone()), + false, + ), + Field::new("billing_address", DataType::Struct(address_fields), true), + ])); + + let mut reader = ReaderBuilder::new() + .with_schema(arrow_schema) + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 4); + + let expected = [ + "+-------+------------------------------------------------------+-----------------------------------------------------------+----------------------------------------------------------+", + "| name | home_address | work_address | billing_address |", + "+-------+------------------------------------------------------+-----------------------------------------------------------+----------------------------------------------------------+", + "| Alice | {street: 123 Main St, city: Springfield, zip: 12345} | {street: 456 Business Ave, city: Metropolis, zip: 67890} | {street: 789 Payment Ln, city: Capital City, zip: 11111} |", + "| Bob | {street: 321 Oak Dr, city: Shelbyville, zip: 54321} | {street: 654 Corporate Blvd, city: Tech City, zip: 98765} | |", + "+-------+------------------------------------------------------+-----------------------------------------------------------+----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_avro_enum_ref() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "Product", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "primary_category", + "type": { + "type": "enum", + "name": "Category", + "symbols": ["ELECTRONICS", "CLOTHING", "FOOD", "BOOKS"] + } + }, + { + "name": "secondary_category", + "type": ["null", "Category"] + }, + { + "name": "tertiary_category", + "type": "Category" + } + ] + }"#, + ) + .unwrap(); + + let p1 = apache_avro::to_value(serde_json::json!({ + "name": "Laptop", + "primary_category": "ELECTRONICS", + "secondary_category": "ELECTRONICS", + "tertiary_category": "ELECTRONICS" + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let p2 = apache_avro::to_value(serde_json::json!({ + "name": "T-Shirt", + "primary_category": "CLOTHING", + "secondary_category": null, + "tertiary_category": "CLOTHING" + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(p1).unwrap(); + w.append(p2).unwrap(); + let bytes = w.into_inner().unwrap(); + + // Define Arrow schema explicitly since read_schema doesn't support Refs yet + let arrow_schema = Arc::new(arrow::datatypes::Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("primary_category", DataType::Utf8, false), + Field::new("secondary_category", DataType::Utf8, true), + Field::new("tertiary_category", DataType::Utf8, false), + ])); + + let mut reader = ReaderBuilder::new() + .with_schema(arrow_schema) + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 4); + + let expected = [ + "+---------+------------------+--------------------+-------------------+", + "| name | primary_category | secondary_category | tertiary_category |", + "+---------+------------------+--------------------+-------------------+", + "| Laptop | ELECTRONICS | ELECTRONICS | ELECTRONICS |", + "| T-Shirt | CLOTHING | | CLOTHING |", + "+---------+------------------+--------------------+-------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_avro_fixed_ref() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "SecurityEvent", + "fields": [ + { + "name": "event_id", + "type": "string" + }, + { + "name": "hash1", + "type": { + "type": "fixed", + "name": "MD5Hash", + "size": 16 + } + }, + { + "name": "hash2", + "type": "MD5Hash" + }, + { + "name": "optional_hash", + "type": ["null", "MD5Hash"] + } + ] + }"#, + ) + .unwrap(); + + // For Avro fixed types, we need to use apache_avro::types::Value::Fixed directly + use apache_avro::types::Value; + + let hash1_bytes = vec![1u8; 16]; + let hash2_bytes = vec![2u8; 16]; + let hash3_bytes = vec![3u8; 16]; + + let e1 = Value::Record(vec![ + ("event_id".to_string(), Value::String("evt001".to_string())), + ("hash1".to_string(), Value::Fixed(16, hash1_bytes.clone())), + ("hash2".to_string(), Value::Fixed(16, hash2_bytes.clone())), + ( + "optional_hash".to_string(), + Value::Union(1, Box::new(Value::Fixed(16, hash3_bytes.clone()))), + ), + ]); + + let e2 = Value::Record(vec![ + ("event_id".to_string(), Value::String("evt002".to_string())), + ("hash1".to_string(), Value::Fixed(16, hash2_bytes.clone())), + ("hash2".to_string(), Value::Fixed(16, hash1_bytes.clone())), + ( + "optional_hash".to_string(), + Value::Union(0, Box::new(Value::Null)), + ), + ]); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(e1).unwrap(); + w.append(e2).unwrap(); + let bytes = w.into_inner().unwrap(); + + // Define Arrow schema explicitly + let arrow_schema = Arc::new(arrow::datatypes::Schema::new(vec![ + Field::new("event_id", DataType::Utf8, false), + Field::new("hash1", DataType::FixedSizeBinary(16), false), + Field::new("hash2", DataType::FixedSizeBinary(16), false), + Field::new("optional_hash", DataType::FixedSizeBinary(16), true), + ])); + + let mut reader = ReaderBuilder::new() + .with_schema(arrow_schema) + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 4); + + // Verify the data types + let schema = batch.schema(); + assert_eq!(schema.field(1).data_type(), &DataType::FixedSizeBinary(16)); + assert_eq!(schema.field(2).data_type(), &DataType::FixedSizeBinary(16)); + assert_eq!(schema.field(3).data_type(), &DataType::FixedSizeBinary(16)); + + // Verify we can read the data + let hash1_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(hash1_array.len(), 2); + assert_eq!(hash1_array.value(0), hash1_bytes.as_slice()); + assert_eq!(hash1_array.value(1), hash2_bytes.as_slice()); + } + + #[test] + fn test_avro_nested_record_ref() { + // Test record refs nested within arrays and other records + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "Organization", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "primary_contact", + "type": { + "type": "record", + "name": "Contact", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "email", + "type": "string" + } + ] + } + }, + { + "name": "secondary_contact", + "type": ["null", "Contact"] + }, + { + "name": "all_contacts", + "type": { + "type": "array", + "items": "Contact" + } + } + ] + }"#, + ) + .unwrap(); + + let o1 = apache_avro::to_value(serde_json::json!({ + "name": "Acme Corp", + "primary_contact": { + "name": "Alice", + "email": "alice@acme.com" + }, + "secondary_contact": { + "name": "Bob", + "email": "bob@acme.com" + }, + "all_contacts": [ + { + "name": "Alice", + "email": "alice@acme.com" + }, + { + "name": "Bob", + "email": "bob@acme.com" + }, + { + "name": "Charlie", + "email": "charlie@acme.com" + } + ] + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let o2 = apache_avro::to_value(serde_json::json!({ + "name": "Beta Inc", + "primary_contact": { + "name": "Dave", + "email": "dave@beta.com" + }, + "secondary_contact": null, + "all_contacts": [ + { + "name": "Dave", + "email": "dave@beta.com" + } + ] + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(o1).unwrap(); + w.append(o2).unwrap(); + let bytes = w.into_inner().unwrap(); + + // Define Arrow schema explicitly + let contact_fields = Fields::from(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("email", DataType::Utf8, false), + ]); + + let arrow_schema = Arc::new(arrow::datatypes::Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new( + "primary_contact", + DataType::Struct(contact_fields.clone()), + false, + ), + Field::new( + "secondary_contact", + DataType::Struct(contact_fields.clone()), + true, + ), + Field::new( + "all_contacts", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(contact_fields), + false, + ))), + false, + ), + ])); + + let mut reader = ReaderBuilder::new() + .with_schema(arrow_schema) + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 4); + + let expected = [ + "+-----------+--------------------------------------+----------------------------------+--------------------------------------------------------------------------------------------------------------------+", + "| name | primary_contact | secondary_contact | all_contacts |", + "+-----------+--------------------------------------+----------------------------------+--------------------------------------------------------------------------------------------------------------------+", + "| Acme Corp | {name: Alice, email: alice@acme.com} | {name: Bob, email: bob@acme.com} | [{name: Alice, email: alice@acme.com}, {name: Bob, email: bob@acme.com}, {name: Charlie, email: charlie@acme.com}] |", + "| Beta Inc | {name: Dave, email: dave@beta.com} | | [{name: Dave, email: dave@beta.com}] |", + "+-----------+--------------------------------------+----------------------------------+--------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_avro_combined_refs() { + // Test combining record, enum, and fixed refs in a single schema + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "Transaction", + "fields": [ + { + "name": "id", + "type": "string" + }, + { + "name": "status", + "type": { + "type": "enum", + "name": "Status", + "symbols": ["PENDING", "APPROVED", "REJECTED", "CANCELLED"] + } + }, + { + "name": "previous_status", + "type": ["null", "Status"] + }, + { + "name": "signature", + "type": { + "type": "fixed", + "name": "Signature", + "size": 32 + } + }, + { + "name": "backup_signature", + "type": ["null", "Signature"] + }, + { + "name": "user", + "type": { + "type": "record", + "name": "User", + "fields": [ + { + "name": "id", + "type": "string" + }, + { + "name": "role", + "type": "Status" + } + ] + } + }, + { + "name": "approver", + "type": ["null", "User"] + }, + { + "name": "status_history", + "type": { + "type": "array", + "items": "Status" + } + } + ] + }"#, + ) + .unwrap(); + + use apache_avro::types::Value; + + let sig1 = vec![1u8; 32]; + let sig2 = vec![2u8; 32]; + + let t1 = Value::Record(vec![ + ("id".to_string(), Value::String("txn001".to_string())), + ("status".to_string(), Value::Enum(1, "APPROVED".to_string())), + ( + "previous_status".to_string(), + Value::Union(1, Box::new(Value::Enum(0, "PENDING".to_string()))), + ), + ("signature".to_string(), Value::Fixed(32, sig1.clone())), + ( + "backup_signature".to_string(), + Value::Union(1, Box::new(Value::Fixed(32, sig2.clone()))), + ), + ( + "user".to_string(), + Value::Record(vec![ + ("id".to_string(), Value::String("user001".to_string())), + ("role".to_string(), Value::Enum(1, "APPROVED".to_string())), + ]), + ), + ( + "approver".to_string(), + Value::Union( + 1, + Box::new(Value::Record(vec![ + ("id".to_string(), Value::String("admin001".to_string())), + ("role".to_string(), Value::Enum(1, "APPROVED".to_string())), + ])), + ), + ), + ( + "status_history".to_string(), + Value::Array(vec![ + Value::Enum(0, "PENDING".to_string()), + Value::Enum(1, "APPROVED".to_string()), + ]), + ), + ]); + + let t2 = Value::Record(vec![ + ("id".to_string(), Value::String("txn002".to_string())), + ("status".to_string(), Value::Enum(2, "REJECTED".to_string())), + ( + "previous_status".to_string(), + Value::Union(0, Box::new(Value::Null)), + ), + ("signature".to_string(), Value::Fixed(32, sig2.clone())), + ( + "backup_signature".to_string(), + Value::Union(0, Box::new(Value::Null)), + ), + ( + "user".to_string(), + Value::Record(vec![ + ("id".to_string(), Value::String("user002".to_string())), + ("role".to_string(), Value::Enum(2, "REJECTED".to_string())), + ]), + ), + ( + "approver".to_string(), + Value::Union(0, Box::new(Value::Null)), + ), + ( + "status_history".to_string(), + Value::Array(vec![ + Value::Enum(0, "PENDING".to_string()), + Value::Enum(2, "REJECTED".to_string()), + ]), + ), + ]); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(t1).unwrap(); + w.append(t2).unwrap(); + let bytes = w.into_inner().unwrap(); + + // Define Arrow schema explicitly + let user_fields = Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("role", DataType::Utf8, false), + ]); + + let arrow_schema = Arc::new(arrow::datatypes::Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("status", DataType::Utf8, false), + Field::new("previous_status", DataType::Utf8, true), + Field::new("signature", DataType::FixedSizeBinary(32), false), + Field::new("backup_signature", DataType::FixedSizeBinary(32), true), + Field::new("user", DataType::Struct(user_fields.clone()), false), + Field::new("approver", DataType::Struct(user_fields), true), + Field::new( + "status_history", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + false, + ), + ])); + + let mut reader = ReaderBuilder::new() + .with_schema(arrow_schema) + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 8); + + // Verify the schema types + let schema = batch.schema(); + assert_eq!(schema.field(1).data_type(), &DataType::Utf8); // enum as string + assert_eq!(schema.field(2).data_type(), &DataType::Utf8); // nullable enum as string + assert_eq!(schema.field(3).data_type(), &DataType::FixedSizeBinary(32)); + assert_eq!(schema.field(4).data_type(), &DataType::FixedSizeBinary(32)); + assert!(matches!(schema.field(5).data_type(), DataType::Struct(_))); // User record + assert!(matches!(schema.field(6).data_type(), DataType::Struct(_))); // nullable User + assert!(matches!(schema.field(7).data_type(), DataType::List(_))); // status_history + + // Verify data content - check a few key fields + let status_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(status_array.value(0), "APPROVED"); + assert_eq!(status_array.value(1), "REJECTED"); + + let sig_array = batch + .column(3) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(sig_array.value(0), sig1.as_slice()); + assert_eq!(sig_array.value(1), sig2.as_slice()); + } + + #[test] + fn test_avro_multiple_record_refs() { + // Test multiple different record types being referenced + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "Order", + "fields": [ + { + "name": "order_id", + "type": "string" + }, + { + "name": "shipping_address", + "type": { + "type": "record", + "name": "Address", + "fields": [ + { + "name": "street", + "type": "string" + }, + { + "name": "city", + "type": "string" + } + ] + } + }, + { + "name": "billing_address", + "type": "Address" + }, + { + "name": "customer", + "type": { + "type": "record", + "name": "Customer", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "home_address", + "type": "Address" + } + ] + } + }, + { + "name": "gift_recipient", + "type": ["null", "Customer"] + } + ] + }"#, + ) + .unwrap(); + + let o1 = apache_avro::to_value(serde_json::json!({ + "order_id": "ord001", + "shipping_address": { + "street": "123 Ship St", + "city": "Shipping City" + }, + "billing_address": { + "street": "456 Bill Ave", + "city": "Billing Town" + }, + "customer": { + "name": "Alice", + "home_address": { + "street": "789 Home Rd", + "city": "Home City" + } + }, + "gift_recipient": { + "name": "Bob", + "home_address": { + "street": "321 Gift Ln", + "city": "Gift Town" + } + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let o2 = apache_avro::to_value(serde_json::json!({ + "order_id": "ord002", + "shipping_address": { + "street": "111 Main St", + "city": "Main City" + }, + "billing_address": { + "street": "111 Main St", + "city": "Main City" + }, + "customer": { + "name": "Charlie", + "home_address": { + "street": "111 Main St", + "city": "Main City" + } + }, + "gift_recipient": null + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(o1).unwrap(); + w.append(o2).unwrap(); + let bytes = w.into_inner().unwrap(); + + // Define Arrow schema explicitly + // When Address is used in nullable contexts (like inside nullable Customer), + // its fields need to be nullable too + let address_fields = Fields::from(vec![ + Field::new("street", DataType::Utf8, true), // Changed to true + Field::new("city", DataType::Utf8, true), // Changed to true + ]); + + let customer_fields = Fields::from(vec![ + Field::new("name", DataType::Utf8, true), // Changed to true + Field::new( + "home_address", + DataType::Struct(address_fields.clone()), + true, // Changed to true + ), + ]); + + let arrow_schema = Arc::new(arrow::datatypes::Schema::new(vec![ + Field::new("order_id", DataType::Utf8, false), + Field::new( + "shipping_address", + DataType::Struct(address_fields.clone()), + false, + ), + Field::new("billing_address", DataType::Struct(address_fields), false), + Field::new("customer", DataType::Struct(customer_fields.clone()), false), + Field::new("gift_recipient", DataType::Struct(customer_fields), true), + ])); + + let mut reader = ReaderBuilder::new() + .with_schema(arrow_schema) + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 5); + + let expected = [ + "+----------+--------------------------------------------+--------------------------------------------+-----------------------------------------------------------------------+-------------------------------------------------------------------+", + "| order_id | shipping_address | billing_address | customer | gift_recipient |", + "+----------+--------------------------------------------+--------------------------------------------+-----------------------------------------------------------------------+-------------------------------------------------------------------+", + "| ord001 | {street: 123 Ship St, city: Shipping City} | {street: 456 Bill Ave, city: Billing Town} | {name: Alice, home_address: {street: 789 Home Rd, city: Home City}} | {name: Bob, home_address: {street: 321 Gift Ln, city: Gift Town}} |", + "| ord002 | {street: 111 Main St, city: Main City} | {street: 111 Main St, city: Main City} | {name: Charlie, home_address: {street: 111 Main St, city: Main City}} | |", + "+----------+--------------------------------------------+--------------------------------------------+-----------------------------------------------------------------------+-------------------------------------------------------------------+" + ]; + assert_batches_eq!(expected, &[batch]); + } } From f454ebbcae49ffb848c48378d33ffe2868aef849 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Wed, 19 Nov 2025 01:25:01 +0200 Subject: [PATCH 2/4] feat: Support Ref schemas in lookup [Avro] --- .../src/avro_to_arrow/arrow_array_reader.rs | 176 +++-- .../src/avro_to_arrow/schema.rs | 669 +++++++++++++++++- 2 files changed, 719 insertions(+), 126 deletions(-) diff --git a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs index 20128729240c..09a7d7641a55 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs @@ -17,10 +17,11 @@ //! Avro to Arrow array readers -use apache_avro::schema::{EnumSchema, FixedSchema, Name, RecordSchema}; use apache_avro::{ error::Details as AvroErrorDetails, - schema::{Schema as AvroSchema, SchemaKind}, + schema::{ + EnumSchema, FixedSchema, Name, RecordSchema, Schema as AvroSchema, SchemaKind, + }, types::Value, Error as AvroError, Reader as AvroReader, }; @@ -41,12 +42,11 @@ use arrow::datatypes::{ }; use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::ArrowError; -use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use datafusion_common::error::{DataFusionError, Result}; -use datafusion_common::{arrow_err, HashMap}; +use datafusion_common::{arrow_err, HashMap, HashSet}; use num_traits::NumCast; use std::collections::BTreeMap; use std::io::Read; @@ -54,9 +54,10 @@ use std::sync::Arc; type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; +#[derive(Default)] pub struct UnresolvedNames { - names_map: HashMap, - to_resolve: HashMap, + in_progress: HashSet, + all_subnames_in_name: HashMap>, } pub struct AvroArrowArrayReader<'a, R: Read> { @@ -66,50 +67,11 @@ pub struct AvroArrowArrayReader<'a, R: Read> { } impl AvroArrowArrayReader<'_, R> { - fn resolve_refs( - unresolved: UnresolvedNames, - schema_lookup: &mut BTreeMap, - ) -> Result<(), ArrowError> { - let UnresolvedNames { - names_map, - to_resolve, - } = unresolved; - - to_resolve.into_iter().try_for_each(|(field_path, name)| { - // Get the original schema that was defiend with the name - let original_schema_location = - names_map.get(&name).ok_or(SchemaError(format!( - "Needed to resolve name {name:?} but it does not exist in schema" - )))?; - - // Get all paths that exit in this schema, but replace the initial path with the field_path we wish to resolve - let resolved = schema_lookup - .iter() - .filter(|(path, _)| { - path.starts_with(original_schema_location) - && *path != original_schema_location - }) - .map(|(path, pos)| { - let resolved_path = - path.replacen(original_schema_location, &field_path, 1); - (resolved_path, *pos) - }) - .collect::>(); - - // Extend our schema lookup with the resolved paths - schema_lookup.extend(resolved); - - Result::<_, ArrowError>::Ok(()) - }) - } - pub fn try_new(reader: R, schema: SchemaRef) -> Result { let reader = AvroReader::new(reader)?; let writer_schema = reader.writer_schema().clone(); - let (mut schema_lookup, unresolved) = Self::schema_lookup(writer_schema)?; - - Self::resolve_refs(unresolved, &mut schema_lookup)?; + let schema_lookup = Self::schema_lookup(writer_schema)?; Ok(Self { reader, @@ -118,9 +80,7 @@ impl AvroArrowArrayReader<'_, R> { }) } - pub fn schema_lookup( - schema: AvroSchema, - ) -> Result<(BTreeMap, UnresolvedNames)> { + pub fn schema_lookup(schema: AvroSchema) -> Result> { match schema { AvroSchema::Record(RecordSchema { fields, @@ -128,11 +88,9 @@ impl AvroArrowArrayReader<'_, R> { name, .. }) => { - // Insert the root into our names map - let mut unresolved_names = UnresolvedNames { - names_map: HashMap::from([(name.clone(), "".to_string())]), - to_resolve: HashMap::new(), - }; + let mut unresolved_names = UnresolvedNames::default(); + unresolved_names.in_progress.insert(name.clone()); + for field in fields { Self::child_schema_lookup( &field.name, @@ -141,9 +99,13 @@ impl AvroArrowArrayReader<'_, R> { &mut unresolved_names, )?; } - Ok((lookup, unresolved_names)) + + unresolved_names.in_progress.remove(&name); + assert!(unresolved_names.in_progress.is_empty()); + + Ok(lookup) } - _ => arrow_err!(SchemaError( + _ => arrow_err!(ArrowError::SchemaError( "expected avro schema to be a record".to_string(), )), } @@ -184,24 +146,36 @@ impl AvroArrowArrayReader<'_, R> { name, .. }) => { - unresolved_names - .names_map - .insert(name.clone(), parent_field_name.to_string()); + let inserted = unresolved_names.in_progress.insert(name.clone()); + if !inserted { + return arrow_err!(ArrowError::SchemaError(format!("Detected circular reference while resolving schema lookup for record schema with name: {name}"))); + } + + let mut inner_lookup = BTreeMap::new(); lookup.iter().for_each(|(field_name, pos)| { - schema_lookup - .insert(format!("{parent_field_name}.{field_name}"), *pos); + inner_lookup.insert(field_name.clone(), *pos); }); for field in fields { - let sub_parent_field_name = - format!("{}.{}", parent_field_name, field.name); Self::child_schema_lookup( - &sub_parent_field_name, + &field.name, &field.schema, - schema_lookup, + &mut inner_lookup, unresolved_names, )?; } + + // Extend the parent schema lookup with the inner lookup entries + schema_lookup.extend(inner_lookup.iter().map(|(lookup_entry, pos)| { + (format!("{parent_field_name}.{lookup_entry}"), *pos) + })); + + // Store the inner lookup for potential references + unresolved_names + .all_subnames_in_name + .insert(name.clone(), inner_lookup); + + unresolved_names.in_progress.remove(name); } AvroSchema::Array(schema) => { Self::child_schema_lookup( @@ -222,18 +196,34 @@ impl AvroArrowArrayReader<'_, R> { } AvroSchema::Fixed(FixedSchema { name, .. }) => { unresolved_names - .names_map - .insert(name.clone(), parent_field_name.to_string()); + .all_subnames_in_name + .insert(name.clone(), BTreeMap::new()); } AvroSchema::Enum(EnumSchema { name, .. }) => { unresolved_names - .names_map - .insert(name.clone(), parent_field_name.to_string()); + .all_subnames_in_name + .insert(name.clone(), BTreeMap::new()); } AvroSchema::Ref { name } => { - unresolved_names - .to_resolve - .insert(parent_field_name.to_string(), name.clone()); + // Detect circular references + if unresolved_names.in_progress.contains(name) { + return arrow_err!(ArrowError::SchemaError(format!("Detected circular reference while resolving schema lookup for record schema with name: {name}"))); + } + + let subnames = unresolved_names + .all_subnames_in_name + .get(name) + .ok_or_else(|| { + AvroError::new(AvroErrorDetails::SchemaResolutionError( + name.clone(), + )) + })?; + + schema_lookup.extend( + subnames + .iter() + .map(|(k, v)| (format!("{parent_field_name}.{k}"), *v)), + ); } _ => (), } @@ -376,7 +366,7 @@ impl AvroArrowArrayReader<'_, R> { ); self.list_array_string_array_builder::(&dtype, col_name, rows) } - ref e => Err(SchemaError(format!( + ref e => Err(ArrowError::SchemaError(format!( "Data type is currently not supported for dictionaries in list : {e}" ))), } @@ -403,7 +393,7 @@ impl AvroArrowArrayReader<'_, R> { Box::new(ListBuilder::new(values_builder)) } e => { - return Err(SchemaError(format!( + return Err(ArrowError::SchemaError(format!( "Nested list data builder type is not supported: {e}" ))) } @@ -426,7 +416,7 @@ impl AvroArrowArrayReader<'_, R> { } else if !matches!(value, Value::Record(_)) { vec![resolve_string(value)?] } else { - return Err(SchemaError( + return Err(ArrowError::SchemaError( "Only scalars are currently supported in Avro arrays".to_string(), )); }; @@ -438,7 +428,7 @@ impl AvroArrowArrayReader<'_, R> { let builder = builder .as_any_mut() .downcast_mut::>() - .ok_or_else(||SchemaError( + .ok_or_else(||ArrowError::SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -453,7 +443,7 @@ impl AvroArrowArrayReader<'_, R> { builder.append(true); } DataType::Dictionary(_, _) => { - let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||SchemaError( + let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -468,7 +458,7 @@ impl AvroArrowArrayReader<'_, R> { builder.append(true); } e => { - return Err(SchemaError(format!( + return Err(ArrowError::SchemaError(format!( "Nested list data builder type is not supported: {e}" ))) } @@ -537,10 +527,12 @@ impl AvroArrowArrayReader<'_, R> { DataType::UInt64 => { self.build_dictionary_array::(rows, col_name) } - _ => Err(SchemaError("unsupported dictionary key type".to_string())), + _ => Err(ArrowError::SchemaError( + "unsupported dictionary key type".to_string(), + )), } } else { - Err(SchemaError( + Err(ArrowError::SchemaError( "dictionary types other than UTF-8 not yet supported".to_string(), )) } @@ -614,7 +606,7 @@ impl AvroArrowArrayReader<'_, R> { DataType::UInt32 => self.read_primitive_list_values::(rows), DataType::UInt64 => self.read_primitive_list_values::(rows), DataType::Float16 => { - return Err(SchemaError("Float16 not supported".to_string())) + return Err(ArrowError::SchemaError("Float16 not supported".to_string())) } DataType::Float32 => self.read_primitive_list_values::(rows), DataType::Float64 => self.read_primitive_list_values::(rows), @@ -623,7 +615,7 @@ impl AvroArrowArrayReader<'_, R> { | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => { - return Err(SchemaError( + return Err(ArrowError::SchemaError( "Temporal types are not yet supported, see ARROW-4803".to_string(), )) } @@ -702,7 +694,7 @@ impl AvroArrowArrayReader<'_, R> { .unwrap() } datatype => { - return Err(SchemaError(format!( + return Err(ArrowError::SchemaError(format!( "Nested list of {datatype} not supported" ))); } @@ -810,7 +802,7 @@ impl AvroArrowArrayReader<'_, R> { &field_path, ), t => { - return Err(SchemaError(format!( + return Err(ArrowError::SchemaError(format!( "TimeUnit {t:?} not supported with Time64" ))) } @@ -824,7 +816,7 @@ impl AvroArrowArrayReader<'_, R> { &field_path, ), t => { - return Err(SchemaError(format!( + return Err(ArrowError::SchemaError(format!( "TimeUnit {t:?} not supported with Time32" ))) } @@ -923,7 +915,7 @@ impl AvroArrowArrayReader<'_, R> { make_array(data) } _ => { - return Err(SchemaError(format!( + return Err(ArrowError::SchemaError(format!( "type {} not supported", field.data_type() ))) @@ -1029,7 +1021,7 @@ fn resolve_string(v: &Value) -> ArrowResult> { Value::Null => Ok(None), other => Err(AvroError::new(AvroErrorDetails::GetString(other.clone()))), } - .map_err(|e| SchemaError(format!("expected resolvable string : {e}"))) + .map_err(|e| ArrowError::SchemaError(format!("expected resolvable string : {e}"))) } fn resolve_u8(v: &Value) -> Option { @@ -1194,7 +1186,7 @@ mod test { let a_array = as_list_array(batch.column(col_id_index)).unwrap(); assert_eq!( *a_array.data_type(), - DataType::List(Arc::new(Field::new("element", DataType::Int64, true))) + DataType::new_list(DataType::Int64, true) ); let array = a_array.value(0); assert_eq!(*array.data_type(), DataType::Int64); @@ -2702,16 +2694,16 @@ mod test { // When Address is used in nullable contexts (like inside nullable Customer), // its fields need to be nullable too let address_fields = Fields::from(vec![ - Field::new("street", DataType::Utf8, true), // Changed to true - Field::new("city", DataType::Utf8, true), // Changed to true + Field::new("street", DataType::Utf8, true), + Field::new("city", DataType::Utf8, true), ]); let customer_fields = Fields::from(vec![ - Field::new("name", DataType::Utf8, true), // Changed to true + Field::new("name", DataType::Utf8, true), Field::new( "home_address", DataType::Struct(address_fields.clone()), - true, // Changed to true + true, ), ]); diff --git a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs index 3fce0d4826a2..2d5714a19a63 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs @@ -23,22 +23,36 @@ use apache_avro::Schema as AvroSchema; use arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit, UnionMode}; use arrow::datatypes::{Field, UnionFields}; use datafusion_common::error::Result; +use datafusion_common::HashSet; use std::collections::HashMap; use std::sync::Arc; +#[derive(Default)] +struct SchemaResolver { + names_lookup: HashMap, + in_progress: HashSet, +} + /// Converts an avro schema to an arrow schema pub fn to_arrow_schema(avro_schema: &apache_avro::Schema) -> Result { let mut schema_fields = vec![]; match avro_schema { - AvroSchema::Record(RecordSchema { fields, .. }) => { + AvroSchema::Record(RecordSchema { fields, name, .. }) => { + let mut resolver = SchemaResolver::default(); + resolver.in_progress.insert(name.clone()); + for field in fields { schema_fields.push(schema_to_field_with_props( &field.schema, Some(&field.name), field.is_nullable(), Some(external_props(&field.schema)), + &mut resolver, )?) } + + // Not really relevant anymore but for correctness + resolver.in_progress.remove(name); } schema => schema_fields.push(schema_to_field(schema, Some(""), false)?), } @@ -52,7 +66,13 @@ fn schema_to_field( name: Option<&str>, nullable: bool, ) -> Result { - schema_to_field_with_props(schema, name, nullable, Default::default()) + schema_to_field_with_props( + schema, + name, + nullable, + Default::default(), + &mut SchemaResolver::default(), + ) } fn schema_to_field_with_props( @@ -60,10 +80,29 @@ fn schema_to_field_with_props( name: Option<&str>, nullable: bool, props: Option>, + resolver: &mut SchemaResolver, ) -> Result { let mut nullable = nullable; let field_type: DataType = match schema { - AvroSchema::Ref { .. } => todo!("Add support for AvroSchema::Ref"), + AvroSchema::Ref { name } => { + // We can't have an infinitely recursing schema in avro, + // so return an error for these kinds of references + if resolver.in_progress.contains(name) { + return Err(apache_avro::Error::new( + apache_avro::error::Details::SchemaResolutionError(name.clone()), + ) + .into()); + } + + if let Some(dt) = resolver.names_lookup.get(name) { + dt.clone() + } else { + return Err(apache_avro::Error::new( + apache_avro::error::Details::SchemaResolutionError(name.clone()), + ) + .into()); + } + } AvroSchema::Null => DataType::Null, AvroSchema::Boolean => DataType::Boolean, AvroSchema::Int => DataType::Int32, @@ -72,15 +111,22 @@ fn schema_to_field_with_props( AvroSchema::Double => DataType::Float64, AvroSchema::Bytes => DataType::Binary, AvroSchema::String => DataType::Utf8, - AvroSchema::Array(item_schema) => DataType::List(Arc::new( - schema_to_field_with_props(&item_schema.items, Some("element"), false, None)?, - )), + AvroSchema::Array(item_schema) => { + DataType::List(Arc::new(schema_to_field_with_props( + &item_schema.items, + Some("item"), + false, + None, + resolver, + )?)) + } AvroSchema::Map(value_schema) => { let value_field = schema_to_field_with_props( &value_schema.types, Some("value"), false, None, + resolver, )?; DataType::Dictionary( Box::new(DataType::Utf8), @@ -103,9 +149,15 @@ fn schema_to_field_with_props( .iter() .find(|&schema| !matches!(schema, AvroSchema::Null)) { - schema_to_field_with_props(schema, None, has_nullable, None)? - .data_type() - .clone() + schema_to_field_with_props( + schema, + None, + has_nullable, + None, + resolver, + )? + .data_type() + .clone() } else { return Err(apache_avro::Error::new( apache_avro::error::Details::GetUnionDuplicate, @@ -115,13 +167,23 @@ fn schema_to_field_with_props( } else { let fields = sub_schemas .iter() - .map(|s| schema_to_field_with_props(s, None, has_nullable, None)) + .map(|s| { + schema_to_field_with_props(s, None, has_nullable, None, resolver) + }) .collect::>>()?; let type_ids = 0_i8..fields.len() as i8; DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) } } - AvroSchema::Record(RecordSchema { fields, .. }) => { + AvroSchema::Record(RecordSchema { fields, name, .. }) => { + let inserted = resolver.in_progress.insert(name.clone()); + if !inserted { + return Err(apache_avro::Error::new( + apache_avro::error::Details::SchemaResolutionError(name.clone()), + ) + .into()); + } + let fields: Result<_> = fields .iter() .map(|field| { @@ -137,14 +199,49 @@ fn schema_to_field_with_props( Some(&field.name), false, Some(props), + resolver, ) }) .collect(); - DataType::Struct(fields?) + + let dtype = DataType::Struct(fields?); + + let previous = resolver.names_lookup.insert(name.clone(), dtype.clone()); + if previous.is_some() { + return Err(apache_avro::Error::new( + apache_avro::error::Details::SchemaResolutionError(name.clone()), + ) + .into()); + } + resolver.in_progress.remove(name); + + dtype + } + AvroSchema::Enum(EnumSchema { name, .. }) => { + let dtype = DataType::Utf8; + + let existing = resolver.names_lookup.insert(name.clone(), dtype.clone()); + if existing.is_some() { + return Err(apache_avro::Error::new( + apache_avro::error::Details::SchemaResolutionError(name.clone()), + ) + .into()); + } + + dtype } - AvroSchema::Enum(EnumSchema { .. }) => DataType::Utf8, - AvroSchema::Fixed(FixedSchema { size, .. }) => { - DataType::FixedSizeBinary(*size as i32) + AvroSchema::Fixed(FixedSchema { size, name, .. }) => { + let dtype = DataType::FixedSizeBinary(*size as i32); + + let existing = resolver.names_lookup.insert(name.clone(), dtype.clone()); + if existing.is_some() { + return Err(apache_avro::Error::new( + apache_avro::error::Details::SchemaResolutionError(name.clone()), + ) + .into()); + } + + dtype } AvroSchema::Decimal(DecimalSchema { precision, scale, .. @@ -314,10 +411,8 @@ mod test { use super::{aliased, external_props, to_arrow_schema}; use apache_avro::schema::{Alias, EnumSchema, FixedSchema, Name, RecordSchema}; use apache_avro::Schema as AvroSchema; - use arrow::datatypes::DataType::{Binary, Float32, Float64, Timestamp, Utf8}; - use arrow::datatypes::DataType::{Boolean, Int32, Int64}; - use arrow::datatypes::TimeUnit::Microsecond; - use arrow::datatypes::{Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use datafusion_common::DataFusionError; fn alias(name: &str) -> Alias { Alias::new(name).unwrap() @@ -444,17 +539,21 @@ mod test { let arrow_schema = to_arrow_schema(&schema.unwrap()); assert!(arrow_schema.is_ok(), "{arrow_schema:?}"); let expected = Schema::new(vec![ - Field::new("id", Int32, true), - Field::new("bool_col", Boolean, true), - Field::new("tinyint_col", Int32, true), - Field::new("smallint_col", Int32, true), - Field::new("int_col", Int32, true), - Field::new("bigint_col", Int64, true), - Field::new("float_col", Float32, true), - Field::new("double_col", Float64, true), - Field::new("date_string_col", Binary, true), - Field::new("string_col", Binary, true), - Field::new("timestamp_col", Timestamp(Microsecond, None), true), + Field::new("id", DataType::Int32, true), + Field::new("bool_col", DataType::Boolean, true), + Field::new("tinyint_col", DataType::Int32, true), + Field::new("smallint_col", DataType::Int32, true), + Field::new("int_col", DataType::Int32, true), + Field::new("bigint_col", DataType::Int64, true), + Field::new("float_col", DataType::Float32, true), + Field::new("double_col", DataType::Float64, true), + Field::new("date_string_col", DataType::Binary, true), + Field::new("string_col", DataType::Binary, true), + Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), ]); assert_eq!(arrow_schema.unwrap(), expected); } @@ -496,10 +595,10 @@ mod test { // should not use Avro Record names. let expected_arrow_schema = Schema::new(vec![Field::new( "col1", - arrow::datatypes::DataType::Struct( + DataType::Struct( vec![ - Field::new("col2", Utf8, false), - Field::new("col3", Utf8, true), + Field::new("col2", DataType::Utf8, false), + Field::new("col3", DataType::Utf8, true), ] .into(), ), @@ -517,7 +616,509 @@ mod test { assert!(arrow_schema.is_ok(), "{arrow_schema:?}"); assert_eq!( arrow_schema.unwrap(), - Schema::new(vec![Field::new("", Utf8, false)]) + Schema::new(vec![Field::new("", DataType::Utf8, false)]) ); } + + #[test] + fn test_self_referential_record_rejected() { + // Self-referential schemas should be rejected + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Person", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "friend", + "type": ["null", "Person"] + } + ] + }"#, + ) + .unwrap(); + + let result = to_arrow_schema(&avro_schema); + let DataFusionError::AvroError(err) = result.as_ref().unwrap_err() else { + panic!("Expected AvroError but got {result:?}"); + }; + + let apache_avro::error::Details::SchemaResolutionError(name) = err.details() + else { + panic!("Expected SchemaResolutionError but got {:?}", err.details()); + }; + + assert_eq!(name.name, "Person"); + assert_eq!(name.namespace, None); + } + + #[test] + fn test_mutually_recursive_records_rejected() { + // Mutually recursive schemas should be rejected + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Node", + "fields": [ + { + "name": "value", + "type": "int" + }, + { + "name": "children", + "type": { + "type": "array", + "items": "Node" + } + } + ] + }"#, + ) + .unwrap(); + + let result = to_arrow_schema(&avro_schema); + assert!( + result.is_err(), + "Self-referential array schemas should be rejected" + ); + } + + #[test] + fn test_enum_ref() { + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Message", + "fields": [ + { + "name": "priority", + "type": { + "type": "enum", + "name": "Priority", + "symbols": ["LOW", "MEDIUM", "HIGH"] + } + }, + { + "name": "fallback_priority", + "type": ["null", "Priority"] + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 2); + assert_eq!(arrow_schema.field(0).data_type(), &DataType::Utf8); + assert_eq!(arrow_schema.field(1).data_type(), &DataType::Utf8); + assert!(arrow_schema.field(1).is_nullable()); + } + + #[test] + fn test_fixed_ref() { + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "HashRecord", + "fields": [ + { + "name": "primary_hash", + "type": { + "type": "fixed", + "name": "MD5", + "size": 16 + } + }, + { + "name": "secondary_hash", + "type": ["null", "MD5"] + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 2); + assert_eq!( + arrow_schema.field(0).data_type(), + &DataType::FixedSizeBinary(16) + ); + assert_eq!( + arrow_schema.field(1).data_type(), + &DataType::FixedSizeBinary(16) + ); + assert!(arrow_schema.field(1).is_nullable()); + } + + #[test] + fn test_multiple_refs_same_type() { + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Container", + "fields": [ + { + "name": "status1", + "type": { + "type": "enum", + "name": "Status", + "symbols": ["ACTIVE", "INACTIVE"] + } + }, + { + "name": "status2", + "type": "Status" + }, + { + "name": "status3", + "type": ["null", "Status"] + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 3); + assert_eq!(arrow_schema.field(0).data_type(), &DataType::Utf8); + assert_eq!(arrow_schema.field(1).data_type(), &DataType::Utf8); + assert_eq!(arrow_schema.field(2).data_type(), &DataType::Utf8); + assert!(arrow_schema.field(2).is_nullable()); + } + + #[test] + fn test_non_recursive_nested_record_ref() { + // Non-recursive nested records with refs should work + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Outer", + "fields": [ + { + "name": "id_type", + "type": { + "type": "fixed", + "name": "UUID", + "size": 16 + } + }, + { + "name": "nested", + "type": { + "type": "record", + "name": "Inner", + "fields": [ + { + "name": "inner_id", + "type": "UUID" + } + ] + } + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 2); + assert_eq!( + arrow_schema.field(0).data_type(), + &DataType::FixedSizeBinary(16) + ); + + if let DataType::Struct(fields) = arrow_schema.field(1).data_type() { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].data_type(), &DataType::FixedSizeBinary(16)); + } else { + panic!("Expected Struct type for nested field"); + } + } + + #[test] + fn test_ref_in_array() { + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "ArrayContainer", + "fields": [ + { + "name": "priority_def", + "type": { + "type": "enum", + "name": "Priority", + "symbols": ["LOW", "HIGH"] + } + }, + { + "name": "priorities", + "type": { + "type": "array", + "items": "Priority" + } + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 2); + + if let DataType::List(item_field) = arrow_schema.field(1).data_type() { + assert_eq!(item_field.data_type(), &DataType::Utf8); + } else { + panic!("Expected List type for array field"); + } + } + + #[test] + fn test_invalid_ref() { + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "BadRecord", + "fields": [ + { + "name": "bad_ref", + "type": "NonExistentType" + } + ] + }"#, + ); + + // This should either fail during Avro parsing or during Arrow conversion + assert!(avro_schema.is_err() || to_arrow_schema(&avro_schema.unwrap()).is_err()); + } + + #[test] + fn test_namespaced_ref() { + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Container", + "namespace": "com.example", + "fields": [ + { + "name": "status_def", + "type": { + "type": "enum", + "name": "Status", + "namespace": "com.example.types", + "symbols": ["OK", "ERROR"] + } + }, + { + "name": "status_ref", + "type": "com.example.types.Status" + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 2); + assert_eq!(arrow_schema.field(0).data_type(), &DataType::Utf8); + assert_eq!(arrow_schema.field(1).data_type(), &DataType::Utf8); + } + + #[test] + fn test_complex_non_recursive_ref_graph() { + // Multiple types referencing each other without cycles + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Root", + "fields": [ + { + "name": "hash_type", + "type": { + "type": "fixed", + "name": "Hash", + "size": 32 + } + }, + { + "name": "status_type", + "type": { + "type": "enum", + "name": "Status", + "symbols": ["PENDING", "COMPLETE"] + } + }, + { + "name": "data", + "type": { + "type": "record", + "name": "Data", + "fields": [ + { + "name": "id", + "type": "Hash" + }, + { + "name": "state", + "type": "Status" + } + ] + } + }, + { + "name": "backup_hash", + "type": ["null", "Hash"] + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 4); + assert_eq!( + arrow_schema.field(0).data_type(), + &DataType::FixedSizeBinary(32) + ); + assert_eq!(arrow_schema.field(1).data_type(), &DataType::Utf8); + + if let DataType::Struct(fields) = arrow_schema.field(2).data_type() { + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].data_type(), &DataType::FixedSizeBinary(32)); + assert_eq!(fields[1].data_type(), &DataType::Utf8); + } else { + panic!("Expected Struct type for data field"); + } + + assert_eq!( + arrow_schema.field(3).data_type(), + &DataType::FixedSizeBinary(32) + ); + assert!(arrow_schema.field(3).is_nullable()); + } + + #[test] + fn test_duplicate_type_name_rejected() { + // Defining the same named type twice should be rejected + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Container", + "fields": [ + { + "name": "first", + "type": { + "type": "enum", + "name": "Status", + "symbols": ["OK"] + } + }, + { + "name": "second", + "type": { + "type": "enum", + "name": "Status", + "symbols": ["ERROR"] + } + } + ] + }"#, + ); + + // Avro parser itself should reject this, or our converter should + assert!(avro_schema.is_err() || to_arrow_schema(&avro_schema.unwrap()).is_err()); + } + + #[test] + fn test_deeply_nested_ref_chain() { + // Test a chain of references without cycles + let avro_schema = AvroSchema::parse_str( + r#" + { + "type": "record", + "name": "Level1", + "fields": [ + { + "name": "id_type", + "type": { + "type": "fixed", + "name": "ID", + "size": 8 + } + }, + { + "name": "level2", + "type": { + "type": "record", + "name": "Level2", + "fields": [ + { + "name": "id", + "type": "ID" + }, + { + "name": "level3", + "type": { + "type": "record", + "name": "Level3", + "fields": [ + { + "name": "id", + "type": "ID" + } + ] + } + } + ] + } + } + ] + }"#, + ) + .unwrap(); + + let arrow_schema = to_arrow_schema(&avro_schema).unwrap(); + + assert_eq!(arrow_schema.fields().len(), 2); + assert_eq!( + arrow_schema.field(0).data_type(), + &DataType::FixedSizeBinary(8) + ); + + // Verify the nested structure + if let DataType::Struct(level2_fields) = arrow_schema.field(1).data_type() { + assert_eq!(level2_fields.len(), 2); + assert_eq!(level2_fields[0].data_type(), &DataType::FixedSizeBinary(8)); + + if let DataType::Struct(level3_fields) = level2_fields[1].data_type() { + assert_eq!(level3_fields.len(), 1); + assert_eq!(level3_fields[0].data_type(), &DataType::FixedSizeBinary(8)); + } else { + panic!("Expected Struct type for level3"); + } + } else { + panic!("Expected Struct type for level2"); + } + } } From 0948f9790f44b2118cda90497d65403c5eee2506 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Wed, 19 Nov 2025 01:36:00 +0200 Subject: [PATCH 3/4] rename struct --- .../datasource-avro/src/avro_to_arrow/arrow_array_reader.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs index 09a7d7641a55..dec44d5e8aa9 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs @@ -55,7 +55,7 @@ use std::sync::Arc; type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; #[derive(Default)] -pub struct UnresolvedNames { +pub struct NamesResolver { in_progress: HashSet, all_subnames_in_name: HashMap>, } @@ -88,7 +88,7 @@ impl AvroArrowArrayReader<'_, R> { name, .. }) => { - let mut unresolved_names = UnresolvedNames::default(); + let mut unresolved_names = NamesResolver::default(); unresolved_names.in_progress.insert(name.clone()); for field in fields { @@ -115,7 +115,7 @@ impl AvroArrowArrayReader<'_, R> { parent_field_name: &str, schema: &AvroSchema, schema_lookup: &'b mut BTreeMap, - unresolved_names: &'b mut UnresolvedNames, + unresolved_names: &'b mut NamesResolver, ) -> Result<&'b BTreeMap> { match schema { AvroSchema::Union(us) => { From ea3d958226823c417547e286fcc50902f5a198f4 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Wed, 19 Nov 2025 10:58:45 +0200 Subject: [PATCH 4/4] address CR --- .../datasource-avro/src/avro_to_arrow/arrow_array_reader.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs index dec44d5e8aa9..57d26a4d4751 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs @@ -55,7 +55,7 @@ use std::sync::Arc; type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; #[derive(Default)] -pub struct NamesResolver { +struct NamesResolver { in_progress: HashSet, all_subnames_in_name: HashMap>, } @@ -101,7 +101,7 @@ impl AvroArrowArrayReader<'_, R> { } unresolved_names.in_progress.remove(&name); - assert!(unresolved_names.in_progress.is_empty()); + debug_assert!(unresolved_names.in_progress.is_empty()); Ok(lookup) }