Skip to content

Commit 3222e2a

Browse files
zanmato1984 and pitrou authored
GH-44393: [C++][Compute] Vector selection functions inverse_permutation and scatter (#44394)
### Rationale for this change For background please see #44393. When implementing the "scatter" function requested in #44393, I found it also useful to make it a public vector API. After some painful thinking, I decided to name it "permute". And when implementing permute, I found it fairly easy to implement it by first computing the "reverse indices" of the positions, and then invoking the existing "take", where I think "reverse_indices" itself can also be a useful public vector API. Thus the PR categorizes them as "placement functions". ### What changes are included in this PR? Implement vector selection API `inverse_permutation` and `scatter`, where `scatter(values, indices)` is implemented as `take(values, inverse_permutation(indices))`. ### Are these changes tested? UT included. ### Are there any user-facing changes? Yes, new public APIs added. Documentation updated. * GitHub Issue: #44393 Lead-authored-by: Ruoxi Sun <zanmato1984@gmail.com> Co-authored-by: Rossi Sun <zanmato1984@gmail.com> Co-authored-by: Antoine Pitrou <pitrou@free.fr> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent d7dc586 commit 3222e2a

11 files changed

Lines changed: 1343 additions & 16 deletions

File tree

cpp/src/arrow/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,13 +771,14 @@ if(ARROW_COMPUTE)
771771
compute/kernels/scalar_validity.cc
772772
compute/kernels/vector_array_sort.cc
773773
compute/kernels/vector_cumulative_ops.cc
774-
compute/kernels/vector_pairwise.cc
775774
compute/kernels/vector_nested.cc
775+
compute/kernels/vector_pairwise.cc
776776
compute/kernels/vector_rank.cc
777777
compute/kernels/vector_replace.cc
778778
compute/kernels/vector_run_end_encode.cc
779779
compute/kernels/vector_select_k.cc
780780
compute/kernels/vector_sort.cc
781+
compute/kernels/vector_swizzle.cc
781782
compute/key_hash_internal.cc
782783
compute/key_map_internal.cc
783784
compute/light_array_internal.cc

cpp/src/arrow/compute/api_vector.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
155155
DataMember("periods", &PairwiseOptions::periods));
156156
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
157157
DataMember("recursive", &ListFlattenOptions::recursive));
158+
static auto kInversePermutationOptionsType =
159+
GetFunctionOptionsType<InversePermutationOptions>(
160+
DataMember("max_index", &InversePermutationOptions::max_index),
161+
DataMember("output_type", &InversePermutationOptions::output_type));
162+
static auto kScatterOptionsType = GetFunctionOptionsType<ScatterOptions>(
163+
DataMember("max_index", &ScatterOptions::max_index));
158164
} // namespace
159165
} // namespace internal
160166

@@ -230,6 +236,17 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
230236
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
231237
constexpr char ListFlattenOptions::kTypeName[];
232238

239+
InversePermutationOptions::InversePermutationOptions(
240+
int64_t max_index, std::shared_ptr<DataType> output_type)
241+
: FunctionOptions(internal::kInversePermutationOptionsType),
242+
max_index(max_index),
243+
output_type(std::move(output_type)) {}
244+
constexpr char InversePermutationOptions::kTypeName[];
245+
246+
ScatterOptions::ScatterOptions(int64_t max_index)
247+
: FunctionOptions(internal::kScatterOptionsType), max_index(max_index) {}
248+
constexpr char ScatterOptions::kTypeName[];
249+
233250
namespace internal {
234251
void RegisterVectorOptions(FunctionRegistry* registry) {
235252
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
@@ -244,6 +261,8 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
244261
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
245262
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
246263
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
264+
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
265+
DCHECK_OK(registry->AddFunctionOptionsType(kScatterOptionsType));
247266
}
248267
} // namespace internal
249268

@@ -429,5 +448,19 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
429448
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
430449
}
431450

451+
// ----------------------------------------------------------------------
452+
// Swizzle functions
453+
454+
Result<Datum> InversePermutation(const Datum& indices,
455+
const InversePermutationOptions& options,
456+
ExecContext* ctx) {
457+
return CallFunction("inverse_permutation", {indices}, &options, ctx);
458+
}
459+
460+
Result<Datum> Scatter(const Datum& values, const Datum& indices,
461+
const ScatterOptions& options, ExecContext* ctx) {
462+
return CallFunction("scatter", {values, indices}, &options, ctx);
463+
}
464+
432465
} // namespace compute
433466
} // namespace arrow

cpp/src/arrow/compute/api_vector.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,40 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
257257
bool recursive = false;
258258
};
259259

