Skip to content

Commit 9c78f59

Browse files
ezyang authored and pytorchmergebot committed
Delete SymIntArrayRef wrapper struct (pytorch#84837)
Since we separated at::foo and at::foo_symint there is no benefit to trying to make initializer lists work in both cases. So we can get rid of the special different struct.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#84837
Approved by: https://github.com/kit1980
1 parent 8cdc067 commit 9c78f59

File tree

12 files changed

+35
-215
lines changed

12 files changed

+35
-215
lines changed

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e0dcc3171c8024ab288551d105fba24fbfae7332
1+
09be9870437684ba2da6741af3eb10126c04aede

aten/src/ATen/core/ivalue.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,6 @@ struct TORCH_API IValue final {
565565
}
566566
}
567567

568-
IValue(c10::SymIntArrayRef v);
569-
570568
bool isSymInt() const {
571569
return Tag::SymInt == tag;
572570
}

aten/src/ATen/core/ivalue_inl.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1999,7 +1999,6 @@ inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
19991999
list.push_back(e);
20002000
}
20012001
}
2002-
inline IValue::IValue(c10::SymIntArrayRef v) : IValue(at::ArrayRef<c10::SymInt>(v.data(), v.size())) {}
20032002
template <class T, IValue::enable_if_ivalue_constructible<T>>
20042003
inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
20052004
auto list = to<c10::List<T>>();

