From 248abecf8d8bdbb2401460f78a0f17094e7c4d31 Mon Sep 17 00:00:00 2001
From: Tsekai Lee <44702332+Binary2355@users.noreply.github.com>
Date: Tue, 25 Feb 2025 16:50:06 +0800
Subject: [PATCH] add 4xH20 performance and 1xH20 performance with
torch.compile (#453)
---
docs/performance/flux.md | 45 ++-
examples/flux_example.py | 31 +-
.../cache/diffusers_adapters/__init__.py | 6 +-
xfuser/model_executor/cache/utils.py | 354 +++++++-----------
.../model_executor/pipelines/base_pipeline.py | 29 +-
.../model_executor/pipelines/pipeline_flux.py | 3 +-
6 files changed, 212 insertions(+), 256 deletions(-)
diff --git a/docs/performance/flux.md b/docs/performance/flux.md
index d1310819..4fad3d35 100644
--- a/docs/performance/flux.md
+++ b/docs/performance/flux.md
@@ -117,15 +117,48 @@ The quality of image generation at 2048px, 3072px, and 4096px resolutions is as
## Cache Methods
-We tested the performance of TeaCache and First-Block-Cache on 4xH20 with SP=4.
+We tested the performance of TeaCache and First-Block-Cache on 4xH20 with SP=4 and on a single H20, each with and without torch.compile.
The performance is shown below:
-| Method | Latency (s) |
-|----------------|--------|
-| Baseline | 2.02s |
-| use_teacache | 1.58s |
-| use_fbcache | 0.93s |
+
+<table>
+    <tr>
+        <th rowspan="2">Method</th>
+        <th colspan="4">Latency (s)</th>
+    </tr>
+    <tr>
+        <th colspan="2">without torch.compile</th>
+        <th colspan="2">with torch.compile</th>
+    </tr>
+    <tr>
+        <th></th>
+        <th>4xH20</th>
+        <th>1xH20</th>
+        <th>4xH20</th>
+        <th>1xH20</th>
+    </tr>
+    <tr>
+        <td>Baseline</td>
+        <td>2.02s</td>
+        <td>6.10s</td>
+        <td>1.81s</td>
+        <td>5.02s</td>
+    </tr>
+    <tr>
+        <td>use_teacache</td>
+        <td>1.60s</td>
+        <td>4.67s</td>
+        <td>1.50s</td>
+        <td>3.92s</td>
+    </tr>
+    <tr>
+        <td>use_fbcache</td>
+        <td>0.93s</td>
+        <td>2.51s</td>
+        <td>0.85s</td>
+        <td>2.09s</td>
+    </tr>
+</table>
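For reference, these settings are driven by the `cache_args` dict that this patch threads from the example script into the pipeline. A minimal sketch following the updated `examples/flux_example.py` (the argument-parsing boilerplate mirrors the existing xfuser example scripts; attribute names come from this patch):

```python
import torch
from xfuser import xFuserArgs, xFuserFluxPipeline
from xfuser.config import FlexibleArgumentParser

# Sketch only: mirrors the updated examples/flux_example.py in this patch.
parser = FlexibleArgumentParser(description="xFuser arguments")
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()

cache_args = {
    "use_teacache": engine_args.use_teacache,
    "use_fbcache": engine_args.use_fbcache,
    "rel_l1_thresh": 0.6,
    "return_hidden_states_first": False,
    "num_steps": input_config.num_inference_steps,
}

pipe = xFuserFluxPipeline.from_pretrained(
    pretrained_model_name_or_path=engine_config.model_config.model,
    engine_config=engine_config,
    cache_args=cache_args,
    torch_dtype=torch.bfloat16,
)
```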
diff --git a/examples/flux_example.py b/examples/flux_example.py
index c70c8279..bac45287 100644
--- a/examples/flux_example.py
+++ b/examples/flux_example.py
@@ -33,9 +33,18 @@ def main():
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
+ cache_args = {
+ "use_teacache": engine_args.use_teacache,
+ "use_fbcache": engine_args.use_fbcache,
+ "rel_l1_thresh": 0.6,
+ "return_hidden_states_first": False,
+ "num_steps": input_config.num_inference_steps,
+ }
+
pipe = xFuserFluxPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
+ cache_args=cache_args,
torch_dtype=torch.bfloat16,
text_encoder_2=text_encoder_2,
)
@@ -48,28 +57,8 @@ def main():
parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
- pipe.prepare_run(input_config, steps=1)
-
- use_cache = engine_args.use_teacache or engine_args.use_fbcache
- if (use_cache
- and get_pipeline_parallel_world_size() == 1
- and get_classifier_free_guidance_world_size() == 1
- and get_tensor_model_parallel_world_size() == 1
- ):
- cache_args = {
- "rel_l1_thresh": 0.6,
- "return_hidden_states_first": False,
- "num_steps": input_config.num_inference_steps,
- }
-
- if engine_args.use_fbcache and engine_args.use_teacache:
- cache_args["use_cache"] = "Fb"
- elif engine_args.use_teacache:
- cache_args["use_cache"] = "Tea"
- elif engine_args.use_fbcache:
- cache_args["use_cache"] = "Fb"
+ pipe.prepare_run(input_config, steps=input_config.num_inference_steps)
- pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args)
torch.cuda.reset_peak_memory_stats()
start_time = time.time()
output = pipe(
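Two things are worth noting about the example change above. First, `prepare_run` now warms up with the full `num_inference_steps` instead of a single step: TeaCache's step counter branches on step 0 and step `num_steps - 1`, so a full warmup lets torch.compile see every code path before timing starts. Second, since CUDA launches are asynchronous, latency numbers like those in the table are only meaningful if the device is synchronized around the timed region; a small illustrative helper (not part of xfuser):

```python
import time
import torch

def timed(fn, *args, **kwargs):
    """Return (output, seconds), synchronizing CUDA so queued kernels
    are counted in the measured latency. Illustrative helper only."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out, time.time() - start
```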
diff --git a/xfuser/model_executor/cache/diffusers_adapters/__init__.py b/xfuser/model_executor/cache/diffusers_adapters/__init__.py
index c23cb07d..f76ef5ef 100644
--- a/xfuser/model_executor/cache/diffusers_adapters/__init__.py
+++ b/xfuser/model_executor/cache/diffusers_adapters/__init__.py
@@ -5,12 +5,16 @@
import importlib
from typing import Type, Dict, TypeVar
from xfuser.model_executor.cache.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY
+from xfuser.logger import init_logger
+
+logger = init_logger(__name__)
def apply_cache_on_transformer(transformer, *args, **kwargs):
adapter_name = TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer))
if not adapter_name:
- raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}")
+ logger.error(f"Unknown transformer class: {transformer.__class__.__name__}")
+ return transformer
adapter_module = importlib.import_module(f".{adapter_name}", __package__)
apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
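`apply_cache_on_transformer` dispatches through `TRANSFORMER_ADAPTER_REGISTRY`, which maps a transformer class to the name of an adapter submodule imported lazily; with this patch, an unregistered class is logged and passed through unchanged rather than raising. A self-contained sketch of that dispatch pattern, with hypothetical names standing in for the registry contents:

```python
import importlib

# Hypothetical stand-in for TRANSFORMER_ADAPTER_REGISTRY: class -> submodule name.
ADAPTER_REGISTRY: dict[type, str] = {}

def register_adapter(cls: type, module_name: str) -> None:
    ADAPTER_REGISTRY[cls] = module_name

def apply_adapter(obj, package: str):
    name = ADAPTER_REGISTRY.get(type(obj))
    if name is None:
        # Mirror the patch: degrade gracefully instead of raising.
        print(f"Unknown transformer class: {type(obj).__name__}")
        return obj
    module = importlib.import_module(f".{name}", package)
    return module.apply_cache_on_transformer(obj)
```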
diff --git a/xfuser/model_executor/cache/utils.py b/xfuser/model_executor/cache/utils.py
index 9737fc0d..48457e8b 100644
--- a/xfuser/model_executor/cache/utils.py
+++ b/xfuser/model_executor/cache/utils.py
@@ -2,129 +2,71 @@
adapted from https://github.com/ali-vilab/TeaCache.git
adapted from https://github.com/chengzeyi/ParaAttention.git
"""
-import contextlib
import dataclasses
-from collections import defaultdict
-from typing import DefaultDict, Dict, Optional, List
+from typing import Dict, Optional, List
from xfuser.core.distributed import (
get_sp_group,
get_sequence_parallel_world_size,
)
import torch
+from torch.nn import Module
from abc import ABC, abstractmethod
+# --------- CacheContext --------- #
+class CacheContext(Module):
+ def __init__(self):
+ super().__init__()
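+        # Coefficients and cached activations live in registered buffers so they
+        # follow the module across .to()/.cuda() and stay visible to torch.compile;
+        # persistent=False keeps the per-run activation caches out of state_dict.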
+ self.register_buffer("default_coef", torch.tensor([1.0, 0.0]).cuda())
+ self.register_buffer("flux_coef", torch.tensor([498.651651, -283.781631, 55.8554382, -3.82021401, 0.264230861]).cuda())
+
+ self.register_buffer("original_hidden_states", None, persistent=False)
+ self.register_buffer("original_encoder_hidden_states", None, persistent=False)
+ self.register_buffer("hidden_states_residual", None, persistent=False)
+ self.register_buffer("encoder_hidden_states_residual", None, persistent=False)
+ self.register_buffer("modulated_inputs", None, persistent=False)
+
+ def get_coef(self, name: str) -> torch.Tensor:
+ return getattr(self, f"{name}_coef")
+
#--------- CacheCallback ---------#
@dataclasses.dataclass
class CacheState:
- transformer: None
- transformer_blocks: None
- single_transformer_blocks: None
- cache_context: None
- rel_l1_thresh: int = 0.6
+ transformer: Optional[torch.nn.Module] = None
+ transformer_blocks: Optional[List[torch.nn.Module]] = None
+ single_transformer_blocks: Optional[List[torch.nn.Module]] = None
+ cache_context: Optional[CacheContext] = None
+ rel_l1_thresh: float = 0.6
return_hidden_states_first: bool = True
- use_cache: bool = False
+    use_cache: torch.Tensor = dataclasses.field(default_factory=lambda: torch.tensor(False, dtype=torch.bool))
num_steps: int = 8
name: str = "default"
class CacheCallback:
- def on_init_end(self, state: CacheState, **kwargs):
- pass
-
- def on_forward_begin(self, state: CacheState, **kwargs):
- pass
-
- def on_forward_remaining_begin(self, state: CacheState, **kwargs):
- pass
-
- def on_forward_end(self, state: CacheState, **kwargs):
- pass
+ def on_init_end(self, state: CacheState, **kwargs): pass
+ def on_forward_begin(self, state: CacheState, **kwargs): pass
+ def on_forward_remaining_begin(self, state: CacheState, **kwargs): pass
+ def on_forward_end(self, state: CacheState, **kwargs): pass
class CallbackHandler(CacheCallback):
- def __init__(self, callbacks):
- self.callbacks = []
- if callbacks is not None:
- for cb in callbacks:
- self.add_callback(cb)
-
- def add_callback(self, callback):
- cb = callback() if isinstance(callback, type) else callback
- self.callbacks.append(cb)
-
- def pop_callback(self, callback):
- if isinstance(callback, type):
- for cb in self.callbacks:
- if isinstance(cb, callback):
- self.callbacks.remove(cb)
- return cb
- else:
- for cb in self.callbacks:
- if cb == callback:
- self.callbacks.remove(cb)
- return cb
-
- def remove_callback(self, callback):
- if isinstance(callback, type):
- for cb in self.callbacks:
- if isinstance(cb, callback):
- self.callbacks.remove(cb)
- return
- else:
- self.callbacks.remove(callback)
-
- def on_init_end(self, state: CacheState):
- return self.call_event("on_init_end", state)
-
- def on_forward_begin(self, state: CacheState):
- return self.call_event("on_forward_begin", state)
-
- def on_forward_remaining_begin(self, state: CacheState):
- return self.call_event("on_forward_remaining_begin", state)
-
- def on_forward_end(self, state: CacheState):
- return self.call_event("on_forward_end", state)
+ def __init__(self, callbacks: Optional[List[CacheCallback]] = None):
+ self.callbacks = list(callbacks) if callbacks else []
- def call_event(self, event, state, **kwargs):
- for callback in self.callbacks:
- getattr(callback, event)(
- state,
- **kwargs,
- )
+ def trigger_event(self, event: str, state: CacheState):
+ for cb in self.callbacks:
+ getattr(cb, event)(state)
-
-#--------- CacheContext ---------#
-@dataclasses.dataclass
-class CacheContext:
- buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
- coefficients: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
-
- def __post_init__(self):
- self.coefficients["default"] = torch.Tensor([1, 0]).cuda()
- self.coefficients["flux"] = torch.Tensor([4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]).cuda()
-
- def get_coef(self, name):
- return self.coefficients.get(name)
-
- def get_buffer(self, name):
- return self.buffers.get(name)
-
- def set_buffer(self, name, buffer):
- self.buffers[name] = buffer
-
- def clear_buffer(self):
- self.buffers.clear()
-
-
-#--------- torch version of poly1d ---------#
-class TorchPoly1D:
- def __init__(self, coefficients):
- self.coefficients = torch.tensor(coefficients, dtype=torch.float32)
+# --------- Vectorized Poly1D --------- #
+class VectorizedPoly1D(Module):
+ def __init__(self, coefficients: torch.Tensor):
+ super().__init__()
+ self.register_buffer("coefficients", coefficients)
self.degree = len(coefficients) - 1
- def __call__(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
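+        # Coefficients are ordered from highest degree to constant term
+        # (numpy.poly1d convention): result = sum(coef[i] * x ** (degree - i)).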
result = torch.zeros_like(x)
for i, coef in enumerate(self.coefficients):
result += coef * (x ** (self.degree - i))
@@ -134,116 +76,91 @@ def __call__(self, x):
class CachedTransformerBlocks(torch.nn.Module, ABC):
def __init__(
self,
- transformer_blocks,
- single_transformer_blocks=None,
+ transformer_blocks: List[Module],
+ single_transformer_blocks: Optional[List[Module]] = None,
*,
- transformer=None,
- rel_l1_thresh=0.6,
- return_hidden_states_first=True,
- num_steps=-1,
- name="default",
+ transformer: Optional[Module] = None,
+ rel_l1_thresh: float = 0.6,
+ return_hidden_states_first: bool = True,
+ num_steps: int = -1,
+ name: str = "default",
callbacks: Optional[List[CacheCallback]] = None,
):
super().__init__()
- self.state = CacheState(
- transformer=transformer,
- transformer_blocks=transformer_blocks,
- single_transformer_blocks=single_transformer_blocks,
- cache_context=CacheContext(),
- )
- self.state.rel_l1_thresh = rel_l1_thresh
- self.state.return_hidden_states_first = return_hidden_states_first
- self.state.use_cache = False
- self.state.num_steps=num_steps
- self.state.name=name
+ self.transformer_blocks = torch.nn.ModuleList(transformer_blocks)
+ self.single_transformer_blocks = torch.nn.ModuleList(single_transformer_blocks) if single_transformer_blocks else None
+ self.transformer = transformer
+ self.register_buffer("cnt", torch.tensor(0).cuda())
+ self.register_buffer("accumulated_rel_l1_distance", torch.tensor([0.0]).cuda())
+ self.register_buffer("use_cache", torch.tensor(False, dtype=torch.bool).cuda())
+
+ self.cache_context = CacheContext()
self.callback_handler = CallbackHandler(callbacks)
- self.callback_handler.on_init_end(self.state)
-
- def is_parallelized(self):
- if get_sequence_parallel_world_size() > 1:
- return True
- return False
-
- def all_reduce(self, input_, op):
- if get_sequence_parallel_world_size() > 1:
- return get_sp_group().all_reduce(input_=input_, op=op)
- raise NotImplementedError("Cache method not support parrellism other than sp")
-
- def l1_distance_two_tensor(self, t1, t2):
- mean_diff = (t1 - t2).abs().mean()
- mean_t1 = t1.abs().mean()
- if self.is_parallelized():
- mean_diff = self.all_reduce(mean_diff.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0]
- mean_t1 = self.all_reduce(mean_t1.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0]
- diff = mean_diff / mean_t1
- return diff
- @abstractmethod
- def are_two_tensor_similar(self, t1, t2, threshold):
- pass
+ self.rel_l1_thresh = torch.tensor(rel_l1_thresh).cuda()
+ self.return_hidden_states_first = return_hidden_states_first
+ self.num_steps = num_steps
+ self.name = name
+ self.callback_handler.trigger_event("on_init_begin", self)
+
+ @property
+ def is_parallelized(self) -> bool:
+ return get_sequence_parallel_world_size() > 1
+
+ def all_reduce(self, input_: torch.Tensor, op=torch.distributed.ReduceOp.AVG) -> torch.Tensor:
+ return get_sp_group().all_reduce(input_, op=op) if self.is_parallelized else input_
- def run_one_block_transformer(self, block, hidden_states, encoder_hidden_states, *args, **kwargs):
- hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
- return (
- (hidden_states, encoder_hidden_states)
- if self.state.return_hidden_states_first
- else (encoder_hidden_states, hidden_states)
- )
+ def l1_distance(self, t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
+ diff = (t1 - t2).abs().mean()
+ norm = t1.abs().mean()
+ diff, norm = self.all_reduce(diff.unsqueeze(0)), self.all_reduce(norm.unsqueeze(0))
+ return (diff / norm).squeeze()
@abstractmethod
- def get_start_idx(self):
- pass
-
- def get_remaining_block_result(self, hidden_states, encoder_hidden_states, *args, **kwargs):
- original_hidden_states = self.state.cache_context.get_buffer("original_hidden_states")
- original_encoder_hidden_states = self.state.cache_context.get_buffer("original_encoder_hidden_states")
- start_idx = self.get_start_idx()
- if start_idx == -1:
- return (hidden_states, encoder_hidden_states)
- for block in self.state.transformer_blocks[start_idx:]:
- hidden_states, encoder_hidden_states = \
- self.run_one_block_transformer(block, hidden_states, encoder_hidden_states, *args, **kwargs)
- if self.state.single_transformer_blocks is not None:
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- for block in self.state.single_transformer_blocks:
- hidden_states = block(hidden_states, *args, **kwargs)
- encoder_hidden_states, hidden_states = hidden_states.split(
- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
- )
- self.state.cache_context.set_buffer("hidden_states_residual", hidden_states - original_hidden_states)
- self.state.cache_context.set_buffer("encoder_hidden_states_residual",
- encoder_hidden_states - original_encoder_hidden_states)
- return (hidden_states, encoder_hidden_states)
+    def are_two_tensor_similar(self, t1: torch.Tensor, t2: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor: pass
@abstractmethod
- def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs):
- pass
+ def get_start_idx(self) -> int: pass
+
+ @abstractmethod
+ def get_modulated_inputs(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, **kwargs): pass
+
+ def process_blocks(self, start_idx: int, hidden: torch.Tensor, encoder: torch.Tensor, *args, **kwargs):
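+        # Run the remaining transformer blocks from start_idx onward, then store the
+        # hidden/encoder residuals relative to the inputs cached in cache_context so
+        # a later cached step can replay them without recomputing the blocks.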
+ for block in self.transformer_blocks[start_idx:]:
+ hidden, encoder = block(hidden, encoder, *args, **kwargs)
+ hidden, encoder = (hidden, encoder) if self.return_hidden_states_first else (encoder, hidden)
+
+ if self.single_transformer_blocks:
+ hidden = torch.cat([encoder, hidden], dim=1)
+ for block in self.single_transformer_blocks:
+ hidden = block(hidden, *args, **kwargs)
+ encoder, hidden = hidden.split([encoder.shape[1], hidden.shape[1] - encoder.shape[1]], dim=1)
+
+ self.cache_context.hidden_states_residual = hidden - self.cache_context.original_hidden_states
+ self.cache_context.encoder_hidden_states_residual = encoder - self.cache_context.original_encoder_hidden_states
+ return hidden, encoder
def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
- self.callback_handler.on_forward_begin(self.state)
- modulated_inputs, prev_modulated_inputs, original_hidden_states, original_encoder_hidden_states = \
+ self.callback_handler.trigger_event("on_forward_begin", self)
+
+ modulated, prev_modulated, orig_hidden, orig_encoder = \
self.get_modulated_inputs(hidden_states, encoder_hidden_states, *args, **kwargs)
- self.state.cache_context.set_buffer("original_hidden_states", original_hidden_states)
- self.state.cache_context.set_buffer("original_encoder_hidden_states", original_encoder_hidden_states)
+ self.cache_context.original_hidden_states = orig_hidden
+ self.cache_context.original_encoder_hidden_states = orig_encoder
- self.state.use_cache = prev_modulated_inputs is not None and self.are_two_tensor_similar(
- t1=prev_modulated_inputs, t2=modulated_inputs, threshold=self.state.rel_l1_thresh)
+        self.use_cache = self.are_two_tensor_similar(prev_modulated, modulated, self.rel_l1_thresh) \
+            if prev_modulated is not None else torch.tensor(False, dtype=torch.bool, device=hidden_states.device)
- self.callback_handler.on_forward_remaining_begin(self.state)
- if self.state.use_cache:
- hidden_states += self.state.cache_context.get_buffer("hidden_states_residual")
- encoder_hidden_states += self.state.cache_context.get_buffer("encoder_hidden_states_residual")
+ self.callback_handler.trigger_event("on_forward_remaining_begin", self)
+ if self.use_cache:
+ hidden = hidden_states + self.cache_context.hidden_states_residual
+ encoder = encoder_hidden_states + self.cache_context.encoder_hidden_states_residual
else:
- hidden_states, encoder_hidden_states = self.get_remaining_block_result(
- original_hidden_states, original_encoder_hidden_states, *args, **kwargs)
+ hidden, encoder = self.process_blocks(self.get_start_idx(), orig_hidden, orig_encoder, *args, **kwargs)
- self.callback_handler.on_forward_end(self.state)
- return (
- (hidden_states, encoder_hidden_states)
- if self.state.return_hidden_states_first
- else (encoder_hidden_states, hidden_states)
- )
+ self.callback_handler.trigger_event("on_forward_end", self)
+ return ((hidden, encoder) if self.return_hidden_states_first else (encoder, hidden))
class FBCachedTransformerBlocks(CachedTransformerBlocks):
@@ -268,20 +185,21 @@ def __init__(
name=name,
callbacks=callbacks)
- def get_start_idx(self):
+ def get_start_idx(self) -> int:
return 1
- def are_two_tensor_similar(self, t1, t2, threshold):
- return self.l1_distance_two_tensor(t1, t2) < threshold
+ def are_two_tensor_similar(self, t1: torch.Tensor, t2: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+ return self.l1_distance(t1, t2) < threshold
def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs):
original_hidden_states = hidden_states
- first_transformer_block = self.state.transformer_blocks[0]
- hidden_states, encoder_hidden_states = \
- self.run_one_block_transformer(first_transformer_block, hidden_states, encoder_hidden_states, *args, **kwargs)
+ first_transformer_block = self.transformer_blocks[0]
+ hidden_states, encoder_hidden_states = first_transformer_block(hidden_states, encoder_hidden_states, *args, **kwargs)
+ hidden_states, encoder_hidden_states = (hidden_states, encoder_hidden_states) if self.return_hidden_states_first else (encoder_hidden_states, hidden_states)
first_hidden_states_residual = hidden_states - original_hidden_states
- prev_first_hidden_states_residual = self.state.cache_context.get_buffer("modulated_inputs")
- self.state.cache_context.set_buffer("modulated_inputs", first_hidden_states_residual)
+ prev_first_hidden_states_residual = self.cache_context.modulated_inputs
+ if not self.use_cache:
+ self.cache_context.modulated_inputs = first_hidden_states_residual
return first_hidden_states_residual, prev_first_hidden_states_residual, hidden_states, encoder_hidden_states
@@ -307,39 +225,25 @@ def __init__(
return_hidden_states_first=return_hidden_states_first,
name=name,
callbacks=callbacks)
- object.__setattr__(self.state, 'cnt', 0)
- object.__setattr__(self.state, 'accumulated_rel_l1_distance', 0)
- object.__setattr__(self.state, 'rescale_func', TorchPoly1D(self.state.cache_context.get_coef(self.state.name)))
+ self.rescale_func = VectorizedPoly1D(self.cache_context.get_coef(self.name))
- def get_start_idx(self):
+ def get_start_idx(self) -> int:
return 0
- def are_two_tensor_similar(self, t1, t2, threshold):
- if self.state.cnt == 0 or self.state.cnt == self.state.num_steps-1:
- self.state.accumulated_rel_l1_distance = 0
- self.state.use_cache = False
- else:
- diff = self.l1_distance_two_tensor(t1, t2)
- self.state.accumulated_rel_l1_distance += self.state.rescale_func(diff)
- if self.state.accumulated_rel_l1_distance < threshold:
- self.state.use_cache = True
- else:
- self.state.use_cache = False
- self.state.accumulated_rel_l1_distance = 0
- self.state.cnt += 1
- if self.state.cnt == self.state.num_steps:
- self.state.cnt = 0
- return self.state.use_cache
+    def are_two_tensor_similar(self, t1: torch.Tensor, t2: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+        # Branch-free update: keep the step counter and the accumulated distance as
+        # tensors and select with torch.where, so torch.compile traces a single graph.
+        diff = self.l1_distance(t1, t2)
+        new_accum = self.accumulated_rel_l1_distance + self.rescale_func(diff)
+        reset_mask = torch.logical_or(self.cnt == 0, self.cnt == self.num_steps - 1)
+        self.use_cache = torch.logical_and(new_accum < threshold, torch.logical_not(reset_mask))
+        self.accumulated_rel_l1_distance[0] = torch.where(self.use_cache, new_accum[0], 0.0)
+        self.cnt = torch.where(self.cnt + 1 < self.num_steps, self.cnt + 1, 0)
+        return self.use_cache
def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs):
inp = hidden_states.clone()
- temb_ = kwargs.get("temb", None)
- if temb_ is not None:
- temb_ = temb_.clone()
- else:
- raise ValueError("'temb' not found in kwargs")
- modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.state.transformer_blocks[0].norm1(inp, emb=temb_)
- previous_modulated_input = self.state.cache_context.get_buffer("modulated_inputs")
- self.state.cache_context.set_buffer("modulated_inputs", modulated_inp)
- return modulated_inp, previous_modulated_input, hidden_states, encoder_hidden_states
\ No newline at end of file
+ temb_ = kwargs.get("temb", None).clone()
+ modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_)
+ prev_modulated = self.cache_context.modulated_inputs
+ self.cache_context.modulated_inputs = modulated
+ return modulated, prev_modulated, hidden_states, encoder_hidden_states
\ No newline at end of file
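The through-line of this `utils.py` refactor is moving cache state out of Python attributes and dict-backed buffers into registered tensors updated with `torch.where`, so `torch.compile` can keep the bookkeeping inside a single captured graph instead of breaking on data-dependent Python branches. A minimal, self-contained sketch of the pattern (illustrative names, not xfuser APIs):

```python
import torch
from torch.nn import Module

class StepState(Module):
    """Illustrative: a per-step counter kept as a buffer and updated branch-free."""

    def __init__(self, num_steps: int):
        super().__init__()
        self.num_steps = num_steps
        self.register_buffer("cnt", torch.tensor(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Equivalent to `cnt = 0 if cnt + 1 >= num_steps else cnt + 1`, but expressed
        # on tensors so no data-dependent Python branch enters the compiled graph.
        self.cnt = torch.where(self.cnt + 1 < self.num_steps, self.cnt + 1, torch.zeros_like(self.cnt))
        return x + self.cnt

step = torch.compile(StepState(num_steps=28))
out = step(torch.zeros(2))  # state advances in-graph across calls
```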
diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py
index ad94e1dd..d76eee70 100644
--- a/xfuser/model_executor/pipelines/base_pipeline.py
+++ b/xfuser/model_executor/pipelines/base_pipeline.py
@@ -48,6 +48,7 @@
from xfuser.model_executor.schedulers import *
from xfuser.model_executor.models.transformers import *
from xfuser.model_executor.layers.attention_processor import *
+from xfuser.model_executor.cache.diffusers_adapters import apply_cache_on_transformer
from xfuser.config.config import ParallelConfig
try:
import os
@@ -142,6 +143,7 @@ def __init__(
self,
pipeline: DiffusionPipeline,
engine_config: EngineConfig,
+ cache_args: Optional[Dict] = None,
):
self.module: DiffusionPipeline
self.engine_config = engine_config
@@ -161,6 +163,7 @@ def __init__(
transformer,
enable_torch_compile=engine_config.runtime_config.use_torch_compile,
enable_onediff=engine_config.runtime_config.use_onediff,
+ cache_args=cache_args,
)
elif unet is not None:
pipeline.unet = self._convert_unet_backbone(unet)
@@ -357,7 +360,7 @@ def _init_fast_attn_state(
initialize_fast_attn_state(pipeline=pipeline, single_config=engine_config.fast_attn_config)
def _convert_transformer_backbone(
- self, transformer: nn.Module, enable_torch_compile: bool, enable_onediff: bool
+ self, transformer: nn.Module, enable_torch_compile: bool, enable_onediff: bool, cache_args: Optional[Dict] = None,
):
if (
get_pipeline_parallel_world_size() == 1
@@ -379,7 +382,29 @@ def _convert_transformer_backbone(
logger.warning(
f"apply --use_torch_compile and --use_onediff togather. we use torch compile only"
)
-
+        if cache_args:
+            use_teacache = cache_args.pop("use_teacache")
+            use_fbcache = cache_args.pop("use_fbcache")
+ use_cache = use_teacache or use_fbcache
+ use_cache = (
+ use_cache
+ and get_pipeline_parallel_world_size() == 1
+ and get_classifier_free_guidance_world_size() == 1
+ and get_tensor_model_parallel_world_size() == 1
+ )
+ if use_cache:
+ if use_teacache and use_fbcache:
+ logger.warning(f"apply --use_teacache and --use_fbcache togather. we use FBCache")
+ cache_args["use_cache"] = "Fb"
+ elif use_teacache:
+ cache_args["use_cache"] = "Tea"
+ elif use_fbcache:
+ cache_args["use_cache"] = "Fb"
+
+ transformer = apply_cache_on_transformer(transformer, **cache_args)
+ self.original_transformer = transformer
if enable_torch_compile or enable_onediff:
if getattr(transformer, "forward") is not None:
if enable_torch_compile:
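The flag resolution above reduces to a small pure function: FBCache takes precedence when both flags are set, and caching is skipped entirely unless pipeline, CFG, and tensor parallel world sizes are all 1 (the cache methods currently support sequence parallelism only). A hedged restatement with an illustrative signature:

```python
from typing import Optional

def resolve_cache_mode(use_teacache: bool, use_fbcache: bool,
                       pp_world: int = 1, cfg_world: int = 1, tp_world: int = 1) -> Optional[str]:
    """Return "Fb", "Tea", or None, mirroring the logic in _convert_transformer_backbone."""
    if not (use_teacache or use_fbcache):
        return None
    if pp_world != 1 or cfg_world != 1 or tp_world != 1:
        return None  # caching is disabled under PP/CFG/TP parallelism
    return "Fb" if use_fbcache else "Tea"

assert resolve_cache_mode(True, True) == "Fb"    # FBCache wins when both are set
assert resolve_cache_mode(True, False) == "Tea"
assert resolve_cache_mode(False, False) is None
```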
diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py
index 1550eb05..b87a00b2 100644
--- a/xfuser/model_executor/pipelines/pipeline_flux.py
+++ b/xfuser/model_executor/pipelines/pipeline_flux.py
@@ -57,13 +57,14 @@ def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
engine_config: EngineConfig,
+        cache_args: Optional[Dict] = None,
return_org_pipeline: bool = False,
**kwargs,
):
pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
if return_org_pipeline:
return pipeline
- return cls(pipeline, engine_config)
+ return cls(pipeline, engine_config, cache_args)
def prepare_run(
self,