Add TeaCache and FBCache #451

Merged: 1 commit, Feb 24, 2025
14 changes: 14 additions & 0 deletions README.md
@@ -84,6 +84,20 @@ We also have implemented the following parallel strategies for reference:
1. Tensor Parallelism
2. [DistriFusion](https://arxiv.org/abs/2402.19481)

<h3 id="meet-xdit-cache">Cache Acceleration</h3>

The cache methods are inspired by [TeaCache](https://github.com/ali-vilab/TeaCache.git) and [ParaAttention](https://github.com/chengzeyi/ParaAttention.git); we have adapted TeaCache and First-Block-Cache for xDiT.

These methods are not orthogonal to parallelism in xDiT: the cache function can only be activated together with sequence parallelism (SP) or when no parallelism is used.

To use this functionality, enable it with `--use_teacache` or `--use_fbcache`, which activate TeaCache and First-Block-Cache respectively. At the moment, only the FLUX model is supported.
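
These flags map onto the `apply_cache_on_transformer` adapter API that this PR adds (see `examples/flux_example.py`). Below is a minimal sketch of enabling the cache directly on a diffusers FLUX pipeline; the model id, prompt, and threshold value are placeholders:

```python
import torch
from diffusers import FluxPipeline

from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer

# Load FLUX (model id is a placeholder).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Wrap the transformer blocks with First-Block-Cache ("Fb") or TeaCache ("Tea").
pipe.transformer = apply_cache_on_transformer(
    pipe.transformer,
    use_cache="Fb",                     # or "Tea"
    rel_l1_thresh=0.6,                  # threshold used in examples/flux_example.py
    return_hidden_states_first=False,
    num_steps=28,                       # should match num_inference_steps
)

image = pipe("a small cat sitting on a bench", num_inference_steps=28).images[0]
```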

The performance below was measured on 4 H20 GPUs with SP=4:
| Method         | Latency |
|----------------|---------|
| Baseline       | 2.02s   |
| use_teacache   | 1.58s   |
| use_fbcache    | 0.93s   |

<h3 id="meet-xdit-perf">Computing Acceleration</h3>

24 changes: 24 additions & 0 deletions examples/flux_example.py
@@ -11,6 +11,10 @@
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
get_pipeline_parallel_world_size,
get_classifier_free_guidance_world_size,
get_tensor_model_parallel_world_size,
get_data_parallel_world_size,
)


@@ -45,7 +49,27 @@ def main():
parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

pipe.prepare_run(input_config, steps=1)
from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer
use_cache = engine_args.use_teacache or engine_args.use_fbcache
if (use_cache
and get_pipeline_parallel_world_size() == 1
and get_classifier_free_guidance_world_size() == 1
and get_tensor_model_parallel_world_size() == 1
):
cache_args = {
"rel_l1_thresh": 0.6,
"return_hidden_states_first": False,
"num_steps": input_config.num_inference_steps,
}

if engine_args.use_fbcache and engine_args.use_teacache:
cache_args["use_cache"] = "Fb"
elif engine_args.use_teacache:
cache_args["use_cache"] = "Tea"
elif engine_args.use_fbcache:
cache_args["use_cache"] = "Fb"

pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args)
torch.cuda.reset_peak_memory_stats()
start_time = time.time()
output = pipe(
6 changes: 5 additions & 1 deletion examples/run.sh
@@ -10,7 +10,7 @@ declare -A MODEL_CONFIGS=(
["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev/ 28"
["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
)

@@ -27,6 +27,9 @@ mkdir -p ./results
# task args
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"

# cache args
# CACHE_ARGS="--use_teacache"
# CACHE_ARGS="--use_fbcache"

# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
N_GPUS=8
@@ -64,3 +67,4 @@ $CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
$QUANTIZE_FLAG \
$CACHE_ARGS \
12 changes: 12 additions & 0 deletions xfuser/config/args.py
@@ -111,6 +111,8 @@ class xFuserArgs:
window_size: int = 64
coco_path: Optional[str] = None
use_cache: bool = False
use_teacache: bool = False
use_fbcache: bool = False
use_fp8_t5_encoder: bool = False

@staticmethod
@@ -154,6 +156,16 @@ def add_cli_args(parser: FlexibleArgumentParser):
action="store_true",
help="Enable onediff to accelerate inference in a single card",
)
runtime_group.add_argument(
"--use_teacache",
action="store_true",
help="Enable teacache to accelerate inference in a single card",
)
runtime_group.add_argument(
"--use_fbcache",
action="store_true",
help="Enable teacache to accelerate inference in a single card",
)

# Parallel arguments
parallel_group = parser.add_argument_group("Parallel Processing Options")
2 changes: 2 additions & 0 deletions xfuser/config/config.py
@@ -60,6 +60,8 @@ class RuntimeConfig:
use_torch_compile: bool = False
use_onediff: bool = False
use_fp8_t5_encoder: bool = False
use_teacache: bool = False
use_fbcache: bool = False

def __post_init__(self):
check_packages()
4 changes: 2 additions & 2 deletions xfuser/core/distributed/group_coordinator.py
@@ -196,7 +196,7 @@ def group_skip_rank(self):
world_size = self.world_size
return (world_size - rank_in_group - 1) % world_size

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor:
"""
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
@@ -206,7 +206,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.world_size == 1:
return input_
else:
torch.distributed.all_reduce(input_, group=self.device_group)
torch.distributed.all_reduce(input_, op=op, group=self.device_group)
return input_

def all_gather(
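For context, the new `op` argument defaults to SUM, preserving the previous hard-coded behaviour, and is forwarded unchanged to `torch.distributed.all_reduce`. A minimal sketch of the effect at the plain `torch.distributed` level (assumes a process group has already been initialized, e.g. via `torchrun`):

```python
import torch
import torch.distributed as dist

# One tensor per rank; assumes dist.init_process_group() and device setup have already run.
t = torch.tensor([float(dist.get_rank())], device="cuda")

# Equivalent to the previous hard-coded behaviour of GroupCoordinator.all_reduce.
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# The new `op` parameter lets callers request other reductions, e.g. MAX.
dist.all_reduce(t, op=dist.ReduceOp.MAX)
```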
1 change: 1 addition & 0 deletions xfuser/core/distributed/parallel_state.py
@@ -217,6 +217,7 @@ def init_distributed_environment(
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
4 changes: 4 additions & 0 deletions xfuser/model_executor/plugins/__init__.py
@@ -0,0 +1,4 @@
"""
adapted from https://github.com/ali-vilab/TeaCache.git
adapted from https://github.com/chengzeyi/ParaAttention.git
"""
4 changes: 4 additions & 0 deletions xfuser/model_executor/plugins/cache_/__init__.py
@@ -0,0 +1,4 @@
"""
adapted from https://github.com/ali-vilab/TeaCache.git
adapted from https://github.com/chengzeyi/ParaAttention.git
"""
@@ -0,0 +1,17 @@
"""
adapted from https://github.com/ali-vilab/TeaCache.git
adapted from https://github.com/chengzeyi/ParaAttention.git
"""
import importlib
from typing import Type, Dict, TypeVar
from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY


def apply_cache_on_transformer(transformer, *args, **kwargs):
adapter_name = TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer))
if not adapter_name:
raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}")

adapter_module = importlib.import_module(f".{adapter_name}", __package__)
apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
74 changes: 74 additions & 0 deletions xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py
@@ -0,0 +1,74 @@
"""
adapted from https://github.com/ali-vilab/TeaCache.git
adapted from https://github.com/chengzeyi/ParaAttention.git
"""
import functools
Review comment (Collaborator): Please note the source the code was copied from. Also, should the directory `cache_` drop the trailing underscore?

import unittest

import torch
from torch import nn
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY

from xfuser.model_executor.plugins.cache_ import utils

def create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps):
cached_transformer_class = {
"Fb": utils.FBCachedTransformerBlocks,
"Tea": utils.TeaCachedTransformerBlocks,
}.get(use_cache)

if not cached_transformer_class:
raise ValueError(f"Unsupported use_cache value: {use_cache}")

return cached_transformer_class(
transformer.transformer_blocks,
transformer.single_transformer_blocks,
transformer=transformer,
rel_l1_thresh=rel_l1_thresh,
return_hidden_states_first=return_hidden_states_first,
num_steps=num_steps,
name=TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer)),
)


def apply_cache_on_transformer(
transformer: FluxTransformer2DModel,
*,
rel_l1_thresh=0.6,
return_hidden_states_first=False,
num_steps=8,
use_cache="Fb",
):
cached_transformer_blocks = nn.ModuleList([
create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps)
])

dummy_single_transformer_blocks = torch.nn.ModuleList()

original_forward = transformer.forward

@functools.wraps(original_forward)
def new_forward(
self,
*args,
**kwargs,
):
with unittest.mock.patch.object(
self,
"transformer_blocks",
cached_transformer_blocks,
), unittest.mock.patch.object(
self,
"single_transformer_blocks",
dummy_single_transformer_blocks,
):
return original_forward(
*args,
**kwargs,
)

transformer.forward = new_forward.__get__(transformer)

return transformer

@@ -0,0 +1,16 @@
"""
adapted from https://github.com/ali-vilab/TeaCache.git
adapted from https://github.com/chengzeyi/ParaAttention.git
"""
from typing import Type, Dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper

TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {}

def register_transformer_adapter(transformer_class: Type, adapter_name: str):
TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name

register_transformer_adapter(FluxTransformer2DModel, "flux")
register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux")
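
For illustration only, extending the registry to another transformer class would look like the sketch below. `PixArtTransformer2DModel` and the `"pixart"` adapter name are assumptions, not part of this PR: a matching adapter module would still need to be added under `diffusers_adapters/`, and the import assumes a diffusers version that exposes this class.

```python
# Hypothetical sketch: this PR only ships the "flux" adapter.
from diffusers import PixArtTransformer2DModel  # placeholder transformer class

from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import (
    register_transformer_adapter,
)

# apply_cache_on_transformer() looks the class up in TRANSFORMER_ADAPTER_REGISTRY
# and then imports ".pixart" from the diffusers_adapters package for its adapter.
register_transformer_adapter(PixArtTransformer2DModel, "pixart")
```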
