Commit 5c6f543

ezyang authored and pytorchmergebot committed
Implement SymBool (pytorch#92149)
We have known for a while that we should in principle support SymBool as a separate concept from SymInt and SymFloat (in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., pytorch#90985 has made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors branches on the passed-in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal, because we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor. However, because we store a bool (not a SymBool, which did not exist prior to this PR) on TensorImpl, we are forced to *immediately* compute these values, even if the value ends up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not).

This PR implements SymBool, and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock-on effects, which I discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or and logical negation. I support the bitwise operations on this class (but not the conventional logic operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (creating a guard in this case). This does not matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc., on SymInt to make it possible to create SymBool. The current implementation of comparison functions makes it unfortunately easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of the equality operation); in the short term, I intend to prevent excess guarding in this situation by unit testing; in the long term, making the equality operators return SymBool is probably the correct fix.
* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from pytorch#85936. In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By and large, the contiguity calculation in the implementations now needs to be duplicated (as the boolean version can short circuit, but the SymBool version cannot); you should carefully review the duplicate new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduced a new symbolic implementation that is branch-less (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special, see next bullet.~~ This appears to cause performance problems, so I am leaving this to an update PR.
* (Update: the Python side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides. While in principle we can still make it go through by using a data-oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically, it will be obvious that a tensor is non-overlapping and dense, because the tensor is contiguous). So we take a different approach: instead of trying to trace through the logic computation of non-overlapping and dense, we introduce a new opaque operator IsNonOverlappingAndDenseIndicator which represents all of the compute that would have been done here. This function returns an integer 0 if `is_non_overlapping_and_dense` would have returned `False`, and an integer 1 otherwise, for technical reasons (Sympy does not easily allow defining custom functions that return booleans). The function itself only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically shortcut it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations; more investigation is necessary. As a technical note, because this operator takes a pair of lists of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single list because I don't know if Sympy operations can actually validly take lists of Sympy expressions as inputs. See, for example, `_make_node_sizes_strides`.
* On the Python side, we also introduce a SymBool class, and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for expressly this reason). Additionally, unlike C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, they don't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, and so the rule is simple in C++; only use logical operations if the types are statically known to be SymBool.) Alas, logical negation is not overrideable, so we have to introduce `sym_not`, which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__`, which may imply that `operator.__not__` might be acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators `&` and `|` do the right thing with booleans and are acceptable to use.
* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra and they support fewer operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand`, which only calls expand on operations which are known to be expandable.

TODO: this PR appears to greatly regress performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Need to investigate.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
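The Python-side subtleties described in the commit message can be checked with plain Python. The sketch below is only an illustration: `ToySymBool` and `toy_sym_not` are hypothetical names invented here, not the classes or functions added by this PR. It shows why `isinstance(x, int)` is too loose, why `~` is unsafe on booleans, and why `not` cannot be kept symbolic.

```python
# bool is a subclass of int, so isinstance() cannot tell the two apart.
print(isinstance(True, int))   # True
print(type(True) is int)       # False

# Bitwise inverse on a plain bool is integer inversion, not logical negation.
print(~True, ~False)           # -2 -1


class ToySymBool:
    """Hypothetical stand-in for a symbolic boolean (NOT torch.SymBool)."""

    def __init__(self, expr: str) -> None:
        self.expr = expr

    def __bool__(self) -> bool:
        # `not x` always goes through __bool__ and must yield a real bool,
        # so a symbolic value cannot stay symbolic under `not`.
        raise RuntimeError(f"cannot convert {self.expr} to bool without a guard")

    def __sym_not__(self) -> "ToySymBool":
        # Build a new symbolic expression instead of forcing a concrete value.
        return ToySymBool(f"Not({self.expr})")


def toy_sym_not(x):
    """Hypothetical helper: use __sym_not__ when present, else plain `not`."""
    if hasattr(x, "__sym_not__"):
        return x.__sym_not__()
    return not x


print(toy_sym_not(True))                        # False
print(toy_sym_not(ToySymBool("s0 > 1")).expr)   # Not(s0 > 1)
```

In this toy, writing `not ToySymBool("s0 > 1")` raises, which mirrors the reason a dedicated negation function is needed wherever a symbolic boolean may appear.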
1 parent 34e8eb2 commit 5c6f543

18 files changed: 626 additions and 149 deletions

c10/core/SymBool.cpp (+72)

@@ -0,0 +1,72 @@
+#include <c10/core/SymBool.h>
+#include <c10/core/SymNodeImpl.h>
+#include <array>
+#include <utility>
+
+namespace c10 {
+
+SymNode SymBool::toSymNodeImpl() const {
+  TORCH_CHECK(is_symbolic());
+  return SymNode::reclaim_copy(toSymNodeImplUnowned());
+}
+
+static std::array<SymNode, 2> normalize_symbools(
+    const SymBool& a_,
+    const SymBool& b_) {
+  SymNode a, b;
+  if (a_.is_symbolic())
+    a = a_.toSymNodeImpl();
+  if (b_.is_symbolic())
+    b = b_.toSymNodeImpl();
+
+  SymNodeImpl* common = a ? a.get() : b.get();
+  if (!a) {
+    a = common->wrap_bool(a_.as_bool_unchecked());
+  }
+  if (!b) {
+    b = common->wrap_bool(b_.as_bool_unchecked());
+  }
+  return {std::move(a), std::move(b)};
+}
+
+SymBool SymBool::sym_and(const SymBool& sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return SymBool(data_ && sci.data_);
+  }
+  auto res = normalize_symbools(*this, sci);
+  return SymBool(res[0]->sym_and(res[1]));
+}
+
+SymBool SymBool::sym_or(const SymBool& sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return SymBool(data_ || sci.data_);
+  }
+  auto res = normalize_symbools(*this, sci);
+  return SymBool(res[0]->sym_or(res[1]));
+}
+
+SymBool SymBool::sym_not() const {
+  if (!is_symbolic()) {
+    return SymBool(!data_);
+  }
+  return SymBool(toSymNodeImpl()->sym_not());
+}
+
+std::ostream& operator<<(std::ostream& os, const SymBool& s) {
+  if (s.is_symbolic()) {
+    os << s.toSymNodeImpl()->str();
+  } else {
+    os << s.as_bool_unchecked();
+  }
+  return os;
+}
+
+bool SymBool::guard_bool(const char* file, int64_t line) const {
+  if (!is_symbolic()) {
+    return data_;
+  }
+  SymNode a = toSymNodeImpl();
+  return a->guard_bool(file, line);
+}
+
+} // namespace c10
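The fast-path-then-normalize pattern in `sym_and` and `normalize_symbools` can be modelled in a few lines of Python. This is a rough sketch under stated assumptions, not the c10 implementation: `ToyNode`, `normalize` and `toy_sym_and` are hypothetical names, and expressions are kept as plain strings purely for illustration.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical symbolic node that stores its expression as a string."""

    expr: str

    def wrap_bool(self, value: bool) -> "ToyNode":
        # Lift a plain bool into a node of the same kind as `self`.
        return ToyNode("True" if value else "False")

    def sym_and(self, other: "ToyNode") -> "ToyNode":
        return ToyNode(f"({self.expr} & {other.expr})")


Operand = Union[bool, ToyNode]


def normalize(a: Operand, b: Operand) -> Tuple[ToyNode, ToyNode]:
    """Wrap whichever operand is a plain bool, using the symbolic one as the wrapper."""
    common = a if isinstance(a, ToyNode) else b
    if not isinstance(common, ToyNode):
        raise ValueError("at least one operand must be symbolic")
    if not isinstance(a, ToyNode):
        a = common.wrap_bool(a)
    if not isinstance(b, ToyNode):
        b = common.wrap_bool(b)
    return a, b


def toy_sym_and(a: Operand, b: Operand) -> Operand:
    # Fast path: two plain bools never touch the symbolic machinery.
    if not isinstance(a, ToyNode) and not isinstance(b, ToyNode):
        return a and b
    left, right = normalize(a, b)
    return left.sym_and(right)


print(toy_sym_and(True, False))                 # False
print(toy_sym_and(ToyNode("s0 < 2"), True))     # ToyNode(expr='(s0 < 2 & True)')
```

Note that the mixed case does not simplify `x & True` at this level; as in the diff above, any simplification is left to the node implementation.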

c10/core/SymBool.h (+70)

@@ -0,0 +1,70 @@
+#pragma once
+
+#include <c10/core/SymNodeImpl.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Exception.h>
+#include <c10/util/intrusive_ptr.h>
+
+#include <limits>
+#include <memory>
+
+namespace c10 {
+
+class C10_API SymBool {
+ public:
+  /*implicit*/ SymBool(bool b) : data_(b){};
+  SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
+    TORCH_CHECK(ptr_->is_bool());
+  };
+  SymBool() : data_(false) {}
+
+  SymNodeImpl* toSymNodeImplUnowned() const {
+    return ptr_.get();
+  }
+
+  SymNodeImpl* release() && {
+    return std::move(ptr_).release();
+  }
+
+  SymNode toSymNodeImpl() const;
+
+  bool expect_bool() const {
+    TORCH_CHECK(!is_symbolic());
+    return data_;
+  }
+
+  SymBool sym_and(const SymBool&) const;
+  SymBool sym_or(const SymBool&) const;
+  SymBool sym_not() const;
+
+  SymBool operator&(const SymBool& other) const {
+    return sym_and(other);
+  }
+  SymBool operator|(const SymBool& other) const {
+    return sym_or(other);
+  }
+  SymBool operator~() const {
+    return sym_not();
+  }
+
+  // Insert a guard for the bool to be its concrete value, and then return
+  // that value.  Note that C++ comparison operations default to returning
+  // bool, so it's not so common to have to call this
+  bool guard_bool(const char* file, int64_t line) const;
+
+  C10_ALWAYS_INLINE bool is_symbolic() const {
+    return ptr_;
+  }
+
+  bool as_bool_unchecked() const {
+    return data_;
+  }
+
+ private:
+  // TODO: optimize to union
+  bool data_;
+  SymNode ptr_;
+};
+
+C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
+} // namespace c10

c10/core/SymInt.cpp (+16 -12)

@@ -94,48 +94,52 @@ SymInt SymInt::operator%(const SymInt& sci) const {
   return SymInt(res[0]->mod(res[1]));
 }
 
-bool SymInt::operator==(const SymInt& sci) const {
+SymBool SymInt::sym_eq(const SymInt& sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return data_ == sci.data_;
   }
   auto res = normalize_symints(*this, sci);
-  return res[0]->eq(res[1])->bool_();
+  return res[0]->eq(res[1]);
 }
 
-bool SymInt::operator!=(const SymInt& sci) const {
-  return !(*this == sci);
+SymBool SymInt::sym_ne(const SymInt& sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return data_ != sci.data_;
+  }
+  auto res = normalize_symints(*this, sci);
+  return res[0]->ne(res[1]);
 }
 
-bool SymInt::operator<(const SymInt& sci) const {
+SymBool SymInt::sym_lt(const SymInt& sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return data_ < sci.data_;
   }
   auto res = normalize_symints(*this, sci);
-  return res[0]->lt(res[1])->bool_();
+  return res[0]->lt(res[1]);
 }
 
-bool SymInt::operator<=(const SymInt& sci) const {
+SymBool SymInt::sym_le(const SymInt& sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return data_ <= sci.data_;
   }
   auto res = normalize_symints(*this, sci);
-  return res[0]->le(res[1])->bool_();
+  return res[0]->le(res[1]);
 }
 
-bool SymInt::operator>(const SymInt& sci) const {
+SymBool SymInt::sym_gt(const SymInt& sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return data_ > sci.data_;
   }
   auto res = normalize_symints(*this, sci);
-  return res[0]->gt(res[1])->bool_();
+  return res[0]->gt(res[1]);
 }
 
-bool SymInt::operator>=(const SymInt& sci) const {
+SymBool SymInt::sym_ge(const SymInt& sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return data_ >= sci.data_;
   }
   auto res = normalize_symints(*this, sci);
-  return res[0]->ge(res[1])->bool_();
+  return res[0]->ge(res[1]);
 }
 
 SymInt SymInt::min(const SymInt& sci) const {

c10/core/SymInt.h (+27 -6)

@@ -1,5 +1,6 @@
 #pragma once
 
+#include <c10/core/SymBool.h>
 #include <c10/core/SymNodeImpl.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
@@ -157,16 +158,36 @@ class C10_API SymInt {
   SymInt operator*(const SymInt& sci) const;
   SymInt operator/(const SymInt& sci) const;
   SymInt operator%(const SymInt& sci) const;
-  bool operator==(const SymInt& sci) const;
-  bool operator!=(const SymInt& p2) const;
-  bool operator<(const SymInt& sci) const;
-  bool operator<=(const SymInt& sci) const;
-  bool operator>(const SymInt& sci) const;
-  bool operator>=(const SymInt& sci) const;
   void operator*=(const SymInt& sci);
   void operator+=(const SymInt& sci);
   void operator/=(const SymInt& sci);
 
+  SymBool sym_eq(const SymInt&) const;
+  SymBool sym_ne(const SymInt&) const;
+  SymBool sym_lt(const SymInt&) const;
+  SymBool sym_le(const SymInt&) const;
+  SymBool sym_gt(const SymInt&) const;
+  SymBool sym_ge(const SymInt&) const;
+
+  bool operator==(const SymInt& o) const {
+    return sym_eq(o).guard_bool(__FILE__, __LINE__);
+  }
+  bool operator!=(const SymInt& o) const {
+    return sym_ne(o).guard_bool(__FILE__, __LINE__);
+  }
+  bool operator<(const SymInt& o) const {
+    return sym_lt(o).guard_bool(__FILE__, __LINE__);
+  }
+  bool operator<=(const SymInt& o) const {
+    return sym_le(o).guard_bool(__FILE__, __LINE__);
+  }
+  bool operator>(const SymInt& o) const {
+    return sym_gt(o).guard_bool(__FILE__, __LINE__);
+  }
+  bool operator>=(const SymInt& o) const {
+    return sym_ge(o).guard_bool(__FILE__, __LINE__);
+  }
+
   SymInt min(const SymInt& sci) const;
   SymInt max(const SymInt& sci) const;
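The split between `sym_eq` (returns a symbolic boolean) and `operator==` (guards immediately) is the part of this header most likely to cause accidental guards. The toy Python model below is a sketch only; `ToyShapeEnv`, `ToySymInt` and `ToySymBool` are hypothetical names and the guard list is an assumption made for illustration. It shows how the two spellings differ in what they record.

```python
from typing import Dict, List, Tuple


class ToyShapeEnv:
    """Hypothetical environment: concrete hints for symbols plus recorded guards."""

    def __init__(self, hints: Dict[str, int]) -> None:
        self.hints = dict(hints)
        self.guards: List[Tuple[str, bool]] = []


class ToySymBool:
    """Hypothetical symbolic boolean carrying an expression and its hinted value."""

    def __init__(self, expr: str, hint: bool, env: ToyShapeEnv) -> None:
        self.expr, self.hint, self.env = expr, hint, env

    def guard_bool(self) -> bool:
        # Record the assumption, then hand back the concrete value.
        self.env.guards.append((self.expr, self.hint))
        return self.hint


class ToySymInt:
    """Hypothetical symbolic integer identified by a symbol name."""

    def __init__(self, name: str, env: ToyShapeEnv) -> None:
        self.name, self.env = name, env

    def sym_eq(self, other: "ToySymInt") -> ToySymBool:
        # Deferred comparison: nothing is guarded here.
        hint = self.env.hints[self.name] == self.env.hints[other.name]
        return ToySymBool(f"Eq({self.name}, {other.name})", hint, self.env)

    def __eq__(self, other: object):
        if not isinstance(other, ToySymInt):
            return NotImplemented
        # Eager comparison: forces a concrete bool and therefore a guard.
        return self.sym_eq(other).guard_bool()


env = ToyShapeEnv({"s0": 4, "s1": 4})
s0, s1 = ToySymInt("s0", env), ToySymInt("s1", env)

deferred = s0.sym_eq(s1)
print(env.guards)   # []
print(s0 == s1)     # True
print(env.guards)   # [('Eq(s0, s1)', True)]
```

The design choice visible here is that existing call sites using `==` keep their old guarding behaviour, while new code can opt into the deferred form explicitly.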

c10/core/SymNodeImpl.h (+25)

@@ -1,6 +1,7 @@
 #pragma once
 
 #include <c10/macros/Macros.h>
+#include <c10/util/ArrayRef.h>
 #include <c10/util/Exception.h>
 #include <c10/util/intrusive_ptr.h>
 #include <memory>
@@ -25,6 +26,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
   virtual bool is_int() {
     TORCH_CHECK(false, "NYI");
   };
+  virtual bool is_bool() {
+    TORCH_CHECK(false, "NYI");
+  };
   virtual bool is_float() {
     TORCH_CHECK(false, "NYI");
   };
@@ -82,6 +86,21 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
   virtual SymNode sym_max(const SymNode& other) {
     TORCH_CHECK(false, "NYI");
   };
+  virtual SymNode sym_or(const SymNode& other) {
+    TORCH_CHECK(false, "NYI");
+  };
+  virtual SymNode sym_and(const SymNode& other) {
+    TORCH_CHECK(false, "NYI");
+  };
+  virtual SymNode sym_not() {
+    TORCH_CHECK(false, "NYI");
+  };
+  // NB: self is ignored here, only the arguments are used
+  virtual SymNode is_non_overlapping_and_dense(
+      ArrayRef<SymNode> sizes,
+      ArrayRef<SymNode> strides) {
+    TORCH_CHECK(false, "NYI");
+  };
   virtual SymNode clone() {
     TORCH_CHECK(false, "NYI");
   };
@@ -94,9 +113,15 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
   virtual SymNode wrap_float(double num) {
     TORCH_CHECK(false, "NYI");
   };
+  virtual SymNode wrap_bool(bool num) {
+    TORCH_CHECK(false, "NYI");
+  };
   virtual int64_t guard_int(const char* file, int64_t line) {
     TORCH_CHECK(false, "NYI");
   };
+  virtual bool guard_bool(const char* file, int64_t line) {
+    TORCH_CHECK(false, "NYI");
+  };
   virtual double guard_float(const char* file, int64_t line) {
     TORCH_CHECK(false, "NYI");
   };
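For concrete integers, the check behind `is_non_overlapping_and_dense` can be sketched in Python as follows. This is a minimal illustration assuming the usual definition (some permutation of the dimensions yields a packed, contiguous layout); the function name is hypothetical and this is not the code added by the PR. It returns an integer 0 or 1, as the commit message describes for the indicator, and it shows the stride sort that makes a branch-free symbolic version awkward.

```python
from typing import Sequence


def non_overlapping_and_dense_indicator(
    sizes: Sequence[int], strides: Sequence[int]
) -> int:
    """Return 1 if the layout is non-overlapping and dense, else 0 (ints only)."""
    if len(sizes) != len(strides):
        raise ValueError("sizes and strides must have the same length")
    # Dimensions of size 0 or 1 place no constraint on their stride,
    # so only dimensions with at least two elements are checked.
    dims = sorted(
        (stride, size) for size, stride in zip(sizes, strides) if size >= 2
    )
    expected_stride = 1
    for stride, size in dims:
        # After sorting by stride, each stride must equal the number of
        # elements covered by all faster-moving dimensions.
        if stride != expected_stride:
            return 0
        expected_stride *= size
    return 1


print(non_overlapping_and_dense_indicator((2, 3), (3, 1)))  # 1 (row-major)
print(non_overlapping_and_dense_indicator((2, 3), (1, 2)))  # 1 (column-major)
print(non_overlapping_and_dense_indicator((2, 3), (6, 2)))  # 0 (gaps)
print(non_overlapping_and_dense_indicator((2, 3), (0, 0)))  # 0 (overlapping)
```

Because the result depends on a data-dependent sort, keeping the whole computation behind one opaque function that only evaluates on integer arguments avoids tracing through the comparison branches.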

docs/source/conf.py (-2)

@@ -336,8 +336,6 @@
     "Quantize",
     # torch.utils.backcompat
     "Warning",
-    "SymInt",
-    "SymFloat",
 ]
 
 # The suffix(es) of source filenames.

docs/source/torch.rst (+10)

@@ -621,6 +621,15 @@ Utilities
 
 Symbolic Numbers
 ----------------
+.. autoclass:: SymInt
+    :members:
+
+.. autoclass:: SymFloat
+    :members:
+
+.. autoclass:: SymBool
+    :members:
+
 .. autosummary::
     :toctree: generated
     :nosignatures:
@@ -629,6 +638,7 @@ Symbolic Numbers
     sym_int
     sym_max
     sym_min
+    sym_not
 
 Optimizations
 -------------
