Skip to content

Commit f4b72ef

Browse files
committed
add a flag in splitter to skip projection
1 parent 2cd60b0 commit f4b72ef

File tree

3 files changed

+230
-46
lines changed

3 files changed

+230
-46
lines changed

crates/iceberg/src/arrow/record_batch_partition_splitter.rs

Lines changed: 225 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef, Struct, St
3131
use crate::transform::{BoxedTransformFunction, create_transform_function};
3232
use crate::{Error, ErrorKind, Result};
3333

34+
/// Name of the struct column that carries pre-computed partition values.
/// `RecordBatchPartitionSplitter` looks this column up in the input batch
/// when it is constructed with `has_partition_column == true`, instead of
/// projecting and transforming the source columns itself.
pub const PROJECTED_PARTITION_VALUE_COLUMN: &str = "_partition";
36+
3437
/// The splitter used to split the record batch into multiple record batches by the partition spec.
3538
/// 1. It will project and transform the input record batch based on the partition spec, get the partitioned record batch.
3639
/// 2. Split the input record batch into multiple record batches based on the partitioned record batch.
@@ -40,11 +43,12 @@ use crate::{Error, ErrorKind, Result};
4043
pub struct RecordBatchPartitionSplitter {
    /// Iceberg schema of the incoming record batches.
    schema: SchemaRef,
    /// Partition spec used to derive partition values from the batch.
    partition_spec: PartitionSpecRef,
    /// Projects the partition source columns out of the input batch.
    /// `None` when `has_partition_column` is true, since the partition
    /// values are then read from the pre-computed `_partition` column
    /// instead of being derived from source columns.
    projector: Option<RecordBatchProjector>,
    /// One transform function per partition field, applied to the projected
    /// source columns; empty when `has_partition_column` is true.
    transform_functions: Vec<BoxedTransformFunction>,

    /// Iceberg struct type describing the partition tuple.
    partition_type: StructType,
    /// Arrow equivalent of `partition_type`.
    partition_arrow_type: DataType,
    /// If true, input batches are expected to carry a pre-computed
    /// partition struct column named `PROJECTED_PARTITION_VALUE_COLUMN`.
    has_partition_column: bool,
}
4953

