diff --git a/examples/flux_example.py b/examples/flux_example.py
index ff578bbf..935e65aa 100644
--- a/examples/flux_example.py
+++ b/examples/flux_example.py
@@ -11,6 +11,9 @@
     get_data_parallel_world_size,
     get_runtime_state,
     is_dp_last_group,
+    get_pipeline_parallel_world_size,
+    get_classifier_free_guidance_world_size,
+    get_tensor_model_parallel_world_size,
 )
 
 
@@ -45,6 +48,26 @@ def main():
     parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
 
     pipe.prepare_run(input_config, steps=1)
+    from xfuser.model_executor.plugins.teacache.diffusers_adapters import apply_teacache_on_transformer
+    from xfuser.model_executor.plugins.first_block_cache.diffusers_adapters import apply_fbcache_on_transformer
+    use_cache = engine_args.use_teacache or engine_args.use_fbcache
+    print(f"use_teacache={engine_args.use_teacache}, use_fbcache={engine_args.use_fbcache}")
+    # Caching is only applied when the transformer is not split across
+    # pipeline, CFG, or tensor parallel ranks.
+    if (use_cache
+        and get_pipeline_parallel_world_size() == 1
+        and get_classifier_free_guidance_world_size() == 1
+        and get_tensor_model_parallel_world_size() == 1
+    ):
+        if engine_args.use_fbcache and engine_args.use_teacache:
+            pipe.transformer = apply_fbcache_on_transformer(
+                pipe.transformer, rel_l1_thresh=0.6, use_cache=use_cache, return_hidden_states_first=False)
+        elif engine_args.use_teacache:
+            pipe.transformer = apply_teacache_on_transformer(
+                pipe.transformer, rel_l1_thresh=0.6, use_cache=use_cache, num_steps=input_config.num_inference_steps)
+        elif engine_args.use_fbcache:
+            pipe.transformer = apply_fbcache_on_transformer(
+                pipe.transformer, rel_l1_thresh=0.6, use_cache=use_cache, return_hidden_states_first=False)
 
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
diff --git a/examples/run.sh b/examples/run.sh
index 12897814..de7b16f6 100644
--- a/examples/run.sh
+++ b/examples/run.sh
@@ -27,6 +27,9 @@ mkdir -p ./results
 
 # task args
 TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
+# cache args
+# CACHE_ARGS="--use_teacache"
+CACHE_ARGS="--use_fbcache"
 
 # On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
 N_GPUS=8
@@ -64,3 +67,4 @@ $CFG_ARGS \
 $PARALLLEL_VAE \
 $COMPILE_FLAG \
 $QUANTIZE_FLAG \
+$CACHE_ARGS \
diff --git a/xfuser/config/args.py b/xfuser/config/args.py
index b4afabe8..7f5095f3 100644
--- a/xfuser/config/args.py
+++ b/xfuser/config/args.py
@@ -111,6 +111,8 @@ class xFuserArgs:
     window_size: int = 64
     coco_path: Optional[str] = None
     use_cache: bool = False
+    use_teacache: bool = False
+    use_fbcache: bool = False
     use_fp8_t5_encoder: bool = False
 
     @staticmethod
@@ -154,6 +156,16 @@ def add_cli_args(parser: FlexibleArgumentParser):
             action="store_true",
             help="Enable onediff to accelerate inference in a single card",
         )
+        runtime_group.add_argument(
+            "--use_teacache",
+            action="store_true",
+            help="Enable TeaCache to accelerate inference in a single card",
+        )
+        runtime_group.add_argument(
+            "--use_fbcache",
+            action="store_true",
+            help="Enable the first block cache (fbcache) to accelerate inference in a single card",
+        )
 
         # Parallel arguments
         parallel_group = parser.add_argument_group("Parallel Processing Options")
diff --git a/xfuser/config/config.py b/xfuser/config/config.py
index 2750b020..a87f0cc5 100644
--- a/xfuser/config/config.py
+++ b/xfuser/config/config.py
@@ -60,6 +60,8 @@ class RuntimeConfig:
     use_torch_compile: bool = False
     use_onediff: bool = False
     use_fp8_t5_encoder: bool = False
+    use_teacache: bool = False
+    use_fbcache: bool = False
 
     def __post_init__(self):
         check_packages()
diff --git a/xfuser/core/distributed/parallel_state.py b/xfuser/core/distributed/parallel_state.py
index 35590e5c..67b434e1 100644
--- a/xfuser/core/distributed/parallel_state.py
+++ b/xfuser/core/distributed/parallel_state.py
@@ -217,6 +217,7 @@ def init_distributed_environment(
             world_size=world_size,
             rank=rank,
         )
