22Original code by @bertmaher; profiling added by @apgoucher
33"""
44
5+ import argparse
56import cProfile
7+ import csv
8+ import os
69import pstats
710import time
811
@@ -42,11 +45,11 @@ def nop_args(
4245def do_bench_walltime (fn ):
4346 print ("Compiling..." )
4447 fn ()
45- torch .cuda .synchronize ()
48+ torch .xpu .synchronize ()
4649
4750 for _ in range (1000 ):
4851 fn ()
49- torch .cuda .synchronize ()
52+ torch .xpu .synchronize ()
5053
5154 n_repeat = 10000
5255
@@ -55,11 +58,11 @@ def do_bench_walltime(fn):
5558 for _ in range (25 ):
5659 print ("Running %d benchmarking iterations..." % n_repeat )
5760 # Benchmark
58- torch .cuda .synchronize ()
61+ torch .xpu .synchronize ()
5962 start_time = time .time ()
6063 for _ in range (n_repeat ):
6164 fn ()
62- torch .cuda .synchronize ()
65+ torch .xpu .synchronize ()
6366 end_time = time .time ()
6467 wall_time_ms = (end_time - start_time ) * 1e3 / n_repeat
6568 mses .append (wall_time_ms )
@@ -71,19 +74,19 @@ def do_bench_walltime(fn):
7174 profile .enable ()
7275 for _ in range (n_repeat ):
7376 fn ()
74- torch .cuda .synchronize ()
77+ torch .xpu .synchronize ()
7578 profile .disable ()
7679 stats = pstats .Stats (profile )
7780 stats .sort_stats ("time" )
7881 stats .print_stats ()
7982 return mses
8083
8184
82- def main (use_tensor_desc : bool ):
85+ def main (use_tensor_desc : bool , reports_dir : str = None ):
8386 if use_tensor_desc :
84- targs = [TensorDescriptor .from_tensor (torch .zeros (1 , 16 , device = "cuda " ), block_shape = [1 , 16 ]) for _ in range (5 )]
87+ targs = [TensorDescriptor .from_tensor (torch .zeros (1 , 16 , device = "xpu " ), block_shape = [1 , 16 ]) for _ in range (5 )]
8588 else :
86- targs = [torch .zeros (1 , device = "cuda " ) for _ in range (5 )]
89+ targs = [torch .zeros (1 , device = "xpu " ) for _ in range (5 )]
8790 ncargs = [0 , 1 , 1024 , 2 ** 31 - 1 , 2 ** 64 - 1 , False , True , None , (16 , 16 )]
8891 cargs = [32 , False , True , 0 , 64 ]
8992
@@ -94,9 +97,26 @@ def main(use_tensor_desc: bool):
9497 print (usecs )
9598 print (sorted (usecs )[len (usecs ) >> 1 ])
9699
100+ if reports_dir :
101+ os .makedirs (reports_dir , exist_ok = True )
102+ csv_path = os .path .join (reports_dir , "launch_overhead_results.csv" )
103+ file_exists = os .path .exists (csv_path )
104+
105+ with open (csv_path , "a" , newline = "" ) as csvfile :
106+ writer = csv .writer (csvfile )
107+ if not file_exists :
108+ writer .writerow (["input_type" , "triton-time_us" ])
109+
110+ input_type = "TensorDescriptor" if use_tensor_desc else "Tensor"
111+ writer .writerow ([input_type , round (sorted (usecs )[len (usecs ) >> 1 ], 2 )])
112+
97113
98114if __name__ == "__main__" :
115+ parser = argparse .ArgumentParser (description = "Benchmark launch overhead for Triton kernels" )
116+ parser .add_argument ("--reports" , type = str , default = None , help = "Path to directory for CSV reports" )
117+ args = parser .parse_args ()
118+
99119 print ("launch overhead of kernel with Tensor inputs" )
100- main (use_tensor_desc = False )
120+ main (use_tensor_desc = False , reports_dir = args . reports )
101121 print ("launch overhead of kernel with TensorDescriptor inputs" )
102- main (use_tensor_desc = True )
122+ main (use_tensor_desc = True , reports_dir = args . reports )
0 commit comments