Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
6cae181
Initial commit
rok Aug 15, 2023
caf67d5
Add VariableShapeTensorArray::ToTensor(i)
rok Sep 3, 2023
420ba90
:Add ragged_dimensions
rok Sep 12, 2023
d5af047
Replace ragged_dimensions with uniform_dimensions
rok Sep 15, 2023
fafc2dd
Add example for explanation
rok Sep 15, 2023
9600363
Add uniform_shape parameter
rok Sep 24, 2023
3ffc016
Apply suggestions from code review
rok Sep 25, 2023
0a5fd57
Post rebase
rok Oct 11, 2023
4c64460
Remove uniform_dimensions, fix python test
rok Oct 12, 2023
31896f1
lint
rok Oct 12, 2023
eb5f45a
uniform_shape values are optional
rok Oct 12, 2023
421f2da
Add scalar test
rok Oct 29, 2023
550e5aa
Create Tensor from scalar
rok Oct 30, 2023
2f1e7f0
Move get_tensor logic to cpp
rok Nov 28, 2023
54678a8
slice buffer with array offset
rok Nov 28, 2023
a188361
Update cpp/src/arrow/extension/variable_shape_tensor.h
rok Nov 28, 2023
f02bfc7
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
e2d83ae
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
c500a18
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
721c9a4
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 28, 2023
0400746
Review feedback
rok Nov 28, 2023
393c358
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Nov 29, 2023
15c1d24
Review feedback
rok Nov 29, 2023
c8418d9
import and uint32->int32
rok Nov 29, 2023
2c15be2
permutation check
rok Nov 29, 2023
035468c
Remove serialization from cython, lint
rok Nov 29, 2023
476b01a
Review feedback
rok Nov 30, 2023
d04e4e7
ndim initializer
rok Nov 30, 2023
7bedaf5
Test null values
rok Nov 30, 2023
fd74c9d
Remove one GetTensor code paths, permutation handling
rok Dec 2, 2023
3f419bd
Allow arbitrary memory layout
rok Dec 3, 2023
312e24f
fix permutation check
rok Dec 3, 2023
32819c0
lint
rok Dec 3, 2023
fb796ea
lint
rok Dec 3, 2023
fadeafa
roundtrip strided
rok Dec 4, 2023
f3a371a
Apply suggestions from code review
rok Dec 13, 2023
e5473a5
remove array.gettensor, simlify
rok Dec 13, 2023
67ef05e
work
rok Dec 14, 2023
49a0bba
Add repr
rok Dec 14, 2023
d0f4632
Review feedback
rok Dec 14, 2023
dc2f383
GetTensor->MakeTensor, static
rok Dec 23, 2023
cd21aac
Better permutations check
rok Dec 23, 2023
6c4420f
post rebase changes
rok Feb 8, 2024
2e612cd
work
rok Feb 9, 2024
c8047d1
ToString new parameter
rok Mar 4, 2024
b20d9fd
Remove Python bindings
rok Mar 4, 2024
8373b64
Review feedback
rok Mar 16, 2024
5b0de2b
Use TensorFromJSON
rok Mar 16, 2024
c2d8284
lint
rok Mar 17, 2024
b92a32e
Apply suggestions from code review
rok Mar 27, 2024
08c73c4
Update cpp/src/arrow/extension/variable_shape_tensor.cc
rok Mar 27, 2024
10da7f1
fix
rok Mar 27, 2024
6bbf8cc
Review feedback
rok Mar 27, 2024
373e29c
mingw64 issue
rok Mar 28, 2024
2c4b0bd
refactor ComputeStrides
rok Mar 29, 2024
38189a7
Change to ComputeStrides
rok Apr 1, 2024
53d40ea
Change ToTensor
rok Apr 1, 2024
6e6b679
Refactoring ComputeStrides
rok Apr 2, 2024
ea80008
Move RoundtripBatch to gtest_util.cc
rok Apr 14, 2024
f4cb7fe
Post rebase changes
rok Jun 6, 2024
8dfb4a3
Post rebase changes
rok Sep 11, 2024
e8afc66
post rebase fixes
rok Nov 12, 2025
438dd29
review feedback
rok Feb 20, 2026
f0088c0
span
rok Feb 20, 2026
f0d0ebe
std::span
rok Feb 20, 2026
d8b83e5
some changes
rok Feb 20, 2026
7f200fc
permutation equivalency check
rok Feb 20, 2026
e49a8a0
lint
rok Feb 20, 2026
96cdfa8
some typos, style changes, etc
rok Feb 20, 2026
87e77b5
lint
rok Feb 21, 2026
ab3edb4
Factor out SliceTensorBuffer
rok Feb 24, 2026
0de39f2
review feedback
rok Feb 24, 2026
586f822
lint
rok Feb 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,8 @@ if(ARROW_JSON)
arrow_add_object_library(ARROW_JSON
extension/fixed_shape_tensor.cc
extension/opaque.cc
extension/tensor_internal.cc
extension/variable_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc)

if(ARROW_JSON)
list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc)
list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc)
endif()

