
Commit ad44670

ezyang authored and pytorchmergebot committed
Back out "Revert D38984222: Don't introduce new overload for SymInt (pytorch#83628)" (pytorch#84173)
Also Back out "Revert D39075159: [acc_tensor] Use SymIntArrayRef for overloaded empty.memory_format's signature"

Original commit changeset: dab4a9dba4fa
Original commit changeset: dcaf16c037a9

Original Phabricator Diff: D38984222
Original Phabricator Diff: D39075159

Also update Metal registrations for C++ registration changes.

Also update NNPI registration to account for tightened schema checking.

Differential Revision: [D39084762](https://our.internmc.facebook.com/intern/diff/D39084762/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39084762/)!

Pull Request resolved: pytorch#84173
Approved by: https://github.com/Krovatkin
1 parent cfd18e1 commit ad44670

File tree

89 files changed (+864, -749 lines)

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-9b2f7929c2dae841888a836449c25b04c8cf4045
+95eedc33fb48c2ba72f5efa45daa4941cb069864

aten/src/ATen/BatchingRegistrations.cpp

Lines changed: 8 additions & 8 deletions
@@ -186,7 +186,8 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
 }
 
 Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
-  return self.expand(asIntArrayRefSlow(psize), implicit);
+  // TODO: properly support this
+  return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
 }
 
 std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
@@ -469,7 +470,8 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
 }
 
 Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
-  return self.view(asIntArrayRefSlow(size));
+  // TODO: properly support this
+  return view_batching_rule(self, asIntArrayRefSlow(size));
 }
 
 Tensor view_as_complex_batching_rule(const Tensor& self) {
@@ -1009,6 +1011,7 @@ Tensor new_empty_symint_batching_rule(
     c10::optional<Layout> layout,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
+  // TODO: properly support this
   return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
 }
 
@@ -1109,8 +1112,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
   m.impl("diagonal", diagonal_batching_rule);
-  m.impl("expand", expand_batching_rule);
-  m.impl("expand.SymInt", expand_symint_batching_rule);
+  m.impl("expand", expand_symint_batching_rule);
   m.impl("expand_as", native::expand_as); // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1138,8 +1140,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("unbind.int", unbind_batching_rule);
   m.impl("unfold", unfold_batching_rule);
   m.impl("unsqueeze", unsqueeze_batching_rule);
-  m.impl("view", view_batching_rule);
-  m.impl("view.SymInt", view_symint_batching_rule);
+  m.impl("view", view_symint_batching_rule);
   m.impl("view_as", native::view_as); // composite wrt autograd
 
   // clamp operations
@@ -1277,8 +1278,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("diagonal_backward", diagonal_backward_batching_rule);
 
   // Tensor.new_* operators
-  m.impl("new_empty", new_empty_batching_rule);
-  m.impl("new_empty.SymInt", new_empty_symint_batching_rule);
+  m.impl("new_empty", new_empty_symint_batching_rule);
   m.impl("new_empty_strided", new_empty_strided_batching_rule);
   m.impl("new_zeros", new_zeros_batching_rule);
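
The pattern above repeats for expand, view, and new_empty: the SymInt-flavored batching rule materializes concrete integers and forwards to the existing integer rule, and the single registration key now points at the SymInt-aware entry point instead of a separate ".SymInt" overload. A minimal, self-contained sketch of that delegation, where SymIntLike and as_concrete_ints are hypothetical stand-ins for c10::SymInt and asIntArrayRefSlow:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for c10::SymInt; the real one may refer to a
// symbolic expression rather than a plain integer.
struct SymIntLike {
  int64_t v;
  int64_t concrete() const { return v; }
};

// Stand-in for asIntArrayRefSlow: materialize every element as a plain int.
std::vector<int64_t> as_concrete_ints(const std::vector<SymIntLike>& size) {
  std::vector<int64_t> out;
  out.reserve(size.size());
  for (const auto& s : size) {
    out.push_back(s.concrete()); // would fail for a truly symbolic element
  }
  return out;
}

// The pre-existing integer kernel remains the single source of truth...
void expand_rule(const std::vector<int64_t>& size) {
  std::cout << "expand to " << size.size() << " dims\n";
}

// ...and the SymInt entry point just converts and forwards, as the diff does.
void expand_symint_rule(const std::vector<SymIntLike>& size) {
  expand_rule(as_concrete_ints(size));
}

int main() {
  expand_symint_rule({{2}, {3}, {4}});
  return 0;
}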

aten/src/ATen/FunctionalInverses.cpp

Lines changed: 4 additions & 15 deletions
@@ -137,12 +137,8 @@ Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tenso
   return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
 }
 
-Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, bool implicit) {
-  return at::sum_to(mutated_view, base.sizes(),/*always_return_non_view=*/!reapply_views);
-}
-
-Tensor FunctionalInverses::expand_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size, bool implicit) {
-  return at::sum_to(mutated_view, c10::asIntArrayRefSlow(base.sym_sizes()),/*always_return_non_view=*/!reapply_views);
+Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) {
+  return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views);
 }
 
 Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) {
@@ -291,22 +287,15 @@ Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Ten
   return base.select_scatter(mutated_view, dim, mutated_view_idx);
 }
 
-Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) {
-  if (reapply_views) {
-    return mutated_view.view(base.sizes());
-  } else {
-    return at::view_copy(mutated_view, base.sizes());
-  }
-}
-
-Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
+Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) {
   if (reapply_views) {
     return mutated_view.view_symint(base.sym_sizes());
   } else {
     return at::view_copy_symint(mutated_view, base.sym_sizes());
   }
 }
 
+
 Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
   if (reapply_views) {
     return mutated_view.view(base.scalar_type());
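
Background for the sum_to call kept above: expand broadcasts along size-1 dims, so its functional inverse accumulates the expanded elements back to the base shape. A toy 1-D illustration with plain vectors and hypothetical helper names (not the real at::expand / at::sum_to):

#include <iostream>
#include <vector>

// "expand": broadcast a single value along a new dimension of length n.
std::vector<double> expand1(double base, int n) {
  return std::vector<double>(n, base);
}

// "sum_to": reduce the broadcast dimension back to the base shape; this is
// why sum-reduction is the right functional inverse for expand.
double sum_to1(const std::vector<double>& expanded) {
  double acc = 0.0;
  for (double x : expanded) {
    acc += x;
  }
  return acc;
}

int main() {
  auto e = expand1(2.0, 4);        // [2, 2, 2, 2]
  std::cout << sum_to1(e) << "\n"; // 8: all contributions folded back
  return 0;
}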

aten/src/ATen/core/NamedRegistrations.cpp

Lines changed: 0 additions & 1 deletion
@@ -179,7 +179,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
   m.impl("exp.out", CppFunction::makeFallthrough());
   m.impl("exp_", CppFunction::makeFallthrough());
   m.impl("expand", CppFunction::makeFallthrough());
-  m.impl("expand.SymInt", CppFunction::makeFallthrough());
   m.impl("expm1", CppFunction::makeFallthrough());
   m.impl("expm1.out", CppFunction::makeFallthrough());
   m.impl("expm1_", CppFunction::makeFallthrough());

aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h

Lines changed: 8 additions & 1 deletion
@@ -353,7 +353,14 @@ namespace impl {
 template<bool AllowDeprecatedTypes>
 struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
   static std::vector<c10::SymInt> call(IValue& v) {
-    return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
+    if (v.isIntList()) {
+      std::vector<c10::SymInt> r;
+      auto src = v.toIntList();
+      std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
+      return r;
+    } else {
+      return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
+    }
   }
 };
 template<class T, bool AllowDeprecatedTypes>

aten/src/ATen/core/custom_class.cpp

Lines changed: 1 addition & 0 deletions
@@ -143,6 +143,7 @@ c10::FunctionSchema class_base::withNewArguments(
     new_args.emplace_back(
         default_arg.name_,
         old_arg.type(),
+        old_arg.real_type(),
         old_arg.N(),
         default_arg.value_);
   }
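
The one-line fix above guards against a classic field-forwarding bug: when an argument list is rebuilt field by field, any field that is not forwarded silently resets to its default. A minimal sketch with a hypothetical mini-Argument (needs C++17 for aggregate init with default member initializers):

#include <iostream>
#include <string>
#include <vector>

// Hypothetical miniature of c10::Argument: a "fake" type seen by most of
// the stack plus a precise real type (e.g. SymInt[] behind int[]).
struct MiniArg {
  std::string name;
  std::string type;
  std::string real_type = type; // defaults to the fake type when omitted
};

int main() {
  MiniArg old_arg{"size", "int[]", "SymInt[]"};

  std::vector<MiniArg> rebuilt;
  rebuilt.push_back({old_arg.name, old_arg.type});                    // real_type silently lost
  rebuilt.push_back({old_arg.name, old_arg.type, old_arg.real_type}); // fixed: forwarded

  std::cout << rebuilt[0].real_type << " vs " << rebuilt[1].real_type << "\n";
  // prints: int[] vs SymInt[]
  return 0;
}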

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 5 additions & 1 deletion
@@ -35,7 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
 
 namespace {
   void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
-    c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
+    // TODO: figure out if we can just directly save real schema at def time
+    c10::optional<std::string> schema_difference = findSchemaDifferences(
+        from_def.cloneWithRealTypes(),
+        inferred.cloneWithRealTypes()
+    );
     if (schema_difference.has_value()) {
       TORCH_CHECK(false,
         "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"

aten/src/ATen/core/dynamic_type.cpp

Lines changed: 0 additions & 4 deletions
@@ -231,8 +231,6 @@ TypePtr DynamicType::fallback() const {
       return BoolType::get();
     case Tag::Int:
       return IntType::get();
-    case Tag::SymInt:
-      return SymIntType::get();
     case Tag::Float:
       return FloatType::get();
     case Tag::Complex:
@@ -326,8 +324,6 @@ DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
       return DynamicTypeTrait<ComplexType>::getBaseType();
     case Tag::Int:
       return DynamicTypeTrait<IntType>::getBaseType();
-    case Tag::SymInt:
-      return DynamicTypeTrait<SymIntType>::getBaseType();
     case Tag::Bool:
       return DynamicTypeTrait<BoolType>::getBaseType();
     case Tag::String:
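
Both hunks delete the SymInt arms of tag switches: with SymInt demoted to a "fake" dynamic type (see the dynamic_type.h change below), symbolic ints take the Int path at runtime. A reduced sketch of the resulting fallback switch, with a simplified stand-in Tag enum:

#include <iostream>
#include <string>

// Simplified stand-in for the DynamicType tag; the SymInt arm is gone,
// matching the deletion above.
enum class Tag { Bool, Int, Float };

std::string fallback(Tag t) {
  switch (t) {
    case Tag::Bool:
      return "BoolType";
    case Tag::Int:
      return "IntType"; // SymInt values now land here
    case Tag::Float:
      return "FloatType";
  }
  return "unreachable";
}

int main() {
  std::cout << fallback(Tag::Int) << "\n";
  return 0;
}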

aten/src/ATen/core/dynamic_type.h

Lines changed: 1 addition & 2 deletions
@@ -16,7 +16,6 @@ constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30);
 
 constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1);
 constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3);
-constexpr DynamicTypeBits kDynamicSymIntTypeBit = DYNAMIC_TYPE_BIT(23);
 constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4);
 constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5);
 constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7);
@@ -29,7 +28,6 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
   _(Bool, DYNAMIC_TYPE_BIT(2), 1) \
   _(Int, kDynamicIntTypeBit, 1) \
   _(Float, kDynamicFloatTypeBit, 1) \
-  _(SymInt, kDynamicSymIntTypeBit, 1) \
   _(Complex, kDynamicComplexTypeBit, 1) \
   _(Number, \
     (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \
@@ -63,6 +61,7 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
 #define FORALL_DYNAMIC_TYPES_FAKE(_) \
   _(ScalarType, kDynamicIntTypeBit, 1) \
   _(Layout, kDynamicIntTypeBit, 1) \
+  _(SymInt, kDynamicIntTypeBit, 1) \
   _(MemoryFormat, kDynamicIntTypeBit, 1)
 
 #define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
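
The bit-layout change in sketch form: each first-class dynamic type owns a bit, while entries in the FAKE list reuse the bit of the type they erase to at runtime. Moving SymInt from owning DYNAMIC_TYPE_BIT(23) into the FAKE list means it now shares Int's bit. Illustrative values only:

#include <cstdint>
#include <iostream>

using DynamicTypeBits = std::uint32_t;
constexpr DynamicTypeBits bit(unsigned n) { return DynamicTypeBits(1) << n; }

constexpr DynamicTypeBits kIntBit = bit(3);
constexpr DynamicTypeBits kFloatBit = bit(4);

// "Fake" types: distinct front-end names, same runtime bit as Int.
constexpr DynamicTypeBits kScalarTypeBit = kIntBit;
constexpr DynamicTypeBits kSymIntBit = kIntBit; // was its own bit(23) before

int main() {
  std::cout << std::boolalpha;
  std::cout << ((kSymIntBit & kIntBit) != 0) << "\n";   // true: erases to Int
  std::cout << ((kSymIntBit & kFloatBit) != 0) << "\n"; // false
  return 0;
}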

aten/src/ATen/core/function_schema.cpp

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,22 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
   }
 }
 
+FunctionSchema FunctionSchema::cloneWithRealTypes() const {
+  auto cloneWithRealTypes = [](const Argument& a) {
+    return a.cloneWithType(a.real_type());
+  };
+  std::vector<Argument> new_arguments, new_returns;
+  std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
+  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
+  return FunctionSchema(
+      name(),
+      overload_name(),
+      std::move(new_arguments),
+      std::move(new_returns),
+      is_vararg(),
+      is_varret());
+}
+
 bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional<AliasTypeSet> &lhs, const c10::optional<AliasTypeSet> &rhs) const {
   if (!lhs || !rhs) {
     return false;
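
A minimal sketch of what cloneWithRealTypes() does: rebuild every argument with its real type swapped in for the fake one, using the same std::transform + back_inserter idiom as the implementation above. MiniArg and MiniSchema are hypothetical stand-ins for c10::Argument and FunctionSchema:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

struct MiniArg {
  std::string name;
  std::string type;      // "fake" type
  std::string real_type; // precise type
  MiniArg cloneWithType(const std::string& t) const { return {name, t, real_type}; }
};

struct MiniSchema {
  std::vector<MiniArg> arguments;

  MiniSchema cloneWithRealTypes() const {
    auto clone = [](const MiniArg& a) { return a.cloneWithType(a.real_type); };
    std::vector<MiniArg> out;
    std::transform(arguments.begin(), arguments.end(),
                   std::back_inserter(out), clone);
    return {std::move(out)};
  }
};

int main() {
  MiniSchema s{{{"size", "int[]", "SymInt[]"}}};
  std::cout << s.cloneWithRealTypes().arguments[0].type << "\n"; // SymInt[]
  return 0;
}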

aten/src/ATen/core/function_schema.h

Lines changed: 5 additions & 1 deletion
@@ -44,7 +44,7 @@ struct Argument {
       c10::optional<AliasInfo> alias_info = c10::nullopt)
       : name_(std::move(name)),
         type_(fake_type ? std::move(fake_type) : TensorType::get()),
-        real_type_(real_type ? std::move(real_type) : TensorType::get()),
+        real_type_(real_type ? std::move(real_type) : type_),
         N_(std::move(N)),
         default_value_(std::move(default_value)),
         alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),
@@ -88,6 +88,8 @@ struct Argument {
   const TypePtr& type() const {
     return type_;
   }
+  // if type() is non-null, this is guaranteed to be non-null (if no real
+  // type was provided, this takes on type()'s value)
   const TypePtr& real_type() const {
     return real_type_;
   }
@@ -472,6 +474,8 @@ struct TORCH_API FunctionSchema {
   FunctionSchema cloneWithRemappedTypes(
       const std::function<TypePtr(TypePtr)> type_map) const;
 
+  FunctionSchema cloneWithRealTypes() const;
+
   // Check that inputs have the correct types and appends any missing default
   // values.
   template <typename T = c10::PlatformType>