diff --git a/docs/performance/flux.md b/docs/performance/flux.md
index d1310819..4fad3d35 100644
--- a/docs/performance/flux.md
+++ b/docs/performance/flux.md
@@ -117,15 +117,48 @@ The quality of image generation at 2048px, 3072px, and 4096px resolutions is as
 ## Cache Methods
 
-We tested the performance of TeaCache and First-Block-Cache on 4xH20 with SP=4.
+We tested the performance of TeaCache and First-Block-Cache on 4xH20 (SP=4) and on a single H20, with and without torch.compile. Latency is shown below:
-| Method | Latency (s) |
-|----------------|--------|
-| Baseline | 2.02s |
-| use_teacache | 1.58s |
-| use_fbcache | 0.93s |
+| Method | 4xH20 (without torch.compile) | 1xH20 (without torch.compile) | 4xH20 (with torch.compile) | 1xH20 (with torch.compile) |
+|----------------|--------|--------|--------|--------|
+| Baseline | 2.02s | 6.10s | 1.81s | 5.02s |
+| use_teacache | 1.60s | 4.67s | 1.50s | 3.92s |
+| use_fbcache | 0.93s | 2.51s | 0.85s | 2.09s |
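+
+To enable either cache from Python, pass the new `cache_args` dictionary when building the pipeline (the example script exposes the same switches as `--use_teacache` / `--use_fbcache`). The snippet below is a minimal sketch that mirrors `examples/flux_example.py`; it assumes `engine_config` and `input_config` were already created via `xFuserArgs`, and note that caching is only applied when the pipeline-parallel, CFG-parallel, and tensor-parallel degrees are all 1.
+
+```python
+import torch
+from xfuser import xFuserFluxPipeline
+
+cache_args = {
+    "use_teacache": False,   # set True to use TeaCache instead
+    "use_fbcache": True,     # First-Block-Cache; if both are True, FBCache is used
+    "rel_l1_thresh": 0.6,    # relative L1 threshold for reusing cached residuals
+    "return_hidden_states_first": False,
+    "num_steps": input_config.num_inference_steps,
+}
+
+# The pipeline wraps its transformer with the cache adapter internally,
+# so no manual call to apply_cache_on_transformer is needed.
+pipe = xFuserFluxPipeline.from_pretrained(
+    pretrained_model_name_or_path=engine_config.model_config.model,
+    engine_config=engine_config,
+    cache_args=cache_args,
+    torch_dtype=torch.bfloat16,
+)
+```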
diff --git a/examples/flux_example.py b/examples/flux_example.py index c70c8279..bac45287 100644 --- a/examples/flux_example.py +++ b/examples/flux_example.py @@ -33,9 +33,18 @@ def main(): quantize(text_encoder_2, weights=qfloat8) freeze(text_encoder_2) + cache_args = { + "use_teacache": engine_args.use_teacache, + "use_fbcache": engine_args.use_fbcache, + "rel_l1_thresh": 0.6, + "return_hidden_states_first": False, + "num_steps": input_config.num_inference_steps, + } + pipe = xFuserFluxPipeline.from_pretrained( pretrained_model_name_or_path=engine_config.model_config.model, engine_config=engine_config, + cache_args=cache_args, torch_dtype=torch.bfloat16, text_encoder_2=text_encoder_2, ) @@ -48,28 +57,8 @@ def main(): parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") - pipe.prepare_run(input_config, steps=1) - - use_cache = engine_args.use_teacache or engine_args.use_fbcache - if (use_cache - and get_pipeline_parallel_world_size() == 1 - and get_classifier_free_guidance_world_size() == 1 - and get_tensor_model_parallel_world_size() == 1 - ): - cache_args = { - "rel_l1_thresh": 0.6, - "return_hidden_states_first": False, - "num_steps": input_config.num_inference_steps, - } - - if engine_args.use_fbcache and engine_args.use_teacache: - cache_args["use_cache"] = "Fb" - elif engine_args.use_teacache: - cache_args["use_cache"] = "Tea" - elif engine_args.use_fbcache: - cache_args["use_cache"] = "Fb" + pipe.prepare_run(input_config, steps=input_config.num_inference_steps) - pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args) torch.cuda.reset_peak_memory_stats() start_time = time.time() output = pipe( diff --git a/xfuser/model_executor/cache/diffusers_adapters/__init__.py b/xfuser/model_executor/cache/diffusers_adapters/__init__.py index c23cb07d..f76ef5ef 100644 --- a/xfuser/model_executor/cache/diffusers_adapters/__init__.py +++ b/xfuser/model_executor/cache/diffusers_adapters/__init__.py @@ -5,12 +5,16 @@ import importlib from typing import Type, Dict, TypeVar from xfuser.model_executor.cache.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY +from xfuser.logger import init_logger + +logger = init_logger(__name__) def apply_cache_on_transformer(transformer, *args, **kwargs): adapter_name = TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer)) if not adapter_name: - raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}") + logger.error(f"Unknown transformer class: {transformer.__class__.__name__}") + return transformer adapter_module = importlib.import_module(f".{adapter_name}", __package__) apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer") diff --git a/xfuser/model_executor/cache/utils.py b/xfuser/model_executor/cache/utils.py index 9737fc0d..48457e8b 100644 --- a/xfuser/model_executor/cache/utils.py +++ b/xfuser/model_executor/cache/utils.py @@ -2,129 +2,71 @@ adapted from https://github.com/ali-vilab/TeaCache.git adapted from https://github.com/chengzeyi/ParaAttention.git """ -import contextlib import dataclasses -from collections import defaultdict -from typing import DefaultDict, Dict, Optional, List +from typing import Dict, Optional, List from xfuser.core.distributed import ( get_sp_group, get_sequence_parallel_world_size, ) import torch +from torch.nn import Module from abc import ABC, abstractmethod +# --------- CacheContext --------- # +class CacheContext(Module): + def __init__(self): + super().__init__() + self.register_buffer("default_coef", 
torch.tensor([1.0, 0.0]).cuda()) + self.register_buffer("flux_coef", torch.tensor([498.651651, -283.781631, 55.8554382, -3.82021401, 0.264230861]).cuda()) + + self.register_buffer("original_hidden_states", None, persistent=False) + self.register_buffer("original_encoder_hidden_states", None, persistent=False) + self.register_buffer("hidden_states_residual", None, persistent=False) + self.register_buffer("encoder_hidden_states_residual", None, persistent=False) + self.register_buffer("modulated_inputs", None, persistent=False) + + def get_coef(self, name: str) -> torch.Tensor: + return getattr(self, f"{name}_coef") + #--------- CacheCallback ---------# @dataclasses.dataclass class CacheState: - transformer: None - transformer_blocks: None - single_transformer_blocks: None - cache_context: None - rel_l1_thresh: int = 0.6 + transformer: Optional[torch.nn.Module] = None + transformer_blocks: Optional[List[torch.nn.Module]] = None + single_transformer_blocks: Optional[List[torch.nn.Module]] = None + cache_context: Optional[CacheContext] = None + rel_l1_thresh: float = 0.6 return_hidden_states_first: bool = True - use_cache: bool = False + use_cache: torch.Tensor = torch.tensor(False, dtype=torch.bool) num_steps: int = 8 name: str = "default" class CacheCallback: - def on_init_end(self, state: CacheState, **kwargs): - pass - - def on_forward_begin(self, state: CacheState, **kwargs): - pass - - def on_forward_remaining_begin(self, state: CacheState, **kwargs): - pass - - def on_forward_end(self, state: CacheState, **kwargs): - pass + def on_init_end(self, state: CacheState, **kwargs): pass + def on_forward_begin(self, state: CacheState, **kwargs): pass + def on_forward_remaining_begin(self, state: CacheState, **kwargs): pass + def on_forward_end(self, state: CacheState, **kwargs): pass class CallbackHandler(CacheCallback): - def __init__(self, callbacks): - self.callbacks = [] - if callbacks is not None: - for cb in callbacks: - self.add_callback(cb) - - def add_callback(self, callback): - cb = callback() if isinstance(callback, type) else callback - self.callbacks.append(cb) - - def pop_callback(self, callback): - if isinstance(callback, type): - for cb in self.callbacks: - if isinstance(cb, callback): - self.callbacks.remove(cb) - return cb - else: - for cb in self.callbacks: - if cb == callback: - self.callbacks.remove(cb) - return cb - - def remove_callback(self, callback): - if isinstance(callback, type): - for cb in self.callbacks: - if isinstance(cb, callback): - self.callbacks.remove(cb) - return - else: - self.callbacks.remove(callback) - - def on_init_end(self, state: CacheState): - return self.call_event("on_init_end", state) - - def on_forward_begin(self, state: CacheState): - return self.call_event("on_forward_begin", state) - - def on_forward_remaining_begin(self, state: CacheState): - return self.call_event("on_forward_remaining_begin", state) - - def on_forward_end(self, state: CacheState): - return self.call_event("on_forward_end", state) + def __init__(self, callbacks: Optional[List[CacheCallback]] = None): + self.callbacks = list(callbacks) if callbacks else [] - def call_event(self, event, state, **kwargs): - for callback in self.callbacks: - getattr(callback, event)( - state, - **kwargs, - ) + def trigger_event(self, event: str, state: CacheState): + for cb in self.callbacks: + getattr(cb, event)(state) - -#--------- CacheContext ---------# -@dataclasses.dataclass -class CacheContext: - buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict) - coefficients: 
Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict) - - def __post_init__(self): - self.coefficients["default"] = torch.Tensor([1, 0]).cuda() - self.coefficients["flux"] = torch.Tensor([4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]).cuda() - - def get_coef(self, name): - return self.coefficients.get(name) - - def get_buffer(self, name): - return self.buffers.get(name) - - def set_buffer(self, name, buffer): - self.buffers[name] = buffer - - def clear_buffer(self): - self.buffers.clear() - - -#--------- torch version of poly1d ---------# -class TorchPoly1D: - def __init__(self, coefficients): - self.coefficients = torch.tensor(coefficients, dtype=torch.float32) +# --------- Vectorized Poly1D --------- # +class VectorizedPoly1D(Module): + def __init__(self, coefficients: torch.Tensor): + super().__init__() + self.register_buffer("coefficients", coefficients) self.degree = len(coefficients) - 1 - def __call__(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: result = torch.zeros_like(x) for i, coef in enumerate(self.coefficients): result += coef * (x ** (self.degree - i)) @@ -134,116 +76,91 @@ def __call__(self, x): class CachedTransformerBlocks(torch.nn.Module, ABC): def __init__( self, - transformer_blocks, - single_transformer_blocks=None, + transformer_blocks: List[Module], + single_transformer_blocks: Optional[List[Module]] = None, *, - transformer=None, - rel_l1_thresh=0.6, - return_hidden_states_first=True, - num_steps=-1, - name="default", + transformer: Optional[Module] = None, + rel_l1_thresh: float = 0.6, + return_hidden_states_first: bool = True, + num_steps: int = -1, + name: str = "default", callbacks: Optional[List[CacheCallback]] = None, ): super().__init__() - self.state = CacheState( - transformer=transformer, - transformer_blocks=transformer_blocks, - single_transformer_blocks=single_transformer_blocks, - cache_context=CacheContext(), - ) - self.state.rel_l1_thresh = rel_l1_thresh - self.state.return_hidden_states_first = return_hidden_states_first - self.state.use_cache = False - self.state.num_steps=num_steps - self.state.name=name + self.transformer_blocks = torch.nn.ModuleList(transformer_blocks) + self.single_transformer_blocks = torch.nn.ModuleList(single_transformer_blocks) if single_transformer_blocks else None + self.transformer = transformer + self.register_buffer("cnt", torch.tensor(0).cuda()) + self.register_buffer("accumulated_rel_l1_distance", torch.tensor([0.0]).cuda()) + self.register_buffer("use_cache", torch.tensor(False, dtype=torch.bool).cuda()) + + self.cache_context = CacheContext() self.callback_handler = CallbackHandler(callbacks) - self.callback_handler.on_init_end(self.state) - - def is_parallelized(self): - if get_sequence_parallel_world_size() > 1: - return True - return False - - def all_reduce(self, input_, op): - if get_sequence_parallel_world_size() > 1: - return get_sp_group().all_reduce(input_=input_, op=op) - raise NotImplementedError("Cache method not support parrellism other than sp") - - def l1_distance_two_tensor(self, t1, t2): - mean_diff = (t1 - t2).abs().mean() - mean_t1 = t1.abs().mean() - if self.is_parallelized(): - mean_diff = self.all_reduce(mean_diff.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0] - mean_t1 = self.all_reduce(mean_t1.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0] - diff = mean_diff / mean_t1 - return diff - @abstractmethod - def are_two_tensor_similar(self, t1, t2, threshold): - pass + self.rel_l1_thresh = 
torch.tensor(rel_l1_thresh).cuda() + self.return_hidden_states_first = return_hidden_states_first + self.num_steps = num_steps + self.name = name + self.callback_handler.trigger_event("on_init_begin", self) + + @property + def is_parallelized(self) -> bool: + return get_sequence_parallel_world_size() > 1 + + def all_reduce(self, input_: torch.Tensor, op=torch.distributed.ReduceOp.AVG) -> torch.Tensor: + return get_sp_group().all_reduce(input_, op=op) if self.is_parallelized else input_ - def run_one_block_transformer(self, block, hidden_states, encoder_hidden_states, *args, **kwargs): - hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs) - return ( - (hidden_states, encoder_hidden_states) - if self.state.return_hidden_states_first - else (encoder_hidden_states, hidden_states) - ) + def l1_distance(self, t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor: + diff = (t1 - t2).abs().mean() + norm = t1.abs().mean() + diff, norm = self.all_reduce(diff.unsqueeze(0)), self.all_reduce(norm.unsqueeze(0)) + return (diff / norm).squeeze() @abstractmethod - def get_start_idx(self): - pass - - def get_remaining_block_result(self, hidden_states, encoder_hidden_states, *args, **kwargs): - original_hidden_states = self.state.cache_context.get_buffer("original_hidden_states") - original_encoder_hidden_states = self.state.cache_context.get_buffer("original_encoder_hidden_states") - start_idx = self.get_start_idx() - if start_idx == -1: - return (hidden_states, encoder_hidden_states) - for block in self.state.transformer_blocks[start_idx:]: - hidden_states, encoder_hidden_states = \ - self.run_one_block_transformer(block, hidden_states, encoder_hidden_states, *args, **kwargs) - if self.state.single_transformer_blocks is not None: - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - for block in self.state.single_transformer_blocks: - hidden_states = block(hidden_states, *args, **kwargs) - encoder_hidden_states, hidden_states = hidden_states.split( - [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 - ) - self.state.cache_context.set_buffer("hidden_states_residual", hidden_states - original_hidden_states) - self.state.cache_context.set_buffer("encoder_hidden_states_residual", - encoder_hidden_states - original_encoder_hidden_states) - return (hidden_states, encoder_hidden_states) + def are_two_tensor_similar(self, t1: torch.Tensor, t2: torch.Tensor, threshold: float) -> torch.Tensor: pass @abstractmethod - def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs): - pass + def get_start_idx(self) -> int: pass + + @abstractmethod + def get_modulated_inputs(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, **kwargs): pass + + def process_blocks(self, start_idx: int, hidden: torch.Tensor, encoder: torch.Tensor, *args, **kwargs): + for block in self.transformer_blocks[start_idx:]: + hidden, encoder = block(hidden, encoder, *args, **kwargs) + hidden, encoder = (hidden, encoder) if self.return_hidden_states_first else (encoder, hidden) + + if self.single_transformer_blocks: + hidden = torch.cat([encoder, hidden], dim=1) + for block in self.single_transformer_blocks: + hidden = block(hidden, *args, **kwargs) + encoder, hidden = hidden.split([encoder.shape[1], hidden.shape[1] - encoder.shape[1]], dim=1) + + self.cache_context.hidden_states_residual = hidden - self.cache_context.original_hidden_states + 
self.cache_context.encoder_hidden_states_residual = encoder - self.cache_context.original_encoder_hidden_states + return hidden, encoder def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs): - self.callback_handler.on_forward_begin(self.state) - modulated_inputs, prev_modulated_inputs, original_hidden_states, original_encoder_hidden_states = \ + self.callback_handler.trigger_event("on_forward_begin", self) + + modulated, prev_modulated, orig_hidden, orig_encoder = \ self.get_modulated_inputs(hidden_states, encoder_hidden_states, *args, **kwargs) - self.state.cache_context.set_buffer("original_hidden_states", original_hidden_states) - self.state.cache_context.set_buffer("original_encoder_hidden_states", original_encoder_hidden_states) + self.cache_context.original_hidden_states = orig_hidden + self.cache_context.original_encoder_hidden_states = orig_encoder - self.state.use_cache = prev_modulated_inputs is not None and self.are_two_tensor_similar( - t1=prev_modulated_inputs, t2=modulated_inputs, threshold=self.state.rel_l1_thresh) + self.use_cache = self.are_two_tensor_similar(prev_modulated, modulated, self.rel_l1_thresh) \ + if prev_modulated is not None else torch.tensor(False, dtype=torch.bool) - self.callback_handler.on_forward_remaining_begin(self.state) - if self.state.use_cache: - hidden_states += self.state.cache_context.get_buffer("hidden_states_residual") - encoder_hidden_states += self.state.cache_context.get_buffer("encoder_hidden_states_residual") + self.callback_handler.trigger_event("on_forward_remaining_begin", self) + if self.use_cache: + hidden = hidden_states + self.cache_context.hidden_states_residual + encoder = encoder_hidden_states + self.cache_context.encoder_hidden_states_residual else: - hidden_states, encoder_hidden_states = self.get_remaining_block_result( - original_hidden_states, original_encoder_hidden_states, *args, **kwargs) + hidden, encoder = self.process_blocks(self.get_start_idx(), orig_hidden, orig_encoder, *args, **kwargs) - self.callback_handler.on_forward_end(self.state) - return ( - (hidden_states, encoder_hidden_states) - if self.state.return_hidden_states_first - else (encoder_hidden_states, hidden_states) - ) + self.callback_handler.trigger_event("on_forward_end", self) + return ((hidden, encoder) if self.return_hidden_states_first else (encoder, hidden)) class FBCachedTransformerBlocks(CachedTransformerBlocks): @@ -268,20 +185,21 @@ def __init__( name=name, callbacks=callbacks) - def get_start_idx(self): + def get_start_idx(self) -> int: return 1 - def are_two_tensor_similar(self, t1, t2, threshold): - return self.l1_distance_two_tensor(t1, t2) < threshold + def are_two_tensor_similar(self, t1: torch.Tensor, t2: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor: + return self.l1_distance(t1, t2) < threshold def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs): original_hidden_states = hidden_states - first_transformer_block = self.state.transformer_blocks[0] - hidden_states, encoder_hidden_states = \ - self.run_one_block_transformer(first_transformer_block, hidden_states, encoder_hidden_states, *args, **kwargs) + first_transformer_block = self.transformer_blocks[0] + hidden_states, encoder_hidden_states = first_transformer_block(hidden_states, encoder_hidden_states, *args, **kwargs) + hidden_states, encoder_hidden_states = (hidden_states, encoder_hidden_states) if self.return_hidden_states_first else (encoder_hidden_states, hidden_states) first_hidden_states_residual = hidden_states - 
original_hidden_states - prev_first_hidden_states_residual = self.state.cache_context.get_buffer("modulated_inputs") - self.state.cache_context.set_buffer("modulated_inputs", first_hidden_states_residual) + prev_first_hidden_states_residual = self.cache_context.modulated_inputs + if not self.use_cache: + self.cache_context.modulated_inputs = first_hidden_states_residual return first_hidden_states_residual, prev_first_hidden_states_residual, hidden_states, encoder_hidden_states @@ -307,39 +225,25 @@ def __init__( return_hidden_states_first=return_hidden_states_first, name=name, callbacks=callbacks) - object.__setattr__(self.state, 'cnt', 0) - object.__setattr__(self.state, 'accumulated_rel_l1_distance', 0) - object.__setattr__(self.state, 'rescale_func', TorchPoly1D(self.state.cache_context.get_coef(self.state.name))) + self.rescale_func = VectorizedPoly1D(self.cache_context.get_coef(self.name)) - def get_start_idx(self): + def get_start_idx(self) -> int: return 0 - def are_two_tensor_similar(self, t1, t2, threshold): - if self.state.cnt == 0 or self.state.cnt == self.state.num_steps-1: - self.state.accumulated_rel_l1_distance = 0 - self.state.use_cache = False - else: - diff = self.l1_distance_two_tensor(t1, t2) - self.state.accumulated_rel_l1_distance += self.state.rescale_func(diff) - if self.state.accumulated_rel_l1_distance < threshold: - self.state.use_cache = True - else: - self.state.use_cache = False - self.state.accumulated_rel_l1_distance = 0 - self.state.cnt += 1 - if self.state.cnt == self.state.num_steps: - self.state.cnt = 0 - return self.state.use_cache + def are_two_tensor_similar(self, t1: torch.Tensor, t2: torch.Tensor, threshold: float) -> torch.Tensor: + diff = self.l1_distance(t1, t2) + new_accum = self.accumulated_rel_l1_distance + self.rescale_func(diff) + reset_mask = (self.cnt == 0) or (self.cnt == self.num_steps - 1) + self.use_cache = torch.logical_and(new_accum < threshold, torch.logical_not(reset_mask)) + self.accumulated_rel_l1_distance[0] = torch.where(self.use_cache, new_accum[0], 0.0) + self.cnt = torch.where(self.cnt + 1 < self.num_steps, self.cnt + 1, 0) + return self.use_cache def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs): inp = hidden_states.clone() - temb_ = kwargs.get("temb", None) - if temb_ is not None: - temb_ = temb_.clone() - else: - raise ValueError("'temb' not found in kwargs") - modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.state.transformer_blocks[0].norm1(inp, emb=temb_) - previous_modulated_input = self.state.cache_context.get_buffer("modulated_inputs") - self.state.cache_context.set_buffer("modulated_inputs", modulated_inp) - return modulated_inp, previous_modulated_input, hidden_states, encoder_hidden_states \ No newline at end of file + temb_ = kwargs.get("temb", None).clone() + modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_) + prev_modulated = self.cache_context.modulated_inputs + self.cache_context.modulated_inputs = modulated + return modulated, prev_modulated, hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py index ad94e1dd..d76eee70 100644 --- a/xfuser/model_executor/pipelines/base_pipeline.py +++ b/xfuser/model_executor/pipelines/base_pipeline.py @@ -48,6 +48,7 @@ from xfuser.model_executor.schedulers import * from xfuser.model_executor.models.transformers import * from 
xfuser.model_executor.layers.attention_processor import * +from xfuser.model_executor.cache.diffusers_adapters import apply_cache_on_transformer from xfuser.config.config import ParallelConfig try: import os @@ -142,6 +143,7 @@ def __init__( self, pipeline: DiffusionPipeline, engine_config: EngineConfig, + cache_args: Optional[Dict] = None, ): self.module: DiffusionPipeline self.engine_config = engine_config @@ -161,6 +163,7 @@ def __init__( transformer, enable_torch_compile=engine_config.runtime_config.use_torch_compile, enable_onediff=engine_config.runtime_config.use_onediff, + cache_args=cache_args, ) elif unet is not None: pipeline.unet = self._convert_unet_backbone(unet) @@ -357,7 +360,7 @@ def _init_fast_attn_state( initialize_fast_attn_state(pipeline=pipeline, single_config=engine_config.fast_attn_config) def _convert_transformer_backbone( - self, transformer: nn.Module, enable_torch_compile: bool, enable_onediff: bool + self, transformer: nn.Module, enable_torch_compile: bool, enable_onediff: bool, cache_args: Optional[Dict] = None, ): if ( get_pipeline_parallel_world_size() == 1 @@ -379,7 +382,29 @@ def _convert_transformer_backbone( logger.warning( f"apply --use_torch_compile and --use_onediff togather. we use torch compile only" ) - + if cache_args: + use_teacache = cache_args["use_teacache"] + use_fbcache = cache_args["use_fbcache"] + cache_args.pop("use_teacache") + cache_args.pop("use_fbcache") + use_cache = use_teacache or use_fbcache + use_cache = ( + use_cache + and get_pipeline_parallel_world_size() == 1 + and get_classifier_free_guidance_world_size() == 1 + and get_tensor_model_parallel_world_size() == 1 + ) + if use_cache: + if use_teacache and use_fbcache: + logger.warning(f"--use_teacache and --use_fbcache are both set; using FBCache only") + cache_args["use_cache"] = "Fb" + elif use_teacache: + cache_args["use_cache"] = "Tea" + elif use_fbcache: + cache_args["use_cache"] = "Fb" + + transformer = apply_cache_on_transformer(transformer, **cache_args) + self.original_transformer = transformer if enable_torch_compile or enable_onediff: if getattr(transformer, "forward") is not None: if enable_torch_compile: diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py index 1550eb05..b87a00b2 100644 --- a/xfuser/model_executor/pipelines/pipeline_flux.py +++ b/xfuser/model_executor/pipelines/pipeline_flux.py @@ -57,13 +57,14 @@ def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], engine_config: EngineConfig, + cache_args: Optional[Dict] = None, return_org_pipeline: bool = False, **kwargs, ): pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) if return_org_pipeline: return pipeline - return cls(pipeline, engine_config) + return cls(pipeline, engine_config, cache_args) def prepare_run( self,