+    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
diff --git a/xfuser/model_executor/plugins/__init__.py b/xfuser/model_executor/plugins/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/xfuser/model_executor/plugins/first_block_cache/__init__.py b/xfuser/model_executor/plugins/first_block_cache/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/__init__.py b/xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/__init__.py
new file mode 100644
index 00000000..eaa7557e
--- /dev/null
+++ b/xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/__init__.py
@@ -0,0 +1,13 @@
+import importlib
+
+
+def apply_fbcache_on_transformer(transformer, *args, **kwargs):
+    transformer_cls_name = transformer.__class__.__name__
+    if transformer_cls_name.startswith("Flux") or transformer_cls_name.startswith("xFuserFlux"):
+        adapter_name = "flux"
+    else:
+        raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")
+
+    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
+    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
diff --git a/xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/flux.py b/xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/flux.py
new file mode 100644
index 00000000..1699a42e
--- /dev/null
+++ b/xfuser/model_executor/plugins/first_block_cache/diffusers_adapters/flux.py
@@ -0,0 +1,57 @@
+import functools
+import unittest.mock
+
+import torch
+from diffusers import FluxTransformer2DModel
+
+from xfuser.model_executor.plugins.first_block_cache import utils
+
+
+def apply_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+    *,
+    rel_l1_thresh=0.6,
+    use_cache=True,
+    return_hidden_states_first=False,
+):
+    # Wrap the double- and single-stream blocks in a single cached module.
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            utils.FBCachedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer.single_transformer_blocks,
+                transformer=transformer,
+                rel_l1_thresh=rel_l1_thresh,
+                return_hidden_states_first=return_hidden_states_first,
+                enable_fbcache=use_cache,
+            )
+        ]
+    )
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+    original_forward = transformer.forward
+
+    @functools.wraps(original_forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        # Temporarily swap in the cached blocks for the duration of the call.
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ), unittest.mock.patch.object(
+            self,
+            "single_transformer_blocks",
+            dummy_single_transformer_blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    return transformer
diff --git a/xfuser/model_executor/plugins/first_block_cache/utils.py b/xfuser/model_executor/plugins/first_block_cache/utils.py
new file mode 100644
index 00000000..39d88112
--- /dev/null
+++ b/xfuser/model_executor/plugins/first_block_cache/utils.py
@@ -0,0 +1,118 @@
+import dataclasses
+from typing import Optional
+
+import torch
+
+from xfuser.core.distributed import (
+    get_sp_group,
+    get_sequence_parallel_world_size,
+)
+
+
+@dataclasses.dataclass
+class CacheContext:
+    first_hidden_states_residual: Optional[torch.Tensor] = None
+    hidden_states_residual: Optional[torch.Tensor] = None
+    encoder_hidden_states_residual: Optional[torch.Tensor] = None
+
+    def clear_buffers(self):
+        self.first_hidden_states_residual = None
+        self.hidden_states_residual = None
+        self.encoder_hidden_states_residual = None
+
+
+class FBCachedTransformerBlocks(torch.nn.Module):
+    def __init__(
+        self,
+        transformer_blocks,
+        single_transformer_blocks=None,
+        *,
+        transformer=None,
+        rel_l1_thresh=0.6,
+        return_hidden_states_first=True,
+        enable_fbcache=True,
+    ):
+        super().__init__()
+        self.transformer = transformer
+        self.transformer_blocks = transformer_blocks
+        self.single_transformer_blocks = single_transformer_blocks
+        self.rel_l1_thresh = rel_l1_thresh
+        self.return_hidden_states_first = return_hidden_states_first
+        self.enable_fbcache = enable_fbcache
+        self.cache_context = CacheContext()
+
+    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+        if not self.enable_fbcache:
+            # Caching disabled: run every block as usual.
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+            if self.single_transformer_blocks is not None:
+                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                for block in self.single_transformer_blocks:
+                    hidden_states = block(hidden_states, *args, **kwargs)
+                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
+            return (
+                (hidden_states, encoder_hidden_states)
+                if self.return_hidden_states_first
+                else (encoder_hidden_states, hidden_states)
+            )
+
+        # Always run the first transformer block and use its residual as a
+        # cheap proxy for how much this step differs from the previous one.
+        original_hidden_states = hidden_states
+        first_transformer_block = self.transformer_blocks[0]
+        hidden_states, encoder_hidden_states = first_transformer_block(
+            hidden_states, encoder_hidden_states, *args, **kwargs
+        )
+        if not self.return_hidden_states_first:
+            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+        first_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        prev_first_hidden_states_residual = self.cache_context.first_hidden_states_residual
+
+        if prev_first_hidden_states_residual is None:
+            # First step: nothing cached yet.
+            use_cache = False
+        else:
+            # Relative L1 distance between the current and previous first-block residuals.
+            mean_diff = (first_hidden_states_residual - prev_first_hidden_states_residual).abs().mean()
+            mean_t1 = prev_first_hidden_states_residual.abs().mean()
+            if get_sequence_parallel_world_size() > 1:
+                # Average across sequence-parallel ranks so every rank makes the same decision.
+                mean_diff = get_sp_group().all_gather(mean_diff.unsqueeze(0)).mean()
+                mean_t1 = get_sp_group().all_gather(mean_t1.unsqueeze(0)).mean()
+            diff = mean_diff / mean_t1
+            use_cache = diff < self.rel_l1_thresh
+
+        if use_cache:
+            # Skip the remaining blocks and reuse the cached residuals.
+            del first_hidden_states_residual
+            hidden_states += self.cache_context.hidden_states_residual
+            encoder_hidden_states += self.cache_context.encoder_hidden_states_residual
+        else:
+            # Run the remaining blocks and refresh the cached residuals.
+            original_hidden_states = hidden_states
+            original_encoder_hidden_states = encoder_hidden_states
+            self.cache_context.first_hidden_states_residual = first_hidden_states_residual
+            for block in self.transformer_blocks[1:]:
+                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+            if self.single_transformer_blocks is not None:
+                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                for block in self.single_transformer_blocks:
+                    hidden_states = block(hidden_states, *args, **kwargs)
+                encoder_hidden_states, hidden_states = hidden_states.split(
+                    [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+                )
+            self.cache_context.hidden_states_residual = hidden_states - original_hidden_states
+            self.cache_context.encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
+
+        return (
+            (hidden_states, encoder_hidden_states)
+            if self.return_hidden_states_first
+            else (encoder_hidden_states, hidden_states)
+        )
diff --git a/xfuser/model_executor/plugins/teacache/__init__.py b/xfuser/model_executor/plugins/teacache/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/xfuser/model_executor/plugins/teacache/diffusers_adapters/__init__.py b/xfuser/model_executor/plugins/teacache/diffusers_adapters/__init__.py
new file mode 100644
index 00000000..a4e8ba65
--- /dev/null
+++ b/xfuser/model_executor/plugins/teacache/diffusers_adapters/__init__.py
@@ -0,0 +1,13 @@
+import importlib
+
+
+def apply_teacache_on_transformer(transformer, *args, **kwargs):
+    transformer_cls_name = transformer.__class__.__name__
+    if transformer_cls_name.startswith("Flux") or transformer_cls_name.startswith("xFuserFlux"):
+        adapter_name = "flux"
+    else:
+        raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")
+
+    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
+    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
diff --git a/xfuser/model_executor/plugins/teacache/diffusers_adapters/flux.py b/xfuser/model_executor/plugins/teacache/diffusers_adapters/flux.py
new file mode 100644
index 00000000..e3affa43
--- /dev/null
+++ b/xfuser/model_executor/plugins/teacache/diffusers_adapters/flux.py
@@ -0,0 +1,61 @@
+import functools
+import unittest.mock
+
+import torch
+from diffusers import FluxTransformer2DModel
+
+from xfuser.model_executor.plugins.teacache import utils
+
+
+def apply_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+    *,
+    rel_l1_thresh=0.6,
+    use_cache=True,
+    num_steps=8,
+    return_hidden_states_first=False,
+    # Polynomial rescaling coefficients fitted for FLUX by TeaCache.
+    coefficients=[4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01],
+):
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            utils.TeaCachedTransformerBlocks(
+                transformer.transformer_blocks,
+                transformer.single_transformer_blocks,
+                transformer=transformer,
+                enable_teacache=use_cache,
+                num_steps=num_steps,
+                rel_l1_thresh=rel_l1_thresh,
+                return_hidden_states_first=return_hidden_states_first,
+                coefficients=coefficients,
+            )
+        ]
+    )
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+    original_forward = transformer.forward
+
+    @functools.wraps(original_forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        # Temporarily swap in the cached blocks for the duration of the call.
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ), unittest.mock.patch.object(
+            self,
+            "single_transformer_blocks",
+            dummy_single_transformer_blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    return transformer
diff --git a/xfuser/model_executor/plugins/teacache/utils.py b/xfuser/model_executor/plugins/teacache/utils.py
new file mode 100644
index 00000000..76c9c930
--- /dev/null
+++ b/xfuser/model_executor/plugins/teacache/utils.py
@@ -0,0 +1,126 @@
+import torch
+
+from xfuser.core.distributed import (
+    get_sp_group,
+    get_sequence_parallel_world_size,
+)
+
+
+class TorchPoly1D:
+    """Evaluate a polynomial with the given coefficients (highest degree first)."""
+
+    def __init__(self, coefficients):
+        self.coefficients = torch.tensor(coefficients, dtype=torch.float32)
+        self.degree = len(coefficients) - 1
+
+    def __call__(self, x):
+        result = torch.zeros_like(x)
+        for i, coef in enumerate(self.coefficients):
+            result += coef * (x ** (self.degree - i))
+        return result
+
+
+class TeaCachedTransformerBlocks(torch.nn.Module):
+    def __init__(
+        self,
+        transformer_blocks,
+        single_transformer_blocks=None,
+        *,
+        transformer=None,
+        enable_teacache=True,
+        num_steps=8,
+        rel_l1_thresh=0.6,
+        return_hidden_states_first=True,
+        coefficients=[4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01],
+    ):
+        super().__init__()
+        self.transformer = transformer
+        self.transformer_blocks = transformer_blocks
+        self.single_transformer_blocks = single_transformer_blocks
+        self.cnt = 0
+        self.enable_teacache = enable_teacache
+        self.num_steps = num_steps
+        self.rel_l1_thresh = rel_l1_thresh
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.previous_residual = None
+        self.previous_residual_encoder = None
+        self.coefficients = coefficients
+        self.return_hidden_states_first = return_hidden_states_first
+        self.rescale_func = TorchPoly1D(coefficients)
+
+    def forward(self, hidden_states, encoder_hidden_states, temb, *args, **kwargs):
+        if not self.enable_teacache:
+            # Caching disabled: run every block as usual.
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb, *args, **kwargs)
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+            if self.single_transformer_blocks is not None:
+                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                for block in self.single_transformer_blocks:
+                    hidden_states = block(hidden_states, temb, *args, **kwargs)
+                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
+            return (
+                (hidden_states, encoder_hidden_states)
+                if self.return_hidden_states_first
+                else (encoder_hidden_states, hidden_states)
+            )
+
+        # Use the timestep-modulated input of the first block as a cheap
+        # indicator of how much this step differs from the previous one.
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+        first_transformer_block = self.transformer_blocks[0]
+        inp = hidden_states.clone()
+        temb_ = temb.clone()
+        modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = first_transformer_block.norm1(inp, emb=temb_)
+        if self.cnt == 0 or self.cnt == self.num_steps - 1:
+            # Always recompute on the first and last denoising steps.
+            should_calc = True
+            self.accumulated_rel_l1_distance = 0
+        else:
+            mean_diff = (modulated_inp - self.previous_modulated_input).abs().mean()
+            mean_t1 = self.previous_modulated_input.abs().mean()
+            if get_sequence_parallel_world_size() > 1:
+                # Average across sequence-parallel ranks so every rank makes the same decision.
+                mean_diff = get_sp_group().all_gather(mean_diff.unsqueeze(0)).mean()
+                mean_t1 = get_sp_group().all_gather(mean_t1.unsqueeze(0)).mean()
+            # Accumulate the rescaled relative L1 distance and recompute only
+            # once it crosses the threshold.
+            self.accumulated_rel_l1_distance += self.rescale_func(mean_diff / mean_t1)
+            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = modulated_inp
+        self.cnt += 1
+        if self.cnt == self.num_steps:
+            self.cnt = 0
+
+        if not should_calc:
+            # Skip all blocks and reuse the residuals cached at the last full step.
+            hidden_states += self.previous_residual
+            encoder_hidden_states += self.previous_residual_encoder
+        else:
+            # Run every block and refresh the cached residuals.
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb, *args, **kwargs)
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+            if self.single_transformer_blocks is not None:
+                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                for block in self.single_transformer_blocks:
+                    hidden_states = block(hidden_states, temb, *args, **kwargs)
+                encoder_hidden_states, hidden_states = hidden_states.split(
+                    [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+                )
+            self.previous_residual = hidden_states - original_hidden_states
+            self.previous_residual_encoder = encoder_hidden_states - original_encoder_hidden_states
+
+        return (
+            (hidden_states, encoder_hidden_states)
+            if self.return_hidden_states_first
+            else (encoder_hidden_states, hidden_states)
+        )