# debug_triton_floordiv.py
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
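# Note: the imports above are the standard preamble that torch._inductor emits
# for a compiled-module wrapper; several (tempfile, random, os, ...) are unused
# in this particular script but are kept as generated.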
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_root/sp/cspucvcwaoybwak2cpgpdnpw272p7qtvwz6zk73yqugonheagwuw.py
# Source Nodes: [y], Original ATen: [aten.add]
# y => add
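# The kernel below adds a scalar input to floor(ks0 / 8), where ks0 is a
# dynamic shape symbol passed in as an i32 kernel argument (see the 'i32'
# entry in the signature). This floor-division lowering is what the script
# is debugging.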
triton_poi_fused_add_0 = async_compile.triton('triton_', '''
import math
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(
    size_hints=[1],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': []},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    # Load the scalar input and broadcast it across the block.
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tl.static_print("ks0:", ks0)
    # The expression under test: floor of the dynamic size ks0 divided by 8.
    # The tl.math.floor variant is kept commented out for comparison.
    # tmp2 = tl.math.floor((1.0 * ks0) / 8.0)
    tmp2 = tl.semantic.floor((1.0 * ks0) / 8.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp4, None)
''')
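# For reference, a minimal eager-mode reproducer that plausibly yields this
# generated code (a hypothetical reconstruction -- the original source is not
# part of this file; kept commented out so the wrapper below runs unchanged):
#
#   def f(x, s):
#       # x + s // 8 lowers to aten.add with a floor-div on the dynamic size,
#       # matching floor((1.0 * ks0) / 8.0) in the kernel above.
#       return x + s // 8
#
#   compiled = torch.compile(f, dynamic=True)
#   out = compiled(torch.tensor(0.0, device="cuda"), 33)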
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    # arg0_1 is the dynamic size s0 (a plain int); arg1_1 is the scalar input.
    s0 = arg0_1
    assert_size_stride(arg1_1, (), ())
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        buf0 = empty((), device='cuda', dtype=torch.float32)
        # Source Nodes: [y], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(arg1_1, buf0, s0, 1, grid=grid(1), stream=stream0)
        del arg1_1
    return (buf0, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 12
    arg1_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # from torch._inductor.wrapper_benchmark import compiled_module_main
    # compiled_module_main('None', benchmark_compiled_module)
    import torch
    arg0_1 = 33
    arg1_1 = torch.tensor(0, dtype=torch.float32, device="cuda:0")
    fn = lambda: call([arg0_1, arg1_1])
    print(fn())
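    # With arg0_1 = 33 the kernel computes 0.0 + floor(33 / 8) = 4.0, so this
    # should print (tensor(4., device='cuda:0'),) if the floor lowering
    # compiles correctly.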