diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh
new file mode 100755
index 000000000..ca51d365c
--- /dev/null
+++ b/benchmarks/run_benchmark.sh
@@ -0,0 +1,166 @@
+#!/bin/bash
+set -e
+
+# indicate which model to run
+# e.g. ./run_benchmark.sh sd15,sd21,sdxl or ./run_benchmark.sh all
+run_model=$1
+
+
+
+# set environment variables
+export NEXFORT_GRAPH_CACHE=1
+export NEXFORT_FX_FORCE_TRITON_SDPA=1
+
+
+# model path
+model_dir="/data1/hf_model"
+sd15_path="${model_dir}/stable-diffusion-v1-5"
+sd21_path="${model_dir}/stable-diffusion-2-1"
+sdxl_path="${model_dir}/stable-diffusion-xl-base-1.0"
+sd3_path="/data1/home/zhangxu/stable-diffusion-3-medium-diffusers"
+flux_dev_path="${model_dir}/FLUX.1-dev/snapshots/0ef5fff789c832c5c7f4e127f94c8b54bbcced44"
+flux_schnell_path="${model_dir}/FLUX.1-schnell"
+
+# get current time
+current_time=$(date +"%Y-%m-%d")
+echo "Current time: ${current_time}"
+
+# get NVIDIA GPU name
+gpu_name=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader,nounits | head -n 1 | sed 's/NVIDIA //; s/ /_/g')
+
+# table header
+BENCHMARK_RESULT_TEXT="| Data update date (yyyy-mm-dd) | GPU | Model | HxW | Compiler | Quantization | Iteration speed (it/s) | E2E Time (s) | Max used CUDA memory (GiB) | Warmup time (s) |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n"
+
+
+prompt="beautiful scenery nature glass bottle landscape, purple galaxy bottle"
+quantize_config='{"quant_type": "fp8_e4m3_e4m3_dynamic_per_tensor"}'
+
+# oneflow 没有compiler_config
+#sd15_nexfort_compiler_config=""
+#sd21_nexfort_compiler_config=""
+#sdxl_nexfort_compiler_config=""
+
+sd3_nexfort_compiler_config='{"mode": "max-optimize:max-autotune:low-precision:cache-all", "memory_format": "channels_last"}'
+flux_nexfort_compiler_config='{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}'
+
+
+# benchmark model with one resolution function
+benchmark_model_with_one_resolution() {
+ # model_name is the name of the model
+ model_name=$1
+ # model_path is the path of the model
+ model_path=$2
+ # steps is the number of inference steps
+ steps=$3
+ # compiler is the compiler used, e.g. none, oneflow, nexfort, transform
+ compiler=$4
+ # compiler_config is the compiler config used
+ compiler_config=$5
+ # height and width are the resolution of the image
+ height=$6
+ width=$7
+ # quantize is whether to quantize
+ quantize=$8
+
+ echo "Running ${model_path} ${height}x${width}..."
+
+ # if model_name contains sd3, use sd3 script
+ if [[ "${model_name}" =~ sd3 ]]; then
+ script_path="onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py"
+ # if model_name contains flux, use flux script
+ elif [[ "${model_name}" =~ flux ]]; then
+ script_path="onediff_diffusers_extensions/examples/flux/text_to_image_flux.py"
+ else
+ # otherwise, use sd script
+ script_path="benchmarks/text_to_image.py"
+ fi
+
+ # if quantize is True, add --quantize and --quantize-config
+ if [[ ${quantize} == True ]]; then
+ script_output=$(python3 ${script_path} \
+ --model ${model_path} --variant fp16 --steps ${steps} \
+ --height ${height} --width ${width} --seed 1 \
+ --compiler ${compiler} --compiler-config "${compiler_config}" \
+ --quantize --quantize-config "${quantize_config}" \
+ --prompt "${prompt}" --print-output | tee /dev/tty)
+ else
+ script_output=$(python3 ${script_path} \
+ --model ${model_path} --variant fp16 --steps ${steps} \
+ --height ${height} --width ${width} --seed 1 \
+ --compiler ${compiler} --compiler-config "${compiler_config}" \
+ --prompt "${prompt}" --print-output | tee /dev/tty)
+ fi
+
+ # get inference time, iterations per second, max used cuda memory, warmup time
+ inference_time=$(echo "${script_output}" | grep -oP '(?<=Inference time: )\d+\.\d+')
+ iterations_per_second=$(echo "${script_output}" | grep -oP '(?<=Iterations per second: )\d+\.\d+')
+ max_used_cuda_memory=$(echo "${script_output}" | grep -oP '(?<=Max used CUDA memory : )\d+\.\d+')
+ warmup_time=$(echo "${script_output}" | grep -oP '(?<=Warmup time: )\d+\.\d+')
+
+ # add benchmark result to BENCHMARK_RESULT_TEXT
+ BENCHMARK_RESULT_TEXT="${BENCHMARK_RESULT_TEXT}| "${current_time}" | "${gpu_name}" | "${model_name}" | ${height}x${width} | ${compiler} | ${quantize} | ${iterations_per_second} | ${inference_time} | ${max_used_cuda_memory} | ${warmup_time} |\n"
+}
+
+# conda init
+source ~/miniconda3/etc/profile.d/conda.sh
+
+#########################################
+# if run_model contains sd15 or all, run sd15
+if [[ "${run_model}" =~ sd15|all ]]; then
+ conda activate oneflow
+ benchmark_model_with_one_resolution sd15 ${sd15_path} 30 none none 512 512 False
+ benchmark_model_with_one_resolution sd15 ${sd15_path} 30 oneflow none 512 512 False
+ benchmark_model_with_one_resolution sd15 ${sd15_path} 30 oneflow none 512 512 True
+fi
+
+# if run_model contains sd21 or all, run sd21
+if [[ "${run_model}" =~ sd21|all ]]; then
+ # activate oneflow environment
+ conda activate oneflow
+ benchmark_model_with_one_resolution sd21 ${sd21_path} 20 none none 768 768 False
+ benchmark_model_with_one_resolution sd21 ${sd21_path} 20 oneflow none 768 768 False
+ benchmark_model_with_one_resolution sd21 ${sd21_path} 20 oneflow none 768 768 True
+fi
+
+# if run_model contains sdxl or all, run sdxl
+if [[ "${run_model}" =~ sdxl|all ]]; then
+ # activate oneflow environment
+ conda activate oneflow
+ benchmark_model_with_one_resolution sdxl ${sdxl_path} 30 none none 1024 1024 False
+ benchmark_model_with_one_resolution sdxl ${sdxl_path} 30 oneflow none 1024 1024 False
+ benchmark_model_with_one_resolution sdxl ${sdxl_path} 30 oneflow none 1024 1024 True
+fi
+#########################################
+
+#########################################
+# if run_model contains sd3 or all, run sd3
+if [[ "${run_model}" =~ sd3|all ]]; then
+ conda activate nexfort
+ # activate nexfort environment
+ benchmark_model_with_one_resolution sd3 ${sd3_path} 28 none none 1024 1024 False
+ benchmark_model_with_one_resolution sd3 ${sd3_path} 28 nexfort "${sd3_nexfort_compiler_config}" 1024 1024 False
+ benchmark_model_with_one_resolution sd3 ${sd3_path} 28 nexfort "${sd3_nexfort_compiler_config}" 1024 1024 True
+fi
+
+# if run_model contains flux or all, run flux
+if [[ "${run_model}" =~ flux|all ]]; then
+ # activate nexfort environment
+ conda activate nexfort
+ benchmark_model_with_one_resolution flux_dev ${flux_dev_path} 20 none none 1024 1024 False
+ benchmark_model_with_one_resolution flux_dev ${flux_dev_path} 20 nexfort "${flux_nexfort_compiler_config}" 1024 1024 False
+ benchmark_model_with_one_resolution flux_dev ${flux_dev_path} 20 nexfort "${flux_nexfort_compiler_config}" 1024 1024 True
+ benchmark_model_with_one_resolution flux_dev ${flux_dev_path} 20 transform none 1024 1024 False
+
+
+ benchmark_model_with_one_resolution flux_schnell ${flux_schnell_path} 4 none none 1024 1024 False
+ benchmark_model_with_one_resolution flux_schnell ${flux_schnell_path} 4 nexfort "${flux_nexfort_compiler_config}" 1024 1024 False
+ benchmark_model_with_one_resolution flux_schnell ${flux_schnell_path} 4 nexfort "${flux_nexfort_compiler_config}" 1024 1024 True
+ benchmark_model_with_one_resolution flux_schnell ${flux_schnell_path} 4 transform none 1024 1024 False
+fi
+#########################################
+
+
+echo -e "\nBenchmark Results:"
+# print benchmark result and add benchmark result to markdown file
+echo -e ${BENCHMARK_RESULT_TEXT} | tee -a benchmark_result_"${gpu_name}".md
+echo -e "\nBenchmark Done!"
diff --git a/benchmarks/text_to_image.py b/benchmarks/text_to_image.py
index 85ec6bb43..426597427 100644
--- a/benchmarks/text_to_image.py
+++ b/benchmarks/text_to_image.py
@@ -35,6 +35,7 @@
import torch
from diffusers.utils import load_image
from onediff.infer_compiler import oneflow_compile
+from onediff.optimization.quant_optimizer import quantize_model
from onediffx import ( # quantize_pipe currently only supports the nexfort backend.
compile_pipe,
@@ -252,6 +253,13 @@ def main():
print("Oneflow backend is now active...")
# Note: The compile_pipe() based on the oneflow backend is incompatible with T5EncoderModel.
# pipe = compile_pipe(pipe)
+
+ if args.quantize:
+ if hasattr(pipe, "unet"):
+ pipe.unet = quantize_model(pipe.unet)
+ if hasattr(pipe, "transformer"):
+ pipe.transformer = quantize_model(pipe.transformer)
+
if hasattr(pipe, "unet"):
pipe.unet = oneflow_compile(pipe.unet)
if hasattr(pipe, "transformer"):
diff --git a/onediff_diffusers_extensions/examples/flux/README.md b/onediff_diffusers_extensions/examples/flux/README.md
new file mode 100644
index 000000000..8ff2c16e5
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/flux/README.md
@@ -0,0 +1,129 @@
+# Run Flux with onediff
+
+
+## Environment setup
+
+### Set up onediff
+https://github.com/siliconflow/onediff?tab=readme-ov-file#installation
+
+### Set up compiler backend
+Support two backends: oneflow and nexfort.
+
+https://github.com/siliconflow/onediff?tab=readme-ov-file#install-a-compiler-backend
+
+### Set up flux
+HF model: https://huggingface.co/black-forest-labs/FLUX.1-dev and https://huggingface.co/black-forest-labs/FLUX.1-schnell
+
+HF pipeline: https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux
+
+### Set up others
+Install extra pkgs and set environment variable.
+```bash
+pip install --upgrade transformers
+pip install --upgrade diffusers[torch]
+pip install nvidia-cublas-cu12==12.4.5.8
+
+export NEXFORT_FX_FORCE_TRITON_SDPA=1
+```
+
+## Run
+
+### Run FLUX.1-dev 1024*1024 without compile (the original pytorch HF diffusers baseline)
+```
+python3 onediff_diffusers_extensions/examples/flux/text_to_image_flux.py \
+--model black-forest-labs/FLUX.1-dev \
+--height 1024 \
+--width 1024 \
+--steps 20 \
+--seed 1 \
+--output-image ./flux.png
+```
+
+### Run FLUX.1-dev 1024*1024 with compile [nexfort backend]
+
+```
+python3 onediff_diffusers_extensions/examples/flux/text_to_image_flux.py \
+--model black-forest-labs/FLUX.1-dev \
+--height 1024 \
+--width 1024 \
+--steps 20 \
+--seed 1 \
+--compiler nexfort \
+--compiler-config '{"mode": "max-optimize:max-autotune:low-precision:cache-all", "memory_format": "channels_last"}' \
+--output-image ./flux_nexfort_compile.png
+```
+
+
+### Run FLUX.1-schnell 1024*1024 without compile (the original pytorch HF diffusers baseline)
+```
+python3 onediff_diffusers_extensions/examples/flux/text_to_image_flux.py \
+--model black-forest-labs/FLUX.1-schnell \
+--height 1024 \
+--width 1024 \
+--steps 4 \
+--seed 1 \
+--output-image ./flux.png
+```
+
+### Run FLUX.1-schnell 1024*1024 with compile [nexfort backend]
+
+```
+python3 onediff_diffusers_extensions/examples/flux/text_to_image_flux.py \
+--model black-forest-labs/FLUX.1-schnell \
+--height 1024 \
+--width 1024 \
+--steps 4 \
+--seed 1 \
+--compiler nexfort \
+--compiler-config '{"mode": "max-optimize:max-autotune:low-precision:cache-all", "memory_format": "channels_last"}' \
+--output-image ./flux_nexfort_compile.png
+```
+
+
+## FLUX.1-dev Performance comparation
+**Testing on NVIDIA H20-SXM4-80GB:**
+
+Data update date: 2024-10-23
+
+| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) 1 | Warmup with Cache time (seconds) |
+|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
+| PyTorch | 1.30 | 15.72 | 35.73 | 16.68 | - |
+| OneDiff (NexFort) | 1.76 (+35.4%) | 11.57 (-26.4%) | 34.85 | 750.78 | 28.57 |
+
+ 1 OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Platinum 8468V.
+
+**Testing on NVIDIA L20-SXM4-48GB:**
+
+Data update date: 2024-10-28
+
+| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) 2 | Warmup with Cache time (seconds) |
+|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
+| PyTorch | 1.10 | 18.45 | 35.71 | 18.695 | - |
+| OneDiff (NexFort) | 1.41 (+28.2%) | 14.44 (-21.7%) | 34.83 | 546.52 | 25.32 |
+
+ 2 OneDiff Warmup with Compilation time is tested on AMD EPYC 9354 32-Core Processor.
+
+
+
+## FLUX.1-schnell Performance comparation
+**Testing on NVIDIA H20-SXM4-80GB:**
+
+Data update date: 2024-10-23
+
+| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) 1 | Warmup with Cache time (seconds) |
+|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
+| PyTorch | 1.30 | 3.38 | 35.71 | 4.35 | - |
+| OneDiff (NexFort) | 1.75 (+34.6%) | 2.46 (-27.2%) | 34.83 | 201.41 | 19.57 |
+
+ 1 OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Platinum 8468V.
+
+**Testing on NVIDIA L20-SXM4-48GB:**
+
+Data update date: 2024-10-28
+
+| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) 2 | Warmup with Cache time (seconds) |
+|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
+| PyTorch | 1.10 | 3.94 | 35.69 | 4.15 | - |
+| OneDiff (NexFort) | 1.41 (+28.2%) | 3.03 (-23.1%) | 34.81 | 145.63 | 13.56 |
+
+ 2 OneDiff Warmup with Compilation time is tested on AMD EPYC 9354 32-Core Processor.
diff --git a/onediff_diffusers_extensions/examples/flux/text_to_image_flux.py b/onediff_diffusers_extensions/examples/flux/text_to_image_flux.py
new file mode 100644
index 000000000..3df6be1d3
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/flux/text_to_image_flux.py
@@ -0,0 +1,446 @@
+MODEL = "black-forest-labs/FLUX.1-dev"
+VARIANT = None
+CUSTOM_PIPELINE = None
+SCHEDULER = "FlowMatchEulerDiscreteScheduler"
+LORA = None
+CONTROLNET = None
+STEPS = 20
+PROMPT = "best quality, realistic, unreal engine, 4K, a beautiful girl"
+SEED = 1
+WARMUPS = 1
+BATCH = 1
+HEIGHT = None
+WIDTH = None
+INPUT_IMAGE = None
+CONTROL_IMAGE = None
+OUTPUT_IMAGE = None
+EXTRA_CALL_KWARGS = None
+CACHE_INTERVAL = 3
+CACHE_LAYER_ID = 0
+CACHE_BLOCK_ID = 0
+COMPILER = "nexfort"
+COMPILER_CONFIG = None
+QUANTIZE_CONFIG = None
+
+import argparse
+import importlib
+import inspect
+import json
+import os
+import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from diffusers.utils import load_image
+
+from nexfort.compilers.transform_model import transform_model
+from nexfort.quantization import quantize
+
+from onediffx import ( # quantize_pipe currently only supports the nexfort backend.
+ compile_pipe,
+ quantize_pipe,
+)
+
+from PIL import Image, ImageDraw
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", type=str, default=MODEL)
+ parser.add_argument("--variant", type=str, default=VARIANT)
+ parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
+ parser.add_argument("--scheduler", type=str, default=SCHEDULER)
+ parser.add_argument("--lora", type=str, default=LORA)
+ parser.add_argument("--controlnet", type=str, default=CONTROLNET)
+ parser.add_argument("--steps", type=int, default=STEPS)
+ parser.add_argument("--prompt", type=str, default=PROMPT)
+ parser.add_argument("--seed", type=int, default=SEED)
+ parser.add_argument("--warmups", type=int, default=WARMUPS)
+ parser.add_argument("--batch", type=int, default=BATCH)
+ parser.add_argument("--height", type=int, default=HEIGHT)
+ parser.add_argument("--width", type=int, default=WIDTH)
+ parser.add_argument("--cache_interval", type=int, default=CACHE_INTERVAL)
+ parser.add_argument("--cache_layer_id", type=int, default=CACHE_LAYER_ID)
+ parser.add_argument("--cache_block_id", type=int, default=CACHE_BLOCK_ID)
+ parser.add_argument("--extra-call-kwargs", type=str, default=EXTRA_CALL_KWARGS)
+ parser.add_argument("--input-image", type=str, default=INPUT_IMAGE)
+ parser.add_argument("--control-image", type=str, default=CONTROL_IMAGE)
+ parser.add_argument("--output-image", type=str, default=OUTPUT_IMAGE)
+ parser.add_argument("--print-output", action="store_true")
+ parser.add_argument("--throughput", action="store_true")
+ parser.add_argument("--deepcache", action="store_true")
+ parser.add_argument(
+ "--compiler",
+ type=str,
+ default=COMPILER,
+ choices=["none", "transform", "nexfort", "compile", "compile-max-autotune"],
+ )
+ parser.add_argument(
+ "--compiler-config",
+ type=str,
+ default=COMPILER_CONFIG,
+ )
+ parser.add_argument(
+ "--run_multiple_resolutions",
+ type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+ default=False,
+ )
+ parser.add_argument("--quantize", action="store_true")
+ parser.add_argument(
+ "--quantize-config",
+ type=str,
+ default=QUANTIZE_CONFIG,
+ )
+ parser.add_argument("--quant-submodules-config-path", type=str, default=None)
+ return parser.parse_args()
+
+
+args = parse_args()
+
+
+def get_gpu_memory():
+ gpu_id = torch.cuda.current_device()
+ total_memory = torch.cuda.get_device_properties(gpu_id).total_memory
+
+ return total_memory / 1024**3
+
+
+def load_pipe(
+ pipeline_cls,
+ model_name,
+ variant=None,
+ dtype=torch.float16,
+ device="cuda",
+ custom_pipeline=None,
+ scheduler=None,
+ lora=None,
+ controlnet=None,
+):
+ extra_kwargs = {}
+ if custom_pipeline is not None:
+ extra_kwargs["custom_pipeline"] = custom_pipeline
+ if variant is not None:
+ extra_kwargs["variant"] = variant
+ if dtype is not None:
+ extra_kwargs["torch_dtype"] = dtype
+ if controlnet is not None:
+ from diffusers import ControlNetModel
+
+ controlnet = ControlNetModel.from_pretrained(
+ controlnet,
+ torch_dtype=dtype,
+ )
+ extra_kwargs["controlnet"] = controlnet
+
+ pipe = pipeline_cls.from_pretrained(model_name, **extra_kwargs)
+
+ if scheduler is not None and scheduler != "none":
+ scheduler_cls = getattr(importlib.import_module("diffusers"), scheduler)
+ pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
+ if lora is not None:
+ pipe.load_lora_weights(lora)
+ pipe.fuse_lora()
+ pipe.safety_checker = None
+ if device is not None and get_gpu_memory() > 24:
+ pipe.to(torch.device(device))
+ return pipe
+
+
+class IterationProfiler:
+ def __init__(self):
+ self.begin = None
+ self.end = None
+ self.num_iterations = 0
+
+ def get_iter_per_sec(self):
+ if self.begin is None or self.end is None:
+ return None
+ self.end.synchronize()
+ dur = self.begin.elapsed_time(self.end)
+ return self.num_iterations / dur * 1000.0
+
+ def callback_on_step_end(self, pipe, i, t, callback_kwargs={}):
+ if self.begin is None:
+ event = torch.cuda.Event(enable_timing=True)
+ event.record()
+ self.begin = event
+ else:
+ event = torch.cuda.Event(enable_timing=True)
+ event.record()
+ self.end = event
+ self.num_iterations += 1
+ return callback_kwargs
+
+
+def calculate_inference_time_and_throughput(height, width, n_steps, model):
+ start_time = time.time()
+ model(prompt=args.prompt, height=height, width=width, num_inference_steps=n_steps)
+ end_time = time.time()
+ inference_time = end_time - start_time
+ # pixels_processed = height * width * n_steps
+ # throughput = pixels_processed / inference_time
+ throughput = n_steps / inference_time
+ return inference_time, throughput
+
+
+def generate_data_and_fit_model(model, steps_range):
+ height, width = 1024, 1024
+ data = {"steps": [], "inference_time": [], "throughput": []}
+
+ for n_steps in steps_range:
+ inference_time, throughput = calculate_inference_time_and_throughput(
+ height, width, n_steps, model
+ )
+ data["steps"].append(n_steps)
+ data["inference_time"].append(inference_time)
+ data["throughput"].append(throughput)
+ print(
+ f"Steps: {n_steps}, Inference Time: {inference_time:.2f} seconds, Throughput: {throughput:.2f} steps/s"
+ )
+
+ average_throughput = np.mean(data["throughput"])
+ print(f"Average Throughput: {average_throughput:.2f} steps/s")
+
+ coefficients = np.polyfit(data["steps"], data["inference_time"], 1)
+ base_time_without_base_cost = 1 / coefficients[0]
+ print(f"Throughput without base cost: {base_time_without_base_cost:.2f} steps/s")
+ return data, coefficients
+
+
+def plot_data_and_model(data, coefficients):
+ plt.figure(figsize=(10, 5))
+ plt.scatter(data["steps"], data["inference_time"], color="blue")
+ plt.plot(data["steps"], np.polyval(coefficients, data["steps"]), color="red")
+ plt.title("Inference Time vs. Steps")
+ plt.xlabel("Steps")
+ plt.ylabel("Inference Time (seconds)")
+ plt.grid(True)
+ # plt.savefig("output.png")
+ plt.show()
+
+ print(
+ f"Model: Inference Time = {coefficients[0]:.2f} * Steps + {coefficients[1]:.2f}"
+ )
+
+
+def main():
+
+ from diffusers import FluxPipeline as pipeline_cls
+
+ pipe = load_pipe(
+ pipeline_cls,
+ args.model,
+ variant=args.variant,
+ custom_pipeline=args.custom_pipeline,
+ scheduler=args.scheduler,
+ lora=args.lora,
+ controlnet=args.controlnet,
+ )
+
+ core_net = None
+ if core_net is None:
+ core_net = getattr(pipe, "unet", None)
+ if core_net is None:
+ core_net = getattr(pipe, "transformer", None)
+ height = args.height or core_net.config.sample_size * pipe.vae_scale_factor
+ width = args.width or core_net.config.sample_size * pipe.vae_scale_factor
+
+ if args.compiler == "none":
+ pass
+ elif args.compiler == "nexfort":
+ print("Nexfort backend is now active...")
+ if args.quantize:
+ if args.quantize_config is not None:
+ quantize_config = json.loads(args.quantize_config)
+ else:
+ quantize_config = '{"quant_type": "fp8_e4m3_e4m3_dynamic_per_tensor"}'
+
+ if get_gpu_memory() > 24:
+ _ = quantize(pipe.transformer, **quantize_config)
+ else:
+ # (TODO:support 4090) for gpu with little memory, such as 4090
+ if hasattr(pipe, "transformer"):
+ pipe.transformer = pipe.transformer.to("cuda")
+ _ = quantize(pipe.transformer, **quantize_config)
+ pipe.transformer = pipe.transformer.to("cpu")
+
+ if hasattr(pipe, "text_encoder_2"):
+ pipe.text_encoder_2 = pipe.text_encoder_2.to("cuda")
+ _ = quantize(pipe.text_encoder_2, **quantize_config) # t5xxl
+ pipe.text_encoder_2 = pipe.text_encoder_2.to("cpu")
+
+ # load pipe to GPU
+ pipe.to("cuda")
+
+ if args.compiler_config is not None:
+ # config with dict
+ options = json.loads(args.compiler_config)
+ else:
+ # config with string
+ options = '{"mode": "max-optimize:max-autotune:low-precision:cache-all", "memory_format": "channels_last"}'
+
+ pipe = compile_pipe(
+ pipe, backend="nexfort", options=options, fuse_qkv_projections=True
+ )
+ elif args.compiler == "transform":
+ if args.quantize:
+ if args.quantize_config is not None:
+ quantize_config = json.loads(args.quantize_config)
+ else:
+ quantize_config = '{"quant_type": "fp8_e4m3_e4m3_dynamic_per_tensor"}'
+
+ if get_gpu_memory() > 24:
+ _ = quantize(pipe.transformer, **quantize_config)
+ else:
+ # for gpu with little memory, such as 4090
+ if hasattr(pipe, "transformer"):
+ pipe.transformer = pipe.transformer.to("cuda")
+ _ = quantize(pipe.transformer, **quantize_config)
+ pipe.transformer = pipe.transformer.to("cpu")
+
+ if hasattr(pipe, "text_encoder_2"):
+ pipe.text_encoder_2 = pipe.text_encoder_2.to("cuda")
+ _ = quantize(pipe.text_encoder_2, **quantize_config) # t5xxl
+ pipe.text_encoder_2 = pipe.text_encoder_2.to("cpu")
+
+ # load pipe to GPU
+ pipe.to("cuda")
+
+ _ = transform_model(pipe.transformer)
+ elif args.compiler in ("compile", "compile-max-autotune"):
+ mode = "max-autotune" if args.compiler == "compile-max-autotune" else None
+ if hasattr(pipe, "unet"):
+ pipe.unet = torch.compile(pipe.unet, mode=mode)
+ if hasattr(pipe, "transformer"):
+ pipe.transformer = torch.compile(pipe.transformer, mode=mode)
+ if hasattr(pipe, "controlnet"):
+ pipe.controlnet = torch.compile(pipe.controlnet, mode=mode)
+ pipe.vae = torch.compile(pipe.vae, mode=mode)
+ else:
+ raise ValueError(f"Unknown compiler: {args.compiler}")
+
+ if args.input_image is None:
+ input_image = None
+ else:
+ input_image = load_image(args.input_image)
+ input_image = input_image.resize((width, height), Image.LANCZOS)
+
+ if args.control_image is None:
+ if args.controlnet is None:
+ control_image = None
+ else:
+ control_image = Image.new("RGB", (width, height))
+ draw = ImageDraw.Draw(control_image)
+ draw.ellipse(
+ (args.width // 4, height // 4, args.width // 4 * 3, height // 4 * 3),
+ fill=(255, 255, 255),
+ )
+ del draw
+ else:
+ control_image = load_image(args.control_image)
+ control_image = control_image.resize((width, height), Image.LANCZOS)
+
+ def get_kwarg_inputs():
+ kwarg_inputs = dict(
+ prompt=args.prompt,
+ height=height,
+ width=width,
+ num_images_per_prompt=args.batch,
+ generator=(
+ None
+ if args.seed is None
+ else torch.Generator(device="cuda").manual_seed(args.seed)
+ ),
+ **(
+ dict()
+ if args.extra_call_kwargs is None
+ else json.loads(args.extra_call_kwargs)
+ ),
+ )
+ if args.steps is not None:
+ kwarg_inputs["num_inference_steps"] = args.steps
+ if input_image is not None:
+ kwarg_inputs["image"] = input_image
+ if control_image is not None:
+ if input_image is None:
+ kwarg_inputs["image"] = control_image
+ else:
+ kwarg_inputs["control_image"] = control_image
+ if args.deepcache:
+ kwarg_inputs["cache_interval"] = args.cache_interval
+ kwarg_inputs["cache_layer_id"] = args.cache_layer_id
+ kwarg_inputs["cache_block_id"] = args.cache_block_id
+ return kwarg_inputs
+
+ # NOTE: Warm it up.
+ # The initial calls will trigger compilation and might be very slow.
+ # After that, it should be very fast.
+ if args.warmups > 0:
+ begin = time.time()
+ print("=======================================")
+ print("Begin warmup")
+ for _ in range(args.warmups):
+ pipe(**get_kwarg_inputs())
+ end = time.time()
+ print("End warmup")
+ print(f"Warmup time: {end - begin:.3f}s")
+ print("=======================================")
+
+ # Let"s see it!
+ # Note: Progress bar might work incorrectly due to the async nature of CUDA.
+ kwarg_inputs = get_kwarg_inputs()
+ iter_profiler = IterationProfiler()
+ if "callback_on_step_end" in inspect.signature(pipe).parameters:
+ kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
+ elif "callback" in inspect.signature(pipe).parameters:
+ kwarg_inputs["callback"] = iter_profiler.callback_on_step_end
+ begin = time.time()
+ output_images = pipe(**kwarg_inputs).images
+ end = time.time()
+
+ print("=======================================")
+ print(f"Inference time: {end - begin:.3f}s")
+ iter_per_sec = iter_profiler.get_iter_per_sec()
+ if iter_per_sec is not None:
+ print(f"Iterations per second: {iter_per_sec:.3f}")
+
+ cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
+ print(f"Max used CUDA memory : {cuda_mem_after_used:.3f}GiB")
+ print("=======================================")
+
+ if args.print_output:
+ from onediff.utils.import_utils import is_nexfort_available
+
+ if is_nexfort_available():
+ from nexfort.utils.term_image import print_image
+
+ for image in output_images:
+ print_image(image, max_width=80)
+
+ if args.output_image is not None:
+ output_images[0].save(args.output_image)
+ else:
+ print("Please set `--output-image` to save the output image")
+
+ if args.run_multiple_resolutions:
+ print("Test run with multiple resolutions...")
+ sizes = [1024, 512, 768, 256]
+ for h in sizes:
+ for w in sizes:
+ kwarg_inputs["height"] = h
+ kwarg_inputs["width"] = w
+ print(f"Running at resolution: {h}x{w}")
+ start_time = time.time()
+ image = pipe(**kwarg_inputs).images
+ end_time = time.time()
+ print(f"Inference time: {end_time - start_time:.2f} seconds")
+
+ if args.throughput:
+ steps_range = range(1, 100, 1)
+ data, coefficients = generate_data_and_fit_model(pipe, steps_range)
+ plot_data_and_model(data, coefficients)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
index c571bd212..2bd33cf57 100644
--- a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
+++ b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
@@ -1,219 +1,406 @@
+MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"
+VARIANT = None
+CUSTOM_PIPELINE = None
+SCHEDULER = "FlowMatchEulerDiscreteScheduler"
+LORA = None
+CONTROLNET = None
+STEPS = 28
+PROMPT = "best quality, realistic, unreal engine, 4K, a beautiful girl"
+NEGATIVE_PROMPT = ""
+SEED = 1
+WARMUPS = 1
+BATCH = 1
+HEIGHT = None
+WIDTH = None
+INPUT_IMAGE = None
+CONTROL_IMAGE = None
+OUTPUT_IMAGE = None
+EXTRA_CALL_KWARGS = None
+CACHE_INTERVAL = 3
+CACHE_LAYER_ID = 0
+CACHE_BLOCK_ID = 0
+COMPILER = "nexfort"
+COMPILER_CONFIG = None
+QUANTIZE_CONFIG = None
+
import argparse
+import importlib
+import inspect
import json
+import os
import time
+import matplotlib.pyplot as plt
+import numpy as np
import torch
-from diffusers import StableDiffusion3Pipeline
-from onediffx import compile_pipe, quantize_pipe
+from diffusers.utils import load_image
+
+from onediffx import ( # quantize_pipe currently only supports the nexfort backend.
+ compile_pipe,
+ quantize_pipe,
+)
+from PIL import Image, ImageDraw
def parse_args():
- parser = argparse.ArgumentParser(
- description="Use onediif (nexfort) to accelerate image generation with Stable Diffusion 3."
- )
- parser.add_argument(
- "--model",
- type=str,
- default="stabilityai/stable-diffusion-3-medium-diffusers",
- help="Model path or identifier.",
- )
- parser.add_argument(
- "--compiler-config", type=str, help="JSON string for compiler config."
- )
- parser.add_argument(
- "--quantize-config", type=str, help="JSON string for quantization config."
- )
- parser.add_argument(
- "--prompt",
- type=str,
- default="photo of a dog and a cat both standing on a red box, with a blue ball in the middle with a parrot standing on top of the ball. The box has the text 'onediff'",
- help="Prompt for the image generation.",
- )
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", type=str, default=MODEL)
+ parser.add_argument("--variant", type=str, default=VARIANT)
+ parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
+ parser.add_argument("--scheduler", type=str, default=SCHEDULER)
+ parser.add_argument("--lora", type=str, default=LORA)
+ parser.add_argument("--controlnet", type=str, default=CONTROLNET)
+ parser.add_argument("--steps", type=int, default=STEPS)
+ parser.add_argument("--prompt", type=str, default=PROMPT)
+ parser.add_argument("--negative-prompt", type=str, default=NEGATIVE_PROMPT)
+ parser.add_argument("--seed", type=int, default=SEED)
+ parser.add_argument("--warmups", type=int, default=WARMUPS)
+ parser.add_argument("--batch", type=int, default=BATCH)
+ parser.add_argument("--height", type=int, default=HEIGHT)
+ parser.add_argument("--width", type=int, default=WIDTH)
+ parser.add_argument("--cache_interval", type=int, default=CACHE_INTERVAL)
+ parser.add_argument("--cache_layer_id", type=int, default=CACHE_LAYER_ID)
+ parser.add_argument("--cache_block_id", type=int, default=CACHE_BLOCK_ID)
+ parser.add_argument("--extra-call-kwargs", type=str, default=EXTRA_CALL_KWARGS)
+ parser.add_argument("--input-image", type=str, default=INPUT_IMAGE)
+ parser.add_argument("--control-image", type=str, default=CONTROL_IMAGE)
+ parser.add_argument("--output-image", type=str, default=OUTPUT_IMAGE)
+ parser.add_argument("--print-output", action="store_true")
+ parser.add_argument("--throughput", action="store_true")
+ parser.add_argument("--deepcache", action="store_true")
parser.add_argument(
- "--negative_prompt",
+ "--compiler",
type=str,
- default="",
- help="Negative prompt for the image generation.",
- )
- parser.add_argument(
- "--height", type=int, default=1024, help="Height of the generated image."
- )
- parser.add_argument(
- "--width", type=int, default=1024, help="Width of the generated image."
+ default=COMPILER,
+ choices=["none", "nexfort", "compile", "compile-max-autotune"],
)
parser.add_argument(
- "--guidance_scale",
- type=float,
- default=4.5,
- help="The scale factor for the guidance.",
- )
- parser.add_argument(
- "--num-inference-steps", type=int, default=28, help="Number of inference steps."
- )
- parser.add_argument(
- "--saved-image",
+ "--compiler-config",
type=str,
- default="./sd3.png",
- help="Path to save the generated image.",
- )
- parser.add_argument(
- "--seed", type=int, default=1, help="Seed for random number generation."
+ default=COMPILER_CONFIG,
)
parser.add_argument(
"--run_multiple_resolutions",
type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
default=False,
)
+ parser.add_argument("--quantize", action="store_true")
parser.add_argument(
- "--run_multiple_prompts",
- type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
- default=False,
+ "--quantize-config",
+ type=str,
+ default=QUANTIZE_CONFIG,
)
+ parser.add_argument("--quant-submodules-config-path", type=str, default=None)
return parser.parse_args()
args = parse_args()
-device = torch.device("cuda")
-
-
-def generate_texts(min_length=50, max_length=302):
- base_text = "a female character with long, flowing hair that appears to be made of ethereal, swirling patterns resembling the Northern Lights or Aurora Borealis. The background is dominated by deep blues and purples, creating a mysterious and dramatic atmosphere. The character's face is serene, with pale skin and striking features. She"
-
- additional_words = [
- "gracefully",
- "beautifully",
- "elegant",
- "radiant",
- "mysteriously",
- "vibrant",
- "softly",
- "gently",
- "luminescent",
- "sparkling",
- "delicately",
- "glowing",
- "brightly",
- "shimmering",
- "enchanting",
- "gloriously",
- "magnificent",
- "majestic",
- "fantastically",
- "dazzlingly",
- ]
-
- for i in range(min_length, max_length):
- idx = i % len(additional_words)
- base_text += " " + additional_words[idx]
- yield base_text
-
-
-class SD3Generator:
- def __init__(self, model, compiler_config=None, quantize_config=None):
- self.pipe = StableDiffusion3Pipeline.from_pretrained(
- model,
- torch_dtype=torch.float16,
+
+def load_pipe(
+ pipeline_cls,
+ model_name,
+ variant=None,
+ dtype=torch.float16,
+ device="cuda",
+ custom_pipeline=None,
+ scheduler=None,
+ lora=None,
+ controlnet=None,
+):
+ extra_kwargs = {}
+ if custom_pipeline is not None:
+ extra_kwargs["custom_pipeline"] = custom_pipeline
+ if variant is not None:
+ extra_kwargs["variant"] = variant
+ if dtype is not None:
+ extra_kwargs["torch_dtype"] = dtype
+ if controlnet is not None:
+ from diffusers import ControlNetModel
+
+ controlnet = ControlNetModel.from_pretrained(
+ controlnet,
+ torch_dtype=dtype,
+ )
+ extra_kwargs["controlnet"] = controlnet
+
+
+ pipe = pipeline_cls.from_pretrained(model_name, **extra_kwargs)
+
+
+ if scheduler is not None and scheduler != "none":
+ scheduler_cls = getattr(importlib.import_module("diffusers"), scheduler)
+ pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
+ if lora is not None:
+ pipe.load_lora_weights(lora)
+ pipe.fuse_lora()
+ pipe.safety_checker = None
+ if device is not None:
+ pipe.to(torch.device(device))
+ return pipe
+
+
+class IterationProfiler:
+ def __init__(self):
+ self.begin = None
+ self.end = None
+ self.num_iterations = 0
+
+ def get_iter_per_sec(self):
+ if self.begin is None or self.end is None:
+ return None
+ self.end.synchronize()
+ dur = self.begin.elapsed_time(self.end)
+ return self.num_iterations / dur * 1000.0
+
+ def callback_on_step_end(self, pipe, i, t, callback_kwargs={}):
+ if self.begin is None:
+ event = torch.cuda.Event(enable_timing=True)
+ event.record()
+ self.begin = event
+ else:
+ event = torch.cuda.Event(enable_timing=True)
+ event.record()
+ self.end = event
+ self.num_iterations += 1
+ return callback_kwargs
+
+
+def calculate_inference_time_and_throughput(height, width, n_steps, model):
+ start_time = time.time()
+ model(prompt=args.prompt, height=height, width=width, num_inference_steps=n_steps)
+ end_time = time.time()
+ inference_time = end_time - start_time
+ # pixels_processed = height * width * n_steps
+ # throughput = pixels_processed / inference_time
+ throughput = n_steps / inference_time
+ return inference_time, throughput
+
+
+def generate_data_and_fit_model(model, steps_range):
+ height, width = 1024, 1024
+ data = {"steps": [], "inference_time": [], "throughput": []}
+
+ for n_steps in steps_range:
+ inference_time, throughput = calculate_inference_time_and_throughput(
+ height, width, n_steps, model
+ )
+ data["steps"].append(n_steps)
+ data["inference_time"].append(inference_time)
+ data["throughput"].append(throughput)
+ print(
+ f"Steps: {n_steps}, Inference Time: {inference_time:.2f} seconds, Throughput: {throughput:.2f} steps/s"
)
- self.pipe.to(device)
- if compiler_config:
- print("compile...")
- self.pipe = self.compile_pipe(self.pipe, compiler_config)
+ average_throughput = np.mean(data["throughput"])
+ print(f"Average Throughput: {average_throughput:.2f} steps/s")
- if quantize_config:
- print("quant...")
- self.pipe = self.quantize_pipe(self.pipe, quantize_config)
+ coefficients = np.polyfit(data["steps"], data["inference_time"], 1)
+ base_time_without_base_cost = 1 / coefficients[0]
+ print(f"Throughput without base cost: {base_time_without_base_cost:.2f} steps/s")
+ return data, coefficients
- def warmup(self, gen_args, warmup_iterations=1):
- warmup_args = gen_args.copy()
- warmup_args["generator"] = torch.Generator(device=device).manual_seed(0)
+def plot_data_and_model(data, coefficients):
+ plt.figure(figsize=(10, 5))
+ plt.scatter(data["steps"], data["inference_time"], color="blue")
+ plt.plot(data["steps"], np.polyval(coefficients, data["steps"]), color="red")
+ plt.title("Inference Time vs. Steps")
+ plt.xlabel("Steps")
+ plt.ylabel("Inference Time (seconds)")
+ plt.grid(True)
+ # plt.savefig("output.png")
+ plt.show()
- print("Starting warmup...")
- start_time = time.time()
- for _ in range(warmup_iterations):
- self.pipe(**warmup_args)
- end_time = time.time()
- print("Warmup complete.")
- print(f"Warmup time: {end_time - start_time:.2f} seconds")
+ print(
+ f"Model: Inference Time = {coefficients[0]:.2f} * Steps + {coefficients[1]:.2f}"
+ )
- def generate(self, gen_args):
- gen_args["generator"] = torch.Generator(device=device).manual_seed(args.seed)
- # Run the model
- start_time = time.time()
- images = self.pipe(**gen_args).images
- end_time = time.time()
+def main():
- images[0].save(args.saved_image)
+ from diffusers import StableDiffusion3Pipeline as pipeline_cls
- return images[0], end_time - start_time
+ pipe = load_pipe(
+ pipeline_cls,
+ args.model,
+ variant=args.variant,
+ custom_pipeline=args.custom_pipeline,
+ scheduler=args.scheduler,
+ lora=args.lora,
+ controlnet=args.controlnet,
+ )
- def compile_pipe(self, pipe, compiler_config):
- options = compiler_config
+ core_net = None
+ if core_net is None:
+ core_net = getattr(pipe, "unet", None)
+ if core_net is None:
+ core_net = getattr(pipe, "transformer", None)
+ height = args.height or core_net.config.sample_size * pipe.vae_scale_factor
+ width = args.width or core_net.config.sample_size * pipe.vae_scale_factor
+
+ if args.compiler == "none":
+ pass
+ elif args.compiler == "nexfort":
+ print("Nexfort backend is now active...")
+ if args.quantize:
+ if args.quantize_config is not None:
+ quantize_config = json.loads(args.quantize_config)
+ else:
+ quantize_config = '{"quant_type": "fp8_e4m3_e4m3_dynamic_per_tensor"}'
+ if args.quant_submodules_config_path:
+ # download: https://huggingface.co/siliconflow/PixArt-alpha-onediff-nexfort-fp8/blob/main/fp8_e4m3.json
+ pipe = quantize_pipe(
+ pipe,
+ quant_submodules_config_path=args.quant_submodules_config_path,
+ ignores=[],
+ **quantize_config,
+ )
+ else:
+ pipe = quantize_pipe(pipe, ignores=[], **quantize_config)
+ if args.compiler_config is not None:
+ # config with dict
+ options = json.loads(args.compiler_config)
+ else:
+ # config with string
+ options = '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}'
pipe = compile_pipe(
pipe, backend="nexfort", options=options, fuse_qkv_projections=True
)
- return pipe
-
- def quantize_pipe(self, pipe, quantize_config):
- pipe = quantize_pipe(pipe, ignores=[], **quantize_config)
- return pipe
-
-
-def main():
- compiler_config = json.loads(args.compiler_config) if args.compiler_config else None
- quantize_config = json.loads(args.quantize_config) if args.quantize_config else None
-
- sd3 = SD3Generator(args.model, compiler_config, quantize_config)
+ elif args.compiler in ("compile", "compile-max-autotune"):
+ mode = "max-autotune" if args.compiler == "compile-max-autotune" else None
+ if hasattr(pipe, "unet"):
+ pipe.unet = torch.compile(pipe.unet, mode=mode)
+ if hasattr(pipe, "transformer"):
+ pipe.transformer = torch.compile(pipe.transformer, mode=mode)
+ if hasattr(pipe, "controlnet"):
+ pipe.controlnet = torch.compile(pipe.controlnet, mode=mode)
+ pipe.vae = torch.compile(pipe.vae, mode=mode)
+ else:
+ raise ValueError(f"Unknown compiler: {args.compiler}")
- if args.run_multiple_prompts:
- # Note: diffusers will truncate the input prompt (limited to 77 tokens).
- # https://github.com/huggingface/diffusers/blob/8e1b7a084addc4711b8d9be2738441dfad680ce0/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py#L238
- dynamic_prompts = generate_texts(max_length=101)
- prompt_list = list(dynamic_prompts)
+ if args.input_image is None:
+ input_image = None
else:
- prompt_list = [args.prompt]
-
- gen_args = {
- "prompt": args.prompt,
- "num_inference_steps": args.num_inference_steps,
- "height": args.height,
- "width": args.width,
- "guidance_scale": args.guidance_scale,
- "negative_prompt": args.negative_prompt,
- }
-
- sd3.warmup(gen_args)
-
- for prompt in prompt_list:
- gen_args["prompt"] = prompt
- print(f"Processing prompt of length {len(prompt)} characters.")
- image, inference_time = sd3.generate(gen_args)
- assert inference_time < 20, "Prompt inference took too long"
- print(
- f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds."
+ input_image = load_image(args.input_image)
+ input_image = input_image.resize((width, height), Image.LANCZOS)
+
+ if args.control_image is None:
+ if args.controlnet is None:
+ control_image = None
+ else:
+ control_image = Image.new("RGB", (width, height))
+ draw = ImageDraw.Draw(control_image)
+ draw.ellipse(
+ (args.width // 4, height // 4, args.width // 4 * 3, height // 4 * 3),
+ fill=(255, 255, 255),
+ )
+ del draw
+ else:
+ control_image = load_image(args.control_image)
+ control_image = control_image.resize((width, height), Image.LANCZOS)
+
+ def get_kwarg_inputs():
+ kwarg_inputs = dict(
+ prompt=args.prompt,
+ height=height,
+ width=width,
+ num_images_per_prompt=args.batch,
+ generator=(
+ None
+ if args.seed is None
+ else torch.Generator(device="cuda").manual_seed(args.seed)
+ ),
+ **(
+ dict()
+ if args.extra_call_kwargs is None
+ else json.loads(args.extra_call_kwargs)
+ ),
)
- cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
- print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
+ if args.steps is not None:
+ kwarg_inputs["num_inference_steps"] = args.steps
+ if input_image is not None:
+ kwarg_inputs["image"] = input_image
+ if control_image is not None:
+ if input_image is None:
+ kwarg_inputs["image"] = control_image
+ else:
+ kwarg_inputs["control_image"] = control_image
+ if args.deepcache:
+ kwarg_inputs["cache_interval"] = args.cache_interval
+ kwarg_inputs["cache_layer_id"] = args.cache_layer_id
+ kwarg_inputs["cache_block_id"] = args.cache_block_id
+ return kwarg_inputs
+
+ # NOTE: Warm it up.
+ # The initial calls will trigger compilation and might be very slow.
+ # After that, it should be very fast.
+ if args.warmups > 0:
+ begin = time.time()
+ print("=======================================")
+ print("Begin warmup")
+ for _ in range(args.warmups):
+ pipe(**get_kwarg_inputs())
+ end = time.time()
+ print("End warmup")
+ print(f"Warmup time: {end - begin:.3f}s")
+ print("=======================================")
+
+
+ # Let"s see it!
+ # Note: Progress bar might work incorrectly due to the async nature of CUDA.
+ kwarg_inputs = get_kwarg_inputs()
+ iter_profiler = IterationProfiler()
+ if "callback_on_step_end" in inspect.signature(pipe).parameters:
+ kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
+ elif "callback" in inspect.signature(pipe).parameters:
+ kwarg_inputs["callback"] = iter_profiler.callback_on_step_end
+ begin = time.time()
+ output_images = pipe(**kwarg_inputs).images
+ end = time.time()
+
+ print("=======================================")
+ print(f"Inference time: {end - begin:.3f}s")
+ iter_per_sec = iter_profiler.get_iter_per_sec()
+ if iter_per_sec is not None:
+ print(f"Iterations per second: {iter_per_sec:.3f}")
+
+
+ cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
+ print(f"Max used CUDA memory : {cuda_mem_after_used:.3f}GiB")
+ print("=======================================")
+
+ if args.print_output:
+ from onediff.utils.import_utils import is_nexfort_available
+
+ if is_nexfort_available():
+ from nexfort.utils.term_image import print_image
+
+ for image in output_images:
+ print_image(image, max_width=80)
+
+ if args.output_image is not None:
+ output_images[0].save(args.output_image)
+ else:
+ print("Please set `--output-image` to save the output image")
if args.run_multiple_resolutions:
- gen_args["prompt"] = args.prompt
print("Test run with multiple resolutions...")
- sizes = [1536, 1024, 768, 720, 576, 512, 256]
+ sizes = [1024, 512, 768, 256]
for h in sizes:
for w in sizes:
- gen_args["height"] = h
- gen_args["width"] = w
+ kwarg_inputs["height"] = h
+ kwarg_inputs["width"] = w
print(f"Running at resolution: {h}x{w}")
start_time = time.time()
- sd3.generate(gen_args)
+ image = pipe(**kwarg_inputs).images
end_time = time.time()
print(f"Inference time: {end_time - start_time:.2f} seconds")
- assert (
- end_time - start_time
- ) < 20, "Resolution switch test took too long"
+
+ if args.throughput:
+ steps_range = range(1, 100, 1)
+ data, coefficients = generate_data_and_fit_model(pipe, steps_range)
+ plot_data_and_model(data, coefficients)
if __name__ == "__main__":