aten/src/ATen/native/metal/ops/MetalReshape.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) {
6464

6565
Tensor reshape(const Tensor& input, IntArrayRef shape) {
6666
TORCH_CHECK(input.is_metal());
67-
return view(input, c10::SymIntArrayRef::fromIntArrayRef(shape));
67+
return view(input, c10::fromIntArrayRef(shape));
6868
}
6969

7070
Tensor flatten_using_ints(

aten/src/ATen/test/extension_backend_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Tensor empty_strided_override(
4444
c10::optional<c10::Device> device,
4545
c10::optional<bool> pin_memory) {
4646

47-
return empty_override(SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
47+
return empty_override(fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
4848
}
4949

5050
TORCH_LIBRARY_IMPL(aten, ORT, m) {

c10/core/SymIntArrayRef.h

Lines changed: 21 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,3 @@
1-
// This file defines `SymIntArrayRef` which serves as the view onto
2-
// std::vector<SymInt>. This class is conceptually and mostly functionally
3-
// equivalent to ArrayRef<SymInt>.
4-
//
5-
// However, ArrayRef<SymInt> can't be used directly as it introduces ambiguity
6-
// in the following cases:
7-
// - a.expand({1, 2, 3}) matches two overloads:
8-
// 1. `at::Tensor Tensor::expand(c10::SymIntArrayRef size, bool implicit)`
9-
// 2. `at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit)`
10-
// Introducing `SymIntArrayRef` allows to have a finer-grained control over
11-
// which overload will be used.
12-
131
#pragma once
142

153
#include <c10/core/SymInt.h>
@@ -23,196 +11,33 @@
2311
#include <vector>
2412

2513
namespace c10 {
26-
/// SymIntArrayRef - Represent a constant reference to an array (0 or more
27-
/// elements consecutively in memory), i.e. a start pointer and a length. It
28-
/// allows various APIs to take consecutive elements easily and conveniently.
29-
///
30-
/// This class does not own the underlying data, it is expected to be used in
31-
/// situations where the data resides in some other buffer, whose lifetime
32-
/// extends past that of the SymIntArrayRef. For this reason, it is not in
33-
/// general safe to store an SymIntArrayRef.
34-
///
35-
/// This is intended to be trivially copyable, so it should be passed by
36-
/// value.
37-
38-
class SymIntArrayRef final {
39-
public:
40-
using iterator = const c10::SymInt*;
41-
using const_iterator = const c10::SymInt*;
42-
using size_type = size_t;
43-
using value_type = c10::SymInt;
44-
45-
using reverse_iterator = std::reverse_iterator<iterator>;
46-
47-
private:
48-
ArrayRef<c10::SymInt> wrapped_symint_array_ref;
49-
50-
public:
51-
/// @name Constructors
52-
/// @{
53-
54-
/// Construct an empty SymIntArrayRef.
55-
/* implicit */ constexpr SymIntArrayRef() {}
56-
57-
/* implicit */ SymIntArrayRef(const std::vector<c10::SymInt>& Vec)
58-
: wrapped_symint_array_ref(Vec) {}
59-
60-
/// Construct an SymIntArrayRef from a pointer and length.
61-
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
62-
const c10::SymInt* data,
63-
size_t length)
64-
: wrapped_symint_array_ref(data, length) {}
65-
66-
template <typename U>
67-
/* implicit */ SymIntArrayRef(
68-
const SmallVectorTemplateCommon<c10::SymInt, U>& Vec)
69-
: wrapped_symint_array_ref(Vec) {}
70-
71-
/// Construct an SymIntArrayRef from a range.
72-
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
73-
const c10::SymInt* begin,
74-
const c10::SymInt* end)
75-
: wrapped_symint_array_ref(begin, end) {}
76-
77-
/// Construct an SymIntArrayRef from a C array.
78-
template <size_t N>
79-
/* implicit */ constexpr SymIntArrayRef(const c10::SymInt (&Arr)[N])
80-
: wrapped_symint_array_ref(Arr) {}
81-
82-
// Prefer using a more semantic constructor, like
83-
// fromIntArrayRefKnownNonNegative
84-
static SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
85-
return SymIntArrayRef(
86-
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
87-
}
88-
89-
static SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
90-
return fromIntArrayRefUnchecked(array_ref);
91-
}
92-
93-
static SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
94-
for (size_t i = 0; i < array_ref.size(); ++i) {
95-
TORCH_CHECK(
96-
SymInt::check_range(array_ref[i]),
97-
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
98-
array_ref[i]);
99-
}
100-
return SymIntArrayRef(
101-
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
102-
}
103-
104-
/// @}
105-
/// @name Simple Operations
106-
/// @{
107-
108-
constexpr iterator begin() const {
109-
return wrapped_symint_array_ref.begin();
110-
}
111-
constexpr iterator end() const {
112-
return wrapped_symint_array_ref.end();
113-
}
114-
115-
// These are actually the same as iterator, since SymIntArrayRef only
116-
// gives you const iterators.
117-
constexpr const_iterator cbegin() const {
118-
return wrapped_symint_array_ref.cbegin();
119-
}
120-
constexpr const_iterator cend() const {
121-
return wrapped_symint_array_ref.cend();
122-
}
123-
124-
/// empty - Check if the array is empty.
125-
constexpr bool empty() const {
126-
return size() == 0;
127-
}
128-
129-
constexpr const c10::SymInt* data() const {
130-
return wrapped_symint_array_ref.data();
131-
}
132-
133-
/// size - Get the array size.
134-
constexpr size_t size() const {
135-
return wrapped_symint_array_ref.size();
136-
}
137-
138-
/// front - Get the first element.
139-
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& front() const {
140-
return wrapped_symint_array_ref.front();
141-
}
142-
143-
/// back - Get the last element.
144-
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& back() const {
145-
return wrapped_symint_array_ref.back();
146-
}
147-
148-
/// equals - Check for element-wise equality.
149-
constexpr bool equals(SymIntArrayRef RHS) const {
150-
return this->wrapped_symint_array_ref.equals(RHS.wrapped_symint_array_ref);
151-
}
152-
153-
/// slice(n, m) - Take M elements of the array starting at element N
154-
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef
155-
slice(size_t N, size_t M) const {
156-
return SymIntArrayRef(wrapped_symint_array_ref.data() + N, M);
157-
}
158-
159-
/// slice(n) - Chop off the first N elements of the array.
160-
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef slice(size_t N) const {
161-
return slice(N, size() - N);
162-
}
163-
164-
/// @}
165-
/// @name Operator Overloads
166-
/// @{
167-
constexpr const c10::SymInt& operator[](size_t Index) const {
168-
return wrapped_symint_array_ref[Index];
169-
}
170-
171-
/// Vector compatibility
172-
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& at(size_t Index) const {
173-
return wrapped_symint_array_ref.at(Index);
174-
}
175-
176-
/// Disallow accidental assignment from a temporary.
177-
///
178-
/// The declaration here is extra complicated so that "arrayRef = {}"
179-
/// continues to select the move assignment operator.
180-
template <typename U>
181-
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
182-
type&
183-
operator=(U&& Temporary) = delete;
184-
185-
/// Disallow accidental assignment from a temporary.
186-
///
187-
/// The declaration here is extra complicated so that "arrayRef = {}"
188-
/// continues to select the move assignment operator.
189-
template <typename U>
190-
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
191-
type&
192-
operator=(std::initializer_list<U>) = delete;
193-
194-
/// @}
195-
/// @name Expensive Operations
196-
/// @{
197-
std::vector<c10::SymInt> vec() const {
198-
return wrapped_symint_array_ref.vec();
199-
}
200-
201-
friend std::ostream& operator<<(
202-
std::ostream& out,
203-
const SymIntArrayRef& list);
204-
/// @}
205-
};
14+
using SymIntArrayRef = ArrayRef<SymInt>;
20615

20716
TORCH_API at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar);
20817
TORCH_API at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar);
20918
TORCH_API c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
21019
c10::SymIntArrayRef ar);
21120

