Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions cpp/src/arrow/ipc/metadata_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,9 @@ class FieldToFlatbufferVisitor {
};

Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
DictionaryMemo* dictionary_memo, std::shared_ptr<Field>* out) {
DictionaryMemo* dictionary_memo,
const IpcReadOptions* options,
std::shared_ptr<Field>* out) {
std::shared_ptr<DataType> type;

std::shared_ptr<KeyValueMetadata> metadata;
Expand All @@ -866,7 +868,7 @@ Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
child_fields.resize(children->size());
for (int i = 0; i < static_cast<int>(children->size()); ++i) {
RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), field_pos.child(i),
dictionary_memo, &child_fields[i]));
dictionary_memo, options, &child_fields[i]));
}
}

Expand Down Expand Up @@ -899,21 +901,26 @@ Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
// Look for extension metadata in custom_metadata field
int name_index = metadata->FindKey(kExtensionTypeKeyName);
if (name_index != -1) {
std::shared_ptr<ExtensionType> ext_type =
GetExtensionType(metadata->value(name_index));
if (ext_type != nullptr) {
int data_index = metadata->FindKey(kExtensionMetadataKeyName);
std::string type_data = data_index == -1 ? "" : metadata->value(data_index);

ARROW_ASSIGN_OR_RAISE(type, ext_type->Deserialize(type, type_data));
// Remove the metadata, for faithful roundtripping
if (data_index != -1) {
RETURN_NOT_OK(metadata->DeleteMany({name_index, data_index}));
} else {
RETURN_NOT_OK(metadata->Delete(name_index));
// Check if extension types are blocked
bool should_deserialize = (options == nullptr || !options->extension_types_blocked);

if (should_deserialize) {
std::shared_ptr<ExtensionType> ext_type =
GetExtensionType(metadata->value(name_index));
if (ext_type != nullptr) {
int data_index = metadata->FindKey(kExtensionMetadataKeyName);
std::string type_data = data_index == -1 ? "" : metadata->value(data_index);

ARROW_ASSIGN_OR_RAISE(type, ext_type->Deserialize(type, type_data));
// Remove the metadata, for faithful roundtripping
if (data_index != -1) {
RETURN_NOT_OK(metadata->DeleteMany({name_index, data_index}));
} else {
RETURN_NOT_OK(metadata->Delete(name_index));
}
}
}
// NOTE: if extension type is unknown, we do not raise here and
// NOTE: if extension type is unknown or blocked, we do not raise here and
// simply return the storage type.
}
}
Expand All @@ -933,6 +940,13 @@ Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
return Status::OK();
}

// Backward-compatible overload without options
Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
DictionaryMemo* dictionary_memo, std::shared_ptr<Field>* out) {
return FieldFromFlatbuffer(field, field_pos, dictionary_memo, nullptr, out);
}


flatbuffers::Offset<KVVector> SerializeCustomMetadata(
FBB& fbb, const std::shared_ptr<const KeyValueMetadata>& metadata) {
std::vector<KeyValueOffset> key_values;
Expand Down Expand Up @@ -1433,7 +1447,7 @@ Status WriteFileFooter(const Schema& schema, const std::vector<FileBlock>& dicti
// ----------------------------------------------------------------------

Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
std::shared_ptr<Schema>* out) {
const IpcReadOptions* options, std::shared_ptr<Schema>* out) {
auto schema = static_cast<const flatbuf::Schema*>(opaque_schema);
CHECK_FLATBUFFERS_NOT_NULL(schema, "schema");
CHECK_FLATBUFFERS_NOT_NULL(schema->fields(), "Schema.fields");
Expand All @@ -1447,7 +1461,7 @@ Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
// XXX I don't think this check is necessary (AP)
CHECK_FLATBUFFERS_NOT_NULL(field, "DictionaryEncoding.indexType");
RETURN_NOT_OK(
FieldFromFlatbuffer(field, field_pos.child(i), dictionary_memo, &fields[i]));
FieldFromFlatbuffer(field, field_pos.child(i), dictionary_memo, options, &fields[i]));
}

std::shared_ptr<KeyValueMetadata> metadata;
Expand All @@ -1460,6 +1474,12 @@ Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
return Status::OK();
}