260+
/// \brief Options for inverse_permutation function
261+
class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
262+
public:
263+
explicit InversePermutationOptions(int64_t max_index = -1,
264+
std::shared_ptr<DataType> output_type = NULLPTR);
265+
static constexpr char const kTypeName[] = "InversePermutationOptions";
266+
static InversePermutationOptions Defaults() { return InversePermutationOptions(); }
267+
268+
/// \brief The max value in the input indices to allow. The length of the function's
269+
/// output will be this value plus 1. If negative, this value will be set to the length
270+
/// of the input indices minus 1 and the length of the function's output will be the
271+
/// length of the input indices.
272+
int64_t max_index = -1;
273+
/// \brief The type of the output inverse permutation. If null, the output will be of
274+
/// the same type as the input indices, otherwise must be signed integer type. An
275+
/// invalid error will be reported if this type is not able to store the length of the
276+
/// input indices.
277+
std::shared_ptr<DataType> output_type = NULLPTR;
278+
};
279+
280+
/// \brief Options for scatter function
281+
class ARROW_EXPORT ScatterOptions : public FunctionOptions {
282+
public:
283+
explicit ScatterOptions(int64_t max_index = -1);
284+
static constexpr char const kTypeName[] = "ScatterOptions";
285+
static ScatterOptions Defaults() { return ScatterOptions(); }
286+
287+
/// \brief The max value in the input indices to allow. The length of the function's
288+
/// output will be this value plus 1. If negative, this value will be set to the length
289+
/// of the input indices minus 1 and the length of the function's output will be the
290+
/// length of the input indices.
291+
int64_t max_index = -1;
292+
};
293+
260294
/// @}
261295

262296
/// \brief Filter with a boolean selection filter
@@ -705,5 +739,58 @@ Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
705739
bool check_overflow = false,
706740
ExecContext* ctx = NULLPTR);
707741

742+
/// \brief Return the inverse permutation of the given indices.
743+
///
744+
/// For indices[i] = x, inverse_permutation[x] = i. And inverse_permutation[x] = null if x
745+
/// does not appear in the input indices. Indices must be in the range of [0, max_index],
746+
/// or null, which will be ignored. If multiple indices point to the same value, the last
747+
/// one is used.
748+
///
749+
/// For example, with
750+
/// indices = [null, 0, null, 2, 4, 1, 1]
751+
/// the inverse permutation is
752+
/// [1, 6, 3, null, 4, null, null]
753+
/// if max_index = 6.
754+
///
755+
/// \param[in] indices array-like indices
756+
/// \param[in] options configures the max index and the output type
757+
/// \param[in] ctx the function execution context, optional
758+
/// \return the resulting inverse permutation
759+
///
760+
/// \since 20.0.0
761+
/// \note API not yet finalized
762+
ARROW_EXPORT
763+
Result<Datum> InversePermutation(
764+
const Datum& indices,
765+
const InversePermutationOptions& options = InversePermutationOptions::Defaults(),
766+
ExecContext* ctx = NULLPTR);
767+
768+
/// \brief Scatter the values into specified positions according to the indices.
769+
///
770+
/// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear
771+
/// in the input indices. Indices must be in the range of [0, max_index], or null, in
772+
/// which case the corresponding value will be ignored. If multiple indices point to the
773+
/// same value, the last one is used.
774+
///
775+
/// For example, with
776+
/// values = [a, b, c, d, e, f, g]
777+
/// indices = [null, 0, null, 2, 4, 1, 1]
778+
/// the output is
779+
/// [b, g, d, null, e, null, null]
780+
/// if max_index = 6.
781+
///
782+
/// \param[in] values datum to scatter
783+
/// \param[in] indices array-like indices
784+
/// \param[in] options configures the max index to scatter to
785+
/// \param[in] ctx the function execution context, optional
786+
/// \return the resulting datum
787+
///
788+
/// \since 20.0.0
789+
/// \note API not yet finalized
790+
ARROW_EXPORT
791+
Result<Datum> Scatter(const Datum& values, const Datum& indices,
792+
const ScatterOptions& options = ScatterOptions::Defaults(),
793+
ExecContext* ctx = NULLPTR);
794+
708795
} // namespace compute
709796
} // namespace arrow

cpp/src/arrow/compute/function_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ TEST(FunctionOptions, Equality) {
136136
options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));
137137
options.emplace_back(new Utf8NormalizeOptions());
138138
options.emplace_back(new Utf8NormalizeOptions(Utf8NormalizeOptions::NFD));
139+
options.emplace_back(
140+
new InversePermutationOptions(/*max_index=*/42, /*output_type=*/int32()));
141+
options.emplace_back(new ScatterOptions());
142+
options.emplace_back(new ScatterOptions(/*max_index=*/42));
139143

140144
for (size_t i = 0; i < options.size(); i++) {
141145
const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;

cpp/src/arrow/compute/kernels/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test
115115
EXTRA_LINK_LIBS
116116
arrow_compute_kernels_testing)
117117

118+
add_arrow_compute_test(vector_swizzle_test
119+
SOURCES
120+
vector_swizzle_test.cc
121+
EXTRA_LINK_LIBS
122+
arrow_compute_kernels_testing)
123+
118124
add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
119125
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
120126
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")

cpp/src/arrow/compute/kernels/codegen_internal.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,8 +1037,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
10371037
// Generate a kernel given a templated functor for integer types
10381038
//
10391039
// See "Numeric" above for description of the generator functor
1040-
template <template <typename...> class Generator, typename Type0, typename... Args>
1041-
ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
1040+
template <template <typename...> class Generator, typename Type0,
1041+
typename KernelType = ArrayKernelExec, typename... Args>
1042+
KernelType GenerateInteger(detail::GetTypeId get_id) {
10421043
switch (get_id.id) {
10431044
case Type::INT8:
10441045
return Generator<Type0, Int8Type, Args...>::Exec;

0 commit comments

Comments
 (0)