teacache && fbcache added
Binary2355 committed Feb 20, 2025
1 parent 836cf85 commit 54aede0
Showing 14 changed files with 418 additions and 0 deletions.
22 changes: 22 additions & 0 deletions examples/flux_example.py
@@ -11,6 +11,10 @@
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
get_pipeline_parallel_world_size,
get_classifier_free_guidance_world_size,
get_tensor_model_parallel_world_size,
get_data_parallel_world_size,
)


@@ -45,6 +49,24 @@ def main():
parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

pipe.prepare_run(input_config, steps=1)
from xfuser.model_executor.plugins.teacache.diffusers_adapters import apply_teacache_on_transformer
from xfuser.model_executor.plugins.first_block_cache.diffusers_adapters import apply_fbcache_on_transformer
# caching only kicks in when PP, CFG and TP parallelism are all disabled;
# otherwise the adapters still wrap the transformer but run the uncached path
use_cache = (
    (engine_args.use_teacache or engine_args.use_fbcache)
    and get_pipeline_parallel_world_size() == 1
    and get_classifier_free_guidance_world_size() == 1
    and get_tensor_model_parallel_world_size() == 1
)
if engine_args.use_fbcache and engine_args.use_teacache:
    # when both flags are set, first-block cache wins
    pipe.transformer = apply_fbcache_on_transformer(
        pipe.transformer, rel_l1_thresh=0.6, use_cache=use_cache, return_hidden_states_first=False)
elif engine_args.use_teacache:
    pipe.transformer = apply_teacache_on_transformer(
        pipe.transformer, rel_l1_thresh=0.6, use_cache=use_cache, num_steps=input_config.num_inference_steps)
elif engine_args.use_fbcache:
    pipe.transformer = apply_fbcache_on_transformer(
        pipe.transformer, rel_l1_thresh=0.6, use_cache=use_cache, return_hidden_states_first=False)

torch.cuda.reset_peak_memory_stats()
start_time = time.time()
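For reference, the same adapter call also works on a plain diffusers Flux pipeline on a single GPU. A minimal sketch follows; the model id, prompt, and step count are illustrative and not part of this commit:

import torch
from diffusers import FluxPipeline
from xfuser.model_executor.plugins.first_block_cache.diffusers_adapters import (
    apply_fbcache_on_transformer,
)

# Load the pipeline and wrap its transformer with the first-block cache.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer = apply_fbcache_on_transformer(
    pipe.transformer,
    rel_l1_thresh=0.6,               # skip remaining blocks when the step-to-step change is below this
    use_cache=True,                  # False falls back to the plain, uncached forward path
    return_hidden_states_first=False,
)
image = pipe("a photo of a cat", num_inference_steps=28).images[0]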
4 changes: 4 additions & 0 deletions examples/run.sh
@@ -27,6 +27,9 @@ mkdir -p ./results
# task args
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"

# cache args
# CACHE_ARGS="--use_teacache"
# CACHE_ARGS="--use_fbcache"

# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
N_GPUS=8
@@ -64,3 +67,4 @@ $CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
$QUANTIZE_FLAG \
$CACHE_ARGS \
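Only one of the two CACHE_ARGS lines should be uncommented at a time; the selected flag is passed through to the launch command via the $CACHE_ARGS \ line added at the end of the invocation above, and the variable simply expands to nothing when both lines stay commented out.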
12 changes: 12 additions & 0 deletions xfuser/config/args.py
@@ -111,6 +111,8 @@ class xFuserArgs:
window_size: int = 64
coco_path: Optional[str] = None
use_cache: bool = False
use_teacache: bool = False
use_fbcache: bool = False
use_fp8_t5_encoder: bool = False

@staticmethod
@@ -154,6 +156,16 @@ def add_cli_args(parser: FlexibleArgumentParser):
action="store_true",
help="Enable onediff to accelerate inference in a single card",
)
runtime_group.add_argument(
"--use_teacache",
action="store_true",
help="Enable teacache to accelerate inference in a single card",
)
runtime_group.add_argument(
"--use_fbcache",
action="store_true",
help="Enable teacache to accelerate inference in a single card",
)

# Parallel arguments
parallel_group = parser.add_argument_group("Parallel Processing Options")
2 changes: 2 additions & 0 deletions xfuser/config/config.py
@@ -60,6 +60,8 @@ class RuntimeConfig:
use_torch_compile: bool = False
use_onediff: bool = False
use_fp8_t5_encoder: bool = False
use_teacache: bool = False
use_fbcache: bool = False

def __post_init__(self):
check_packages()
1 change: 1 addition & 0 deletions xfuser/core/distributed/parallel_state.py
@@ -217,6 +217,7 @@ def init_distributed_environment(
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
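The added torch.cuda.set_device call pins each process to a GPU chosen by global rank modulo the visible device count. As the comment below it notes, local_rank is not available from the torch ProcessGroup, so this is a simple way to give each rank its own device on a single-node run before any CUDA work happens.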
Empty file.
Empty file.
16 changes: 16 additions & 0 deletions xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/__init__.py
@@ -0,0 +1,16 @@
import importlib

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper


def apply_fbcache_on_transformer(transformer, *args, **kwargs):
    if isinstance(transformer, (FluxTransformer2DModel, xFuserFluxTransformer2DWrapper)):
        adapter_name = "flux"
    else:
        raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
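
The adapter layer is pure dispatch: the isinstance check picks an adapter name and importlib loads the sibling module that does the actual wrapping. For Flux the call above resolves to the flux module added below, so it is equivalent to this sketch:

# `transformer` is a FluxTransformer2DModel (or the xFuser wrapper), e.g. pipe.transformer.
from xfuser.model_executor.plugins.first_block_cache.diffusers_adapters import flux

transformer = flux.apply_cache_on_transformer(
    transformer, rel_l1_thresh=0.6, use_cache=True, return_hidden_states_first=False
)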

56 changes: 56 additions & 0 deletions xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/flux.py
@@ -0,0 +1,56 @@
import functools
import unittest.mock

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from xfuser.model_executor.plugins.first_block_cache import utils


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
    *,
    rel_l1_thresh=0.6,
    use_cache=True,
    return_hidden_states_first=False,
):
    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.FBCachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer.single_transformer_blocks,
                transformer=transformer,
                rel_l1_thresh=rel_l1_thresh,
                return_hidden_states_first=return_hidden_states_first,
                enable_fbcache=use_cache,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ), unittest.mock.patch.object(
            self,
            "single_transformer_blocks",
            dummy_single_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    return transformer
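
A note on the design: apply_cache_on_transformer does not edit the model graph. It builds one FBCachedTransformerBlocks wrapper that owns both the double-stream and single-stream blocks, then wraps the original forward so that, only for the duration of each call, unittest.mock.patch.object swaps transformer_blocks for the wrapper and empties single_transformer_blocks (the wrapper runs them itself). The original FluxTransformer2DModel.forward logic and weights stay untouched, and the bound new_forward is attached with __get__ so it behaves like a regular method.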

113 changes: 113 additions & 0 deletions xfuser/model_executor/plugins/first_block_cache/utils.py
@@ -0,0 +1,113 @@
import dataclasses
from typing import Optional

from xfuser.core.distributed import (
    get_sp_group,
    get_sequence_parallel_world_size,
)

import torch


@dataclasses.dataclass
class CacheContext:
    first_hidden_states_residual: Optional[torch.Tensor] = None
    hidden_states_residual: Optional[torch.Tensor] = None
    encoder_hidden_states_residual: Optional[torch.Tensor] = None

    def clear_buffers(self):
        self.first_hidden_states_residual = None
        self.hidden_states_residual = None
        self.encoder_hidden_states_residual = None


class FBCachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        rel_l1_thresh=0.6,
        return_hidden_states_first=True,
        enable_fbcache=True,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.rel_l1_thresh = rel_l1_thresh
        self.return_hidden_states_first = return_hidden_states_first
        self.enable_fbcache = enable_fbcache
        self.cache_context = CacheContext()

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        if not self.enable_fbcache:
            # the branch to disable cache
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
            return (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )

        # run first block of transformer
        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states, encoder_hidden_states = first_transformer_block(
            hidden_states, encoder_hidden_states, *args, **kwargs
        )
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        prev_first_hidden_states_residual = self.cache_context.first_hidden_states_residual

        if prev_first_hidden_states_residual is None:
            use_cache = False
        else:
            mean_diff = (first_hidden_states_residual - prev_first_hidden_states_residual).abs().mean()
            mean_t1 = prev_first_hidden_states_residual.abs().mean()
            if get_sequence_parallel_world_size() > 1:
                mean_diff = get_sp_group().all_gather(mean_diff.unsqueeze(0)).mean()
                mean_t1 = get_sp_group().all_gather(mean_t1.unsqueeze(0)).mean()
            diff = mean_diff / mean_t1
            use_cache = diff < self.rel_l1_thresh

        if use_cache:
            del first_hidden_states_residual
            hidden_states += self.cache_context.hidden_states_residual
            encoder_hidden_states += self.cache_context.encoder_hidden_states_residual
        else:
            original_hidden_states = hidden_states
            original_encoder_hidden_states = encoder_hidden_states
            self.cache_context.first_hidden_states_residual = first_hidden_states_residual
            for block in self.transformer_blocks[1:]:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                encoder_hidden_states, hidden_states = hidden_states.split(
                    [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
                )
            self.cache_context.hidden_states_residual = hidden_states - original_hidden_states
            self.cache_context.encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        return (
            (hidden_states, encoder_hidden_states)
            if self.return_hidden_states_first
            else (encoder_hidden_states, hidden_states)
        )
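
The caching decision boils down to one scalar test: with r_t the residual of the first double-stream block at the current step and r_{t-1} the stored residual from the previous step, all remaining blocks are skipped when mean(|r_t - r_{t-1}|) / mean(|r_{t-1}|) < rel_l1_thresh, with both means averaged across sequence-parallel ranks when SP is active. A tiny standalone sketch of that rule on dummy tensors (shapes are illustrative):

import torch

def should_reuse_cache(curr_residual, prev_residual, rel_l1_thresh=0.6):
    # Relative L1 change of the first-block residual between consecutive steps.
    if prev_residual is None:
        return False
    rel_change = (curr_residual - prev_residual).abs().mean() / prev_residual.abs().mean()
    return bool(rel_change < rel_l1_thresh)

prev = torch.randn(1, 1024, 3072)
curr = prev + 0.05 * torch.randn_like(prev)   # small drift between steps
print(should_reuse_cache(curr, prev))          # True: relative change is well below 0.6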
Empty file.
15 changes: 15 additions & 0 deletions xfuser/model_executor/plugins/teacache/diffusers_adapters/__init__.py
@@ -0,0 +1,15 @@
import importlib

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper


def apply_teacache_on_transformer(transformer, *args, **kwargs):
    if isinstance(transformer, (FluxTransformer2DModel, xFuserFluxTransformer2DWrapper)):
        adapter_name = "flux"
    else:
        raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
59 changes: 59 additions & 0 deletions xfuser/model_executor/plugins/teacache/diffusers_adapters/flux.py
@@ -0,0 +1,59 @@
import functools
import unittest.mock

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from xfuser.model_executor.plugins.teacache import utils


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
    *,
    rel_l1_thresh=0.6,
    use_cache=True,
    num_steps=8,
    return_hidden_states_first=False,
    coefficients=[4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01],
):
    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.TeaCachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer.single_transformer_blocks,
                transformer=transformer,
                enable_teacache=use_cache,
                num_steps=num_steps,
                rel_l1_thresh=rel_l1_thresh,
                return_hidden_states_first=return_hidden_states_first,
                coefficients=coefficients,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ), unittest.mock.patch.object(
            self,
            "single_transformer_blocks",
            dummy_single_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    return transformer
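
The matching TeaCachedTransformerBlocks implementation lives in xfuser/model_executor/plugins/teacache/utils.py, the last file in this commit, whose diff did not load on this page. Going by the upstream TeaCache method, the coefficients default is presumably the FLUX-specific polynomial used to rescale the relative-L1 change of the timestep-modulated inputs, which is accumulated across steps and compared against rel_l1_thresh, while num_steps presumably lets the cache force recomputation at the first and last denoising steps, as the upstream implementation does.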
118 changes: 118 additions & 0 deletions xfuser/model_executor/plugins/teacache/utils.py
