65 changes: 49 additions & 16 deletions datafusion/proto-common/src/from_proto/mod.rs
@@ -28,7 +28,12 @@ use arrow::datatypes::{
DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema,
TimeUnit, UnionFields, UnionMode, i256,
};
use arrow::ipc::{reader::read_record_batch, root_as_message};
use arrow::ipc::{
convert::fb_to_schema,
reader::{read_dictionary, read_record_batch},
root_as_message,
writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions},
};

use datafusion_common::{
Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef,
@@ -406,6 +411,35 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
));
};

// IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf
// `Schema` doesn't preserve those IDs. Reconstruct them deterministically by
// round-tripping the schema through IPC.
let schema: Schema = {
let ipc_gen = IpcDataGenerator {};
let write_options = IpcWriteOptions::default();
let mut dict_tracker = DictionaryTracker::new(false);
let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker(
&schema,
&mut dict_tracker,
&write_options,
);
let message =
root_as_message(encoded_schema.ipc_message.as_slice()).map_err(
|e| {
Error::General(format!(
"Error IPC schema message while deserializing ScalarValue::List: {e}"
))
},
)?;
Comment on lines +427 to +433


Severity: medium

The current error handling loses the original ArrowError type by wrapping it in Error::General. It's better to convert it to a DataFusionError to preserve the error type and add context, which improves debuggability. This can be done using arrow_datafusion_err! and the .context() method.

This pattern can be applied to similar root_as_message calls in this file (e.g., lines 443-447 and 463-467).

                        root_as_message(encoded_schema.ipc_message.as_slice()).map_err(|e| {
                            arrow_datafusion_err!(e)
                                .context("Error decoding IPC schema message while deserializing ScalarValue::List")
                        })?

Owner Author


value:good-to-have; category:bug; feedback: The Gemini AI reviewer is correct! Using the arrow_datafusion_err!() macro preserves the original error type and captures a backtrace at the call site. Constructing a new custom Error keeps only the Display rendering of the underlying error and does not initialize the backtrace at the correct call. The code nevertheless keeps the custom Error for consistency with the rest of this file.
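
To make the trade-off concrete, a minimal sketch, assuming datafusion_common and arrow as dependencies; DataFusionError::Internal stands in here for this file's custom Error::General, and the function names are illustrative:

    use arrow::error::ArrowError;
    use datafusion_common::{arrow_datafusion_err, DataFusionError};

    fn decode() -> Result<(), ArrowError> {
        Err(ArrowError::IpcError("truncated message".to_string()))
    }

    // Lossy: only the Display rendering of the ArrowError survives; the
    // original error value is dropped and no backtrace is captured here.
    fn lossy() -> Result<(), DataFusionError> {
        decode().map_err(|e| DataFusionError::Internal(format!("decoding failed: {e}")))
    }

    // Lossless: the ArrowError is kept as the source, the macro captures a
    // backtrace at this call site, and context is layered on top.
    fn lossless() -> Result<(), DataFusionError> {
        decode().map_err(|e| arrow_datafusion_err!(e).context("decoding failed"))
    }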

let ipc_schema = message.header_as_schema().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing ScalarValue::List schema"
.to_string(),
)
})?;
fb_to_schema(ipc_schema)
};

let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List: {e}"
@@ -420,7 +454,12 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
)
})?;

let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
let mut dict_by_id: HashMap<i64, ArrayRef> = HashMap::new();
for protobuf::scalar_nested_value::Dictionary {
ipc_message,
arrow_data,
} in dictionaries
{
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
@@ -434,22 +473,16 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
.to_string(),
)
})?;

let id = dict_batch.id();

let record_batch = read_record_batch(
read_dictionary(
&buffer,
dict_batch.data().unwrap(),
Arc::new(schema.clone()),
&Default::default(),
None,
dict_batch,
&schema,
&mut dict_by_id,
&message.version(),
)?;

let values: ArrayRef = Arc::clone(record_batch.column(0));

Ok((id, values))
}).collect::<datafusion_common::Result<HashMap<_, _>>>()?;
)
.map_err(|e| arrow_datafusion_err!(e))
.map_err(|e| e.context("Decoding ScalarValue::List dictionary"))?;
}

let record_batch = read_record_batch(
&buffer,
7 changes: 7 additions & 0 deletions datafusion/proto-common/src/to_proto/mod.rs
@@ -1025,6 +1025,13 @@ fn encode_scalar_nested_value(
let ipc_gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let write_options = IpcWriteOptions::default();
// The IPC writer requires pre-allocated dictionary IDs (normally assigned when
// serializing the schema). Populate `dict_tracker` by encoding the schema first.
ipc_gen.schema_to_bytes_with_dictionary_tracker(
batch.schema().as_ref(),
&mut dict_tracker,
&write_options,
);
let mut compression_context = CompressionContext::default();
let (encoded_dictionaries, encoded_message) = ipc_gen
.encode(
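Taken together, the two sides of this PR must agree on dictionary IDs: the encoder assigns them while serializing the schema (above), and the decoder re-derives the same IDs by round-tripping its reconstructed schema through the IPC writer. A minimal sketch of that derivation, using only the arrow-ipc APIs already appearing in this diff (the helper name is illustrative):

    use arrow::datatypes::Schema;
    use arrow::ipc::{
        convert::fb_to_schema,
        root_as_message,
        writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions},
    };

    // Encode the schema with a fresh tracker and immediately decode it back;
    // the flatbuffer schema carries the deterministically assigned
    // dictionary IDs that the protobuf `Schema` does not preserve.
    fn schema_with_dictionary_ids(schema: &Schema) -> Schema {
        let ipc_gen = IpcDataGenerator {};
        let mut dict_tracker = DictionaryTracker::new(false);
        let encoded = ipc_gen.schema_to_bytes_with_dictionary_tracker(
            schema,
            &mut dict_tracker,
            &IpcWriteOptions::default(),
        );
        let message = root_as_message(encoded.ipc_message.as_slice())
            .expect("schema we just encoded is valid IPC");
        fb_to_schema(message.header_as_schema().expect("header is a schema"))
    }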
19 changes: 19 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -2564,3 +2564,22 @@ fn custom_proto_converter_intercepts() -> Result<()> {

Ok(())
}

#[test]
fn roundtrip_call_null_scalar_struct_dict() -> Result<()> {
let data_type = DataType::Struct(Fields::from(vec![Field::new(
"item",
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
true,
)]));

let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)]));
let scan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
let scalar = lit(ScalarValue::try_from(data_type)?);
let filter = Arc::new(FilterExec::try_new(
Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)),
scan,
)?);

roundtrip_test(filter)
}