Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 49 additions & 42 deletions examples/diffusers/quantization/diffusion_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import argparse

import numpy as np
import torch
from onnx_utils.export import (
generate_dummy_inputs_and_dynamic_axes_and_shapes,
Expand Down Expand Up @@ -49,6 +50,7 @@
}


@torch.inference_mode()
def generate_image(pipe, prompt, image_name):
seed = 42
image = pipe(
Expand All @@ -61,56 +63,56 @@ def generate_image(pipe, prompt, image_name):
print(f"Image generated saved as {image_name}")


def benchmark_model(
pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
@torch.inference_mode()
def benchmark_backbone_standalone(
pipe,
num_warmup=10,
num_benchmark=100,
model_name="flux-dev",
):
"""Benchmark the backbone model inference time."""
"""Benchmark the backbone model directly without running the full pipeline."""
backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

backbone_times = []
# Generate dummy inputs for the backbone
dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)

# Extract the dict from the tuple and move to cuda
dummy_inputs_dict = {
k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
}

# Warmup
print(f"Warming up: {num_warmup} iterations")
for _ in tqdm(range(num_warmup), desc="Warmup"):
_ = backbone(**dummy_inputs_dict)

# Benchmark
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

def forward_pre_hook(_module, _input):
print(f"Benchmarking: {num_benchmark} iterations")
times = []
for _ in tqdm(range(num_benchmark), desc="Benchmark"):
torch.cuda.profiler.cudart().cudaProfilerStart()
start_event.record()

def forward_hook(_module, _input, _output):
_ = backbone(**dummy_inputs_dict)
end_event.record()
torch.cuda.synchronize()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to call sync here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The synchronization call is needed. Or we run into this error:

RuntimeError: Both events must be completed before calculating elapsed time.

backbone_times.append(start_event.elapsed_time(end_event))

pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
post_handle = backbone.register_forward_hook(forward_hook)

try:
print(f"Starting warmup: {num_warmup} runs")
for _ in tqdm(range(num_warmup), desc="Warmup"):
with torch.amp.autocast("cuda", dtype=model_dtype):
_ = pipe(
prompt,
output_type="pil",
num_inference_steps=num_inference_steps,
generator=torch.Generator("cuda").manual_seed(42),
)

backbone_times.clear()

print(f"Starting benchmark: {num_runs} runs")
for _ in tqdm(range(num_runs), desc="Benchmark"):
with torch.amp.autocast("cuda", dtype=model_dtype):
_ = pipe(
prompt,
output_type="pil",
num_inference_steps=num_inference_steps,
generator=torch.Generator("cuda").manual_seed(42),
)
finally:
pre_handle.remove()
post_handle.remove()

total_backbone_time = sum(backbone_times)
avg_latency = total_backbone_time / (num_runs * num_inference_steps)
print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
torch.cuda.profiler.cudart().cudaProfilerStop()
times.append(start_event.elapsed_time(end_event))

avg_latency = sum(times) / len(times)
p50 = np.percentile(times, 50)
p95 = np.percentile(times, 95)
p99 = np.percentile(times, 99)

print("\nBackbone-only inference latency:")
print(f" Average: {avg_latency:.2f} ms")
print(f" P50: {p50:.2f} ms")
print(f" P95: {p95:.2f} ms")
print(f" P99: {p99:.2f} ms")

return avg_latency


Expand Down Expand Up @@ -196,7 +198,12 @@ def main():
pipe.to("cuda")

if args.benchmark:
benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
benchmark_backbone_standalone(
pipe,
num_warmup=10,
num_benchmark=100,
model_name=args.model,
)

if not args.skip_image:
generate_image(pipe, args.prompt, image_name)
Expand Down
3 changes: 1 addition & 2 deletions modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import torch.nn as nn
from onnx import ModelProto
from onnxconverter_common import convert_float_to_float16
from packaging.version import Version
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelopt.onnx.autocast.convert import convert_to_f16
Expand Down Expand Up @@ -443,7 +442,7 @@ def get_onnx_bytes_and_metadata(
)
with torch.inference_mode(), autocast, quantizer_context:
additional_kwargs = {}
if not dynamo_export and Version(torch.__version__) >= Version("2.8"):
if not dynamo_export:
additional_kwargs["dynamic_axes"] = dynamic_axes
torch.onnx.export(
model,
Expand Down