teacache && fbcache added
Binary2355 committed Feb 23, 2025
1 parent 836cf85 commit 78edca6
Showing 20 changed files with 858 additions and 2 deletions.
24 changes: 24 additions & 0 deletions examples/flux_example.py
@@ -11,6 +11,10 @@
    get_data_parallel_world_size,
    get_runtime_state,
    is_dp_last_group,
    get_pipeline_parallel_world_size,
    get_classifier_free_guidance_world_size,
    get_tensor_model_parallel_world_size,
    get_data_parallel_world_size,
)


@@ -45,7 +49,27 @@ def main():
    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    pipe.prepare_run(input_config, steps=1)
    from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer
    use_cache = engine_args.use_teacache or engine_args.use_fbcache
    if (use_cache
        and get_pipeline_parallel_world_size() == 1
        and get_classifier_free_guidance_world_size() == 1
        and get_tensor_model_parallel_world_size() == 1
    ):
        cache_args = {
            "rel_l1_thresh": 0.6,
            "return_hidden_states_first": False,
            "num_steps": input_config.num_inference_steps,
        }

        if engine_args.use_fbcache and engine_args.use_teacache:
            cache_args["use_cache"] = "Fb"
        elif engine_args.use_teacache:
            cache_args["use_cache"] = "Tea"
        elif engine_args.use_fbcache:
            cache_args["use_cache"] = "Fb"

        pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args)
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
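For reference, a minimal sketch of applying the new cache adapter to a plain diffusers Flux pipeline. Only apply_cache_on_transformer and its keyword arguments come from this commit; FluxPipeline.from_pretrained, the model id, and the prompt are illustrative assumptions, and (as the guard in the example suggests) the cache is intended for runs without pipeline, CFG, or tensor parallelism.

# Hedged sketch: apply the new cache adapter to a plain diffusers Flux pipeline.
# FluxPipeline.from_pretrained, the model id, and the prompt are assumptions for
# illustration; apply_cache_on_transformer and its kwargs come from this commit.
import torch
from diffusers import FluxPipeline
from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer = apply_cache_on_transformer(
    pipe.transformer,
    use_cache="Tea",                    # "Tea" = TeaCache, "Fb" = first-block cache
    rel_l1_thresh=0.6,                  # relative L1 threshold for reusing cached residuals
    num_steps=28,                       # should match num_inference_steps
    return_hidden_states_first=False,
)
image = pipe("a small cabin in the woods", num_inference_steps=28).images[0]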
4 changes: 4 additions & 0 deletions examples/run.sh
@@ -27,6 +27,9 @@ mkdir -p ./results
# task args
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"

# cache args
# CACHE_ARGS="--use_teacache"
# CACHE_ARGS="--use_fbcache"

# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
N_GPUS=8
@@ -64,3 +67,4 @@ $CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
$QUANTIZE_FLAG \
$CACHE_ARGS \
12 changes: 12 additions & 0 deletions xfuser/config/args.py
@@ -111,6 +111,8 @@ class xFuserArgs:
    window_size: int = 64
    coco_path: Optional[str] = None
    use_cache: bool = False
    use_teacache: bool = False
    use_fbcache: bool = False
    use_fp8_t5_encoder: bool = False

    @staticmethod
@@ -154,6 +156,16 @@ def add_cli_args(parser: FlexibleArgumentParser):
action="store_true",
help="Enable onediff to accelerate inference in a single card",
)
runtime_group.add_argument(
"--use_teacache",
action="store_true",
help="Enable teacache to accelerate inference in a single card",
)
runtime_group.add_argument(
"--use_fbcache",
action="store_true",
help="Enable teacache to accelerate inference in a single card",
)

# Parallel arguments
parallel_group = parser.add_argument_group("Parallel Processing Options")
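A small sketch of how the two new flags select a cache variant, mirroring the precedence used in examples/flux_example.py (FBCache wins when both flags are set). Plain argparse stands in for xFuser's FlexibleArgumentParser here.

# Sketch only: plain argparse standing in for FlexibleArgumentParser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_teacache", action="store_true")
parser.add_argument("--use_fbcache", action="store_true")
args = parser.parse_args(["--use_teacache"])

# Same precedence as the example: "Fb" when --use_fbcache is set (alone or with
# --use_teacache), otherwise "Tea" when --use_teacache is set.
if args.use_fbcache:
    use_cache = "Fb"
elif args.use_teacache:
    use_cache = "Tea"
else:
    use_cache = None
print(use_cache)  # -> Tea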
2 changes: 2 additions & 0 deletions xfuser/config/config.py
@@ -60,6 +60,8 @@ class RuntimeConfig:
    use_torch_compile: bool = False
    use_onediff: bool = False
    use_fp8_t5_encoder: bool = False
    use_teacache: bool = False
    use_fbcache: bool = False

    def __post_init__(self):
        check_packages()
4 changes: 2 additions & 2 deletions xfuser/core/distributed/group_coordinator.py
@@ -196,7 +196,7 @@ def group_skip_rank(self):
world_size = self.world_size
return (world_size - rank_in_group - 1) % world_size

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor:
"""
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
@@ -206,7 +206,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.world_size == 1:
return input_
else:
torch.distributed.all_reduce(input_, group=self.device_group)
torch.distributed.all_reduce(input_, op=op, group=self.device_group)
return input_

def all_gather(
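The new op parameter lets callers pick a reduction other than the default SUM without touching existing call sites. A hedged sketch follows, where group stands for any initialized GroupCoordinator instance; torch.distributed.ReduceOp is the public alias of the torch._C._distributed_c10d.ReduceOp enum used in the default value.

# Sketch: `group` is assumed to be an initialized GroupCoordinator.
import torch
import torch.distributed as dist

t = torch.tensor([float(dist.get_rank())], device="cuda")
t = group.all_reduce(t)                        # default behaviour: element-wise SUM
t = group.all_reduce(t, op=dist.ReduceOp.MAX)  # element-wise MAX across the group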
1 change: 1 addition & 0 deletions xfuser/core/distributed/parallel_state.py
@@ -217,6 +217,7 @@ def init_distributed_environment(
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
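The added torch.cuda.set_device call binds each process to a GPU by global rank modulo the local device count. A tiny worked example of that mapping, assuming two nodes with eight visible GPUs each:

# Assumed layout: 2 nodes x 8 GPUs, so torch.cuda.device_count() == 8 everywhere.
for rank in (0, 7, 8, 15):
    print(f"global rank {rank} -> cuda:{rank % 8}")
# global rank 0  -> cuda:0   (node 0)
# global rank 7  -> cuda:7   (node 0)
# global rank 8  -> cuda:0   (node 1)
# global rank 15 -> cuda:7   (node 1)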
Empty file.
Empty file.
@@ -0,0 +1,13 @@
import importlib
from typing import Type, Dict, TypeVar
from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY


def apply_cache_on_transformer(transformer, *args, **kwargs):
    adapter_name = TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer))
    if not adapter_name:
        raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
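The dispatcher resolves the adapter module from the registry by the transformer's concrete class; anything unregistered raises immediately. A minimal sketch of that failure path:

# Sketch: nn.Linear is not in TRANSFORMER_ADAPTER_REGISTRY, so the lookup fails.
from torch import nn
from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer

try:
    apply_cache_on_transformer(nn.Linear(4, 4))
except ValueError as err:
    print(err)  # Unknown transformer class: Linear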
70 changes: 70 additions & 0 deletions xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py
@@ -0,0 +1,70 @@
import functools
import unittest.mock

import torch
from torch import nn
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY

from xfuser.model_executor.plugins.cache_ import utils

def create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps):
    cached_transformer_class = {
        "Fb": utils.FBCachedTransformerBlocks,
        "Tea": utils.TeaCachedTransformerBlocks,
    }.get(use_cache)

    if not cached_transformer_class:
        raise ValueError(f"Unsupported use_cache value: {use_cache}")

    return cached_transformer_class(
        transformer.transformer_blocks,
        transformer.single_transformer_blocks,
        transformer=transformer,
        rel_l1_thresh=rel_l1_thresh,
        return_hidden_states_first=return_hidden_states_first,
        num_steps=num_steps,
        name=TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer)),
    )


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
    *,
    rel_l1_thresh=0.6,
    return_hidden_states_first=False,
    num_steps=8,
    use_cache="Fb",
):
    cached_transformer_blocks = nn.ModuleList([
        create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps)
    ])

    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ), unittest.mock.patch.object(
            self,
            "single_transformer_blocks",
            dummy_single_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    return transformer

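The adapter above swaps the cached blocks in only for the duration of each forward call via unittest.mock.patch.object, so the transformer's registered modules (and its state_dict) are untouched outside forward. A self-contained sketch of that patching pattern:

# Minimal illustration of the per-call attribute swap used by new_forward above.
import unittest.mock

class Model:
    def __init__(self):
        self.blocks = ["original"]
    def forward(self):
        return list(self.blocks)

m = Model()
with unittest.mock.patch.object(m, "blocks", ["cached"]):
    print(m.forward())  # ['cached']   -- patched view while the context is active
print(m.forward())      # ['original'] -- attribute restored afterwards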
@@ -0,0 +1,12 @@
from typing import Type, Dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper

TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {}

def register_transformer_adapter(transformer_class: Type, adapter_name: str):
    TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name

register_transformer_adapter(FluxTransformer2DModel, "flux")
register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux")

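Supporting another transformer would mean adding a sibling adapter module under diffusers_adapters and registering its class here. A hypothetical sketch; MyTransformer2DModel and the "mymodel" module are assumptions, not part of this commit:

# Hypothetical: my_package.models.MyTransformer2DModel and a diffusers_adapters/mymodel.py
# providing its own apply_cache_on_transformer are assumed to exist.
from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import (
    register_transformer_adapter,
)
from my_package.models import MyTransformer2DModel  # hypothetical

register_transformer_adapter(MyTransformer2DModel, "mymodel")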