Skip to content

Commit 4adffe6

Browse files
larryliu0820pytorchmergebot
authored andcommitted
[torchgen] Let native function declaration generation logic take a callable (pytorch#90780)
Retry of pytorch#90590, which is a retry of pytorch#89594. Original PR reverted due to internal breakage. This PR fixes the breakage by adding a default value to the new argument. This PR allows `get_native_function_declarations` API to take a function as argument. This function should take `NativeFunction` as input and emit code for native function declaration. By default it is `dest.compute_native_function_declaration`. Pull Request resolved: pytorch#90780 Approved by: https://github.com/ezyang
1 parent df58020 commit 4adffe6

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

tools/test/test_codegen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import yaml
1010

1111
from tools.autograd import gen_autograd_functions, load_derivatives
12+
from torchgen import dest
1213
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
1314
from torchgen.context import native_function_manager
1415
from torchgen.gen import (
@@ -356,6 +357,7 @@ def test_native_function_declaration_1_op_2_ns_error(self) -> None:
356357
self.op_2_native_function,
357358
],
358359
backend_indices=self.backend_indices,
360+
native_function_decl_gen=dest.compute_native_function_declaration,
359361
)
360362

361363
def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
@@ -365,6 +367,7 @@ def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
365367
self.op_1_native_function,
366368
],
367369
backend_indices=self.backend_indices,
370+
native_function_decl_gen=dest.compute_native_function_declaration,
368371
)
369372
target = """
370373
namespace at {

torchgen/gen.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,18 @@
55
import pathlib
66
from collections import defaultdict, namedtuple, OrderedDict
77
from dataclasses import dataclass
8-
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
8+
from typing import (
9+
Any,
10+
Callable,
11+
Dict,
12+
List,
13+
Optional,
14+
Sequence,
15+
Set,
16+
Tuple,
17+
TypeVar,
18+
Union,
19+
)
920

1021
import yaml
1122
from typing_extensions import Literal
@@ -1406,7 +1417,17 @@ def get_native_function_declarations(
14061417
*,
14071418
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
14081419
backend_indices: Dict[DispatchKey, BackendIndex],
1420+
native_function_decl_gen: Callable[
1421+
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
1422+
] = dest.compute_native_function_declaration,
14091423
) -> List[str]:
1424+
"""
1425+
Generate kernel declarations, in `NativeFunction(s).h`.
1426+
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
1427+
:param backend_indices: kernel collections grouped by dispatch key.
1428+
:param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
1429+
:return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
1430+
"""
14101431
declarations: List[str] = []
14111432
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
14121433
newline = "\n"
@@ -1425,7 +1446,7 @@ def get_native_function_declarations(
14251446
len(native_function_namespaces) <= 1
14261447
), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
14271448
ns_grouped_kernels[namespace].extend(
1428-
dest.compute_native_function_declaration(f, backend_idx)
1449+
native_function_decl_gen(f, backend_idx)
14291450
)
14301451

14311452
for namespace, kernels in ns_grouped_kernels.items():
@@ -1863,7 +1884,9 @@ def gen_per_operator_headers(
18631884
},
18641885
)
18651886
declarations = get_native_function_declarations(
1866-
grouped_native_functions=grouped_functions, backend_indices=backend_indices
1887+
grouped_native_functions=grouped_functions,
1888+
backend_indices=backend_indices,
1889+
native_function_decl_gen=dest.compute_native_function_declaration,
18671890
)
18681891
ops_fm.write_with_template(
18691892
f"{name}_native.h",

0 commit comments

Comments
 (0)