diff --git a/cpp/src/arrow/ipc/metadata_internal.cc b/cpp/src/arrow/ipc/metadata_internal.cc
index 65a4fcee7a2..e73df6e9828 100644
--- a/cpp/src/arrow/ipc/metadata_internal.cc
+++ b/cpp/src/arrow/ipc/metadata_internal.cc
@@ -851,7 +851,9 @@ class FieldToFlatbufferVisitor {
 };
 
 Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
-                           DictionaryMemo* dictionary_memo, std::shared_ptr* out) {
+                           DictionaryMemo* dictionary_memo,
+                           const IpcReadOptions* options,
+                           std::shared_ptr* out) {
   std::shared_ptr type;
 
   std::shared_ptr metadata;
@@ -866,7 +868,7 @@ Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
     child_fields.resize(children->size());
     for (int i = 0; i < static_cast(children->size()); ++i) {
       RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), field_pos.child(i),
-                                        dictionary_memo, &child_fields[i]));
+                                        dictionary_memo, options, &child_fields[i]));
     }
   }
 
@@ -899,21 +901,26 @@ Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
     // Look for extension metadata in custom_metadata field
     int name_index = metadata->FindKey(kExtensionTypeKeyName);
     if (name_index != -1) {
-      std::shared_ptr ext_type =
-          GetExtensionType(metadata->value(name_index));
-      if (ext_type != nullptr) {
-        int data_index = metadata->FindKey(kExtensionMetadataKeyName);
-        std::string type_data = data_index == -1 ? "" : metadata->value(data_index);
-
-        ARROW_ASSIGN_OR_RAISE(type, ext_type->Deserialize(type, type_data));
-        // Remove the metadata, for faithful roundtripping
-        if (data_index != -1) {
-          RETURN_NOT_OK(metadata->DeleteMany({name_index, data_index}));
-        } else {
-          RETURN_NOT_OK(metadata->Delete(name_index));
+      // Check if extension types are blocked
+      bool should_deserialize = (options == nullptr || !options->extension_types_blocked);
+
+      if (should_deserialize) {
+        std::shared_ptr ext_type =
+            GetExtensionType(metadata->value(name_index));
+        if (ext_type != nullptr) {
+          int data_index = metadata->FindKey(kExtensionMetadataKeyName);
+          std::string type_data = data_index == -1 ? "" : metadata->value(data_index);
+
+          ARROW_ASSIGN_OR_RAISE(type, ext_type->Deserialize(type, type_data));
+          // Remove the metadata, for faithful roundtripping
+          if (data_index != -1) {
+            RETURN_NOT_OK(metadata->DeleteMany({name_index, data_index}));
+          } else {
+            RETURN_NOT_OK(metadata->Delete(name_index));
+          }
         }
       }
-      // NOTE: if extension type is unknown, we do not raise here and
+      // NOTE: if extension type is unknown or blocked, we do not raise here and
       // simply return the storage type.
     }
   }
@@ -933,6 +940,12 @@ Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
   return Status::OK();
 }
 
+// Backward-compatible overload: null options, so extension types are never blocked
+Status FieldFromFlatbuffer(const flatbuf::Field* field, FieldPosition field_pos,
+                           DictionaryMemo* dictionary_memo, std::shared_ptr* out) {
+  return FieldFromFlatbuffer(field, field_pos, dictionary_memo, nullptr, out);
+}
+
 flatbuffers::Offset SerializeCustomMetadata(
     FBB& fbb, const std::shared_ptr& metadata) {
   std::vector key_values;
@@ -1433,7 +1446,7 @@ Status WriteFileFooter(const Schema& schema, const std::vector& dicti
 // ----------------------------------------------------------------------
 
 Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
-                 std::shared_ptr* out) {
+                 const IpcReadOptions* options, std::shared_ptr* out) {
   auto schema = static_cast(opaque_schema);
   CHECK_FLATBUFFERS_NOT_NULL(schema, "schema");
   CHECK_FLATBUFFERS_NOT_NULL(schema->fields(), "Schema.fields");
@@ -1447,7 +1460,7 @@ Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
     // XXX I don't think this check is necessary (AP)
     CHECK_FLATBUFFERS_NOT_NULL(field, "DictionaryEncoding.indexType");
-    RETURN_NOT_OK(
-        FieldFromFlatbuffer(field, field_pos.child(i), dictionary_memo, &fields[i]));
+    RETURN_NOT_OK(FieldFromFlatbuffer(field, field_pos.child(i), dictionary_memo, options,
+                                      &fields[i]));
   }
 
   std::shared_ptr metadata;
@@ -1460,6 +1473,12 @@ Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
   return Status::OK();
 }
 
