Commit b623bde

smessmer authored and facebook-github-bot committed
Move TensorOptions ops to c10 (pytorch#39492)
Summary: Pull Request resolved: pytorch#39492

This PR adds use_c10_dispatcher: full to ops taking TensorOptions. To allow this, since the c10 operator library doesn't know about TensorOptions, we need to register the operator kernels as taking optional<ScalarType>, optional<Device>, optional<Layout>, optional<bool> instead, and also call them this way.

Changes:
- Add use_c10_dispatcher: full to those ops.
- Write hacky_wrapper_for_legacy_signatures, which takes an old-style kernel (i.e. one written to take TensorOptions) and creates a wrapper kernel for it that takes the scattered optional<ScalarType>, optional<Device>, optional<Layout>, optional<bool> instead.
- Change codegen so that all op registrations are wrapped into hacky_wrapper_for_legacy_signatures. This is applied to all ops but is a no-op if the op doesn't take TensorOptions. It allows us, in the future, to change a kernel signature from TensorOptions to the scattered version and have it work without having to touch codegen.
- Change codegen so that the frontend calls those operators with expanded arguments instead of with a TensorOptions object. This is required because the kernels are now written this way.

This PR does not remove TensorOptions special cases from codegen; instead, it separates the kernels from the codegen/frontend issues. After this, kernels can be worked on separately without having to touch codegen, and codegen can be worked on without having to touch kernels.

Codegen diff: P133121032
ghstack-source-id: 106426630

Test Plan: waitforsandcastle

Differential Revision: D21581908

fbshipit-source-id: 6d4a9f526fd70fae40581bf26f3ccf794ce6a89e
1 parent f6b9848 commit b623bde

26 files changed: +546 −64 lines
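
To make the scattering concrete, here is a hedged sketch of what the change means for a kernel signature. The op my_full and its argument list are illustrative, not taken from this commit; the argument order follows the wrapper added below (dtype, layout, device, pin_memory).

// Legacy signature: TensorOptions packed into one argument.
at::Tensor my_full(at::IntArrayRef size, at::Scalar fill_value,
                   const at::TensorOptions& options);

// Scattered signature, as the c10 dispatcher now registers and calls it:
// the same information split into four optionals.
at::Tensor my_full(at::IntArrayRef size, at::Scalar fill_value,
                   c10::optional<at::ScalarType> dtype,
                   c10::optional<at::Layout> layout,
                   c10::optional<at::Device> device,
                   c10::optional<bool> pin_memory);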

aten/src/ATen/common_with_cwrap.py

Lines changed: 7 additions & 1 deletion
@@ -1,6 +1,7 @@
 # this code should be common among cwrap and ATen preprocessing
 # for now, I have put it in one place but right now is copied out of cwrap
 
+import copy
 
 def parse_arguments(args):
     new_args = []
@@ -50,11 +51,16 @@ def set_declaration_defaults(declaration):
         declaration['unqual_operator_name_with_overload'] = ''
     # Simulate multiple dispatch, even if it's not necessary
     if 'options' not in declaration:
-        declaration['options'] = [{'arguments': declaration['arguments']}]
+        declaration['options'] = [{
+            'arguments': copy.deepcopy(declaration['arguments']),
+            'schema_order_arguments': copy.deepcopy(declaration['schema_order_arguments']),
+        }]
         del declaration['arguments']
+        del declaration['schema_order_arguments']
     # Parse arguments (some of them can be strings)
     for option in declaration['options']:
         option['arguments'] = parse_arguments(option['arguments'])
+        option['schema_order_arguments'] = parse_arguments(option['schema_order_arguments'])
     # Propagate defaults from declaration to options
     for option in declaration['options']:
         for k, v in declaration.items():

aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h

Lines changed: 4 additions & 1 deletion
@@ -42,7 +42,10 @@ using supported_primitive_arg_types = guts::typelist::typelist<
   at::Tensor,
   at::Scalar,
   c10::QScheme,
-  c10::ScalarType
+  c10::ScalarType,
+  c10::Device,
+  c10::Layout,
+  c10::MemoryFormat
 >;
 
 template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {
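
For context: this typelist is the compile-time whitelist of primitive argument types the c10 dispatcher accepts in kernel signatures, so adding c10::Device, c10::Layout, and c10::MemoryFormat is what lets the scattered kernels be registered at all. A minimal hedged sketch; the kernel name and schema are hypothetical:

// With the three new entries, a kernel signature mentioning these types
// (directly, or wrapped in optional<>) now passes the dispatcher's
// compile-time argument-type checks:
at::Tensor clone_with_format(const at::Tensor& self,
                             c10::optional<c10::MemoryFormat> memory_format);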

aten/src/ATen/core/jit_type.h

Lines changed: 24 additions & 3 deletions
@@ -1517,9 +1517,12 @@ namespace detail {
 template <typename T>
 struct getTypePtr_ final {
   static TypePtr call() {
-    if (!isCustomClassRegistered<T>()) {
-      throw c10::Error("Type could not be converted to any of the known types.", "");
-    }
+    TORCH_CHECK(
+        isCustomClassRegistered<T>(),
+        "Type ",
+        c10::util::get_fully_qualified_type_name<T>(),
+        " could not be converted to any of the known types."
+    );
     auto res = getCustomClassType<T>();
     return std::dynamic_pointer_cast<Type>(std::move(res));
   }
@@ -1557,6 +1560,24 @@ struct getTypePtr_<c10::ScalarType> final {
   }
 };
 template <>
+struct getTypePtr_<c10::Device> final {
+  static TypePtr call() {
+    return DeviceObjType::get();
+  }
+};
+template <>
+struct getTypePtr_<c10::Layout> final {
+  static TypePtr call() {
+    return IntType::get();
+  }
+};
+template <>
+struct getTypePtr_<c10::MemoryFormat> final {
+  static TypePtr call() {
+    return IntType::get();
+  }
+};
+template <>
 struct getTypePtr_<bool> final {
   static TypePtr call() {
     return BoolType::get();
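
These specializations mirror how the JIT type system models the scattered arguments: Device gets its own DeviceObjType, while Layout and MemoryFormat are represented as integers. A hedged usage sketch, assuming the getTypePtr<T>() helper in jit_type.h that dispatches to these specializations:

// Deriving JIT types for scattered TensorOptions arguments.
// Layout/MemoryFormat come back as IntType, Device as DeviceObjType.
c10::TypePtr device_t = c10::getTypePtr<c10::Device>();        // Device
c10::TypePtr layout_t = c10::getTypePtr<c10::Layout>();        // int
c10::TypePtr memfmt_t = c10::getTypePtr<c10::MemoryFormat>();  // int
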
hacky_wrapper_for_legacy_signatures.h (new file)

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+#pragma once
+
+#include <c10/util/Metaprogramming.h>
+#include <c10/util/TypeList.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/core/CompileTimeFunctionPointer.h>
+
+// This file defines hacky_wrapper_for_legacy_signatures, which takes a kernel written in a legacy way
+// (e.g. with TensorOptions packed) and wraps it into a kernel with the signature expected by
+// the PyTorch operator library. The intention is to ultimately rewrite kernels to take the new signature
+// and then delete this file. This transition process can happen kernel-by-kernel, since this wrapper
+// is a no-op for kernels that already have a non-legacy signature.
+
+namespace c10 {
+namespace impl {
+
+inline c10::optional<MemoryFormat> process_memory_format(const TensorOptions& options, c10::optional<MemoryFormat> memory_format) {
+  TORCH_CHECK(
+      !(options.has_memory_format() && memory_format.has_value()),
+      "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
+      "the redundant setter.");
+  if (memory_format.has_value()) {
+    return memory_format;
+  } else {
+    return options.memory_format_opt();
+  }
+}
+
+namespace detail {
+
+// with_scattered_tensor_options takes a function pointer that potentially takes a TensorOptions argument.
+// If it does, then it creates a new function pointer that takes scattered arguments, internally
+// gathers those arguments, and then calls the underlying function pointer. If the underlying
+// function pointer does not take a TensorOptions argument, it is passed through unmodified.
+
+template<class Type, class Enable = void> struct is_tensoroptions_arg : std::false_type {};
+template<class Type> struct is_tensoroptions_arg<Type, std::enable_if_t<std::is_same<TensorOptions, std::decay_t<Type>>::value>> : std::true_type {};
+template<class Type>
+using is_tensoroptions_arg_t = typename is_tensoroptions_arg<Type>::type;
+
+template<class FuncType>
+inline constexpr bool has_tensoroptions_arg() {
+  using parameter_types = typename guts::infer_function_traits_t<FuncType>::parameter_types;
+  constexpr size_t num_tensoroptions_args = guts::typelist::count_if<is_tensoroptions_arg_t, parameter_types>::value;
+  static_assert(num_tensoroptions_args <= 1, "Function has multiple TensorOptions parameters. We support at most one.");
+  return num_tensoroptions_args > 0;
+}
+
+// sanity checks
+static_assert(has_tensoroptions_arg<int (int64_t, const TensorOptions&)>(), "");
+static_assert(has_tensoroptions_arg<int (int64_t, TensorOptions)>(), "");
+static_assert(!has_tensoroptions_arg<int (int64_t, std::string)>(), "");
+
+template<class FuncPtr, class ParametersBeforeTensorOptions, class ParametersAfterTensorOptions> struct with_scattered_tensor_options_;
+
+template<class FuncPtr, class Enable = void>
+struct with_scattered_tensor_options final {};
+
+template<class UnderlyingFuncPtr>
+struct with_scattered_tensor_options<UnderlyingFuncPtr, std::enable_if_t<!has_tensoroptions_arg<typename UnderlyingFuncPtr::FuncType>()>> final {
+  // FuncType does not have TensorOptions arguments.
+  // Don't wrap anything but just return the base pointer.
+  using FuncPtr = UnderlyingFuncPtr;
+};
+
+template<class UnderlyingFuncPtr>
+struct with_scattered_tensor_options<UnderlyingFuncPtr, std::enable_if_t<has_tensoroptions_arg<typename UnderlyingFuncPtr::FuncType>()>> final {
+private:
+  // FuncType has TensorOptions arguments.
+  // Return a function pointer to a wrapper function that replaces those with expanded arguments.
+  using gathered_parameter_types = typename guts::infer_function_traits_t<typename UnderlyingFuncPtr::FuncType>::parameter_types;
+  static constexpr size_t tensoroptions_arg_index =
+      guts::typelist::find_if<
+          gathered_parameter_types,
+          is_tensoroptions_arg_t
+      >::value;
+
+  using parameters_before_tensoroptions =
+      guts::typelist::take_t<gathered_parameter_types, tensoroptions_arg_index>;
+  using parameters_after_tensoroptions =
+      guts::typelist::drop_t<gathered_parameter_types, tensoroptions_arg_index + 1>;
+
+  using wrapper = with_scattered_tensor_options_<UnderlyingFuncPtr, parameters_before_tensoroptions, parameters_after_tensoroptions>;
+public:
+  using FuncPtr = TORCH_FN_TYPE(&wrapper::wrapper);
+};
+
+template<class FuncPtr, class... ParametersBeforeTensorOptions, class... ParametersAfterTensorOptions>
+struct with_scattered_tensor_options_<FuncPtr, guts::typelist::typelist<ParametersBeforeTensorOptions...>, guts::typelist::typelist<ParametersAfterTensorOptions...>> final {
+  static decltype(auto) wrapper(
+      ParametersBeforeTensorOptions... parameters_before,
+      optional<ScalarType> scalar_type,
+      optional<Layout> layout,
+      optional<Device> device,
+      optional<bool> pin_memory,
+      ParametersAfterTensorOptions... parameters_after) {
+    return (*FuncPtr::func_ptr())(
+        std::forward<ParametersBeforeTensorOptions>(parameters_before)...,
+        TensorOptions().dtype(scalar_type).device(device).layout(layout).pinned_memory(pin_memory),
+        std::forward<ParametersAfterTensorOptions>(parameters_after)...
+    );
+  }
+};
+
+}
+
+template<class FuncPtr>
+constexpr auto hacky_wrapper_for_legacy_signatures(FuncPtr) {
+  return typename detail::with_scattered_tensor_options<FuncPtr>::FuncPtr();
+};
+
+}
+}
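
A hedged usage sketch of the new wrapper; the kernel my_full is hypothetical, and TORCH_FN is the compile-time function pointer macro from the included CompileTimeFunctionPointer.h:

// A legacy kernel written against the packed TensorOptions signature.
at::Tensor my_full(at::IntArrayRef size, at::Scalar fill_value,
                   const at::TensorOptions& options);

// Because my_full takes TensorOptions, the wrapper yields a new compile-time
// function pointer whose TensorOptions slot is replaced by optional<ScalarType>,
// optional<Layout>, optional<Device>, optional<bool>; it gathers those back
// into a TensorOptions before calling my_full. For kernels without a
// TensorOptions argument, the call passes the pointer through unchanged.
constexpr auto wrapped =
    c10::impl::hacky_wrapper_for_legacy_signatures(TORCH_FN(my_full));

This is the hook the codegen uses: every generated registration goes through hacky_wrapper_for_legacy_signatures, so a kernel can later be migrated to the scattered signature without touching codegen.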

aten/src/ATen/cwrap_parser.py

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,6 @@
 import yaml
+import copy
+
 try:
     # use faster C loader if available
     from yaml import CLoader as Loader
@@ -24,4 +26,13 @@ def parse(filename):
                 declarations.append(declaration)
             elif in_declaration:
                 declaration_lines.append(line)
+    declarations = [process_declaration(declaration) for declaration in declarations]
     return declarations
+
+def process_declaration(declaration):
+    declaration = copy.deepcopy(declaration)
+    if "arguments" in declaration:
+        declaration["schema_order_arguments"] = copy.deepcopy(declaration["arguments"])
+    if "options" in declaration:
+        declaration["options"] = [process_declaration(option) for option in declaration["options"]]
+    return declaration
