Commit 1ff5222

ezyang authored and pytorchmergebot committed
Unify SymIntNode and SymFloatNode into SymNode (pytorch#87817)
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy (pytorch#87722). This PR takes a different approach.

The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This type is erased: we no longer know statically in C++ whether we have an int or a float, and have to test for it with the is_int()/is_float() virtual methods. This has a number of knock-on effects.

- We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API: SymInt/SymFloat classes are defined entirely in Python and hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python and is wrapped into a C++ SymNode using PythonSymNode when it crosses into C++. This implies a userland rename. In principle, the canonical implementation of SymNode could also be written in C++ and then bound to Python with pybind11 (we have this code, although it is commented out); I did not pursue this, as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes; currently, this is done just by manually importing torch and getting the attributes.
- Because SymInt/SymFloat are now lightweight Python wrappers, __sym_dispatch__ takes SymInt/SymFloat rather than SymNode, bringing it in line with how __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. It replaces toSymFloat/toSymInt, and is a mild optimization since rvalue references work automatically.
- We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via SymIntNode/SymFloatNode.
- Removed some unnecessary toSymInt/toSymFloat calls in the normalize_* functions; pretty sure this doesn't change behavior.
- guard_int is now a free function, since to guard on an int you cannot assume the method exists; a free function can handle both int and SymInt inputs.
- We clean up the magic method definition code for SymInt/SymFloat/SymNode. Only the user-facing classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods. This helps avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <[email protected]>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305

Pull Request resolved: pytorch#87817
Approved by: https://github.com/albanD, https://github.com/anjali411
1 parent 2205f56 commit 1ff5222
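The type-erased design described in the commit message can be sketched in miniature. This is an illustrative Python sketch, not PyTorch's actual classes: the node carries its int/float identity dynamically via is_int()/is_float() and exposes only plain methods, while the user-facing wrapper is the only class with magic methods. All names here (SymNode, SymFloat, wrap_float, add) are simplified stand-ins for the real API.

```python
class SymNode:
    """Type-erased node: knows dynamically whether it holds an int or a float."""

    def __init__(self, value):
        self._value = value

    def is_int(self):
        return isinstance(self._value, int)

    def is_float(self):
        return isinstance(self._value, float)

    # Plain methods only -- the node deliberately has no magic methods.
    def add(self, other):
        return SymNode(self._value + other._value)

    def wrap_float(self, f):
        # Lift a plain float into a node of the same class as self.
        return SymNode(float(f))


class SymFloat:
    """User-facing wrapper; only this layer defines magic methods."""

    def __init__(self, node):
        assert node.is_float()  # mirrors the TORCH_CHECK in the new constructor
        self.node = node

    def __add__(self, other):
        if isinstance(other, float):
            # Normalize a mixed symbolic/plain operation by wrapping the plain side.
            other = SymFloat(self.node.wrap_float(other))
        return SymFloat(self.node.add(other.node))
```

The split between plain methods on the node and magic methods on the wrapper is exactly the confusion-avoidance rule the commit message describes.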

54 files changed: +730 -1437 lines

.github/ci_commit_pins/xla.txt

+1-1
@@ -1 +1 @@
-095ee628212f0235ad0d6908bdd514123639fc86
+1e9b8bdc75114ac6c16305c970be37a1cd2fdb1c

.lintrunner.toml

+1-1
@@ -439,7 +439,7 @@ command = [
     """--error-description=\
     This line has an isinstance call that directly refers to \
     int or float. This is error-prone because you may also \
-    have wanted to allow SymIntNode or SymFloatNode in your test. \
+    have wanted to allow SymInt or SymFloat in your test. \
     To suppress this lint, use an appropriate type alias defined \
     in torch._prims_common; use IntLike/FloatLike when you would accept \
     both regular and symbolic numbers, Dim for ints representing \

aten/src/ATen/FunctionalStorageImpl.cpp

+1-1
@@ -95,7 +95,7 @@ c10::SymInt get_nbytes(const Tensor& value) {
   if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
     // Today, the two implementations of SymInt are in Python (proxy tensor),
     // and lazy tensor (LTC/XLA).
-    // LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl).
+    // LTC hasn't implemented SymInt support yet though
     // Once it does, we should remove this check.
     if (value.key_set().has(c10::DispatchKey::Python)) {
       return value.storage().sym_nbytes();

aten/src/ATen/core/ivalue.h

+4-4
@@ -562,7 +562,7 @@ struct TORCH_API IValue final {
   IValue(c10::SymInt i) {
     if (i.is_symbolic()) {
       tag = Tag::SymInt;
-      payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release();
+      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
     } else {
       tag = Tag::Int;
       payload.u.as_int = i.as_int_unchecked();
@@ -578,7 +578,7 @@ struct TORCH_API IValue final {
   IValue(c10::SymFloat i) {
     if (i.is_symbolic()) {
       tag = Tag::SymFloat;
-      payload.u.as_intrusive_ptr = i.toSymFloatNodeImpl().release();
+      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
     } else {
       tag = Tag::Double;
       payload.u.as_double = i.as_float_unchecked();
@@ -812,10 +812,10 @@ struct TORCH_API IValue final {
     // for both SymFloat and double
     if (s.isSymInt()) {
       tag = Tag::SymInt;
-      payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release();
+      payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release();
     } else if (s.isSymFloat()) {
       tag = Tag::SymFloat;
-      payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release();
+      payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
     } else if (s.isFloatingPoint()) {
       tag = Tag::Double;
       payload.u.as_double = s.toDouble();
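The IValue changes above keep separate SymInt/SymFloat tags but funnel both symbolic cases through the single toSymNodeImpl(). A rough Python analogue of that tagged-storage pattern (all names here are illustrative, not the real IValue API):

```python
from enum import Enum


class Tag(Enum):
    INT = 1
    SYM_INT = 2


class FakeSymNode:
    """Stand-in for a type-erased symbolic node (illustrative only)."""

    def __init__(self, hint):
        self.hint = hint


class Value:
    """Tagged storage: a plain int lives inline; a symbolic value owns a node."""

    def __init__(self, payload):
        if isinstance(payload, int):
            self.tag, self.payload = Tag.INT, payload
        else:
            # Symbolic case: store the node object itself, standing in for
            # the owned intrusive pointer in the C++ version.
            self.tag, self.payload = Tag.SYM_INT, payload

    def to_int_hint(self):
        # Both tags are handled in one place, mirroring how the diff routes
        # every symbolic case through a single node accessor.
        return self.payload if self.tag is Tag.INT else self.payload.hint
```

The point of the change in the diff is that only the tag distinguishes int from float; the stored pointer type is uniformly SymNodeImpl.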

aten/src/ATen/core/ivalue_inl.h

+2-2
@@ -219,7 +219,7 @@ inline at::Generator IValue::toGenerator() const& {
 inline c10::SymInt IValue::toSymInt() const {
   AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
   if (isSymInt()) {
-    return c10::SymInt::toSymInt(toIntrusivePtr<c10::SymIntNodeImpl>());
+    return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
   } else {
     return c10::SymInt(payload.u.as_int);
   }
@@ -228,7 +228,7 @@ inline c10::SymInt IValue::toSymInt() const {
 inline c10::SymFloat IValue::toSymFloat() const {
   AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
   if (isSymFloat()) {
-    return c10::SymFloat::toSymFloat(toIntrusivePtr<c10::SymFloatNodeImpl>());
+    return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
   } else {
     return c10::SymFloat(payload.u.as_double);
   }

aten/src/ATen/core/jit_type.h

-1
@@ -1310,7 +1310,6 @@ struct TORCH_API SymIntType : public Type {
     return "SymInt";
   }
   std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
-    // TODO: will become a Union[SymIntNodeImpl|int] in the near future
     return "int";
   }
   static const TypeKind Kind = TypeKind::SymIntType;

aten/src/ATen/test/scalar_test.cpp

-31
@@ -194,34 +194,3 @@ TEST(TestScalar, TestFormatting) {
   ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1))));
   ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt())));
 }
-
-TEST(TestSymInt, Basic) {
-  Scalar foo;
-  auto a_impl = c10::make_intrusive<c10::SymIntNodeImpl>();
-  foo = Scalar(a_impl->toSymInt());
-  ASSERT_EQ(a_impl.use_count(), 2);
-  Scalar bar{foo};
-  ASSERT_EQ(a_impl.use_count(), 3);
-  auto baz = bar;
-  ASSERT_EQ(a_impl.use_count(), 4);
-  auto foo2 = std::move(bar);
-  ASSERT_EQ(a_impl.use_count(), 4);
-  ASSERT_TRUE(foo2.isSymInt());
-  // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
-  ASSERT_TRUE(bar.isIntegral(false));
-  foo2 = SymInt(4);
-  ASSERT_FALSE(foo2.isSymInt());
-  ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
-  // NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded)
-  foo2 = foo2;
-  ASSERT_FALSE(foo2.isSymInt());
-  ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
-
-  ASSERT_EQ(a_impl.use_count(), 3);
-
-  ASSERT_THROW(foo.to<double>(), c10::Error);
-
-  Scalar int_s = 3;
-  TORCH_CHECK(int_s.toSymInt().expect_int(), 3);
-
-}

build_variables.bzl

+1
@@ -958,6 +958,7 @@ libtorch_python_core_sources = [
     "torch/csrc/utils/object_ptr.cpp",
     "torch/csrc/utils/python_arg_parser.cpp",
     "torch/csrc/utils/python_dispatch.cpp",
+    "torch/csrc/utils/python_symnode.cpp",
     "torch/csrc/utils/structseq.cpp",
     "torch/csrc/utils/tensor_apply.cpp",
     "torch/csrc/utils/tensor_dtypes.cpp",

c10/core/Scalar.h

+4-5
@@ -92,18 +92,17 @@ class C10_API Scalar {

   SymInt toSymInt() const {
     if (Tag::HAS_si == tag) {
-      return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy(
-          static_cast<SymIntNodeImpl*>(v.p)));
+      return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
+          static_cast<SymNodeImpl*>(v.p)));
     } else {
       return toLong();
     }
   }

   SymFloat toSymFloat() const {
     if (Tag::HAS_sd == tag) {
-      return c10::SymFloat::toSymFloat(
-          intrusive_ptr<SymFloatNodeImpl>::reclaim_copy(
-              static_cast<SymFloatNodeImpl*>(v.p)));
+      return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
+          static_cast<SymNodeImpl*>(v.p)));
     } else {
       return toDouble();
     }

c10/core/SymFloat.cpp

+15-24
@@ -1,32 +1,27 @@
 #include <c10/core/SymFloat.h>
-#include <c10/core/SymFloatNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <array>

 namespace c10 {

-SymFloatNode SymFloat::toSymFloatNodeImpl() const {
+SymNode SymFloat::toSymNodeImpl() const {
   TORCH_CHECK(is_symbolic());
-  return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned());
+  return SymNode::reclaim_copy(toSymNodeImplUnowned());
 }

-static std::array<SymFloatNode, 2> normalize_symfloats(
-    SymFloat a_,
-    SymFloat b_) {
-  SymFloatNode a, b;
+static std::array<SymNode, 2> normalize_symfloats(SymFloat a_, SymFloat b_) {
+  SymNode a, b;
   if (a_.is_symbolic())
-    a = a_.toSymFloatNodeImpl();
+    a = a_.toSymNodeImpl();
   if (b_.is_symbolic())
-    b = b_.toSymFloatNodeImpl();
+    b = b_.toSymNodeImpl();

-  SymFloatNodeImpl* common = a ? a.get() : b.get();
-  // TODO: technically we need to check that the classes match
+  SymNodeImpl* common = a ? a.get() : b.get();
   if (!a) {
-    a = common->wrap(a_.as_float_unchecked());
-    a_.toSymFloat(a); //
+    a = common->wrap_float(a_.as_float_unchecked());
   }
   if (!b) {
-    b = common->wrap(b_.as_float_unchecked());
-    b_.toSymFloat(b);
+    b = common->wrap_float(b_.as_float_unchecked());
   }
   return {a, b};
 }
@@ -36,40 +31,36 @@ SymFloat SymFloat::operator+(SymFloat sci) const {
     return SymFloat(data_ + sci.data_);
   }
   auto res = normalize_symfloats(*this, sci);
-  return SymFloat::toSymFloat(res[0]->add(res[1]));
+  return SymFloat(res[0]->add(res[1]));
 }

 SymFloat SymFloat::operator-(SymFloat sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return SymFloat(data_ - sci.data_);
   }
   auto res = normalize_symfloats(*this, sci);
-  return SymFloat::toSymFloat(res[0]->sub(res[1]));
+  return SymFloat(res[0]->sub(res[1]));
 }

 SymFloat SymFloat::operator*(SymFloat sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return SymFloat(data_ * sci.data_);
   }
   auto res = normalize_symfloats(*this, sci);
-  return SymFloat::toSymFloat(res[0]->mul(res[1]));
+  return SymFloat(res[0]->mul(res[1]));
 }

 SymFloat SymFloat::operator/(SymFloat sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return SymFloat(data_ / sci.data_);
   }
   auto res = normalize_symfloats(*this, sci);
-  return SymFloat::toSymFloat(res[0]->truediv(res[1]));
-}
-
-c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
-  return c10::SymFloat(std::move(sin_sp));
+  return SymFloat(res[0]->truediv(res[1]));
 }

 std::ostream& operator<<(std::ostream& os, SymFloat s) {
   if (s.is_symbolic()) {
-    os << s.toSymFloatNodeImpl()->str();
+    os << s.toSymNodeImpl()->str();
   } else {
     os << s.as_float_unchecked();
   }
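The normalize_symfloats change above shows the core trick for mixed symbolic/plain arithmetic: whichever operand already holds a node supplies wrap_float to lift the plain one, so the binary op always sees two nodes. A rough Python analogue of that normalization step (Node, wrap_float, and normalize are illustrative names, not PyTorch's API):

```python
class Node:
    """Minimal stand-in for a symbolic node with float payload."""

    def __init__(self, v):
        self.v = v

    def wrap_float(self, f):
        # Lift a plain float into a node of the same class as self,
        # mirroring common->wrap_float(...) in the diff.
        return type(self)(float(f))

    def add(self, other):
        return type(self)(self.v + other.v)


def normalize(a, b):
    """Return (node, node): the operand that is already a Node wraps the other."""
    na = a if isinstance(a, Node) else None
    nb = b if isinstance(b, Node) else None
    common = na or nb  # at least one operand must be symbolic
    if na is None:
        na = common.wrap_float(a)
    if nb is None:
        nb = common.wrap_float(b)
    return na, nb
```

Routing the wrap through the existing node is what lets different SymNode backends (e.g. proxy tensor vs. lazy tensor) each produce nodes of their own class.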

c10/core/SymFloat.h

+9-8
@@ -1,6 +1,6 @@
 #pragma once

-#include <c10/core/SymFloatNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
 #include <c10/util/intrusive_ptr.h>
@@ -14,20 +14,21 @@ namespace c10 {
 class C10_API SymFloat {
  public:
   /*implicit*/ SymFloat(double d) : data_(d){};
-  SymFloat(SymFloatNode ptr)
-      : data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)){};
+  SymFloat(SymNode ptr)
+      : data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
+    TORCH_CHECK(ptr_->is_float());
+  };
   SymFloat() : data_(0.0) {}

-  SymFloatNodeImpl* toSymFloatNodeImplUnowned() const {
+  SymNodeImpl* toSymNodeImplUnowned() const {
     return ptr_.get();
   }

-  SymFloatNodeImpl* release() && {
+  SymNodeImpl* release() && {
     return std::move(ptr_).release();
   }

-  SymFloatNode toSymFloatNodeImpl() const;
-  static c10::SymFloat toSymFloat(SymFloatNode sin);
+  SymNode toSymNodeImpl() const;

   double expect_float() const {
     TORCH_CHECK(!is_symbolic());
@@ -53,7 +54,7 @@ class C10_API SymFloat {
  private:
   // TODO: optimize to union
   double data_;
-  SymFloatNode ptr_;
+  SymNode ptr_;
 };

 C10_API std::ostream& operator<<(std::ostream& os, SymFloat s);

c10/core/SymFloatNodeImpl.cpp

-20
This file was deleted.

c10/core/SymFloatNodeImpl.h

-76
This file was deleted.