+// Backward-compatible overload: forwards null options (extension types not blocked)
+Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo,
+                 std::shared_ptr* out) {
+  return GetSchema(opaque_schema, dictionary_memo, nullptr, out);
+}
+
 Status GetTensorMetadata(const Buffer& metadata, std::shared_ptr* type,
                          std::vector* shape, std::vector* strides,
                          std::vector* dim_names) {
diff --git a/cpp/src/arrow/ipc/options.h b/cpp/src/arrow/ipc/options.h
index ec0e2a5b6f9..a40053ab340 100644
--- a/cpp/src/arrow/ipc/options.h
+++ b/cpp/src/arrow/ipc/options.h
@@ -189,6 +189,16 @@ struct ARROW_EXPORT IpcReadOptions {
   /// The lazy property will always be reset to true to deliver the expected behavior
   io::CacheOptions pre_buffer_cache_options = io::CacheOptions::LazyDefaults();
 
+  /// \brief Whether to disable deserialization of extension types
+  ///
+  /// If true, extension types will be deserialized as their storage types instead
+  /// of calling custom deserialization code. This can be useful for security-sensitive
+  /// applications that want to avoid potentially buggy third-party extension type
+  /// deserialization code.
+  ///
+  /// Default is false (extension types are deserialized normally).
+  bool extension_types_blocked = false;
+
   static IpcReadOptions Defaults();
 };
 
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index f1571f76c24..6ad06a09f0d 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -810,7 +810,7 @@ Status UnpackSchemaMessage(const void* opaque_schema, const IpcReadOptions& opti
                            std::shared_ptr* schema,
                            std::shared_ptr* out_schema,
                            std::vector* field_inclusion_mask, bool* swap_endian) {
-  RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, schema));
+  RETURN_NOT_OK(internal::GetSchema(opaque_schema, dictionary_memo, &options, schema));
 
   // If we are selecting only certain fields, populate the inclusion mask now
   // for fast lookups
diff --git a/cpp/src/arrow/sparse_tensor.cc b/cpp/src/arrow/sparse_tensor.cc
index b84070b3d28..477fa2f7650 100644
--- a/cpp/src/arrow/sparse_tensor.cc
+++ b/cpp/src/arrow/sparse_tensor.cc
@@ -405,13 +405,10 @@ SparseCSFIndex::SparseCSFIndex(const std::vector>& indpt
 std::string SparseCSFIndex::ToString() const { return std::string("SparseCSFIndex"); }
 
 bool SparseCSFIndex::Equals(const SparseCSFIndex& other) const {
-  for (int64_t i = 0; i < static_cast(indices().size()); ++i) {
-    if (!indices()[i]->Equals(*other.indices()[i])) return false;
-  }
-  for (int64_t i = 0; i < static_cast(indptr().size()); ++i) {
-    if (!indptr()[i]->Equals(*other.indptr()[i])) return false;
-  }
-  return axis_order() == other.axis_order();
+  auto eq = [](const auto& a, const auto& b) { return a->Equals(*b); };
+  return axis_order() == other.axis_order() &&
+         std::ranges::equal(indices(), other.indices(), eq) &&
+         std::ranges::equal(indptr(), other.indptr(), eq);
 }
 
 // ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/sparse_tensor_test.cc b/cpp/src/arrow/sparse_tensor_test.cc
index c9c28a11b1b..434f4a1723c 100644
--- a/cpp/src/arrow/sparse_tensor_test.cc
+++ b/cpp/src/arrow/sparse_tensor_test.cc
@@ -1641,10 +1641,32 @@ TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestNonAscendingShape) {
   ASSERT_TRUE(st->Equals(*sparse_tensor));
 }
 
+TYPED_TEST_P(TestSparseCSFTensorForIndexValueType, TestEqualityMismatchedDimensions) {
+  using IndexValueType = TypeParam;
+  using c_index_value_type = typename IndexValueType::c_type;
+
+  // 2D vs 3D - comparing indices with different dimensionality
+  // 2D CSF: ndim=2, so indptr.size()=1, indices.size()=2
+  std::vector axis_order_2D = {0, 1};
+  std::vector> indptr_2D = {{0, 1}};
+  std::vector> indices_2D = {{0}, {0}};
+  auto si_2D = this->MakeSparseCSFIndex(axis_order_2D, indptr_2D, indices_2D);
+
+  // 3D CSF: ndim=3, so indptr.size()=2, indices.size()=3
+  std::vector axis_order_3D = {0, 1, 2};
+  std::vector> indptr_3D = {{0, 1}, {0, 1}};
+  std::vector> indices_3D = {{0}, {0}, {0}};
+  auto si_3D = this->MakeSparseCSFIndex(axis_order_3D, indptr_3D, indices_3D);
+
+  ASSERT_FALSE(si_2D->Equals(*si_3D));
+  ASSERT_FALSE(si_3D->Equals(*si_2D));
+  ASSERT_TRUE(si_2D->Equals(*si_2D));
+}
+
 REGISTER_TYPED_TEST_SUITE_P(TestSparseCSFTensorForIndexValueType, TestCreateSparseTensor,
                             TestTensorToSparseTensor, TestSparseTensorToTensor,
                             TestAlternativeAxisOrder, TestNonAscendingShape,
-                            TestRoundTrip);
+                            TestRoundTrip, TestEqualityMismatchedDimensions);
 
 INSTANTIATE_TYPED_TEST_SUITE_P(TestInt8, TestSparseCSFTensorForIndexValueType, Int8Type);
 INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestSparseCSFTensorForIndexValueType,