Skip to content

Commit 41035d4

Browse files
authored
GH-49716: [C++] FixedShapeTensorType::Deserialize should strictly validate serialized metadata (#49718)
### Rationale for this change FixedShapeTensorType::Deserialize should validate input from unknown sources. ### What changes are included in this PR? Adds stricter deserialization validation. ### Are these changes tested? Yes. New tests are added. ### Are there any user-facing changes? Stricter validation should not be noticed if metadata is correct as per spec of fixed_shape_tensor. * GitHub Issue: #49716 Authored-by: Rok Mihevc <rok@mihevc.org> Signed-off-by: Raúl Cumplido <raulcumplido@gmail.com>
1 parent 0fb2c8e commit 41035d4

File tree

5 files changed

+202
-22
lines changed

5 files changed

+202
-22
lines changed

cpp/src/arrow/extension/fixed_shape_tensor.cc

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
#include <limits>
1819
#include <numeric>
1920
#include <sstream>
2021

@@ -109,8 +110,8 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
109110
return Status::Invalid("Expected FixedSizeList storage type, got ",
110111
storage_type->ToString());
111112
}
112-
auto value_type =
113-
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
113+
auto fsl_type = internal::checked_pointer_cast<FixedSizeListType>(storage_type);
114+
auto value_type = fsl_type->value_type();
114115
rj::Document document;
115116
if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
116117
!document.IsObject() || !document.HasMember("shape") ||
@@ -119,29 +120,66 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
119120
}
120121

121122
std::vector<int64_t> shape;
122-
for (auto& x : document["shape"].GetArray()) {
123+
for (const auto& x : document["shape"].GetArray()) {
124+
if (!x.IsInt64()) {
125+
return Status::Invalid("shape must contain integers, got ",
126+
internal::JsonTypeName(x));
127+
}
123128
shape.emplace_back(x.GetInt64());
124129
}
130+
125131
std::vector<int64_t> permutation;
126132
if (document.HasMember("permutation")) {
127-
for (auto& x : document["permutation"].GetArray()) {
133+
const auto& json_permutation = document["permutation"];
134+
if (!json_permutation.IsArray()) {
135+
return Status::Invalid("permutation must be an array, got ",
136+
internal::JsonTypeName(json_permutation));
137+
}
138+
for (const auto& x : json_permutation.GetArray()) {
139+
if (!x.IsInt64()) {
140+
return Status::Invalid("permutation must contain integers, got ",
141+
internal::JsonTypeName(x));
142+
}
128143
permutation.emplace_back(x.GetInt64());
129144
}
130145
if (shape.size() != permutation.size()) {
131146
return Status::Invalid("Invalid permutation");
132147
}
148+
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
133149
}
134150
std::vector<std::string> dim_names;
135151
if (document.HasMember("dim_names")) {
136-
for (auto& x : document["dim_names"].GetArray()) {
152+
const auto& json_dim_names = document["dim_names"];
153+
if (!json_dim_names.IsArray()) {
154+
return Status::Invalid("dim_names must be an array, got ",
155+
internal::JsonTypeName(json_dim_names));
156+
}
157+
for (const auto& x : json_dim_names.GetArray()) {
158+
if (!x.IsString()) {
159+
return Status::Invalid("dim_names must contain strings, got ",
160+
internal::JsonTypeName(x));
161+
}
137162
dim_names.emplace_back(x.GetString());
138163
}
139164
if (shape.size() != dim_names.size()) {
140165
return Status::Invalid("Invalid dim_names");
141166
}
142167
}
143168

144-
return fixed_shape_tensor(value_type, shape, permutation, dim_names);
169+
// Validate product of shape dimensions matches storage type list_size.
170+
// This check is intentionally after field parsing so that metadata-level errors
171+
// (type mismatches, size mismatches) are reported first.
172+
ARROW_ASSIGN_OR_RAISE(auto ext_type, FixedShapeTensorType::Make(
173+
value_type, shape, permutation, dim_names));
174+
const auto& fst_type = internal::checked_cast<const FixedShapeTensorType&>(*ext_type);
175+
ARROW_ASSIGN_OR_RAISE(const int64_t expected_size,
176+
internal::ComputeShapeProduct(fst_type.shape()));
177+
if (expected_size != fsl_type->list_size()) {
178+
return Status::Invalid("Product of shape dimensions (", expected_size,
179+
") does not match FixedSizeList size (", fsl_type->list_size(),
180+
")");
181+
}
182+
return ext_type;
145183
}
146184

147185
std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
@@ -310,8 +348,7 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
310348
}
311349

312350
std::vector<int64_t> shape = ext_type.shape();
313-
auto cell_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
314-
std::multiplies<>());
351+
ARROW_ASSIGN_OR_RAISE(const int64_t cell_size, internal::ComputeShapeProduct(shape));
315352
shape.insert(shape.begin(), 1, this->length());
316353
internal::Permute<int64_t>(permutation, &shape);
317354