// Backward-compatible overload
Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
std::shared_ptr<Schema>* out) {
return GetSchema(opaque_schema, dictionary_memo, nullptr, out);
}

Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr<DataType>* type,
std::vector<int64_t>* shape, std::vector<int64_t>* strides,
std::vector<std::string>* dim_names) {
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/ipc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,16 @@ struct ARROW_EXPORT IpcReadOptions {
/// The lazy property will always be reset to true to deliver the expected behavior
io::CacheOptions pre_buffer_cache_options = io::CacheOptions::LazyDefaults();

/// \brief Whether to disable deserialization of extension types
///
/// If true, extension types will be deserialized as their storage types instead
/// of calling custom deserialization code. This can be useful for security-sensitive
/// applications that want to avoid potentially buggy third-party extension type
/// deserialization code.
///
/// Default is false (extension types are deserialized normally).
bool extension_types_blocked = false;

static IpcReadOptions Defaults();
};

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& opti
std::shared_ptr<Schema>* schema,
std::shared_ptr<Schema>* out_schema,
std::vector<bool>* field_inclusion_mask, bool* swap_endian) {
RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema));
RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, &options, schema));

// If we are selecting only certain fields, populate the inclusion mask now
// for fast lookups
Expand Down
11 changes: 4 additions & 7 deletions cpp/src/arrow/sparse_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,10 @@ SparseCSFIndex::SparseCSFIndex(const std::vector<std::shared_ptr<Tensor>>& indpt
std::string SparseCSFIndex::ToString() const { return std::string("SparseCSFIndex"); }

bool SparseCSFIndex::Equals(const SparseCSFIndex& other) const {
for (int64_t i = 0; i < static_cast<int64_t>(indices().size()); ++i) {
if (!indices()[i]->Equals(*other.indices()[i])) return false;
}
for (int64_t i = 0; i < static_cast<int64_t>(indptr().size()); ++i) {
if (!indptr()[i]->Equals(*other.indptr()[i])) return false;
}
return axis_order() == other.axis_order();
auto eq = [](const auto& a, const auto& b) { return a->Equals(*b); };
return axis_order() == other.axis_order() &&
std::ranges::equal(indices(), other.indices(), eq) &&
std::ranges::equal(indptr(), other.indptr(), eq);
}

// ----------------------------------------------------------------------
Expand Down
24 changes: 23 additions & 1 deletion cpp/src/arrow/sparse_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1641,10 +1641,32 @@ TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestNonAscendingShape) {
ASSERT_TRUE(st->Equals(*sparse_tensor));
}

TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestEqualityMismatchedDimensions) {
using IndexValueType = TypeParam;
using c_index_value_type = typename IndexValueType::c_type;

// 2D vs 3D - comparing indices with different dimensionality
// 2D CSF: ndim=2, so indptr.size()=1, indices.size()=2
std::vector<int64_t> axis_order_2D = {0, 1};
std::vector<std::vector<c_index_value_type>> indptr_2D = {{0, 1}};
std::vector<std::vector<c_index_value_type>> indices_2D = {{0}, {0}};
auto si_2D = this->MakeSparseCSFIndex(axis_order_2D, indptr_2D, indices_2D);

// 3D CSF: ndim=3, so indptr.size()=2, indices.size()=3
std::vector<int64_t> axis_order_3D = {0, 1, 2};
std::vector<std::vector<c_index_value_type>> indptr_3D = {{0, 1}, {0, 1}};
std::vector<std::vector<c_index_value_type>> indices_3D = {{0}, {0}, {0}};
auto si_3D = this->MakeSparseCSFIndex(axis_order_3D, indptr_3D, indices_3D);

ASSERT_FALSE(si_2D->Equals(*si_3D));
ASSERT_FALSE(si_3D->Equals(*si_2D));
ASSERT_TRUE(si_2D->Equals(*si_2D));
}

REGISTER_TYPED_TEST_SUITE_P(TestSparseCSFTensorForIndexValueType, TestCreateSparseTensor,
TestTensorToSparseTensor, TestSparseTensorToTensor,
TestAlternativeAxisOrder, TestNonAscendingShape,
TestRoundTrip);
TestRoundTrip, TestEqualityMismatchedDimensions);

INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestSparseCSFTensorForIndexValueType, Int8Type);
INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestSparseCSFTensorForIndexValueType,
Expand Down
Loading