Skip to content

Commit f59845d

Browse files
tugsbayasgalan authored and pytorchmergebot committed
Symintify pytorch slicing logic (pytorch#91340)
Differential Revision: [D42398023](https://our.internmc.facebook.com/intern/diff/D42398023) Pull Request resolved: pytorch#91340 Approved by: https://github.com/Skylion007, https://github.com/albanD
1 parent 81b5eff commit f59845d

File tree

4 files changed

+123
-51
lines changed

4 files changed

+123
-51
lines changed

aten/src/ATen/TensorIndexing.h

+25-30
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
namespace at {
2424
namespace indexing {
2525

26-
const int64_t INDEX_MAX = std::numeric_limits<int64_t>::max();
27-
const int64_t INDEX_MIN = std::numeric_limits<int64_t>::min();
26+
const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
27+
const int64_t INDEX_MAX = -(INDEX_MIN + 1);
2828

2929
enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };
3030

@@ -37,52 +37,47 @@ TORCH_API extern const EllipsisIndexType Ellipsis;
3737

3838
struct TORCH_API Slice final {
3939
public:
40-
// This mirrors `__PySlice_Unpack` in torch/csrc/utils/python_compat.h
4140
Slice(
42-
c10::optional<int64_t> start_index = c10::nullopt,
43-
c10::optional<int64_t> stop_index = c10::nullopt,
44-
c10::optional<int64_t> step_index = c10::nullopt) {
41+
c10::optional<c10::SymInt> start_index = c10::nullopt,
42+
c10::optional<c10::SymInt> stop_index = c10::nullopt,
43+
c10::optional<c10::SymInt> step_index = c10::nullopt) {
4544
if (!step_index.has_value()) {
46-
step_ = 1;
45+
step_ = c10::SymInt(1);
4746
} else {
48-
step_ = step_index.value();
49-
TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
50-
51-
// Here step might be -INDEX_MAX-1; in this case we replace it
52-
// with -INDEX_MAX. This doesn't affect the semantics, and it
53-
// guards against later undefined behaviour resulting from code that
54-
// does "step = -step" as part of a slice reversal.
55-
if (step_ < -INDEX_MAX)
56-
step_ = -INDEX_MAX;
47+
step_ = std::move(step_index).value();
5748
}
49+
50+
TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
51+
5852
if (!start_index.has_value()) {
59-
start_ = step_ < 0 ? INDEX_MAX : 0;
53+
start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
6054
} else {
61-
start_ = start_index.value();
55+
start_ = std::move(start_index).value();
6256
}
57+
6358
if (!stop_index.has_value()) {
64-
stop_ = step_ < 0 ? INDEX_MIN : INDEX_MAX;
59+
stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
6560
} else {
66-
stop_ = stop_index.value();
61+
stop_ = std::move(stop_index).value();
6762
}
6863
}
6964

70-
inline int64_t start() const {
65+
inline c10::SymInt start() const {
7166
return start_;
7267
}
7368

74-
inline int64_t stop() const {
69+
inline c10::SymInt stop() const {
7570
return stop_;
7671
}
7772

78-
inline int64_t step() const {
73+
inline c10::SymInt step() const {
7974
return step_;
8075
}
8176

8277
private:
83-
int64_t start_;
84-
int64_t stop_;
85-
int64_t step_;
78+
c10::SymInt start_;
79+
c10::SymInt stop_;
80+
c10::SymInt step_;
8681
};
8782

8883
TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
@@ -213,9 +208,9 @@ namespace impl {
213208
static inline Tensor applySlice(
214209
const Tensor& self,
215210
int64_t dim,
216-
int64_t start,
217-
int64_t stop,
218-
int64_t step,
211+
c10::SymInt start,
212+
c10::SymInt stop,
213+
c10::SymInt step,
219214
bool disable_slice_optimization,
220215
const at::Device& self_device,
221216
const c10::optional<SymIntArrayRef>& self_sizes) {
@@ -235,7 +230,7 @@ static inline Tensor applySlice(
235230
return self;
236231
}
237232
}
238-
return self.slice(dim, start, stop, step);
233+
return self.slice_symint(dim, start, stop, step);
239234
}
240235

241236
static inline Tensor applySelect(

c10/core/SymInt.h

+5
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ class C10_API SymInt {
190190
return i > MAX_UNREPRESENTABLE_INT;
191191
}
192192

193+
// Return the minimum integer representable by a SymInt, as an int64_t
194+
static constexpr int64_t min_representable_int() {
195+
return MAX_UNREPRESENTABLE_INT + 1;
196+
}
197+
193198
private:
194199
// Constraints on the internal representation:
195200
//

torch/csrc/autograd/python_variable_indexing.cpp

+8-21
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,6 @@ inline Variable valueToTensor(
133133
}
134134
}
135135

136-
static inline void checkUnpackSlice(
137-
PyObject* index,
138-
Py_ssize_t* start_ptr,
139-
Py_ssize_t* stop_ptr,
140-
Py_ssize_t* step_ptr) {
141-
if (PySlice_Unpack(index, start_ptr, stop_ptr, step_ptr) != 0) {
142-
throw python_error();
143-
}
144-
}
145-
146136
static inline void recordSliceTrace(PyObject* obj) {
147137
PySliceObject* sliceobj = (PySliceObject*)obj;
148138
if (THPVariable_Check(sliceobj->start)) {
@@ -214,14 +204,12 @@ static inline Variable applySlicing(
214204
}
215205
return at::indexing::TensorIndex(THPUtils_unpackLong(obj));
216206
} else if (PySlice_Check(obj)) {
217-
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
218-
Py_ssize_t start, stop, step;
219-
checkUnpackSlice(obj, &start, &stop, &step);
207+
auto val = __PySlice_Unpack(obj);
220208
if (is_tracing) {
221209
recordSliceTrace(obj);
222210
}
223211
return at::indexing::TensorIndex(
224-
at::indexing::Slice(start, stop, step));
212+
at::indexing::Slice(val.start, val.stop, val.step));
225213
} else if (obj == Py_Ellipsis) {
226214
return at::indexing::TensorIndex(at::indexing::Ellipsis);
227215
} else if (obj == Py_None) {
@@ -360,15 +348,14 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
360348
return THPVariable_Wrap(at::indexing::get_item(
361349
self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}));
362350
} else if (PySlice_Check(index)) {
363-
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
364-
Py_ssize_t start, stop, step;
365-
checkUnpackSlice(index, &start, &stop, &step);
351+
auto val = __PySlice_Unpack(index);
366352
if (is_tracing) {
367353
recordSliceTrace(index);
368354
}
369355
return THPVariable_Wrap(at::indexing::get_item(
370356
self_,
371-
{at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))}));
357+
{at::indexing::TensorIndex(
358+
at::indexing::Slice(val.start, val.stop, val.step))}));
372359
} else if (index == Py_False || index == Py_True) {
373360
return THPVariable_Wrap(([&]() {
374361
pybind11::gil_scoped_release no_gil;
@@ -490,16 +477,16 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
490477
return 0;
491478
} else if (PySlice_Check(index)) {
492479
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
493-
Py_ssize_t start, stop, step;
494-
checkUnpackSlice(index, &start, &stop, &step);
480+
auto val = __PySlice_Unpack(index);
495481
if (is_tracing) {
496482
recordSliceTrace(index);
497483
}
498484
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
499485
// indexing functions from Python ]
500486
dispatch_set_item(
501487
self_,
502-
{at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))},
488+
{at::indexing::TensorIndex(
489+
at::indexing::Slice(val.start, val.stop, val.step))},
503490
value,
504491
/*disable_slice_optimization=*/is_tracing);
505492
return 0;

torch/csrc/autograd/python_variable_indexing.h

+85
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,96 @@
11
#pragma once
22

3+
#include <c10/core/SymInt.h>
34
#include <torch/csrc/autograd/python_variable.h>
45
#include <torch/csrc/python_headers.h>
6+
#include <torch/csrc/utils/pybind.h>
7+
#include <torch/csrc/utils/python_symnode.h>
58

69
namespace torch {
710
namespace autograd {
811

12+
struct UnpackedSlice {
13+
c10::SymInt start;
14+
c10::SymInt stop;
15+
c10::SymInt step;
16+
};
17+
18+
// This mirrors Cpython's PySlice_Unpack method
19+
static inline UnpackedSlice __PySlice_Unpack(PyObject* _r) {
20+
PySliceObject* r = (PySliceObject*)_r;
21+
/* this is harder to get right than you might think */
22+
23+
c10::SymInt start_sym, stop_sym, step_sym;
24+
25+
auto clip_val = [](Py_ssize_t val) {
26+
if (val < c10::SymInt::min_representable_int()) {
27+
auto r = PyErr_WarnEx(
28+
PyExc_UserWarning,
29+
"Truncating the start/stop/step "
30+
"of slice. This is likely because of "
31+
"saved old models when the start/stop/step were larger.",
32+
1);
33+
if (r != 0) {
34+
throw python_error();
35+
}
36+
return (Py_ssize_t)(c10::SymInt::min_representable_int());
37+
}
38+
return val;
39+
};
40+
41+
if (r->step == Py_None) {
42+
step_sym = c10::SymInt(1);
43+
} else {
44+
if (torch::is_symint(r->step)) {
45+
auto step_sym = py::handle(r->step).cast<c10::SymInt>();
46+
} else {
47+
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
48+
Py_ssize_t step;
49+
if (!_PyEval_SliceIndex(r->step, &step)) {
50+
throw python_error();
51+
}
52+
if (step == 0) {
53+
PyErr_SetString(PyExc_ValueError, "slice step cannot be zero");
54+
}
55+
56+
step = clip_val(step);
57+
step_sym = c10::SymInt(step);
58+
}
59+
}
60+
61+
if (torch::is_symint(r->start)) {
62+
start_sym = py::handle(r->start).cast<c10::SymInt>();
63+
} else if (r->start == Py_None) {
64+
start_sym = c10::SymInt(step_sym < 0 ? PY_SSIZE_T_MAX : 0);
65+
} else {
66+
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
67+
Py_ssize_t start;
68+
if (!_PyEval_SliceIndex(r->start, &start)) {
69+
throw python_error();
70+
}
71+
start = clip_val(start);
72+
start_sym = c10::SymInt(start);
73+
}
74+
75+
if (torch::is_symint(r->stop)) {
76+
stop_sym = py::handle(r->stop).cast<c10::SymInt>();
77+
} else if (r->stop == Py_None) {
78+
stop_sym = c10::SymInt(
79+
step_sym < 0 ? c10::SymInt::min_representable_int() : PY_SSIZE_T_MAX);
80+
} else {
81+
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
82+
Py_ssize_t stop;
83+
if (!_PyEval_SliceIndex(r->stop, &stop)) {
84+
throw python_error();
85+
}
86+
stop = clip_val(stop);
87+
stop_sym = c10::SymInt(stop);
88+
}
89+
90+
return UnpackedSlice{
91+
std::move(start_sym), std::move(stop_sym), std::move(step_sym)};
92+
}
93+
994
Py_ssize_t THPVariable_length(PyObject* self);
1095
PyObject* THPVariable_getitem(PyObject* self, PyObject* index);
1196
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value);

0 commit comments

Comments
 (0)