Skip to content

Commit c2bf3be

Browse files
janselpytorchmergebot
authored andcommitted
[inductor] Remove _get_grid_fn_str (pytorch#146800)
Pull Request resolved: pytorch#146800 Approved by: https://github.com/yanboliang
1 parent 0d5fb09 commit c2bf3be

File tree

3 files changed

+4
-11
lines changed

3 files changed

+4
-11
lines changed

tools/build_with_debinfo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def main() -> None:
113113
print("More than 100 items needs to be rebuild, run `ninja torch_python` first")
114114
sys.exit(-1)
115115
for idx, (name, cmd) in enumerate(build_plan):
116-
print(f"[{idx + 1 } / {len(build_plan)}] Building {name}")
116+
print(f"[{idx + 1} / {len(build_plan)}] Building {name}")
117117
if args.verbose:
118118
print(cmd)
119119
subprocess.check_call(["sh", "-c", cmd], cwd=BUILD_DIR)

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3615,9 +3615,6 @@ def is_static_integer(expr: sympy.Expr) -> bool:
36153615
if tree.prefix == "x" and self.no_x_dim:
36163616
code.writeline("XBLOCK: tl.constexpr = 1")
36173617

3618-
def _get_grid_fn_str(self):
3619-
return self._get_grid_fn().__name__
3620-
36213618
def _get_grid_fn(self):
36223619
if self.cooperative_reduction:
36233620
return cooperative_reduction_grid
@@ -3648,9 +3645,8 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None):
36483645
for ws in self.args.workspace_args:
36493646
wrapper.generate_workspace_allocation(ws)
36503647

3651-
grid = wrapper.generate_default_grid(
3652-
name, grid, grid_callable=self._get_grid_fn()
3653-
)
3648+
grid_fn = self._get_grid_fn()
3649+
grid = wrapper.generate_default_grid(name, grid, grid_callable=grid_fn)
36543650
wrapper.generate_kernel_call(
36553651
name,
36563652
call_args,
@@ -3659,7 +3655,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None):
36593655
gpu=current_device.type != "cpu",
36603656
triton=True,
36613657
arg_types=arg_types,
3662-
grid_fn=self._get_grid_fn_str(),
3658+
grid_fn=grid_fn.__name__,
36633659
triton_meta=self.triton_meta,
36643660
)
36653661

torch/_inductor/codegen/triton_split_scan.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,5 @@ def scan(self, dtypes, combine_fn, values):
203203
def _get_heuristic(self):
204204
return "split_scan"
205205

206-
def _get_grid_fn_str(self):
207-
return "split_scan_grid"
208-
209206
def _get_grid_fn(self):
210207
return split_scan_grid

0 commit comments

Comments
 (0)