55""" 
66# std imports 
77from  argparse  import  ArgumentParser 
8+ import  contextlib 
89import  json 
910import  logging 
1011import  os 
11- import  tempfile 
1212from  typing  import  Optional 
1313
1414# tpl imports 
@@ -30,8 +30,11 @@ def get_args():
3030    parser .add_argument ("input_json" , type = str , help = "Input JSON file containing the test cases." )
3131    parser .add_argument ("-o" , "--output" , type = str , help = "Output JSON file containing the results." )
3232    parser .add_argument ("--scratch-dir" , type = str , help = "If provided, put scratch files here." )
33+     parser .add_argument ("--driver-root" , type = str , help = "Where to look for the driver files, if not in cwd." )
3334    parser .add_argument ("--launch-configs" , type = str , default = "launch-configs.json" , 
3435        help = "config for how to run samples." )
36+     parser .add_argument ("--build-configs" , type = str , default = "build-configs.json" ,
37+         help = "config for how to build samples. If not provided, will use the default build settings for each model." )
3538    parser .add_argument ("--problem-sizes" , type = str , default = "problem-sizes.json" , 
3639        help = "config for how to run samples." )
3740    parser .add_argument ("--yes-to-all" , action = "store_true" , help = "If provided, automatically answer yes to all prompts." )
@@ -56,11 +59,19 @@ def get_args():
5659    parser .add_argument ("--log-runs" , action = "store_true" , help = "Display the stderr and stdout of runs." )
5760    return  parser .parse_args ()
5861
59- def  get_driver (prompt : dict , scratch_dir : Optional [os .PathLike ], launch_configs : dict , problem_sizes : dict , dry : bool , ** kwargs ) ->  DriverWrapper :
62+ def  get_driver (
63+     prompt : dict , 
64+     scratch_dir : Optional [os .PathLike ], 
65+     launch_configs : dict , 
66+     build_configs : dict , 
67+     problem_sizes : dict , 
68+     dry : bool , 
69+     ** kwargs 
70+ ) ->  DriverWrapper :
6071    """ Get the language drive wrapper for this prompt """ 
6172    driver_cls  =  LANGUAGE_DRIVERS [prompt ["language" ]]
6273    return  driver_cls (parallelism_model = prompt ["parallelism_model" ], launch_configs = launch_configs , 
63-         problem_sizes = problem_sizes , scratch_dir = scratch_dir , dry = dry , ** kwargs )
74+         build_configs = build_configs ,  problem_sizes = problem_sizes , scratch_dir = scratch_dir , dry = dry , ** kwargs )
6475
6576def  already_has_results (prompt : dict ) ->  bool :
6677    """ Check if a prompt already has results stored in it. """ 
@@ -102,10 +113,25 @@ def main():
102113    launch_configs  =  load_json (args .launch_configs )
103114    logging .info (f"Loaded launch configs from { args .launch_configs }  )
104115
116+     # load build configs 
117+     build_configs  =  load_json (args .build_configs )
118+     logging .info (f"Loaded build configs from { args .build_configs }  )
119+ 
105120    # load problem sizes 
106121    problem_sizes  =  load_json (args .problem_sizes )
107122    logging .info (f"Loaded problem sizes from { args .problem_sizes }  )
108123
124+     # set driver root; If provided, use user argument. If it's not provided, then check if the PAREVAL_ROOT environment 
125+     # variable is set, then use "${PAREVAL_ROOT}/drivers" as the root. If neither is set, then use the location of  
126+     # this script as the root. 
127+     if  args .driver_root :
128+         DRIVER_ROOT  =  args .driver_root 
129+     elif  "PAREVAL_ROOT"  in  os .environ :
130+         DRIVER_ROOT  =  os .path .join (os .environ ["PAREVAL_ROOT" ], "drivers" )
131+     else :
132+         DRIVER_ROOT  =  os .path .dirname (os .path .abspath (__file__ ))
133+     logging .info (f"Using driver root: { DRIVER_ROOT }  )
134+ 
109135    # gather the list of parallelism models to test 
110136    models_to_test  =  args .include_models  if  args .include_models  else  ["serial" , "omp" , "mpi" , "mpi+omp" , "kokkos" , "cuda" , "hip" ]
111137    if  args .exclude_models :
@@ -139,15 +165,18 @@ def main():
139165            prompt , 
140166            args .scratch_dir , 
141167            launch_configs , 
168+             build_configs ,
142169            problem_sizes ,
143170            args .dry , 
144171            display_build_errors = args .log_build_errors ,
145172            display_runs = args .log_runs ,
146173            early_exit_runs = args .early_exit_runs ,
147174            build_timeout = args .build_timeout ,
148-             run_timeout = args .run_timeout 
175+             run_timeout = args .run_timeout , 
149176        )
150-         driver .test_all_outputs_in_prompt (prompt )
177+ 
178+         with  contextlib .chdir (DRIVER_ROOT ):
179+             driver .test_all_outputs_in_prompt (prompt )
151180
152181        # go ahead and write out outputs now 
153182        if  args .output  and  args .output  !=  '-' :
0 commit comments