212-
inline std::ostream& operator<<(
213-
std::ostream& out,
214-
const c10::SymIntArrayRef& list) {
215-
return out << list.wrapped_symint_array_ref;
21+
// Prefer using a more semantic constructor, like
22+
// fromIntArrayRefKnownNonNegative
23+
inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
24+
return SymIntArrayRef(
25+
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
26+
}
27+
28+
inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
29+
return fromIntArrayRefUnchecked(array_ref);
30+
}
31+
32+
inline SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
33+
for (size_t i = 0; i < array_ref.size(); ++i) {
34+
TORCH_CHECK(
35+
SymInt::check_range(array_ref[i]),
36+
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
37+
array_ref[i]);
38+
}
39+
return SymIntArrayRef(
40+
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
21641
}
21742

21843
} // namespace c10

c10/core/TensorImpl.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
603603
return sym_sizes_custom();
604604
}
605605
// Sizes guaranteed to be non-negative, so unchecked cast is OK
606-
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
606+
return c10::fromIntArrayRefKnownNonNegative(
607607
sizes_and_strides_.sizes_arrayref());
608608
}
609609

@@ -620,8 +620,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
620620
return extra_meta_->sizes_;
621621
} else {
622622
// Sizes guaranteed to be non-negative, so unchecked cast is OK
623-
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
624-
sizes_default());
623+
return c10::fromIntArrayRefKnownNonNegative(sizes_default());
625624
}
626625
}
627626

@@ -733,8 +732,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
733732
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
734733
return sym_strides_custom();
735734
}
736-
// strides guaranteed to be non-negative, so unchecked cast is OK
737-
return c10::SymIntArrayRef::fromIntArrayRefUnchecked(strides_default());
735+
return c10::fromIntArrayRefKnownNonNegative(strides_default());
738736
}
739737

740738
IntArrayRef strides_default() const {
@@ -748,8 +746,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
748746
if (has_symbolic_sizes_strides_) {
749747
return extra_meta_->strides_;
750748
} else {
751-
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
752-
strides_default());
749+
return c10::fromIntArrayRefKnownNonNegative(strides_default());
753750
}
754751
}
755752

test/cpp/tensorexpr/test_quantization.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ TEST_F(Quantization, QuantDequantUInt8_NLC) {
103103
parseIR(graph_string, &*graph);
104104

105105
auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
106-
x.unsafeGetTensorImpl()->set_sizes_and_strides({1, 2, 2}, {4, 1, 2});
106+
x.unsafeGetTensorImpl()->set_sizes_and_strides(
107+
std::initializer_list<int64_t>{1, 2, 2}, {4, 1, 2});
107108
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
108109
auto y_expected = at::dequantize(q);
109110
TensorExprKernel k(graph);

torch/csrc/lazy/core/tensor_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
157157
return c10::SymIntArrayRef(sym_sizes_->data(), sym_sizes_->size());
158158
}
159159

160-
return c10::SymIntArrayRef::fromIntArrayRef(sizes_custom());
160+
return c10::fromIntArrayRef(sizes_custom());
161161
}
162162

163163
void LTCTensorImpl::setup_size_properties() {

torch/csrc/lazy/ts_backend/ts_native_functions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
308308
c10::optional<bool> pin_memory) {
309309
TORCH_LAZY_FN_COUNTER("lazy::");
310310
at::Tensor t = empty_symint(
311-
c10::SymIntArrayRef::fromIntArrayRef(size),
311+
c10::fromIntArrayRef(size),
312312
dtype,
313313
layout,
314314
device,
@@ -410,7 +410,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
410410
at::IntArrayRef size) {
411411
TORCH_LAZY_FN_COUNTER("lazy::");
412412
return LazyNativeFunctions::view_copy_symint(
413-
self, c10::SymIntArrayRef::fromIntArrayRef(size));
413+
self, c10::fromIntArrayRef(size));
414414
}
415415

416416
// This is needed by the torch.tensor constructor.

torchgen/api/translate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def direct_solve(goal: NamedCType) -> str:
339339
elif goal.type == BaseCType(symIntArrayRefT):
340340
try:
341341
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
342-
return f"c10::SymIntArrayRef::fromIntArrayRef({r})"
342+
return f"c10::fromIntArrayRef({r})"
343343
except UnsatError:
344344
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
345345
elif goal.type == BaseCType(SymIntT):

torchgen/gen_functionalization_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
8989
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
9090
return self.reshape(size);
9191
} else {
92-
auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
92+
auto output = at::_ops::view::call(self, c10::fromIntArrayRef(size));
9393
return output.clone();
9494
}
9595
}

0 commit comments

Comments
 (0)