@@ -330,6 +367,11 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
330367
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
331368
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
332369
const size_t ndim = shape.size();
370+
for (auto dim : shape) {
371+
if (dim < 0) {
372+
return Status::Invalid("shape must have non-negative values, got ", dim);
373+
}
374+
}
333375
if (!permutation.empty() && ndim != permutation.size()) {
334376
return Status::Invalid("permutation size must match shape size. Expected: ", ndim,
335377
" Got: ", permutation.size());
@@ -342,8 +384,12 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
342384
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
343385
}
344386

345-
const int64_t size = std::accumulate(shape.begin(), shape.end(),
346-
static_cast<int64_t>(1), std::multiplies<>());
387+
ARROW_ASSIGN_OR_RAISE(const int64_t size, internal::ComputeShapeProduct(shape));
388+
if (size > std::numeric_limits<int32_t>::max()) {
389+
return Status::Invalid("Product of shape dimensions (", size,
390+
") exceeds maximum FixedSizeList size (",
391+
std::numeric_limits<int32_t>::max(), ")");
392+
}
347393
return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size),
348394
shape, permutation, dim_names);
349395
}

cpp/src/arrow/extension/tensor_extension_array_test.cc

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,73 @@ TEST_F(TestFixedShapeTensorType, MetadataSerializationRoundtrip) {
219219
CheckDeserializationRaises(ext_type_, storage_type,
220220
R"({"shape":[3],"dim_names":["x","y"]})",
221221
"Invalid dim_names");
222+
223+
// Validate shape values must be integers. Error message should include the
224+
// JSON type name of the offending value.
225+
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3.5,4]})",
226+
"shape must contain integers, got Number");
227+
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":["3","4"]})",
228+
"shape must contain integers, got String");
229+
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[null]})",
230+
"shape must contain integers, got Null");
231+
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[true]})",
232+
"shape must contain integers, got True");
233+
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[false]})",
234+
"shape must contain integers, got False");
235+
236+
// Validate shape values must be non-negative
237+
CheckDeserializationRaises(ext_type_, fixed_size_list(int64(), 1), R"({"shape":[-1]})",
238+
"shape must have non-negative values");
239+
240+
// Validate product of shape matches storage list_size
241+
CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3,3]})",
242+
"Product of shape dimensions");
243+
244+
// Validate permutation member must be an array with integer values
245+
CheckDeserializationRaises(ext_type_, storage_type,
246+
R"({"shape":[3,4],"permutation":"invalid"})",
247+
"permutation must be an array, got String");
248+
CheckDeserializationRaises(ext_type_, storage_type,
249+
R"({"shape":[3,4],"permutation":{"a":1}})",
250+
"permutation must be an array, got Object");
251+
CheckDeserializationRaises(ext_type_, storage_type,
252+
R"({"shape":[3,4],"permutation":[1.5,0.5]})",
253+
"permutation must contain integers, got Number");
254+
CheckDeserializationRaises(ext_type_, storage_type,
255+
R"({"shape":[3,4],"permutation":["a","b"]})",
256+
"permutation must contain integers, got String");
257+
258+
// Validate permutation values must be unique integers in [0, N-1]
259+
CheckDeserializationRaises(ext_type_, storage_type,
260+
R"({"shape":[3,4],"permutation":[0,0]})",
261+
"Permutation indices");
262+
CheckDeserializationRaises(ext_type_, storage_type,
263+
R"({"shape":[3,4],"permutation":[0,5]})",
264+
"Permutation indices");
265+
CheckDeserializationRaises(ext_type_, storage_type,
266+
R"({"shape":[3,4],"permutation":[-1,0]})",
267+
"Permutation indices");
268+
269+
// Validate dim_names member must be an array with string values
270+
CheckDeserializationRaises(ext_type_, storage_type,
271+
R"({"shape":[3,4],"dim_names":"invalid"})",
272+
"dim_names must be an array, got String");
273+
CheckDeserializationRaises(ext_type_, storage_type,
274+
R"({"shape":[3,4],"dim_names":[1,2]})",
275+
"dim_names must contain strings, got Number");
276+
CheckDeserializationRaises(ext_type_, storage_type,
277+
R"({"shape":[3,4],"dim_names":[null,null]})",
278+
"dim_names must contain strings, got Null");
279+
}
280+
281+
TEST_F(TestFixedShapeTensorType, MakeValidatesShape) {
282+
// Negative shape values should be rejected
283+
EXPECT_RAISES_WITH_MESSAGE_THAT(
284+
Invalid, testing::HasSubstr("shape must have non-negative values"),
285+
FixedShapeTensorType::Make(value_type_, {-1}));
286+
EXPECT_RAISES_WITH_MESSAGE_THAT(
287+
Invalid, testing::HasSubstr("shape must have non-negative values"),
288+
FixedShapeTensorType::Make(value_type_, {3, -1, 4}));
222289
}
223290

