Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
#include "arrow/tensor.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging.h"
#include "arrow/util/print.h"
#include "arrow/util/sort.h"
#include "arrow/util/string.h"

#include <rapidjson/document.h>
#include <rapidjson/writer.h>
Expand Down Expand Up @@ -104,6 +106,22 @@ bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
permutation_equivalent;
}

std::string FixedShapeTensorType::ToString() const {
std::stringstream ss;
ss << "extension<" << this->extension_name()
<< "[value_type=" << value_type_->ToString()
<< ", shape=" << ::arrow::internal::PrintVector{shape_, ","};

if (!permutation_.empty()) {
ss << ", permutation=" << ::arrow::internal::PrintVector{permutation_, ","};
}
if (!dim_names_.empty()) {
ss << ", dim_names=[" << internal::JoinStrings(dim_names_, ",") << "]";
}
ss << "]>";
return ss.str();
}

std::string FixedShapeTensorType::Serialize() const {
rj::Document document;
document.SetObject();
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
dim_names_(dim_names) {}

std::string extension_name() const override { return "arrow.fixed_shape_tensor"; }
std::string ToString() const override;

/// Number of dimensions of tensor elements
size_t ndim() { return shape_.size(); }
Expand Down
28 changes: 28 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,32 @@ TEST_F(TestExtensionType, ComputeStrides) {
ASSERT_EQ(ext_type_7->Serialize(), R"({"shape":[3,4,7],"permutation":[2,0,1]})");
}

TEST_F(TestExtensionType, ToString) {
auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);

auto ext_type_1 = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int16(), {3, 4, 7}));
auto ext_type_2 = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int32(), {3, 4, 7}, {1, 0, 2}));
auto ext_type_3 = internal::checked_pointer_cast<FixedShapeTensorType>(
fixed_shape_tensor(int64(), {3, 4, 7}, {}, {"C", "H", "W"}));

std::string result_1 = ext_type_1->ToString();
std::string expected_1 =
"extension<arrow.fixed_shape_tensor[value_type=int16, shape=[3,4,7]]>";
ASSERT_EQ(expected_1, result_1);

std::string result_2 = ext_type_2->ToString();
std::string expected_2 =
"extension<arrow.fixed_shape_tensor[value_type=int32, shape=[3,4,7], "
"permutation=[1,0,2]]>";
ASSERT_EQ(expected_2, result_2);

std::string result_3 = ext_type_3->ToString();
std::string expected_3 =
"extension<arrow.fixed_shape_tensor[value_type=int64, shape=[3,4,7], "
"dim_names=[C,H,W]]>";
ASSERT_EQ(expected_3, result_3);
}

} // namespace arrow
26 changes: 26 additions & 0 deletions cpp/src/arrow/util/print.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#pragma once

#include <tuple>
#include "arrow/util/string.h"

using arrow::internal::ToChars;

namespace arrow {
namespace internal {
Expand Down Expand Up @@ -47,5 +50,28 @@ void PrintTuple(OStream* os, const std::tuple<Args&...>& tup) {
detail::TuplePrinter<OStream, std::tuple<Args&...>, sizeof...(Args)>::Print(os, tup);
}

template <typename Range, typename Separator>
struct PrintVector {
const Range& range_;
const Separator& separator_;

template <typename Os> // template to dodge inclusion of <ostream>
friend Os& operator<<(Os& os, PrintVector l) {
bool first = true;
os << "[";
for (const auto& element : l.range_) {
if (first) {
first = false;
} else {
os << l.separator_;
}
os << ToChars(element); // use ToChars to avoid locale dependence
}
os << "]";
return os;
}
};
template <typename Range, typename Separator>
PrintVector(const Range&, const Separator&) -> PrintVector<Range, Separator>;
} // namespace internal
} // namespace arrow
4 changes: 2 additions & 2 deletions docs/source/python/extending_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ Extension arrays can be used as columns in ``pyarrow.Table`` or
f0: int8
f1: string
f2: bool
tensors_int: extension<arrow.fixed_size_tensor>
tensors_float: extension<arrow.fixed_size_tensor>
tensors_int: extension<arrow.fixed_shape_tensor[value_type=int32, shape=[2,2]]>
tensors_float: extension<arrow.fixed_shape_tensor[value_type=float, shape=[2,2]]>
----
f0: [[1,2,3]]
f1: [["foo","bar",null]]
Expand Down
19 changes: 19 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,3 +1351,22 @@ def test_tensor_type_is_picklable(pickle_module):
result = pickle_module.loads(pickle_module.dumps(expected_arr))

assert result == expected_arr


@pytest.mark.parametrize(("tensor_type", "text"), [
(
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
'fixed_shape_tensor[value_type=int8, shape=[2,2,3]]'
),
(
pa.fixed_shape_tensor(pa.int32(), [2, 2, 3], permutation=[0, 2, 1]),
'fixed_shape_tensor[value_type=int32, shape=[2,2,3], permutation=[0,2,1]]'
),
(
pa.fixed_shape_tensor(pa.int64(), [2, 2, 3], dim_names=['C', 'H', 'W']),
'fixed_shape_tensor[value_type=int64, shape=[2,2,3], dim_names=[C,H,W]]'
)
])
def test_tensor_type_str(tensor_type, text):
tensor_type_str = tensor_type.__str__()
assert text in tensor_type_str
6 changes: 3 additions & 3 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,7 @@ cdef class FixedShapeTensorType(BaseExtensionType):

>>> import pyarrow as pa
>>> pa.fixed_shape_tensor(pa.int32(), [2, 2])
FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
FixedShapeTensorType(extension<arrow.fixed_shape_tensor[value_type=int32, shape=[2,2]]>)

Create an instance of fixed shape tensor extension type with
permutation:
Expand Down Expand Up @@ -4744,7 +4744,7 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N
>>> import pyarrow as pa
>>> tensor_type = pa.fixed_shape_tensor(pa.int32(), [2, 2])
>>> tensor_type
FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
FixedShapeTensorType(extension<arrow.fixed_shape_tensor[value_type=int32, shape=[2,2]]>)

Inspect the data type:

Expand All @@ -4760,7 +4760,7 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N
>>> tensor = pa.ExtensionArray.from_storage(tensor_type, storage)
>>> pa.table([tensor], names=["tensor_array"])
pyarrow.Table
tensor_array: extension<arrow.fixed_shape_tensor>
tensor_array: extension<arrow.fixed_shape_tensor[value_type=int32, shape=[2,2]]>
----
tensor_array: [[[1,2,3,4],[10,20,30,40],[100,200,300,400]]]

Expand Down