diff --git a/imgs/nexfort_sd3_demo.png b/imgs/nexfort_sd3_demo.png index e5f144fbd..57a022cfd 100644 Binary files a/imgs/nexfort_sd3_demo.png and b/imgs/nexfort_sd3_demo.png differ diff --git a/onediff_diffusers_extensions/examples/sd3/README.md b/onediff_diffusers_extensions/examples/sd3/README.md index 71e0f306d..9759e5e4b 100644 --- a/onediff_diffusers_extensions/examples/sd3/README.md +++ b/onediff_diffusers_extensions/examples/sd3/README.md @@ -3,14 +3,15 @@ 1. [Environment Setup](#environment-setup) - [Set Up OneDiff](#set-up-onediff) - [Set Up NexFort Backend](#set-up-nexfort-backend) - - [Set Up Diffusers Library](#set-up-diffusers-library) + - [Set Up Diffusers](#set-up-diffusers) - [Download SD3 Model for Diffusers](#download-sd3-model-for-diffusers) 2. [Execution Instructions](#execution-instructions) - [Run Without Compilation (Baseline)](#run-without-compilation-baseline) - [Run With Compilation](#run-with-compilation) 3. [Performance Comparison](#performance-comparison) 4. [Dynamic Shape for SD3](#dynamic-shape-for-sd3) -5. [Quality](#quality) +5. [Quantization](#quantization) +6. [Quality](#quality) ## Environment setup ### Set up onediff @@ -25,12 +26,12 @@ https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/back # Ensure diffusers include the SD3 pipeline. pip3 install --upgrade diffusers[torch] ``` -### Set up SD3 +### Download SD3 model for diffusers Model version for diffusers: https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers HF pipeline: https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md -## Run +## Execution instructions ### Run 1024*1024 without compile (the original pytorch HF diffusers baseline) ``` @@ -38,24 +39,24 @@ python3 onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py \ --saved-image sd3.png ``` -### Run 1024*1024 with compile +### Run 1024*1024 with onediff (nexfort) compile ``` python3 onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py \ - --compiler-config '{"mode": "max-optimize:max-autotune:low-precision:cache-all:freezing:benchmark", "memory_format": "channels_last"}' \ + --compiler-config '{"mode": "max-optimize:max-autotune:low-precision:cache-all", "memory_format": "channels_last"}' \ --saved-image sd3_compile.png ``` ## Performance comparation -Testing on H800-NVL-80GB, with image size of 1024*1024, iterating 28 steps: +Testing on H800-NVL-80GB with torch 2.3.0, with image size of 1024*1024, iterating 28 steps: | Metric | | | ------------------------------------------------ | ----------------------------------- | -| Data update date(yyyy-mm-dd) | 2024-06-24 | +| Data update date(yyyy-mm-dd) | 2024-06-29 | | PyTorch iteration speed | 15.56 it/s | -| OneDiff iteration speed | 25.91 it/s (+66.5%) | +| OneDiff iteration speed | 24.12 it/s (+55.0%) | | PyTorch E2E time | 1.96 s | -| OneDiff E2E time | 1.15 s (-41.3%) | +| OneDiff E2E time | 1.31 s (-33.2%) | | PyTorch Max Mem Used | 18.784 GiB | | OneDiff Max Mem Used | 18.324 GiB | | PyTorch Warmup with Run time | 2.86 s | @@ -68,11 +69,11 @@ Testing on H800-NVL-80GB, with image size of 1024*1024, iterating 28 steps: Testing on 4090: | Metric | | | ------------------------------------------------ | ----------------------------------- | -| Data update date(yyyy-mm-dd) | 2024-06-24 | +| Data update date(yyyy-mm-dd) | 2024-06-29 | | PyTorch iteration speed | 6.67 it/s | -| OneDiff iteration speed | 12.24 it/s (+83.3%) | +| OneDiff iteration speed | 11.51 it/s (+72.6%) | | PyTorch E2E time | 4.90 s | -| OneDiff E2E time | 2.48 s (-49.4%) | +| OneDiff E2E time | 2.67 s (-45.5%) | | PyTorch Max Mem Used | 18.799 GiB | | OneDiff Max Mem Used | 17.902 GiB | | PyTorch Warmup with Run time | 4.99 s | @@ -95,9 +96,46 @@ python3 onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py \ --run_multiple_resolutions 1 \ --saved-image sd3_compile.png ``` +## Quantization + +> [!NOTE] +Quantization is a feature for onediff enterprise. + +### Run + +Quantization of the model's layers can be selectively performed based on precision. Download `fp8_e4m3.json` or `fp8_e4m3_per_tensor.json` from https://huggingface.co/siliconflow/stable-diffusion-3-onediff-nexfort-fp8. + +The --arg `quant-submodules-config-path` is optional. If left `None`, it will quantize all linear layers. + +``` +# Applies dynamic symmetric per-tensor activation and per-tensor weight quantization to all linear layers. Both activations and weights are quantized to e4m3 format. +python3 onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py \ + --compiler-config '{"mode": "quant:max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}' \ + --quantize-config '{"quant_type": "fp8_e4m3_e4m3_dynamic_per_tensor"}' \ + --quant-submodules-config-path /path/to/fp8_e4m3_per_tensor.json \ + --saved-image sd3_fp8.png +``` +or +``` +# Applies dynamic symmetric per-token activation and per-channel weight quantization to all linear layers. +python3 onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py \ + --compiler-config '{"mode": "quant:max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}' \ + --quantize-config '{"quant_type": "fp8_e4m3_e4m3_dynamic"}' \ + --quant-submodules-config-path /path/to/fp8_e4m3.json \ + --saved-image sd3_fp8.png +``` + +### Metric + +The performance of above quantization types on the H800-NVL-80GB is as follows: + +| quant_type | E2E Inference Time | Iteration speed | Max Used CUDA Memory | +|----------------------------------|--------------------|--------------------|----------------------| +| fp8_e4m3_e4m3_dynamic_per_tensor | 1.22 s (-37.8%) | 25.26 it/s (+62.3%)| 16.933 GiB | +| fp8_e4m3_e4m3_dynamic | 1.14 s (-41.8%) | 27.12 it/s (+74.3%)| 17.098 GiB | ## Quality -When using nexfort as the backend for onediff compilation acceleration, the generated images are lossless. +When using nexfort as the backend for onediff compilation acceleration, the generated images are almost lossless.
diff --git a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
index 4809c9f07..45294811b 100644
--- a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
+++ b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
@@ -26,7 +26,7 @@ def parse_args():
parser.add_argument(
"--prompt",
type=str,
- default="photo of a dog and a cat both standing on a red box, with a blue ball in the middle with a parrot standing on top of the ball. The box has the text 'onediff'",
+ default="photo of a dog and a cat both standing on a red box, with a blue ball in the middle with a parrot standing on top of the ball. The box has the text 'OneDiff'",
help="Prompt for the image generation.",
)
parser.add_argument(
@@ -42,7 +42,10 @@ def parse_args():
"--width", type=int, default=1024, help="Width of the generated image."
)
parser.add_argument(
- "--guidance_scale", type=float, default=4.5, help="The scale factor for the guidance."
+ "--guidance_scale",
+ type=float,
+ default=4.5,
+ help="The scale factor for the guidance.",
)
parser.add_argument(
"--num-inference-steps", type=int, default=28, help="Number of inference steps."
@@ -54,7 +57,13 @@ def parse_args():
help="Path to save the generated image.",
)
parser.add_argument(
- "--seed", type=int, default=1, help="Seed for random number generation."
+ "--seed", type=int, default=2, help="Seed for random number generation."
+ )
+ parser.add_argument(
+ "--warmup-iterations",
+ type=int,
+ default=1,
+ help="Number of warm-up iterations before actual inference.",
)
parser.add_argument(
"--run_multiple_resolutions",
@@ -66,6 +75,13 @@ def parse_args():
type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
default=False,
)
+ parser.add_argument("--quant-submodules-config-path", type=str, default=None)
+ parser.add_argument(
+ "--use_torch_compile",
+ type=lambda x: (str(x).lower() in ["true", "1", "yes"]),
+ default=False,
+ help="Whether to use torch.compile optimizations.",
+ )
return parser.parse_args()
@@ -107,13 +123,18 @@ def generate_texts(min_length=50, max_length=302):
class SD3Generator:
- def __init__(self, model, compiler_config=None, quantize_config=None):
+ def __init__(
+ self, model, compiler_config=None, quantize_config=None, use_torch_compile=False
+ ):
self.pipe = StableDiffusion3Pipeline.from_pretrained(
model,
torch_dtype=torch.float16,
)
self.pipe.to(device)
+ if use_torch_compile:
+ self.setup_torch_compile()
+
if compiler_config:
print("compile...")
self.pipe = self.compile_pipe(self.pipe, compiler_config)
@@ -122,7 +143,7 @@ def __init__(self, model, compiler_config=None, quantize_config=None):
print("quant...")
self.pipe = self.quantize_pipe(self.pipe, quantize_config)
- def warmup(self, gen_args, warmup_iterations=1):
+ def warmup(self, gen_args, warmup_iterations):
warmup_args = gen_args.copy()
warmup_args["generator"] = torch.Generator(device=device).manual_seed(0)
@@ -147,6 +168,23 @@ def generate(self, gen_args):
return images[0], end_time - start_time
+ def setup_torch_compile(self):
+ torch.set_float32_matmul_precision("high")
+ torch._inductor.config.conv_1x1_as_mm = True
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.epilogue_fusion = False
+ torch._inductor.config.coordinate_descent_check_all_directions = True
+
+ self.pipe.transformer.to(memory_format=torch.channels_last)
+ self.pipe.vae.to(memory_format=torch.channels_last)
+
+ self.pipe.transformer = torch.compile(
+ self.pipe.transformer, mode="max-autotune", fullgraph=True
+ )
+ self.pipe.vae.decode = torch.compile(
+ self.pipe.vae.decode, mode="max-autotune", fullgraph=True
+ )
+
def compile_pipe(self, pipe, compiler_config):
options = compiler_config
pipe = compile_pipe(
@@ -155,7 +193,17 @@ def compile_pipe(self, pipe, compiler_config):
return pipe
def quantize_pipe(self, pipe, quantize_config):
- pipe = quantize_pipe(pipe, ignores=[], **quantize_config)
+ if args.quant_submodules_config_path:
+ # Quantitative submodules configuration file download: https://huggingface.co/siliconflow/stable-diffusion-3-onediff-nexfort-fp8
+ pipe = quantize_pipe(
+ pipe,
+ quant_submodules_config_path=args.quant_submodules_config_path,
+ top_percentage=75,
+ ignores=[],
+ **quantize_config,
+ )
+ else:
+ pipe = quantize_pipe(pipe, ignores=[], **quantize_config)
return pipe
@@ -163,7 +211,13 @@ def main():
compiler_config = json.loads(args.compiler_config) if args.compiler_config else None
quantize_config = json.loads(args.quantize_config) if args.quantize_config else None
- sd3 = SD3Generator(args.model, compiler_config, quantize_config)
+ if not args.use_torch_compile:
+ sd3 = SD3Generator(args.model, compiler_config, quantize_config)
+ else:
+ assert (
+ args.compiler_config is None and args.quantize_config is None
+ ), "compiler_config and quantize_config must be None when use_torch_compile is enabled"
+ sd3 = SD3Generator(args.model, use_torch_compile=True)
if args.run_multiple_prompts:
# Note: diffusers will truncate the input prompt (limited to 77 tokens).
@@ -182,7 +236,7 @@ def main():
"negative_prompt": args.negative_prompt,
}
- sd3.warmup(gen_args)
+ sd3.warmup(gen_args, args.warmup_iterations)
for prompt in prompt_list:
gen_args["prompt"] = prompt