224291
TEST_F(TestFixedShapeTensorType, RoundtripBatch) {
@@ -794,6 +861,32 @@ TEST_F(TestVariableShapeTensorType, MetadataSerializationRoundtrip) {
794861
"Invalid: permutation");
795862
CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":["x","y"]})",
796863
"Invalid: dim_names");
864+
865+
// Validate permutation member must be an array with integer values. Error
866+
// message should include the JSON type name of the offending value.
867+
CheckDeserializationRaises(ext_type_, storage_type, R"({"permutation":"invalid"})",
868+
"permutation must be an array, got String");
869+
CheckDeserializationRaises(ext_type_, storage_type, R"({"permutation":[1.5,0.5,2.5]})",
870+
"permutation must contain integers, got Number");
871+
CheckDeserializationRaises(ext_type_, storage_type,
872+
R"({"permutation":[null,null,null]})",
873+
"permutation must contain integers, got Null");
874+
875+
// Validate dim_names member must be an array with string values
876+
CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":"invalid"})",
877+
"dim_names must be an array, got String");
878+
CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":[1,2,3]})",
879+
"dim_names must contain strings, got Number");
880+
881+
// Validate uniform_shape member must be an array with integer-or-null values
882+
CheckDeserializationRaises(ext_type_, storage_type, R"({"uniform_shape":"invalid"})",
883+
"uniform_shape must be an array, got String");
884+
CheckDeserializationRaises(ext_type_, storage_type,
885+
R"({"uniform_shape":[1.5,null,null]})",
886+
"uniform_shape must contain integers or nulls, got Number");
887+
CheckDeserializationRaises(ext_type_, storage_type,
888+
R"({"uniform_shape":["x",null,null]})",
889+
"uniform_shape must contain integers or nulls, got String");
797890
}
798891

