@@ -31,6 +31,9 @@ use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef, Struct, St
3131use crate :: transform:: { BoxedTransformFunction , create_transform_function} ;
3232use crate :: { Error , ErrorKind , Result } ;
3333
34+ /// Name of the record-batch column that carries the pre-computed partition values struct.
/// `RecordBatchPartitionSplitter::split` looks this column up by name (and requires it to be a
/// `StructArray`) when the splitter was created with `has_partition_column = true`.
35+ pub const PROJECTED_PARTITION_VALUE_COLUMN : & str = "_partition" ;
36+
3437/// The splitter used to split the record batch into multiple record batches by the partition spec.
3538/// 1. It will project and transform the input record batch based on the partition spec, get the partitioned record batch.
3639/// 2. Split the input record batch into multiple record batches based on the partitioned record batch.
@@ -40,11 +43,12 @@ use crate::{Error, ErrorKind, Result};
4043pub struct RecordBatchPartitionSplitter {
// Iceberg schema handed to `new` (stored from its `iceberg_schema` argument).
4144 schema : SchemaRef ,
// Partition spec whose fields supply the source column ids and transforms.
4245 partition_spec : PartitionSpecRef ,
// NOTE(review): the `-` line below is the pre-change field; the `+` line replaces it.
43- projector : RecordBatchProjector ,
// Projects the partition source columns out of an input batch.
// `None` when `has_partition_column` is true, because `new` skips building it in that mode.
46+ projector : Option < RecordBatchProjector > ,
// One transform function per partition field, in spec order.
// Left empty when `has_partition_column` is true.
4447 transform_functions : Vec < BoxedTransformFunction > ,
4548
// Struct type of the partition tuple, obtained from `partition_spec.partition_type(&schema)`.
4649 partition_type : StructType ,
// Arrow data type converted from `partition_type` via `type_to_arrow_type`.
4750 partition_arrow_type : DataType ,
// When true, `split` reads partition values from the `PROJECTED_PARTITION_VALUE_COLUMN`
// column of the batch instead of projecting and transforming the source columns.
51+ has_partition_column : bool ,
4852}
4953
5054// # TODO
@@ -58,6 +62,7 @@ impl RecordBatchPartitionSplitter {
5862 /// * `input_schema` - The Arrow schema of the input record batches
5963 /// * `iceberg_schema` - The Iceberg schema reference
6064 /// * `partition_spec` - The partition specification reference
65+ /// * `has_partition_column` - If true, expects a pre-computed partition column in the input batch
6166 ///
6267 /// # Returns
6368 ///
@@ -66,47 +71,55 @@ impl RecordBatchPartitionSplitter {
6671 input_schema : ArrowSchemaRef ,
6772 iceberg_schema : SchemaRef ,
6873 partition_spec : PartitionSpecRef ,
74+ has_partition_column : bool ,
6975 ) -> Result < Self > {
70- let projector = RecordBatchProjector :: new (
71- input_schema,
72- & partition_spec
73- . fields ( )
74- . iter ( )
75- . map ( |field| field. source_id )
76- . collect :: < Vec < _ > > ( ) ,
77- // The source columns, selected by ids, must be a primitive type and cannot be contained in a map or list, but may be nested in a struct.
78- // ref: https://iceberg.apache.org/spec/#partitioning
79- |field| {
80- if !field. data_type ( ) . is_primitive ( ) {
81- return Ok ( None ) ;
82- }
83- field
84- . metadata ( )
85- . get ( PARQUET_FIELD_ID_META_KEY )
86- . map ( |s| {
87- s. parse :: < i64 > ( )
88- . map_err ( |e| Error :: new ( ErrorKind :: Unexpected , e. to_string ( ) ) )
89- } )
90- . transpose ( )
91- } ,
92- |_| true ,
93- ) ?;
94- let transform_functions = partition_spec
95- . fields ( )
96- . iter ( )
97- . map ( |field| create_transform_function ( & field. transform ) )
98- . collect :: < Result < Vec < _ > > > ( ) ?;
99-
10076 let partition_type = partition_spec. partition_type ( & iceberg_schema) ?;
10177 let partition_arrow_type = type_to_arrow_type ( & Type :: Struct ( partition_type. clone ( ) ) ) ?;
10278
79+ let ( projector, transform_functions) = if has_partition_column {
80+ // Skip projector and transform initialization when partition column is pre-computed
81+ ( None , Vec :: new ( ) )
82+ } else {
83+ let projector = RecordBatchProjector :: new (
84+ input_schema,
85+ & partition_spec
86+ . fields ( )
87+ . iter ( )
88+ . map ( |field| field. source_id )
89+ . collect :: < Vec < _ > > ( ) ,
90+ // The source columns, selected by ids, must be a primitive type and cannot be contained in a map or list, but may be nested in a struct.
91+ // ref: https://iceberg.apache.org/spec/#partitioning
92+ |field| {
93+ if !field. data_type ( ) . is_primitive ( ) {
94+ return Ok ( None ) ;
95+ }
96+ field
97+ . metadata ( )
98+ . get ( PARQUET_FIELD_ID_META_KEY )
99+ . map ( |s| {
100+ s. parse :: < i64 > ( )
101+ . map_err ( |e| Error :: new ( ErrorKind :: Unexpected , e. to_string ( ) ) )
102+ } )
103+ . transpose ( )
104+ } ,
105+ |_| true ,
106+ ) ?;
107+ let transform_functions = partition_spec
108+ . fields ( )
109+ . iter ( )
110+ . map ( |field| create_transform_function ( & field. transform ) )
111+ . collect :: < Result < Vec < _ > > > ( ) ?;
112+ ( Some ( projector) , transform_functions)
113+ } ;
114+
103115 Ok ( Self {
104116 schema : iceberg_schema,
105117 partition_spec,
106118 projector,
107119 transform_functions,
108120 partition_type,
109121 partition_arrow_type,
122+ has_partition_column,
110123 } )
111124 }
112125
@@ -153,14 +166,66 @@ impl RecordBatchPartitionSplitter {
153166
154167 /// Split the record batch into multiple record batches based on the partition spec.
155168 pub fn split ( & self , batch : & RecordBatch ) -> Result < Vec < ( PartitionKey , RecordBatch ) > > {
156- let source_columns = self . projector . project_column ( batch. columns ( ) ) ?;
157- let partition_columns = source_columns
158- . into_iter ( )
159- . zip_eq ( self . transform_functions . iter ( ) )
160- . map ( |( source_column, transform_function) | transform_function. transform ( source_column) )
161- . collect :: < Result < Vec < _ > > > ( ) ?;
169+ let partition_structs = if self . has_partition_column {
170+ // Extract partition values from pre-computed partition column
171+ let partition_column = batch
172+ . column_by_name ( PROJECTED_PARTITION_VALUE_COLUMN )
173+ . ok_or_else ( || {
174+ Error :: new (
175+ ErrorKind :: DataInvalid ,
176+ format ! (
177+ "Partition column '{}' not found in batch" ,
178+ PROJECTED_PARTITION_VALUE_COLUMN
179+ ) ,
180+ )
181+ } ) ?;
182+
183+ let partition_struct_array = partition_column
184+ . as_any ( )
185+ . downcast_ref :: < StructArray > ( )
186+ . ok_or_else ( || {
187+ Error :: new (
188+ ErrorKind :: DataInvalid ,
189+ "Partition column is not a StructArray" ,
190+ )
191+ } ) ?;
192+
193+ let arrow_struct_array = Arc :: new ( partition_struct_array. clone ( ) ) as ArrayRef ;
194+ let struct_array = arrow_struct_to_literal ( & arrow_struct_array, & self . partition_type ) ?;
195+
196+ struct_array
197+ . into_iter ( )
198+ . map ( |s| {
199+ if let Some ( Literal :: Struct ( s) ) = s {
200+ Ok ( s)
201+ } else {
202+ Err ( Error :: new (
203+ ErrorKind :: DataInvalid ,
204+ "Partition value is not a struct literal or is null" ,
205+ ) )
206+ }
207+ } )
208+ . collect :: < Result < Vec < _ > > > ( ) ?
209+ } else {
210+ // Compute partition values from source columns
211+ let projector = self . projector . as_ref ( ) . ok_or_else ( || {
212+ Error :: new (
213+ ErrorKind :: DataInvalid ,
214+ "Projector not initialized for non-partition-column mode" ,
215+ )
216+ } ) ?;
217+
218+ let source_columns = projector. project_column ( batch. columns ( ) ) ?;
219+ let partition_columns = source_columns
220+ . into_iter ( )
221+ . zip_eq ( self . transform_functions . iter ( ) )
222+ . map ( |( source_column, transform_function) | {
223+ transform_function. transform ( source_column)
224+ } )
225+ . collect :: < Result < Vec < _ > > > ( ) ?;
162226
163- let partition_structs = self . partition_columns_to_struct ( partition_columns) ?;
227+ self . partition_columns_to_struct ( partition_columns) ?
228+ } ;
164229
165230 // Group the batch by row value.
166231 let mut group_ids = HashMap :: new ( ) ;
@@ -246,9 +311,13 @@ mod tests {
246311 . unwrap ( ) ,
247312 ) ;
248313 let input_schema = Arc :: new ( schema_to_arrow_schema ( & schema) . unwrap ( ) ) ;
249- let partition_splitter =
250- RecordBatchPartitionSplitter :: new ( input_schema. clone ( ) , schema. clone ( ) , partition_spec)
251- . expect ( "Failed to create splitter" ) ;
314+ let partition_splitter = RecordBatchPartitionSplitter :: new (
315+ input_schema. clone ( ) ,
316+ schema. clone ( ) ,
317+ partition_spec,
318+ false ,
319+ )
320+ . expect ( "Failed to create splitter" ) ;
252321
253322 let id_array = Int32Array :: from ( vec ! [ 1 , 2 , 1 , 3 , 2 , 3 , 1 ] ) ;
254323 let data_array = StringArray :: from ( vec ! [ "a" , "b" , "c" , "d" , "e" , "f" , "g" ] ) ;
@@ -319,4 +388,119 @@ mod tests {
319388 Struct :: from_iter( vec![ Some ( Literal :: int( 3 ) ) ] ) ,
320389 ] ) ;
321390 }
391+
// Verifies that a splitter built with `has_partition_column = true` groups rows by the values
// found in the pre-computed `_partition` struct column of the input batch.
392+ #[ test]
393+ fn test_record_batch_partition_split_with_partition_column ( ) {
394+ use arrow_array:: StructArray ;
395+ use arrow_schema:: { Field , Schema as ArrowSchema } ;
396+
// Iceberg table schema: required int `id` (field 1) and required string `name` (field 2).
397+ let schema = Arc :: new (
398+ Schema :: builder ( )
399+ . with_fields ( vec ! [
400+ NestedField :: required(
401+ 1 ,
402+ "id" ,
403+ Type :: Primitive ( crate :: spec:: PrimitiveType :: Int ) ,
404+ )
405+ . into( ) ,
406+ NestedField :: required(
407+ 2 ,
408+ "name" ,
409+ Type :: Primitive ( crate :: spec:: PrimitiveType :: String ) ,
410+ )
411+ . into( ) ,
412+ ] )
413+ . build ( )
414+ . unwrap ( ) ,
415+ ) ;
// Single partition field on `id`. Despite the `id_bucket` name it uses the Identity transform,
// so the partition value equals the source value.
416+ let partition_spec = Arc :: new (
417+ PartitionSpecBuilder :: new ( schema. clone ( ) )
418+ . with_spec_id ( 1 )
419+ . add_unbound_field ( UnboundPartitionField {
420+ source_id : 1 ,
421+ field_id : None ,
422+ name : "id_bucket" . to_string ( ) ,
423+ transform : Transform :: Identity ,
424+ } )
425+ . unwrap ( )
426+ . build ( )
427+ . unwrap ( ) ,
428+ ) ;
429+
430+ // Build the Arrow input schema: the two data columns plus the `_partition` struct column.
431+ // Note: partition field IDs start from 1000 by default (`field_id` is `None` above, so the ID is auto-assigned).
432+ let partition_field = Field :: new ( "id_bucket" , DataType :: Int32 , false ) . with_metadata (
433+ HashMap :: from ( [ ( PARQUET_FIELD_ID_META_KEY . to_string ( ) , "1000" . to_string ( ) ) ] ) ,
434+ ) ;
435+ let partition_struct_field = Field :: new (
436+ PROJECTED_PARTITION_VALUE_COLUMN ,
437+ DataType :: Struct ( vec ! [ partition_field. clone( ) ] . into ( ) ) ,
438+ false ,
439+ ) ;
440+
441+ let input_schema = Arc :: new ( ArrowSchema :: new ( vec ! [
442+ Field :: new( "id" , DataType :: Int32 , false ) ,
443+ Field :: new( "name" , DataType :: Utf8 , false ) ,
444+ partition_struct_field,
445+ ] ) ) ;
446+
447+ // Create the splitter with has_partition_column=true, so no projector or transforms are built.
448+ let partition_splitter = RecordBatchPartitionSplitter :: new (
449+ input_schema. clone ( ) ,
450+ schema. clone ( ) ,
451+ partition_spec,
452+ true ,
453+ )
454+ . expect ( "Failed to create splitter" ) ;
455+
456+ // Create test data: seven rows whose ids take three distinct values (1, 2, 3).
457+ let id_array = Int32Array :: from ( vec ! [ 1 , 2 , 1 , 3 , 2 , 3 , 1 ] ) ;
458+ let data_array = StringArray :: from ( vec ! [ "a" , "b" , "c" , "d" , "e" , "f" , "g" ] ) ;
459+
460+ // Create the pre-computed partition column (same values as id, matching the Identity transform).
461+ let partition_values = Int32Array :: from ( vec ! [ 1 , 2 , 1 , 3 , 2 , 3 , 1 ] ) ;
462+ let partition_struct = StructArray :: from ( vec ! [ (
463+ Arc :: new( partition_field) ,
464+ Arc :: new( partition_values) as ArrayRef ,
465+ ) ] ) ;
466+
467+ let batch = RecordBatch :: try_new ( input_schema. clone ( ) , vec ! [
468+ Arc :: new( id_array) ,
469+ Arc :: new( data_array) ,
470+ Arc :: new( partition_struct) ,
471+ ] )
472+ . expect ( "Failed to create RecordBatch" ) ;
473+
474+ // Split using the pre-computed partition column.
475+ let mut partitioned_batches = partition_splitter
476+ . split ( & batch)
477+ . expect ( "Failed to split RecordBatch" ) ;
478+
// Sort by the single Int partition value so the assertions below do not depend on the order
// in which `split` returns the groups.
479+ partitioned_batches. sort_by_key ( |( partition_key, _) | {
480+ if let PrimitiveLiteral :: Int ( i) = partition_key. data ( ) . fields ( ) [ 0 ]
481+ . as_ref ( )
482+ . unwrap ( )
483+ . as_primitive_literal ( )
484+ . unwrap ( )
485+ {
486+ i
487+ } else {
488+ panic ! ( "The partition value is not a int" ) ;
489+ }
490+ } ) ;
491+
// Three distinct partition values are expected to yield three output batches.
492+ assert_eq ! ( partitioned_batches. len( ) , 3 ) ;
493+
494+ // Verify partition values; only the keys are checked, not the row contents of each batch.
495+ let partition_values = partitioned_batches
496+ . iter ( )
497+ . map ( |( partition_key, _) | partition_key. data ( ) . clone ( ) )
498+ . collect :: < Vec < _ > > ( ) ;
499+
500+ assert_eq ! ( partition_values, vec![
501+ Struct :: from_iter( vec![ Some ( Literal :: int( 1 ) ) ] ) ,
502+ Struct :: from_iter( vec![ Some ( Literal :: int( 2 ) ) ] ) ,
503+ Struct :: from_iter( vec![ Some ( Literal :: int( 3 ) ) ] ) ,
504+ ] ) ;
505+ }
322506}
0 commit comments