|
1 |
| -// This file defines `SymIntArrayRef` which serves as the view onto |
2 |
| -// std::vector<SymInt>. This class is conceptually and mostly functionally |
3 |
| -// equivalent to ArrayRef<SymInt>. |
4 |
| -// |
5 |
| -// However, ArrayRef<SymInt> can't be used directly as it introduces ambiguity |
6 |
| -// in the following cases: |
7 |
| -// - a.expand({1, 2, 3}) matches two overloads: |
8 |
| -// 1. `at::Tensor Tensor::expand(c10::SymIntArrayRef size, bool implicit)` |
9 |
| -// 2. `at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit)` |
10 |
| -// Introducing `SymIntArrayRef` allows to have a finer-grained control over |
11 |
| -// which overload will be used. |
12 |
| - |
13 | 1 | #pragma once
|
14 | 2 |
|
15 | 3 | #include <c10/core/SymInt.h>
|
|
23 | 11 | #include <vector>
|
24 | 12 |
|
25 | 13 | namespace c10 {
|
26 |
| -/// SymIntArrayRef - Represent a constant reference to an array (0 or more |
27 |
| -/// elements consecutively in memory), i.e. a start pointer and a length. It |
28 |
| -/// allows various APIs to take consecutive elements easily and conveniently. |
29 |
| -/// |
30 |
| -/// This class does not own the underlying data, it is expected to be used in |
31 |
| -/// situations where the data resides in some other buffer, whose lifetime |
32 |
| -/// extends past that of the SymIntArrayRef. For this reason, it is not in |
33 |
| -/// general safe to store an SymIntArrayRef. |
34 |
| -/// |
35 |
| -/// This is intended to be trivially copyable, so it should be passed by |
36 |
| -/// value. |
37 |
| - |
38 |
| -class SymIntArrayRef final { |
39 |
| - public: |
40 |
| - using iterator = const c10::SymInt*; |
41 |
| - using const_iterator = const c10::SymInt*; |
42 |
| - using size_type = size_t; |
43 |
| - using value_type = c10::SymInt; |
44 |
| - |
45 |
| - using reverse_iterator = std::reverse_iterator<iterator>; |
46 |
| - |
47 |
| - private: |
48 |
| - ArrayRef<c10::SymInt> wrapped_symint_array_ref; |
49 |
| - |
50 |
| - public: |
51 |
| - /// @name Constructors |
52 |
| - /// @{ |
53 |
| - |
54 |
| - /// Construct an empty SymIntArrayRef. |
55 |
| - /* implicit */ constexpr SymIntArrayRef() {} |
56 |
| - |
57 |
| - /* implicit */ SymIntArrayRef(const std::vector<c10::SymInt>& Vec) |
58 |
| - : wrapped_symint_array_ref(Vec) {} |
59 |
| - |
60 |
| - /// Construct an SymIntArrayRef from a pointer and length. |
61 |
| - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef( |
62 |
| - const c10::SymInt* data, |
63 |
| - size_t length) |
64 |
| - : wrapped_symint_array_ref(data, length) {} |
65 |
| - |
66 |
| - template <typename U> |
67 |
| - /* implicit */ SymIntArrayRef( |
68 |
| - const SmallVectorTemplateCommon<c10::SymInt, U>& Vec) |
69 |
| - : wrapped_symint_array_ref(Vec) {} |
70 |
| - |
71 |
| - /// Construct an SymIntArrayRef from a range. |
72 |
| - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef( |
73 |
| - const c10::SymInt* begin, |
74 |
| - const c10::SymInt* end) |
75 |
| - : wrapped_symint_array_ref(begin, end) {} |
76 |
| - |
77 |
| - /// Construct an SymIntArrayRef from a C array. |
78 |
| - template <size_t N> |
79 |
| - /* implicit */ constexpr SymIntArrayRef(const c10::SymInt (&Arr)[N]) |
80 |
| - : wrapped_symint_array_ref(Arr) {} |
81 |
| - |
82 |
| - // Prefer using a more semantic constructor, like |
83 |
| - // fromIntArrayRefKnownNonNegative |
84 |
| - static SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) { |
85 |
| - return SymIntArrayRef( |
86 |
| - reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size()); |
87 |
| - } |
88 |
| - |
89 |
| - static SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) { |
90 |
| - return fromIntArrayRefUnchecked(array_ref); |
91 |
| - } |
92 |
| - |
93 |
| - static SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) { |
94 |
| - for (size_t i = 0; i < array_ref.size(); ++i) { |
95 |
| - TORCH_CHECK( |
96 |
| - SymInt::check_range(array_ref[i]), |
97 |
| - "IntArrayRef contains an int that cannot be represented as a SymInt: ", |
98 |
| - array_ref[i]); |
99 |
| - } |
100 |
| - return SymIntArrayRef( |
101 |
| - reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size()); |
102 |
| - } |
103 |
| - |
104 |
| - /// @} |
105 |
| - /// @name Simple Operations |
106 |
| - /// @{ |
107 |
| - |
108 |
| - constexpr iterator begin() const { |
109 |
| - return wrapped_symint_array_ref.begin(); |
110 |
| - } |
111 |
| - constexpr iterator end() const { |
112 |
| - return wrapped_symint_array_ref.end(); |
113 |
| - } |
114 |
| - |
115 |
| - // These are actually the same as iterator, since SymIntArrayRef only |
116 |
| - // gives you const iterators. |
117 |
| - constexpr const_iterator cbegin() const { |
118 |
| - return wrapped_symint_array_ref.cbegin(); |
119 |
| - } |
120 |
| - constexpr const_iterator cend() const { |
121 |
| - return wrapped_symint_array_ref.cend(); |
122 |
| - } |
123 |
| - |
124 |
| - /// empty - Check if the array is empty. |
125 |
| - constexpr bool empty() const { |
126 |
| - return size() == 0; |
127 |
| - } |
128 |
| - |
129 |
| - constexpr const c10::SymInt* data() const { |
130 |
| - return wrapped_symint_array_ref.data(); |
131 |
| - } |
132 |
| - |
133 |
| - /// size - Get the array size. |
134 |
| - constexpr size_t size() const { |
135 |
| - return wrapped_symint_array_ref.size(); |
136 |
| - } |
137 |
| - |
138 |
| - /// front - Get the first element. |
139 |
| - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& front() const { |
140 |
| - return wrapped_symint_array_ref.front(); |
141 |
| - } |
142 |
| - |
143 |
| - /// back - Get the last element. |
144 |
| - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& back() const { |
145 |
| - return wrapped_symint_array_ref.back(); |
146 |
| - } |
147 |
| - |
148 |
| - /// equals - Check for element-wise equality. |
149 |
| - constexpr bool equals(SymIntArrayRef RHS) const { |
150 |
| - return this->wrapped_symint_array_ref.equals(RHS.wrapped_symint_array_ref); |
151 |
| - } |
152 |
| - |
153 |
| - /// slice(n, m) - Take M elements of the array starting at element N |
154 |
| - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef |
155 |
| - slice(size_t N, size_t M) const { |
156 |
| - return SymIntArrayRef(wrapped_symint_array_ref.data() + N, M); |
157 |
| - } |
158 |
| - |
159 |
| - /// slice(n) - Chop off the first N elements of the array. |
160 |
| - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef slice(size_t N) const { |
161 |
| - return slice(N, size() - N); |
162 |
| - } |
163 |
| - |
164 |
| - /// @} |
165 |
| - /// @name Operator Overloads |
166 |
| - /// @{ |
167 |
| - constexpr const c10::SymInt& operator[](size_t Index) const { |
168 |
| - return wrapped_symint_array_ref[Index]; |
169 |
| - } |
170 |
| - |
171 |
| - /// Vector compatibility |
172 |
| - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& at(size_t Index) const { |
173 |
| - return wrapped_symint_array_ref.at(Index); |
174 |
| - } |
175 |
| - |
176 |
| - /// Disallow accidental assignment from a temporary. |
177 |
| - /// |
178 |
| - /// The declaration here is extra complicated so that "arrayRef = {}" |
179 |
| - /// continues to select the move assignment operator. |
180 |
| - template <typename U> |
181 |
| - typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>:: |
182 |
| - type& |
183 |
| - operator=(U&& Temporary) = delete; |
184 |
| - |
185 |
| - /// Disallow accidental assignment from a temporary. |
186 |
| - /// |
187 |
| - /// The declaration here is extra complicated so that "arrayRef = {}" |
188 |
| - /// continues to select the move assignment operator. |
189 |
| - template <typename U> |
190 |
| - typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>:: |
191 |
| - type& |
192 |
| - operator=(std::initializer_list<U>) = delete; |
193 |
| - |
194 |
| - /// @} |
195 |
| - /// @name Expensive Operations |
196 |
| - /// @{ |
197 |
| - std::vector<c10::SymInt> vec() const { |
198 |
| - return wrapped_symint_array_ref.vec(); |
199 |
| - } |
200 |
| - |
201 |
| - friend std::ostream& operator<<( |
202 |
| - std::ostream& out, |
203 |
| - const SymIntArrayRef& list); |
204 |
| - /// @} |
205 |
| -}; |
| 14 | +using SymIntArrayRef = ArrayRef<SymInt>; |
206 | 15 |
|
207 | 16 | TORCH_API at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar);
|
208 | 17 | TORCH_API at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar);
|
209 | 18 | TORCH_API c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
|
210 | 19 | c10::SymIntArrayRef ar);
|
211 | 20 |
|
212 |
| -inline std::ostream& operator<<( |
213 |
| - std::ostream& out, |
214 |
| - const c10::SymIntArrayRef& list) { |
215 |
| - return out << list.wrapped_symint_array_ref; |
| 21 | +// Prefer using a more semantic constructor, like |
| 22 | +// fromIntArrayRefKnownNonNegative |
| 23 | +inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) { |
| 24 | + return SymIntArrayRef( |
| 25 | + reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size()); |
| 26 | +} |
| 27 | + |
| 28 | +inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) { |
| 29 | + return fromIntArrayRefUnchecked(array_ref); |
| 30 | +} |
| 31 | + |
| 32 | +inline SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) { |
| 33 | + for (size_t i = 0; i < array_ref.size(); ++i) { |
| 34 | + TORCH_CHECK( |
| 35 | + SymInt::check_range(array_ref[i]), |
| 36 | + "IntArrayRef contains an int that cannot be represented as a SymInt: ", |
| 37 | + array_ref[i]); |
| 38 | + } |
| 39 | + return SymIntArrayRef( |
| 40 | + reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size()); |
216 | 41 | }
|
217 | 42 |
|
218 | 43 | } // namespace c10
|
0 commit comments