799892
TEST_F(TestVariableShapeTensorType, RoundtripBatch) {

cpp/src/arrow/extension/tensor_internal.cc

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,31 @@
3030

3131
namespace arrow::internal {
3232

33+
namespace {
34+
35+
// Names indexed by rapidjson::Type enum value:
36+
// kNullType=0, kFalseType=1, kTrueType=2, kObjectType=3,
37+
// kArrayType=4, kStringType=5, kNumberType=6.
38+
constexpr const char* kJsonTypeNames[] = {"Null", "False", "True", "Object",
39+
"Array", "String", "Number"};
40+
41+
} // namespace
42+
43+
const char* JsonTypeName(const ::arrow::rapidjson::Value& v) {
44+
return kJsonTypeNames[v.GetType()];
45+
}
46+
47+
Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape) {
48+
int64_t product = 1;
49+
for (const auto dim : shape) {
50+
if (MultiplyWithOverflow(product, dim, &product)) {
51+
return Status::Invalid(
52+
"Product of tensor shape dimensions would not fit in 64-bit integer");
53+
}
54+
}
55+
return product;
56+
}
57+
3358
bool IsPermutationTrivial(std::span<const int64_t> permutation) {
3459
for (size_t i = 1; i < permutation.size(); ++i) {
3560
if (permutation[i - 1] + 1 != permutation[i]) {
@@ -105,12 +130,7 @@ Result<std::shared_ptr<Buffer>> SliceTensorBuffer(const Array& data_array,
105130
const DataType& value_type,
106131
std::span<const int64_t> shape) {
107132
const int64_t byte_width = value_type.byte_width();
108-
int64_t size = 1;
109-
for (const auto dim : shape) {
110-
if (MultiplyWithOverflow(size, dim, &size)) {
111-
return Status::Invalid("Tensor size would not fit in 64-bit integer");
112-
}
113-
}
133+
ARROW_ASSIGN_OR_RAISE(const int64_t size, ComputeShapeProduct(shape));
114134
if (size != data_array.length()) {
115135
return Status::Invalid("Expected data array of length ", size, ", got ",
116136
data_array.length());

cpp/src/arrow/extension/tensor_internal.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,25 @@
2121
#include <span>
2222
#include <vector>
2323

24+
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
2425
#include "arrow/result.h"
2526
#include "arrow/type_fwd.h"
2627

28+
#include <rapidjson/document.h>
29+
2730
namespace arrow::internal {
2831

32+
/// \brief Return the name of a RapidJSON value's type (e.g., "Null", "Array", "Number").
33+
ARROW_EXPORT
34+
const char* JsonTypeName(const ::arrow::rapidjson::Value& v);
35+
36+
/// \brief Compute the product of the given shape dimensions.
37+
///
38+
/// Returns Status::Invalid if the product would overflow int64_t.
39+
/// An empty shape returns 1 (the multiplicative identity).
40+
ARROW_EXPORT
41+
Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape);
42+
2943
ARROW_EXPORT
3044
bool IsPermutationTrivial(std::span<const int64_t> permutation);
3145

cpp/src/arrow/extension/variable_shape_tensor.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,26 +159,31 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
159159
if (document.HasMember("permutation")) {
160160
const auto& json_permutation = document["permutation"];
161161
if (!json_permutation.IsArray()) {
162-
return Status::Invalid("permutation must be an array");
162+
return Status::Invalid("permutation must be an array, got ",
163+
internal::JsonTypeName(json_permutation));
163164
}
164165
permutation.reserve(ndim);
165166
for (const auto& x : json_permutation.GetArray()) {
166167
if (!x.IsInt64()) {
167-
return Status::Invalid("permutation must contain integers");
168+
return Status::Invalid("permutation must contain integers, got ",
169+
internal::JsonTypeName(x));
168170
}
169171
permutation.emplace_back(x.GetInt64());
170172
}
173+
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
171174
}
172175
std::vector<std::string> dim_names;
173176
if (document.HasMember("dim_names")) {
174177
const auto& json_dim_names = document["dim_names"];
175178
if (!json_dim_names.IsArray()) {
176-
return Status::Invalid("dim_names must be an array");
179+
return Status::Invalid("dim_names must be an array, got ",
180+
internal::JsonTypeName(json_dim_names));
177181
}
178182
dim_names.reserve(ndim);
179183
for (const auto& x : json_dim_names.GetArray()) {
180184
if (!x.IsString()) {
181-
return Status::Invalid("dim_names must contain strings");
185+
return Status::Invalid("dim_names must contain strings, got ",
186+
internal::JsonTypeName(x));
182187
}
183188
dim_names.emplace_back(x.GetString());
184189
}
@@ -188,7 +193,8 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
188193
if (document.HasMember("uniform_shape")) {
189194
const auto& json_uniform_shape = document["uniform_shape"];
190195
if (!json_uniform_shape.IsArray()) {
191-
return Status::Invalid("uniform_shape must be an array");
196+
return Status::Invalid("uniform_shape must be an array, got ",
197+
internal::JsonTypeName(json_uniform_shape));
192198
}
193199
uniform_shape.reserve(ndim);
194200
for (const auto& x : json_uniform_shape.GetArray()) {
@@ -197,7 +203,8 @@ Result<std::shared_ptr<DataType>> VariableShapeTensorType::Deserialize(
197203
} else if (x.IsInt64()) {
198204
uniform_shape.emplace_back(x.GetInt64());
199205
} else {
200-
return Status::Invalid("uniform_shape must contain integers or nulls");
206+
return Status::Invalid("uniform_shape must contain integers or nulls, got ",
207+
internal::JsonTypeName(x));
201208
}
202209
}
203210
}

0 commit comments

Comments
 (0)