5054
// # TODO
@@ -58,6 +62,7 @@ impl RecordBatchPartitionSplitter {
5862
/// * `input_schema` - The Arrow schema of the input record batches
5963
/// * `iceberg_schema` - The Iceberg schema reference
6064
/// * `partition_spec` - The partition specification reference
65+
/// * `has_partition_column` - If true, expects a pre-computed partition column in the input batch
6166
///
6267
/// # Returns
6368
///
@@ -66,47 +71,55 @@ impl RecordBatchPartitionSplitter {
6671
input_schema: ArrowSchemaRef,
6772
iceberg_schema: SchemaRef,
6873
partition_spec: PartitionSpecRef,
74+
has_partition_column: bool,
6975
) -> Result<Self> {
70-
let projector = RecordBatchProjector::new(
71-
input_schema,
72-
&partition_spec
73-
.fields()
74-
.iter()
75-
.map(|field| field.source_id)
76-
.collect::<Vec<_>>(),
77-
// The source columns, selected by ids, must be a primitive type and cannot be contained in a map or list, but may be nested in a struct.
78-
// ref: https://iceberg.apache.org/spec/#partitioning
79-
|field| {
80-
if !field.data_type().is_primitive() {
81-
return Ok(None);
82-
}
83-
field
84-
.metadata()
85-
.get(PARQUET_FIELD_ID_META_KEY)
86-
.map(|s| {
87-
s.parse::<i64>()
88-
.map_err(|e| Error::new(ErrorKind::Unexpected, e.to_string()))
89-
})
90-
.transpose()
91-
},
92-
|_| true,
93-
)?;
94-
let transform_functions = partition_spec
95-
.fields()
96-
.iter()
97-
.map(|field| create_transform_function(&field.transform))
98-
.collect::<Result<Vec<_>>>()?;
99-
10076
let partition_type = partition_spec.partition_type(&iceberg_schema)?;
10177
let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
10278

79+
let (projector, transform_functions) = if has_partition_column {
80+
// Skip projector and transform initialization when partition column is pre-computed
81+
(None, Vec::new())
82+
} else {
83+
let projector = RecordBatchProjector::new(
84+
input_schema,
85+
&partition_spec
86+
.fields()
87+
.iter()
88+
.map(|field| field.source_id)
89+
.collect::<Vec<_>>(),
90+
// The source columns, selected by ids, must be a primitive type and cannot be contained in a map or list, but may be nested in a struct.
91+
// ref: https://iceberg.apache.org/spec/#partitioning
92+
|field| {
93+
if !field.data_type().is_primitive() {
94+
return Ok(None);
95+
}
96+
field
97+
.metadata()
98+
.get(PARQUET_FIELD_ID_META_KEY)
99+
.map(|s| {
100+
s.parse::<i64>()
101+
.map_err(|e| Error::new(ErrorKind::Unexpected, e.to_string()))
102+
})
103+
.transpose()
104+
},
105+
|_| true,
106+
)?;
107+
let transform_functions = partition_spec
108+
.fields()
109+
.iter()
110+
.map(|field| create_transform_function(&field.transform))
111+
.collect::<Result<Vec<_>>>()?;
112+
(Some(projector), transform_functions)
113+
};
114+
103115
Ok(Self {
104116
schema: iceberg_schema,
105117
partition_spec,
106118
projector,
107119
transform_functions,
108120
partition_type,
109121
partition_arrow_type,
122+
has_partition_column,
110123
})
111124
}
112125

@@ -153,14 +166,66 @@ impl RecordBatchPartitionSplitter {
153166

154167
/// Split the record batch into multiple record batches based on the partition spec.
155168
pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(PartitionKey, RecordBatch)>> {
156-
let source_columns = self.projector.project_column(batch.columns())?;
157-
let partition_columns = source_columns
158-
.into_iter()
159-
.zip_eq(self.transform_functions.iter())
160-
.map(|(source_column, transform_function)| transform_function.transform(source_column))
161-
.collect::<Result<Vec<_>>>()?;
169+
let partition_structs = if self.has_partition_column {
170+
// Extract partition values from pre-computed partition column
171+
let partition_column = batch
172+
.column_by_name(PROJECTED_PARTITION_VALUE_COLUMN)
173+
.ok_or_else(|| {
174+
Error::new(
175+
ErrorKind::DataInvalid,
176+
format!(
177+
"Partition column '{}' not found in batch",
178+
PROJECTED_PARTITION_VALUE_COLUMN
179+
),
180+
)
181+
})?;
182+
183+
let partition_struct_array = partition_column
184+
.as_any()
185+
.downcast_ref::<StructArray>()
186+
.ok_or_else(|| {
187+
Error::new(
188+
ErrorKind::DataInvalid,
189+
"Partition column is not a StructArray",
190+
)
191+
})?;
192+
193+
let arrow_struct_array = Arc::new(partition_struct_array.clone()) as ArrayRef;
194+
let struct_array = arrow_struct_to_literal(&arrow_struct_array, &self.partition_type)?;
195+
196+
struct_array
197+
.into_iter()
198+
.map(|s| {
199+
if let Some(Literal::Struct(s)) = s {
200+
Ok(s)
201+
} else {
202+
Err(Error::new(
203+
ErrorKind::DataInvalid,
204+
"Partition value is not a struct literal or is null",
205+
))
206+
}
207+
})
208+
.collect::<Result<Vec<_>>>()?
209+
} else {
210+
// Compute partition values from source columns
211+
let projector = self.projector.as_ref().ok_or_else(|| {
212+
Error::new(
213+
ErrorKind::DataInvalid,
214+
"Projector not initialized for non-partition-column mode",
215+
)
216+
})?;
217+
218+
let source_columns = projector.project_column(batch.columns())?;
219+
let partition_columns = source_columns
220+
.into_iter()
221+
.zip_eq(self.transform_functions.iter())
222+
.map(|(source_column, transform_function)| {
223+
transform_function.transform(source_column)
224+
})
225+
.collect::<Result<Vec<_>>>()?;
162226

163-
let partition_structs = self.partition_columns_to_struct(partition_columns)?;
227+
self.partition_columns_to_struct(partition_columns)?
228+
};
164229

165230
// Group the batch by row value.
166231
let mut group_ids = HashMap::new();
@@ -246,9 +311,13 @@ mod tests {
246311
.unwrap(),
247312
);
248313
let input_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
249-
let partition_splitter =
250-
RecordBatchPartitionSplitter::new(input_schema.clone(), schema.clone(), partition_spec)
251-
.expect("Failed to create splitter");
314+
let partition_splitter = RecordBatchPartitionSplitter::new(
315+
input_schema.clone(),
316+
schema.clone(),
317+
partition_spec,
318+
false,
319+
)
320+
.expect("Failed to create splitter");
252321

253322
let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
254323
let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]);
@@ -319,4 +388,119 @@ mod tests {
319388
Struct::from_iter(vec![Some(Literal::int(3))]),
320389
]);
321390
}
391+
392+
// Verifies the `has_partition_column == true` path: the splitter must read
// partition values from the pre-computed `_partition` struct column rather
// than projecting/transforming source columns, and split the batch into one
// batch per distinct partition value.
#[test]
fn test_record_batch_partition_split_with_partition_column() {
    use arrow_array::StructArray;
    use arrow_schema::{Field, Schema as ArrowSchema};

    // Iceberg schema: (id: int, name: string).
    let schema = Arc::new(
        Schema::builder()
            .with_fields(vec![
                NestedField::required(
                    1,
                    "id",
                    Type::Primitive(crate::spec::PrimitiveType::Int),
                )
                .into(),
                NestedField::required(
                    2,
                    "name",
                    Type::Primitive(crate::spec::PrimitiveType::String),
                )
                .into(),
            ])
            .build()
            .unwrap(),
    );
    // NOTE(review): the field is named "id_bucket" but the transform is
    // Identity, not Bucket — the name mirrors the sibling test above; consider
    // renaming to "id_identity" (or similar) for clarity.
    let partition_spec = Arc::new(
        PartitionSpecBuilder::new(schema.clone())
            .with_spec_id(1)
            .add_unbound_field(UnboundPartitionField {
                source_id: 1,
                field_id: None,
                name: "id_bucket".to_string(),
                transform: Transform::Identity,
            })
            .unwrap()
            .build()
            .unwrap(),
    );

    // Create input schema with _partition column
    // Note: partition field IDs start from 1000 by default
    let partition_field = Field::new("id_bucket", DataType::Int32, false).with_metadata(
        HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1000".to_string())]),
    );
    let partition_struct_field = Field::new(
        PROJECTED_PARTITION_VALUE_COLUMN,
        DataType::Struct(vec![partition_field.clone()].into()),
        false,
    );

    // Arrow input schema: the data columns plus the pre-computed partition
    // struct column appended last.
    let input_schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, false),
        partition_struct_field,
    ]));

    // Create splitter with has_partition_column=true
    let partition_splitter = RecordBatchPartitionSplitter::new(
        input_schema.clone(),
        schema.clone(),
        partition_spec,
        true,
    )
    .expect("Failed to create splitter");

    // Create test data with pre-computed partition column
    let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
    let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]);

    // Create partition column (same values as id for Identity transform)
    let partition_values = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
    let partition_struct = StructArray::from(vec![(
        Arc::new(partition_field),
        Arc::new(partition_values) as ArrayRef,
    )]);

    let batch = RecordBatch::try_new(input_schema.clone(), vec![
        Arc::new(id_array),
        Arc::new(data_array),
        Arc::new(partition_struct),
    ])
    .expect("Failed to create RecordBatch");

    // Split using the pre-computed partition column
    let mut partitioned_batches = partition_splitter
        .split(&batch)
        .expect("Failed to split RecordBatch");

    // Sort by the (single, int) partition value so assertions below are
    // independent of split ordering.
    partitioned_batches.sort_by_key(|(partition_key, _)| {
        if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
            .as_ref()
            .unwrap()
            .as_primitive_literal()
            .unwrap()
        {
            i
        } else {
            panic!("The partition value is not a int");
        }
    });

    // Input had ids {1, 2, 3}, so three partitions are expected.
    assert_eq!(partitioned_batches.len(), 3);

    // Verify partition values
    // (shadows the earlier `partition_values` Int32Array — intentional reuse
    // of the name for the collected partition structs)
    let partition_values = partitioned_batches
        .iter()
        .map(|(partition_key, _)| partition_key.data().clone())
        .collect::<Vec<_>>();

    assert_eq!(partition_values, vec![
        Struct::from_iter(vec![Some(Literal::int(1))]),
        Struct::from_iter(vec![Some(Literal::int(2))]),
        Struct::from_iter(vec![Some(Literal::int(3))]),
    ]);
}
322506
}

