Skip to content

Commit 272a446

Browse files
committed
[C++][Python] Support Python-like slicing in list_slice kernel
1 parent 3cbc27a commit 272a446

3 files changed

Lines changed: 26 additions & 26 deletions

File tree

cpp/src/arrow/compute/kernels/scalar_nested.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ Result<TypeHolder> ListSliceOutputType(const ListSliceOptions& opts,
162162
"`stop` being set.");
163163
}
164164
if (opts.step < 1) {
165-
return Status::Invalid("`step` must be >= 1, got: ", opts.step);
165+
return Status::Invalid("`step` must be greater than or equal to 1, got: ",
166+
opts.step);
166167
}
167168
const auto length = ListSliceLength(opts.start, opts.step, *stop);
168169
return fixed_size_list(value_type, static_cast<int32_t>(length));
@@ -183,14 +184,15 @@ struct ListSlice {
183184
const auto* list_type = checked_cast<const BaseListType*>(list_array.type);
184185

185186
// Pre-conditions
186-
if (opts.start < 0 || (opts.stop.has_value() && opts.start >= opts.stop.value())) {
187-
// TODO(ARROW-18281): support start == stop which should give empty lists
188-
return Status::Invalid("`start`(", opts.start,
189-
") should be greater than 0 and smaller than `stop`(",
190-
ToString(opts.stop), ")");
187+
if (opts.start < 0 || (opts.stop.has_value() && opts.start > opts.stop.value())) {
188+
return Status::Invalid(
189+
"`start`(", opts.start,
190+
") should be greater than or equal to 0 and not greater than `stop`(",
191+
ToString(opts.stop), ")");
191192
}
192193
if (opts.step < 1) {
193-
return Status::Invalid("`step` must be >= 1, got: ", opts.step);
194+
return Status::Invalid("`step` must be greater than or equal to 1, got: ",
195+
opts.step);
194196
}
195197

196198
auto* pool = ctx->memory_pool();

cpp/src/arrow/compute/kernels/scalar_nested_test.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ TEST(TestScalarNested, ListSliceVariableOutput) {
176176
auto input = ArrayFromJSON(fixed_size_list(int32(), 1), "[[1]]");
177177
auto expected = ArrayFromJSON(list(int32()), "[[1]]");
178178
CheckScalarUnary("list_slice", input, expected, &args);
179+
180+
args.start = 0;
181+
args.stop = 0;
182+
auto input_empty = ArrayFromJSON(list(int32()), "[[1, 2, 3], [4, 5], null]");
183+
auto expected_empty = ArrayFromJSON(list(int32()), "[[], [], null]");
184+
CheckScalarUnary("list_slice", input_empty, expected_empty, &args);
179185
}
180186

181187
TEST(TestScalarNested, ListSliceFixedOutput) {
@@ -315,22 +321,17 @@ TEST(TestScalarNested, ListSliceBadParameters) {
315321
EXPECT_RAISES_WITH_MESSAGE_THAT(
316322
Invalid,
317323
::testing::HasSubstr(
318-
"`start`(-1) should be greater than 0 and smaller than `stop`(1)"),
324+
"`start`(-1) should be greater than or equal to 0 and not greater than "
325+
"`stop`(1)"),
319326
CallFunction("list_slice", {input}, &args));
320327
// start greater than stop
321328
args.start = 1;
322329
args.stop = 0;
323330
EXPECT_RAISES_WITH_MESSAGE_THAT(
324331
Invalid,
325332
::testing::HasSubstr(
326-
"`start`(1) should be greater than 0 and smaller than `stop`(0)"),
327-
CallFunction("list_slice", {input}, &args));
328-
// start same as stop
329-
args.stop = args.start;
330-
EXPECT_RAISES_WITH_MESSAGE_THAT(
331-
Invalid,
332-
::testing::HasSubstr(
333-
"`start`(1) should be greater than 0 and smaller than `stop`(1)"),
333+
"`start`(1) should be greater than or equal to 0 and not greater than "
334+
"`stop`(0)"),
334335
CallFunction("list_slice", {input}, &args));
335336
// stop not set and FixedSizeList requested with variable sized input
336337
args.stop = std::nullopt;
@@ -343,9 +344,9 @@ TEST(TestScalarNested, ListSliceBadParameters) {
343344
args.start = 0;
344345
args.stop = 2;
345346
args.step = 0;
346-
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
347-
::testing::HasSubstr("`step` must be >= 1, got: 0"),
348-
CallFunction("list_slice", {input}, &args));
347+
EXPECT_RAISES_WITH_MESSAGE_THAT(
348+
Invalid, ::testing::HasSubstr("`step` must be greater than or equal to 1, got: 0"),
349+
CallFunction("list_slice", {input}, &args));
349350
}
350351

351352
TEST(TestScalarNested, StructField) {

python/pyarrow/tests/test_compute.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3898,7 +3898,8 @@ def test_list_slice_output_fixed(start, stop, step, expected, value_type,
38983898
(0, 1,),
38993899
(0, 2,),
39003900
(1, 2,),
3901-
(2, 4,)
3901+
(2, 4,),
3902+
(0, 0,)
39023903
))
39033904
@pytest.mark.parametrize("step", (1, 2))
39043905
@pytest.mark.parametrize("value_type", (pa.string, pa.int16, pa.float64))
@@ -3946,18 +3947,14 @@ def test_list_slice_field_names_retained(return_fixed_size, type):
39463947

39473948
def test_list_slice_bad_parameters():
39483949
arr = pa.array([[1]], pa.list_(pa.int8(), 1))
3949-
msg = r"`start`(.*) should be greater than 0 and smaller than `stop`(.*)"
3950+
msg = r"`start`(.*) should be greater than or equal to 0 and not greater than `stop`(.*)"
39503951
with pytest.raises(pa.ArrowInvalid, match=msg):
39513952
pc.list_slice(arr, -1, 1) # negative start?
39523953
with pytest.raises(pa.ArrowInvalid, match=msg):
39533954
pc.list_slice(arr, 2, 1) # start > stop?
39543955

3955-
# TODO(ARROW-18281): start==stop -> empty lists
3956-
with pytest.raises(pa.ArrowInvalid, match=msg):
3957-
pc.list_slice(arr, 0, 0) # start == stop?
3958-
39593956
# Step not >= 1
3960-
msg = "`step` must be >= 1, got: "
3957+
msg = "`step` must be greater than or equal to 1, got: "
39613958
with pytest.raises(pa.ArrowInvalid, match=msg + "0"):
39623959
pc.list_slice(arr, 0, 1, step=0)
39633960
with pytest.raises(pa.ArrowInvalid, match=msg + "-1"):

0 commit comments

Comments
 (0)