This repository was archived by the owner on Nov 27, 2024. It is now read-only.

StableVideoDiffusion model converter #137

Open
wants to merge 3 commits into master
3 changes: 3 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/.gitignore
@@ -0,0 +1,3 @@
/footprints/
/cache/
/result_*.png
20 changes: 20 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/README.md
@@ -0,0 +1,20 @@
# OnnxStack.Converter - StableVideoDiffusion

## Requirements
```bash
pip install onnxruntime-directml
pip install olive-ai[directml]
python -m pip install -r requirements.txt
```
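
If the DirectML packages installed correctly, a quick sanity check (a minimal sketch, not part of this converter) is to confirm ONNX Runtime can see the DirectML execution provider:

```python
import onnxruntime as ort

# "DmlExecutionProvider" should be listed when onnxruntime-directml is installed.
print(ort.get_available_providers())
```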

## Usage
```bash
python convert.py --optimize --model_input '..\stable-video-diffusion-img2vid-xt' --model_output '..\converted'
```
`--optimize` - Run the model optimization

`--model_input` - Path to the safetensors model to convert

`--model_output` - Output for converted ONNX model (NOTE: This folder is deleted before each run)

`--image_encoder` - Convert the optional image encoder
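
A minimal sketch of loading one of the converted submodels with ONNX Runtime — the path below is an assumption based on the output layout produced by `sd_utils/ort.py` (below):

```python
import onnxruntime as ort

# Hypothetical path: <model_output>/optimized/<submodel>/model.onnx
session = ort.InferenceSession(
    r"..\converted\optimized\vae_decoder\model.onnx",
    providers=["DmlExecutionProvider"],
)
for inp in session.get_inputs():
    print(inp.name, inp.shape)  # e.g. latent_sample [batch, channels, height, width]
```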
8 changes: 8 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/config.py
@@ -0,0 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

vae_sample_size = 512
unet_sample_size = 24
cross_attention_dim = 1280
110 changes: 110 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/config_unet.json
@@ -0,0 +1,110 @@
{
"input_model": {
"type": "PyTorchModel",
"config": {
"model_path": "stabilityai/stable-video-diffusion-img2vid-xt",
"model_loader": "unet_load",
"model_script": "models.py",
"io_config": {
"input_names": [ "sample", "timestep", "encoder_hidden_states", "added_time_ids" ],
"output_names": [ "out_sample" ],
"dynamic_axes": {
"sample": {"0": "batch", "1": "frames", "2": "channel", "3": "height", "4": "width"},
"timestep": {"0": "timestep"},
"encoder_hidden_states": {"0": "batch", "1": "sequence_length", "2": "cross_attention_dim"},
"added_time_ids": {"0": "batch", "1": "num_additional_ids" }
}
},
"dummy_inputs_func": "unet_conversion_inputs"
}
},
"systems": {
"local_system": {
"type": "LocalSystem",
"config": {
"accelerators": [
{
"device": "gpu",
"execution_providers": [
"DmlExecutionProvider"
]
}
]
}
}
},
"evaluators": {
"common_evaluator": {
"metrics": [
{
"name": "latency",
"type": "latency",
"sub_types": [{"name": "avg"}],
"user_config": {
"user_script": "models.py",
"dataloader_func": "unet_data_loader",
"batch_size": 2
}
}
]
}
},
"passes": {
"convert": {
"type": "OnnxConversion",
"config": {
"target_opset": 16,
"save_as_external_data": true,
"all_tensors_to_one_file": true
}
},
"optimize": {
"type": "OrtTransformersOptimization",
"config": {
"model_type": "unet",
"opt_level": 0,
"float16": true,
"use_gpu": true,
"keep_io_types": true,
"optimization_options": {
"enable_gelu": true,
"enable_layer_norm": true,
"enable_attention": true,
"use_multi_head_attention": true,
"enable_skip_layer_norm": false,
"enable_embed_layer_norm": true,
"enable_bias_skip_layer_norm": false,
"enable_bias_gelu": true,
"enable_gelu_approximation": false,
"enable_qordered_matmul": false,
"enable_shape_inference": true,
"enable_gemm_fast_gelu": false,
"enable_nhwc_conv": false,
"enable_group_norm": true,
"enable_bias_splitgelu": false,
"enable_packed_qkv": true,
"enable_packed_kv": true,
"enable_bias_add": false,
"group_norm_channels_last": false
},
"force_fp32_ops": ["RandomNormalLike"],
"force_fp16_inputs": {
"GroupNorm": [0, 1, 2]
}
}
}
},
"pass_flows": [
["convert", "optimize"]
],
"engine": {
"log_severity_level": 0,
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
"target": "local_system",
"cache_dir": "cache",
"output_name": "unet",
"output_dir": "footprints"
}
}
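
Each `config_*.json` above is a self-contained Olive workflow; `convert.py` (below) loads it, patches `model_path`, and passes it to `olive_run`. A minimal sketch of running just this UNet config through the same Python API:

```python
import json

from olive.workflows import run as olive_run

# Mirrors what convert.py does for each submodel config.
with open("config_unet.json") as fin:
    olive_config = json.load(fin)
olive_config["input_model"]["config"]["model_path"] = r"..\stable-video-diffusion-img2vid-xt"
olive_run(olive_config)
```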
105 changes: 105 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/config_vae_decoder.json
@@ -0,0 +1,105 @@
{
"input_model": {
"type": "PyTorchModel",
"config": {
"model_path": "stabilityai/stable-video-diffusion-img2vid-xt",
"model_loader": "vae_decoder_load",
"model_script": "models.py",
"io_config": {
"input_names": [ "latent_sample", "num_frames" ],
"output_names": [ "sample" ],
"dynamic_axes": {
"latent_sample": { "0": "batch", "1": "channels", "2": "height", "3": "width" }
}
},
"dummy_inputs_func": "vae_decoder_conversion_inputs"
}
},
"systems": {
"local_system": {
"type": "LocalSystem",
"config": {
"accelerators": [
{
"device": "gpu",
"execution_providers": [
"DmlExecutionProvider"
]
}
]
}
}
},
"evaluators": {
"common_evaluator": {
"metrics": [
{
"name": "latency",
"type": "latency",
"sub_types": [{"name": "avg"}],
"user_config": {
"user_script": "models.py",
"dataloader_func": "vae_decoder_data_loader",
"batch_size": 1
}
}
]
}
},
"passes": {
"convert": {
"type": "OnnxConversion",
"config": {
"target_opset": 16
}
},
"optimize": {
"type": "OrtTransformersOptimization",
"config": {
"model_type": "vae",
"opt_level": 0,
"float16": true,
"use_gpu": true,
"keep_io_types": false,
"optimization_options": {
"enable_gelu": true,
"enable_layer_norm": true,
"enable_attention": true,
"use_multi_head_attention": true,
"enable_skip_layer_norm": false,
"enable_embed_layer_norm": true,
"enable_bias_skip_layer_norm": false,
"enable_bias_gelu": true,
"enable_gelu_approximation": false,
"enable_qordered_matmul": false,
"enable_shape_inference": true,
"enable_gemm_fast_gelu": false,
"enable_nhwc_conv": false,
"enable_group_norm": true,
"enable_bias_splitgelu": false,
"enable_packed_qkv": true,
"enable_packed_kv": true,
"enable_bias_add": false,
"group_norm_channels_last": false
},
"force_fp32_ops": ["RandomNormalLike"],
"force_fp16_inputs": {
"GroupNorm": [0, 1, 2]
}
}
}
},
"pass_flows": [
["convert", "optimize"]
],
"engine": {
"log_severity_level": 0,
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
"target": "local_system",
"cache_dir": "cache",
"output_name": "vae_decoder",
"output_dir": "footprints"
}
}
103 changes: 103 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/config_vae_encoder.json
@@ -0,0 +1,103 @@
{
"input_model": {
"type": "PyTorchModel",
"config": {
"model_path": "stabilityai/stable-video-diffusion-img2vid-xt",
"model_loader": "vae_encoder_load",
"model_script": "models.py",
"io_config": {
"input_names": [ "sample" ],
"output_names": [ "latent_sample" ],
"dynamic_axes": { "sample": { "0": "batch", "1": "channels", "2": "height", "3": "width" } }
},
"dummy_inputs_func": "vae_encoder_conversion_inputs"
}
},
"systems": {
"local_system": {
"type": "LocalSystem",
"config": {
"accelerators": [
{
"device": "gpu",
"execution_providers": [
"DmlExecutionProvider"
]
}
]
}
}
},
"evaluators": {
"common_evaluator": {
"metrics": [
{
"name": "latency",
"type": "latency",
"sub_types": [{"name": "avg"}],
"user_config": {
"user_script": "models.py",
"dataloader_func": "vae_encoder_data_loader",
"batch_size": 1
}
}
]
}
},
"passes": {
"convert": {
"type": "OnnxConversion",
"config": {
"target_opset": 16
}
},
"optimize": {
"type": "OrtTransformersOptimization",
"config": {
"model_type": "vae",
"opt_level": 0,
"float16": true,
"use_gpu": true,
"keep_io_types": false,
"optimization_options": {
"enable_gelu": true,
"enable_layer_norm": true,
"enable_attention": true,
"use_multi_head_attention": true,
"enable_skip_layer_norm": false,
"enable_embed_layer_norm": true,
"enable_bias_skip_layer_norm": false,
"enable_bias_gelu": true,
"enable_gelu_approximation": false,
"enable_qordered_matmul": false,
"enable_shape_inference": true,
"enable_gemm_fast_gelu": false,
"enable_nhwc_conv": false,
"enable_group_norm": true,
"enable_bias_splitgelu": false,
"enable_packed_qkv": true,
"enable_packed_kv": true,
"enable_bias_add": false,
"group_norm_channels_last": false
},
"force_fp32_ops": ["RandomNormalLike"],
"force_fp16_inputs": {
"GroupNorm": [0, 1, 2]
}
}
}
},
"pass_flows": [
["convert", "optimize"]
],
"engine": {
"log_severity_level": 0,
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"host": "local_system",
"target": "local_system",
"cache_dir": "cache",
"output_name": "vae_encoder",
"output_dir": "footprints"
}
}
211 changes: 211 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/convert.py
@@ -0,0 +1,211 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import json
import shutil
import sys
import warnings
from pathlib import Path
from typing import Dict

import config
import torch
from diffusers import DiffusionPipeline
from packaging import version

from olive.common.utils import set_tempdir
from olive.workflows import run as olive_run


# pylint: disable=redefined-outer-name
# ruff: noqa: TID252, T201


def save_image(result, batch_size, provider, num_images, images_saved, image_callback=None):
passed_safety_checker = 0
for image_index in range(batch_size):
if result.nsfw_content_detected is None or not result.nsfw_content_detected[image_index]:
passed_safety_checker += 1
if images_saved < num_images:
output_path = f"result_{images_saved}.png"
result.images[image_index].save(output_path)
if image_callback:
image_callback(images_saved, output_path)
images_saved += 1
print(f"Generated {output_path}")
print(f"Inference Batch End ({passed_safety_checker}/{batch_size} images).")
print("Images passed the safety checker.")
return images_saved


def run_inference_loop(
pipeline,
prompt,
num_images,
batch_size,
image_size,
num_inference_steps,
guidance_scale,
strength: float,
provider: str,
image_callback=None,
step_callback=None,
):
images_saved = 0

def update_steps(step, timestep, latents):
if step_callback:
step_callback((images_saved // batch_size) * num_inference_steps + step)

while images_saved < num_images:
print(f"\nInference Batch Start (batch size = {batch_size}).")

kwargs = {}

result = pipeline(
[prompt] * batch_size,
num_inference_steps=num_inference_steps,
callback=update_steps if step_callback else None,
height=image_size,
width=image_size,
guidance_scale=guidance_scale,
**kwargs,
)

images_saved = save_image(result, batch_size, provider, num_images, images_saved, image_callback)


def update_config_with_provider(config: Dict, provider: str):
if provider == "dml":
# DirectML EP is the default, so no need to update config.
return config
elif provider == "cuda":
from sd_utils.ort import update_cuda_config

return update_cuda_config(config)
else:
raise ValueError(f"Unsupported provider: {provider}")


def optimize(
model_input: str,
model_output: Path,
provider: str,
image_encoder: bool
):
from google.protobuf import __version__ as protobuf_version

# protobuf 4.x aborts with OOM when optimizing unet
if version.parse(protobuf_version) > version.parse("3.20.3"):
print("This script requires protobuf 3.20.3. Please ensure your package version matches requirements.txt.")
sys.exit(1)

model_dir = model_input
script_dir = Path(__file__).resolve().parent

# Clean up previously optimized models, if any.
shutil.rmtree(script_dir / "footprints", ignore_errors=True)
shutil.rmtree(model_output, ignore_errors=True)

# Load the entire PyTorch pipeline to ensure all models and their configurations are downloaded and cached.
# This avoids an issue where the non-ONNX components (tokenizer, scheduler, and feature extractor) are not
# automatically cached correctly if individual models are fetched one at a time.
print("Download stable diffusion PyTorch pipeline...")
pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float32, **{"local_files_only": True})
# config.vae_sample_size = pipeline.vae.config.sample_size
# config.cross_attention_dim = pipeline.unet.config.cross_attention_dim
# config.unet_sample_size = pipeline.unet.config.sample_size

model_info = {}

submodel_names = [ "vae_encoder", "vae_decoder", "unet" ]

if image_encoder:
submodel_names.append("image_encoder")

for submodel_name in submodel_names:
print(f"\nOptimizing {submodel_name}")

with (script_dir / f"config_{submodel_name}.json").open() as fin:
olive_config = json.load(fin)
olive_config = update_config_with_provider(olive_config, provider)
olive_config["input_model"]["config"]["model_path"] = model_dir

olive_run(olive_config)

from sd_utils.ort import save_optimized_onnx_submodel

save_optimized_onnx_submodel(submodel_name, provider, model_info)

from sd_utils.ort import save_onnx_pipeline

save_onnx_pipeline(
model_info, model_output, pipeline, submodel_names
)

return model_info


def parse_common_args(raw_args):
parser = argparse.ArgumentParser("Common arguments")
parser.add_argument("--model_input", default="stable-diffusion-v1-5", type=str)
parser.add_argument("--model_output", default="stable-diffusion-v1-5", type=Path)
parser.add_argument("--image_encoder",action="store_true", help="Create image encoder model")
parser.add_argument("--provider", default="dml", type=str, choices=["dml", "cuda"], help="Execution provider to use")
parser.add_argument("--optimize", action="store_true", help="Runs the optimization step")
parser.add_argument("--clean_cache", action="store_true", help="Deletes the Olive cache")
parser.add_argument("--test_unoptimized", action="store_true", help="Use unoptimized model for inference")
parser.add_argument("--tempdir", default=None, type=str, help="Root directory for tempfile directories and files")
return parser.parse_known_args(raw_args)


def parse_ort_args(raw_args):
parser = argparse.ArgumentParser("ONNX Runtime arguments")

parser.add_argument(
"--static_dims",
action="store_true",
help="DEPRECATED (now enabled by default). Use --dynamic_dims to disable static_dims.",
)
parser.add_argument("--dynamic_dims", action="store_true", help="Disable static shape optimization")

return parser.parse_known_args(raw_args)


def main(raw_args=None):
common_args, extra_args = parse_common_args(raw_args)

provider = common_args.provider
model_input = common_args.model_input
model_output = common_args.model_output

script_dir = Path(__file__).resolve().parent


if common_args.clean_cache:
shutil.rmtree(script_dir / "cache", ignore_errors=True)

ort_args, extra_args = parse_ort_args(extra_args)

if common_args.optimize or not model_output.exists():
set_tempdir(common_args.tempdir)

# TODO(jstoecker): clean up warning filter (mostly during conversion from torch to ONNX)
with warnings.catch_warnings():
warnings.simplefilter("ignore")

from sd_utils.ort import validate_args

validate_args(ort_args, common_args.provider)
optimize(common_args.model_input, common_args.model_output, common_args.provider, common_args.image_encoder)

if not common_args.optimize:
print("TODO: Create OnnxStableCascadePipeline")


if __name__ == "__main__":
main()
101 changes: 101 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/models.py
@@ -0,0 +1,101 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import config
import torch
from diffusers import UNetSpatioTemporalConditionModel, AutoencoderKLTemporalDecoder
from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection

# Helper latency-only dataloader that creates random tensors with no label
class RandomDataLoader:
def __init__(self, create_inputs_func, batchsize, torch_dtype):
self.create_input_func = create_inputs_func
self.batchsize = batchsize
self.torch_dtype = torch_dtype

def __getitem__(self, idx):
label = None
return self.create_input_func(self.batchsize, self.torch_dtype), label
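
# Usage sketch (assumption: Olive's latency evaluator indexes this loader like a
# map-style dataset and discards the label):
#   loader = RandomDataLoader(unet_inputs, 1, torch.float16)
#   inputs, _ = loader[0]   # dict of random tensors keyed by model input name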



# -----------------------------------------------------------------------------
# UNET
# -----------------------------------------------------------------------------

def unet_inputs(batchsize, torch_dtype, is_conversion_inputs=False):
inputs = {
"sample": torch.rand((batchsize, 25, 8, 72, 128), dtype=torch_dtype),
"timestep": torch.rand((1,), dtype=torch_dtype),
"encoder_hidden_states": torch.rand((batchsize , 1, 1024), dtype=torch_dtype),
"added_time_ids": torch.rand((batchsize, 3), dtype=torch_dtype)
}
return inputs


def unet_load(model_name):
model = UNetSpatioTemporalConditionModel.from_pretrained(model_name, subfolder="unet")
return model


def unet_conversion_inputs(model=None):
return tuple(unet_inputs(1, torch.float32, True).values())


def unet_data_loader(data_dir, batchsize, *args, **kwargs):
return RandomDataLoader(unet_inputs, batchsize, torch.float16)



# -----------------------------------------------------------------------------
# VAE ENCODER
# -----------------------------------------------------------------------------


def vae_encoder_inputs(batchsize, torch_dtype):
return {"sample": torch.rand((batchsize, 3, 72, 128), dtype=torch_dtype)}


def vae_encoder_load(model_name):
model = AutoencoderKLTemporalDecoder.from_pretrained(model_name, subfolder="vae", use_safetensors=True)
model.forward = lambda sample: model.encode(sample)[0].sample()
return model


def vae_encoder_conversion_inputs(model=None):
return tuple(vae_encoder_inputs(1, torch.float32).values())


def vae_encoder_data_loader(data_dir, batchsize, *args, **kwargs):
return RandomDataLoader(vae_encoder_inputs, batchsize, torch.float16)




# -----------------------------------------------------------------------------
# VAE DECODER
# -----------------------------------------------------------------------------


def vae_decoder_inputs(batchsize, torch_dtype):
return {
"latent_sample": torch.rand((batchsize, 4, 72, 128), dtype=torch_dtype),
"num_frames": 1,
}


def vae_decoder_load(model_name):
model = AutoencoderKLTemporalDecoder.from_pretrained(model_name, subfolder="vae", use_safetensors=True)
model.forward = model.decode
return model


def vae_decoder_conversion_inputs(model=None):
return tuple(vae_decoder_inputs(1, torch.float32).values())


def vae_decoder_data_loader(data_dir, batchsize, *args, **kwargs):
return RandomDataLoader(vae_decoder_inputs, batchsize, torch.float16)
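
The dummy-input shapes above are hard-coded for img2vid-xt (25 frames, 72x128 latents, 1024-dim image embeddings). A minimal sanity check — a sketch assuming this `models.py` is importable — that the conversion inputs line up with the `io_config` names in `config_unet.json`:

```python
import torch
from models import unet_inputs

inputs = unet_inputs(batchsize=1, torch_dtype=torch.float32)
for name, tensor in inputs.items():
    print(name, tuple(tensor.shape))
# Expected: sample (1, 25, 8, 72, 128), timestep (1,),
#           encoder_hidden_states (1, 1, 1024), added_time_ids (1, 3)
```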
9 changes: 9 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/requirements.txt
@@ -0,0 +1,9 @@
accelerate
diffusers
onnx
pillow
protobuf==3.20.3 # protobuf 4.x aborts with OOM when optimizing unet
tabulate
torch
transformers
onnxruntime-directml>=1.16.0
117 changes: 117 additions & 0 deletions OnnxStack.Converter/stable_diffusion_video/sd_utils/ort.py
@@ -0,0 +1,117 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import json
import shutil
import sys
from pathlib import Path
from typing import Dict

import onnxruntime as ort
from onnxruntime import __version__ as OrtVersion
from packaging import version

from olive.model import ONNXModelHandler

# ruff: noqa: TID252, T201


def update_cuda_config(config: Dict):
if version.parse(OrtVersion) < version.parse("1.17.0"):
# disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models
config["passes"]["optimize_cuda"]["config"]["optimization_options"] = {"enable_skip_group_norm": False}
config["pass_flows"] = [["convert", "optimize_cuda"]]
config["systems"]["local_system"]["config"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"]
return config


def validate_args(args, provider):
ort.set_default_logger_severity(4)
if args.static_dims:
print(
"WARNING: the --static_dims option is deprecated, and static shape optimization is enabled by default. "
"Use --dynamic_dims to disable static shape optimization."
)

validate_ort_version(provider)


def validate_ort_version(provider: str):
if provider == "dml" and version.parse(OrtVersion) < version.parse("1.16.0"):
print("This script requires onnxruntime-directml 1.16.0 or newer")
sys.exit(1)
elif provider == "cuda" and version.parse(OrtVersion) < version.parse("1.17.0"):
if version.parse(OrtVersion) < version.parse("1.16.2"):
print("This script requires onnxruntime-gpu 1.16.2 or newer")
sys.exit(1)
print(
f"WARNING: onnxruntime {OrtVersion} has known issues with shape inference for SkipGroupNorm. Will disable"
" skip_group_norm fusion. onnxruntime-gpu 1.17.0 or newer is strongly recommended!"
)


def save_optimized_onnx_submodel(submodel_name, provider, model_info):
footprints_file_path = (
Path(__file__).resolve().parents[1] / "footprints" / f"{submodel_name}_gpu-{provider}_footprints.json"
)
with footprints_file_path.open("r") as footprint_file:
footprints = json.load(footprint_file)

conversion_footprint = None
optimizer_footprint = None
for footprint in footprints.values():
if footprint["from_pass"] == "OnnxConversion":
conversion_footprint = footprint
elif footprint["from_pass"] == "OrtTransformersOptimization":
optimizer_footprint = footprint

assert conversion_footprint
assert optimizer_footprint

unoptimized_olive_model = ONNXModelHandler(**conversion_footprint["model_config"]["config"])
optimized_olive_model = ONNXModelHandler(**optimizer_footprint["model_config"]["config"])

model_info[submodel_name] = {
"unoptimized": {
"path": Path(unoptimized_olive_model.model_path),
"data": Path(unoptimized_olive_model.model_path + ".data"),
},
"optimized": {
"path": Path(optimized_olive_model.model_path),
"data": Path(optimized_olive_model.model_path + ".data"),
},
}

print(f"Unoptimized Model : {model_info[submodel_name]['unoptimized']['path']}")
print(f"Optimized Model : {model_info[submodel_name]['optimized']['path']}")


def save_onnx_pipeline(
model_info, model_output, pipeline, submodel_names
):
# Save the unoptimized models in a directory structure that the diffusers library can load and run.
# This is optional, and the optimized models can be used directly in a custom pipeline if desired.
# print("\nCreating ONNX pipeline...")

# TODO: Create OnnxStableVideoDiffusionPipeline

# Copy both the unoptimized and optimized models from the Olive cache into the output directory.
print("Copying optimized models...")
for pass_type in ["optimized", "unoptimized"]:
model_dir = model_output / pass_type
for submodel_name in submodel_names:
src_path = model_info[submodel_name][pass_type]["path"]       # model.onnx
src_data_path = model_info[submodel_name][pass_type]["data"]  # model.onnx.data

dst_path = model_dir / submodel_name
os.makedirs(dst_path, exist_ok=True)

shutil.copyfile(src_path, dst_path / "model.onnx")
if os.path.exists(src_data_path):
shutil.copyfile(src_data_path, dst_path / "model.onnx.data")

print(f"The converted model is located here: {model_output}")