crates/iceberg/src/writer/task/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ mod tests {
353353
arrow_schema.clone(),
354354
schema.clone(),
355355
partition_spec.clone(),
356+
false,
356357
)?;
357358

358359
// Create DefaultTaskWriter with FanoutWriter and splitter
@@ -451,6 +452,7 @@ mod tests {
451452
arrow_schema.clone(),
452453
schema.clone(),
453454
partition_spec.clone(),
455+
false,
454456
)?;
455457

456458
// Create DefaultTaskWriter with ClusteredWriter and splitter

crates/integrations/datafusion/src/physical_plan/project.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,14 @@ use datafusion::physical_expr::PhysicalExpr;
2727
use datafusion::physical_expr::expressions::Column;
2828
use datafusion::physical_plan::projection::ProjectionExec;
2929
use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
30+
use iceberg::arrow::PROJECTED_PARTITION_VALUE_COLUMN;
3031
use iceberg::arrow::record_batch_projector::RecordBatchProjector;
3132
use iceberg::spec::{PartitionSpec, Schema};
3233
use iceberg::table::Table;
3334
use iceberg::transform::BoxedTransformFunction;
3435

3536
use crate::to_datafusion_error;
3637

37-
/// Column name for the combined partition values struct
38-
const PARTITION_VALUES_COLUMN: &str = "_partition";
39-
4038
/// Extends an ExecutionPlan with partition value calculations for Iceberg tables.
4139
///
4240
/// This function takes an input ExecutionPlan and extends it with an additional column
@@ -81,7 +79,7 @@ pub fn project_with_partition(
8179
}
8280

8381
let partition_expr = Arc::new(PartitionExpr::new(calculator));
84-
projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string()));
82+
projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
8583

8684
let projection = ProjectionExec::try_new(projection_exprs, input)?;
8785
Ok(Arc::new(projection))
@@ -343,7 +341,7 @@ mod tests {
343341
}
344342

345343
let partition_expr = Arc::new(PartitionExpr::new(calculator));
346-
projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string()));
344+
projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
347345

348346
let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
349347
let result = Arc::new(projection);

0 commit comments

Comments
 (0)