
 from distutils.util import strtobool

+from benchmark_autotuner import auto_tf_func_tuner
+
 from benchmark_utils import DataAggregator
-from benchmark_utils import force_gpu_resync
 from benchmark_utils import print_dict
 from benchmark_utils import timed_section

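A note on the imports kept above: `distutils` has been deprecated since Python 3.10 and is removed in Python 3.12 (PEP 632), so the surviving `from distutils.util import strtobool` will eventually need a local replacement. A minimal drop-in sketch matching the semantics of `distutils.util.strtobool` (the placement is a suggestion, not part of this diff):

```python
def strtobool(val: str) -> int:
    """Return 1 for truthy strings, 0 for falsy ones (PEP 632 drop-in)."""
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if val in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError(f"invalid truth value {val!r}")
```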
@@ -383,16 +384,14 @@ def execute_benchmark(self):
         dataset, bypass_data_to_eval = self.get_dataset_batches()

         if self._args.use_synthetic_data:
-            old_ds = dataset
             try:
-                dataset = SyntheticDataset(old_ds, device="/gpu:0")
+                dataset = SyntheticDataset(dataset, device="/gpu:0")
                 self._debug_print(
                     "Model dataset has been replaced by a synthetic data "
                     "loader to minimize data loading jitter."
                 )

             except Exception as e:
-                dataset = old_ds
                 print(
                     f"[ERROR] Impossible to transform the dataset into a "
                     f"synthetic dataset. Performance numbers will be "
@@ -401,8 +400,10 @@ def execute_benchmark(self):
         else:
             dataset = ensure_dataset_on_gpu(dataset, device="GPU:0")

-        @force_gpu_resync
-        @tf.function(jit_compile=self._args.use_xla)
+        @auto_tf_func_tuner(
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
+        )
         def infer_batch(x):
             if isinstance(x, (tuple, list)):
                 model_out = graph_func(*x)
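`auto_tf_func_tuner` replaces the fixed `@force_gpu_resync` + `@tf.function(jit_compile=...)` pair. Its implementation lives in `benchmark_autotuner`, which this diff does not show; the sketch below is only a guess at the mechanism implied by the name and arguments (build XLA and non-XLA `tf.function` variants, time a few calls of each, keep the faster one):

```python
import time

import tensorflow as tf

def auto_tf_func_tuner(use_xla=False, use_synthetic_data=False, calls_per_trial=10):
    """Illustrative tuner: pick the faster of the XLA / non-XLA variants."""
    # use_synthetic_data is accepted only to mirror the call site above; a
    # real tuner could use it to decide how aggressively to benchmark.
    def decorator(func):
        jit_modes = [True] if use_xla else [False, True]
        variants = [tf.function(func, jit_compile=jit) for jit in jit_modes]
        state = {"best": None}

        def tuned(*args, **kwargs):
            if state["best"] is None:
                timings = []
                for fn in variants:
                    fn(*args, **kwargs)  # warm up: trace / compile first
                    start = time.perf_counter()
                    for _ in range(calls_per_trial):
                        fn(*args, **kwargs)
                    timings.append(time.perf_counter() - start)
                state["best"] = variants[timings.index(min(timings))]
            return state["best"](*args, **kwargs)

        return tuned

    return decorator
```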
@@ -439,72 +440,112 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
                 )

         if self._args.tf_profile_export_path:
-            profiling_ctx = tf.profiler.experimental.Profile(
-                self._args.tf_profile_export_path
-            )
+            def start_profiling():
+                if self._args.tf_profile_verbose:
+                    profiler_opts = tf.profiler.experimental.ProfilerOptions(
+                        # Adjust TraceMe levels:
+                        # - 1: critical
+                        # - 2: info [default]
+                        # - 3: verbose
+                        host_tracer_level=2,
+                        # Enable Python function call tracing:
+                        # - 0: disable [default]
+                        # - 1: enable
+                        python_tracer_level=1,
+                        # Adjust device (TPU/GPU) tracer level:
+                        # - 0: disable
+                        # - 1: enable [default]
+                        device_tracer_level=1,
+                        # Start profiling after 30 sec to:
+                        # - skip tf.function building
+                        # - skip autotuning
+                        delay_ms=30000
+                    )
+                    print("[INFO] Using verbose TF Profiler.")
+                else:
+                    profiler_opts = None
+
+                tf.profiler.experimental.start(
+                    self._args.tf_profile_export_path,
+                    options=profiler_opts
+                )
+
+            stop_profiling = tf.profiler.experimental.stop
+
             tracing_ctx = tf.profiler.experimental.Trace
+
         else:
+            start_profiling = stop_profiling = lambda *a, **kw: None
             profiling_ctx = contextlib.nullcontext()
             tracing_ctx = lambda *a, **kw: contextlib.nullcontext()

         step_idx = 0
         ds_iter = iter(dataset)

-        dequeue_batch_fn = get_dequeue_batch_fn(ds_iter)
+        dequeue_batch_fn = get_dequeue_batch_fn(
+            ds_iter,
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
+        )
+
         force_data_on_gpu_fn = get_force_data_on_gpu_fn(
             device="/gpu:0",
-            use_xla=self._args.use_xla
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
         )

-        with profiling_ctx:
-
-            while True:
-
-                step_idx += 1
+        while True:

-                if (self._args.num_iterations is not None and
-                        step_idx > self._args.num_iterations):
-                    break
-
-                with tracing_ctx('Inference Step', step_num=step_idx, _r=1):
+            step_idx += 1

-                    with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
-                        try:
-                            start_time = time.time()
-                            data_batch = dequeue_batch_fn()
-                            dequeue_times.append(time.time() - start_time)
-                        except (StopIteration, OutOfRangeError):
-                            print("[Exiting] Reached end of dataset ...")
-                            break
+            if step_idx == self._args.num_warmup_iterations - 5:
+                start_profiling()

-                    with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
-                        x, y = self.preprocess_model_inputs(data_batch)
+            if (
+                self._args.num_iterations is not None and
+                step_idx > self._args.num_iterations
+            ):
+                break

-                    with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
-                        start_time = time.time()
-                        x = force_data_on_gpu_fn(x)
-                        memcopy_times.append(time.time() - start_time)
+            with tracing_ctx('', step_num=step_idx, _r=1):

-                    with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
+                with tracing_ctx('Input Dequeueing'):
+                    try:
                         start_time = time.time()
-                        y_pred = infer_batch(x)
-                        iter_times.append(time.time() - start_time)
-
-                if not self._args.debug_performance:
-                    log_step(
-                        step_idx,
-                        display_every=self._args.display_every,
-                        iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
-                        memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
-                        dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
-                    )
-                else:
-                    print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
-                    print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
-                    print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
+                        data_batch = dequeue_batch_fn()
+                        dequeue_times.append(time.time() - start_time)
+                    except (StopIteration, OutOfRangeError):
+                        print("[Exiting] Reached end of dataset ...")
+                        break
+
+                with tracing_ctx('Inputs Preprocessing'):
+                    x, y = self.preprocess_model_inputs(data_batch)
+
+                with tracing_ctx('Inputs MemcpyHtoD'):
+                    start_time = time.time()
+                    x = force_data_on_gpu_fn(x)
+                    memcopy_times.append(time.time() - start_time)
+
+                with tracing_ctx('GPU Inference'):
+                    start_time = time.time()
+                    y_pred = infer_batch(x)
+                    iter_times.append(time.time() - start_time)
+
+            if not self._args.debug_performance:
+                log_step(
+                    step_idx,
+                    display_every=self._args.display_every,
+                    iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
+                    memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
+                    dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
+                )
+            else:
+                print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
+                print(f"{'Data MemCopyHtoD Time':18s}: {memcopy_times[-1]:08.4f}s")
+                print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")

-                if not self._args.use_synthetic_data:
-                    data_aggregator.aggregate_data(y_pred, y)
+            if not self._args.use_synthetic_data:
+                data_aggregator.aggregate_data(y_pred, y)

         if (
             not self._args.debug_performance and
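The profiling rework above swaps the `tf.profiler.experimental.Profile` context manager for explicit `start`/`stop` calls, so tracing can begin a few steps before warmup ends and stop once enough steps are captured. A self-contained sketch of the same pattern against the public TF profiler API (the log directory and step counts are placeholders):

```python
import tensorflow as tf

opts = tf.profiler.experimental.ProfilerOptions(
    host_tracer_level=2,    # TraceMe detail: 1=critical, 2=info, 3=verbose
    python_tracer_level=1,  # also trace Python function calls
    device_tracer_level=1,  # trace GPU/TPU activity
    delay_ms=30000,         # wait 30 s so tf.function builds are skipped
)

tf.profiler.experimental.start("/tmp/tf_profile", options=opts)
for step in range(200):
    with tf.profiler.experimental.Trace("inference", step_num=step, _r=1):
        pass  # run one inference step here
tf.profiler.experimental.stop()
```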
@@ -518,6 +559,9 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
                 dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
             )

+        if step_idx >= 100:
+            stop_profiling()
+
         with timed_section("Metric Computation"):

             metrics = dict()
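The "Metric Computation" section that begins here aggregates the recorded timings. As a hedged illustration of what a summary over `iter_times` could look like, assuming warmup iterations should be excluded (the helper name and choice of statistics are mine, not from this diff):

```python
import numpy as np

def summarize_latencies(iter_times, num_warmup_iterations, batch_size):
    """Summarize per-step timings (seconds), skipping warmup steps."""
    steady = np.asarray(iter_times[num_warmup_iterations:])
    return {
        "mean_latency_ms": float(steady.mean() * 1000),
        "p99_latency_ms": float(np.percentile(steady, 99) * 1000),
        "throughput_samples_per_sec": float(batch_size / steady.mean()),
    }
```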