add_arrow_test(test
Expand Down
102 changes: 19 additions & 83 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include "arrow/array/array_primitive.h"
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
#include "arrow/tensor.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging_internal.h"
#include "arrow/util/print_internal.h"
#include "arrow/util/sort_internal.h"
Expand All @@ -37,71 +36,18 @@

namespace rj = arrow::rapidjson;

namespace arrow {

namespace extension {

namespace {

Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation,
std::vector<int64_t>* strides) {
if (permutation.empty()) {
return internal::ComputeRowMajorStrides(type, shape, strides);
}

const int byte_width = type.byte_width();

int64_t remaining = 0;
if (!shape.empty() && shape.front() > 0) {
remaining = byte_width;
for (auto i : permutation) {
if (i > 0) {
if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
return Status::Invalid(
"Strides computed from shape would not fit in 64-bit integer");
}
}
}
}

if (remaining == 0) {
strides->assign(shape.size(), byte_width);
return Status::OK();
}

strides->push_back(remaining);
for (auto i : permutation) {
if (i > 0) {
remaining /= shape[i];
strides->push_back(remaining);
}
}
internal::Permute(permutation, strides);

return Status::OK();
}

} // namespace
namespace arrow::extension {

bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
return false;
}
const auto& other_ext = internal::checked_cast<const FixedShapeTensorType&>(other);

auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
for (size_t i = 1; i < permutation.size(); ++i) {
if (permutation[i - 1] + 1 != permutation[i]) {
return false;
}
}
return true;
};
const bool permutation_equivalent =
((permutation_ == other_ext.permutation()) ||
(permutation_.empty() && is_permutation_trivial(other_ext.permutation())) ||
(is_permutation_trivial(permutation_) && other_ext.permutation().empty()));
(permutation_ == other_ext.permutation()) ||
(internal::IsPermutationTrivial(permutation_) &&
internal::IsPermutationTrivial(other_ext.permutation()));

return (storage_type()->Equals(other_ext.storage_type())) &&
(this->shape() == other_ext.shape()) && (dim_names_ == other_ext.dim_names()) &&
Expand Down Expand Up @@ -167,7 +113,8 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
rj::Document document;
if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
!document.HasMember("shape") || !document["shape"].IsArray()) {
!document.IsObject() || !document.HasMember("shape") ||
!document["shape"].IsArray()) {
return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
}

Expand Down Expand Up @@ -218,10 +165,6 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
if (array->null_count() > 0) {
return Status::Invalid("Cannot convert data with nulls to Tensor.");
}
const auto& value_type =
internal::checked_cast<const FixedWidthType&>(*ext_type.value_type());
const auto byte_width = value_type.byte_width();

std::vector<int64_t> permutation = ext_type.permutation();
if (permutation.empty()) {
permutation.resize(ext_type.ndim());
Expand All @@ -236,13 +179,10 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
internal::Permute<std::string>(permutation, &dim_names);
}

std::vector<int64_t> strides;
RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
const auto start_position = array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
const auto buffer =
SliceBuffer(array->data()->buffers[1], start_position, size * byte_width);
ARROW_ASSIGN_OR_RAISE(
auto strides, internal::ComputeStrides(ext_type.value_type(), shape, permutation));
ARROW_ASSIGN_OR_RAISE(const auto buffer, internal::SliceTensorBuffer(
*array, *ext_type.value_type(), shape));

return Tensor::Make(ext_type.value_type(), buffer, shape, strides, dim_names);
}
Expand Down Expand Up @@ -304,7 +244,7 @@ Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor
break;
}
case Type::UINT64: {
value_array = std::make_shared<Int64Array>(tensor->size(), tensor->data());
value_array = std::make_shared<UInt64Array>(tensor->size(), tensor->data());
break;
}
case Type::INT64: {
Expand Down Expand Up @@ -375,10 +315,8 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
shape.insert(shape.begin(), 1, this->length());
internal::Permute<int64_t>(permutation, &shape);

std::vector<int64_t> tensor_strides;
const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
ARROW_RETURN_NOT_OK(
ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));
ARROW_ASSIGN_OR_RAISE(auto tensor_strides,
internal::ComputeStrides(value_type, shape, permutation));

const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
ARROW_ASSIGN_OR_RAISE(
Expand Down Expand Up @@ -412,11 +350,10 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(
ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
strides_ = tensor_strides;
auto maybe_strides =
internal::ComputeStrides(this->value_type_, this->shape(), this->permutation());
ARROW_CHECK_OK(maybe_strides.status());
strides_ = std::move(maybe_strides).MoveValueUnsafe();
}
return strides_;
}
Expand All @@ -426,9 +363,8 @@ std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& va
const std::vector<int64_t>& permutation,
const std::vector<std::string>& dim_names) {
auto maybe_type = FixedShapeTensorType::Make(value_type, shape, permutation, dim_names);
ARROW_DCHECK_OK(maybe_type.status());
ARROW_CHECK_OK(maybe_type.status());
return maybe_type.MoveValueUnsafe();
}

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
7 changes: 2 additions & 5 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

#include "arrow/extension_type.h"

namespace arrow {
namespace extension {
namespace arrow::extension {

class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
Expand Down Expand Up @@ -112,7 +111,6 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
const std::vector<std::string>& dim_names = {});

private:
std::shared_ptr<DataType> storage_type_;
std::shared_ptr<DataType> value_type_;
std::vector<int64_t> shape_;
std::vector<int64_t> strides_;
Expand All @@ -126,5 +124,4 @@ ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
const std::vector<int64_t>& permutation = {},
const std::vector<std::string>& dim_names = {});

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
Loading
Loading