diff --git a/README.md b/README.md
index 19e790a9..05f77c0d 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,20 @@ We also have implemented the following parallel strategies for reference:
1. Tensor Parallelism
2. [DistriFusion](https://arxiv.org/abs/2402.19481)
+
Cache Acceleration
+
+The cache methods are inspired by [TeaCache](https://github.com/ali-vilab/TeaCache.git) and [ParaAttention](https://github.com/chengzeyi/ParaAttention.git); we adapted TeaCache and First-Block-Cache for xDiT.
+
+These methods are not orthogonal to parallelism in xDiT: the cache function can only be activated together with sequence parallelism (SP) or with no parallelism at all.
+
+To use this functionality, pass `--use_teacache` or `--use_fbcache`, which enable TeaCache and First-Block-Cache respectively. Currently, only the FLUX model is supported.
+
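+For example, `examples/run.sh` already wires a `CACHE_ARGS` variable into the launch command, so enabling a cache is a matter of uncommenting one line (a minimal sketch of that script's pattern; the rest of the script stays unchanged):
+
+```bash
+# cache args (enable at most one)
+# CACHE_ARGS="--use_teacache"
+CACHE_ARGS="--use_fbcache"
+```
+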
+The performance below was measured on 4x H20 GPUs with SP=4:
+
+| Method              | Latency |
+|---------------------|---------|
+| Baseline (no cache) | 2.02s   |
+| use_teacache        | 1.58s   |
+| use_fbcache         | 0.93s   |
Computing Acceleration
diff --git a/examples/flux_example.py b/examples/flux_example.py
index ff578bbf..921a6262 100644
--- a/examples/flux_example.py
+++ b/examples/flux_example.py
@@ -11,6 +11,10 @@
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
+ get_pipeline_parallel_world_size,
+ get_classifier_free_guidance_world_size,
+ get_tensor_model_parallel_world_size,
)
@@ -45,7 +49,27 @@ def main():
parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
pipe.prepare_run(input_config, steps=1)
+ from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer
+ use_cache = engine_args.use_teacache or engine_args.use_fbcache
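+    # Caching is only compatible with sequence parallelism (or no parallelism at all):
+    # pipeline, classifier-free-guidance, and tensor parallel degrees must all be 1.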
+ if (use_cache
+ and get_pipeline_parallel_world_size() == 1
+ and get_classifier_free_guidance_world_size() == 1
+ and get_tensor_model_parallel_world_size() == 1
+ ):
+ cache_args = {
+ "rel_l1_thresh": 0.6,
+ "return_hidden_states_first": False,
+ "num_steps": input_config.num_inference_steps,
+ }
+
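+        # First-Block-Cache takes precedence when both flags are set.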
+ if engine_args.use_fbcache and engine_args.use_teacache:
+ cache_args["use_cache"] = "Fb"
+ elif engine_args.use_teacache:
+ cache_args["use_cache"] = "Tea"
+ elif engine_args.use_fbcache:
+ cache_args["use_cache"] = "Fb"
+ pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args)
torch.cuda.reset_peak_memory_stats()
start_time = time.time()
output = pipe(
diff --git a/examples/run.sh b/examples/run.sh
index 12897814..12473309 100644
--- a/examples/run.sh
+++ b/examples/run.sh
@@ -10,7 +10,7 @@ declare -A MODEL_CONFIGS=(
["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
- ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
+ ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev/ 28"
["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
)
@@ -27,6 +27,9 @@ mkdir -p ./results
# task args
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
+# cache args
+# CACHE_ARGS="--use_teacache"
+# CACHE_ARGS="--use_fbcache"
# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
N_GPUS=8
@@ -64,3 +67,4 @@ $CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
$QUANTIZE_FLAG \
+$CACHE_ARGS \
diff --git a/xfuser/config/args.py b/xfuser/config/args.py
index b4afabe8..7f5095f3 100644
--- a/xfuser/config/args.py
+++ b/xfuser/config/args.py
@@ -111,6 +111,8 @@ class xFuserArgs:
window_size: int = 64
coco_path: Optional[str] = None
use_cache: bool = False
+ use_teacache: bool = False
+ use_fbcache: bool = False
use_fp8_t5_encoder: bool = False
@staticmethod
@@ -154,6 +156,16 @@ def add_cli_args(parser: FlexibleArgumentParser):
action="store_true",
help="Enable onediff to accelerate inference in a single card",
)
+ runtime_group.add_argument(
+ "--use_teacache",
+ action="store_true",
+ help="Enable teacache to accelerate inference in a single card",
+ )
+ runtime_group.add_argument(
+ "--use_fbcache",
+ action="store_true",
+ help="Enable teacache to accelerate inference in a single card",
+ )
# Parallel arguments
parallel_group = parser.add_argument_group("Parallel Processing Options")
diff --git a/xfuser/config/config.py b/xfuser/config/config.py
index 2750b020..a87f0cc5 100644
--- a/xfuser/config/config.py
+++ b/xfuser/config/config.py
@@ -60,6 +60,8 @@ class RuntimeConfig:
use_torch_compile: bool = False
use_onediff: bool = False
use_fp8_t5_encoder: bool = False
+ use_teacache: bool = False
+ use_fbcache: bool = False
def __post_init__(self):
check_packages()
diff --git a/xfuser/core/distributed/group_coordinator.py b/xfuser/core/distributed/group_coordinator.py
index 0402756f..b3c613a6 100644
--- a/xfuser/core/distributed/group_coordinator.py
+++ b/xfuser/core/distributed/group_coordinator.py
@@ -196,7 +196,7 @@ def group_skip_rank(self):
world_size = self.world_size
return (world_size - rank_in_group - 1) % world_size
- def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+ def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor:
"""
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
@@ -206,7 +206,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.world_size == 1:
return input_
else:
- torch.distributed.all_reduce(input_, group=self.device_group)
+ torch.distributed.all_reduce(input_, op=op, group=self.device_group)
return input_
def all_gather(
diff --git a/xfuser/core/distributed/parallel_state.py b/xfuser/core/distributed/parallel_state.py
index 35590e5c..67b434e1 100644
--- a/xfuser/core/distributed/parallel_state.py
+++ b/xfuser/core/distributed/parallel_state.py
@@ -217,6 +217,7 @@ def init_distributed_environment(
world_size=world_size,
rank=rank,
)
+ torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
diff --git a/xfuser/model_executor/plugins/__init__.py b/xfuser/model_executor/plugins/__init__.py
new file mode 100644
index 00000000..8e81242b
--- /dev/null
+++ b/xfuser/model_executor/plugins/__init__.py
@@ -0,0 +1,4 @@
+"""
+adapted from https://github.com/ali-vilab/TeaCache.git
+adapted from https://github.com/chengzeyi/ParaAttention.git
+"""
\ No newline at end of file
diff --git a/xfuser/model_executor/plugins/cache_/__init__.py b/xfuser/model_executor/plugins/cache_/__init__.py
new file mode 100644
index 00000000..8e81242b
--- /dev/null
+++ b/xfuser/model_executor/plugins/cache_/__init__.py
@@ -0,0 +1,4 @@
+"""
+adapted from https://github.com/ali-vilab/TeaCache.git
+adapted from https://github.com/chengzeyi/ParaAttention.git
+"""
\ No newline at end of file
diff --git a/xfuser/model_executor/plugins/cache_/diffusers_adapters/__init__.py b/xfuser/model_executor/plugins/cache_/diffusers_adapters/__init__.py
new file mode 100644
index 00000000..211afb9a
--- /dev/null
+++ b/xfuser/model_executor/plugins/cache_/diffusers_adapters/__init__.py
@@ -0,0 +1,17 @@
+"""
+adapted from https://github.com/ali-vilab/TeaCache.git
+adapted from https://github.com/chengzeyi/ParaAttention.git
+"""
+import importlib
+from typing import Type, Dict, TypeVar
+from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY
+
+
+def apply_cache_on_transformer(transformer, *args, **kwargs):
+ adapter_name = TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer))
+ if not adapter_name:
+ raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}")
+
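+    # Dynamically import the registered adapter module (e.g. ".flux") and delegate to its apply_cache_on_transformer.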
+ adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+ apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
+ return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
diff --git a/xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py b/xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py
new file mode 100644
index 00000000..c3b40ba2
--- /dev/null
+++ b/xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py
@@ -0,0 +1,74 @@
+"""
+adapted from https://github.com/ali-vilab/TeaCache.git
+adapted from https://github.com/chengzeyi/ParaAttention.git
+"""
+import functools
+import unittest.mock
+
+import torch
+from torch import nn
+from diffusers import DiffusionPipeline, FluxTransformer2DModel
+from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY
+
+from xfuser.model_executor.plugins.cache_ import utils
+
+def create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps):
+ cached_transformer_class = {
+ "Fb": utils.FBCachedTransformerBlocks,
+ "Tea": utils.TeaCachedTransformerBlocks,
+ }.get(use_cache)
+
+ if not cached_transformer_class:
+ raise ValueError(f"Unsupported use_cache value: {use_cache}")
+
+ return cached_transformer_class(
+ transformer.transformer_blocks,
+ transformer.single_transformer_blocks,
+ transformer=transformer,
+ rel_l1_thresh=rel_l1_thresh,
+ return_hidden_states_first=return_hidden_states_first,
+ num_steps=num_steps,
+ name=TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer)),
+ )
+
+
+def apply_cache_on_transformer(
+ transformer: FluxTransformer2DModel,
+ *,
+ rel_l1_thresh=0.6,
+ return_hidden_states_first=False,
+ num_steps=8,
+ use_cache="Fb",
+):
+ cached_transformer_blocks = nn.ModuleList([
+ create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps)
+ ])
+
+ dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+ original_forward = transformer.forward
+
+ @functools.wraps(original_forward)
+ def new_forward(
+ self,
+ *args,
+ **kwargs,
+ ):
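+        # Temporarily swap in the cached blocks for this forward pass; the cached wrapper runs
+        # the single transformer blocks itself, so they are replaced by an empty ModuleList here.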
+ with unittest.mock.patch.object(
+ self,
+ "transformer_blocks",
+ cached_transformer_blocks,
+ ), unittest.mock.patch.object(
+ self,
+ "single_transformer_blocks",
+ dummy_single_transformer_blocks,
+ ):
+ return original_forward(
+ *args,
+ **kwargs,
+ )
+
+ transformer.forward = new_forward.__get__(transformer)
+
+ return transformer
+
diff --git a/xfuser/model_executor/plugins/cache_/diffusers_adapters/registry.py b/xfuser/model_executor/plugins/cache_/diffusers_adapters/registry.py
new file mode 100644
index 00000000..2232832a
--- /dev/null
+++ b/xfuser/model_executor/plugins/cache_/diffusers_adapters/registry.py
@@ -0,0 +1,16 @@
+"""
+adapted from https://github.com/ali-vilab/TeaCache.git
+adapted from https://github.com/chengzeyi/ParaAttention.git
+"""
+from typing import Type, Dict
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper
+
+TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {}
+
+def register_transformer_adapter(transformer_class: Type, adapter_name: str):
+ TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name
+
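+# Both the plain diffusers FluxTransformer2DModel and xDiT's wrapper map to the same "flux" adapter module.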
+register_transformer_adapter(FluxTransformer2DModel, "flux")
+register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux")
+
diff --git a/xfuser/model_executor/plugins/cache_/utils.py b/xfuser/model_executor/plugins/cache_/utils.py
new file mode 100644
index 00000000..9737fc0d
--- /dev/null
+++ b/xfuser/model_executor/plugins/cache_/utils.py
@@ -0,0 +1,345 @@
+"""
+adapted from https://github.com/ali-vilab/TeaCache.git
+adapted from https://github.com/chengzeyi/ParaAttention.git
+"""
+import contextlib
+import dataclasses
+from collections import defaultdict
+from typing import DefaultDict, Dict, Optional, List
+from xfuser.core.distributed import (
+ get_sp_group,
+ get_sequence_parallel_world_size,
+)
+
+import torch
+from abc import ABC, abstractmethod
+
+
+#--------- CacheCallback ---------#
+@dataclasses.dataclass
+class CacheState:
+    transformer: Optional[torch.nn.Module]
+    transformer_blocks: Optional[torch.nn.ModuleList]
+    single_transformer_blocks: Optional[torch.nn.ModuleList]
+    cache_context: Optional["CacheContext"]
+    rel_l1_thresh: float = 0.6
+ return_hidden_states_first: bool = True
+ use_cache: bool = False
+ num_steps: int = 8
+ name: str = "default"
+
+
+class CacheCallback:
+ def on_init_end(self, state: CacheState, **kwargs):
+ pass
+
+ def on_forward_begin(self, state: CacheState, **kwargs):
+ pass
+
+ def on_forward_remaining_begin(self, state: CacheState, **kwargs):
+ pass
+
+ def on_forward_end(self, state: CacheState, **kwargs):
+ pass
+
+
+class CallbackHandler(CacheCallback):
+ def __init__(self, callbacks):
+ self.callbacks = []
+ if callbacks is not None:
+ for cb in callbacks:
+ self.add_callback(cb)
+
+ def add_callback(self, callback):
+ cb = callback() if isinstance(callback, type) else callback
+ self.callbacks.append(cb)
+
+ def pop_callback(self, callback):
+ if isinstance(callback, type):
+ for cb in self.callbacks:
+ if isinstance(cb, callback):
+ self.callbacks.remove(cb)
+ return cb
+ else:
+ for cb in self.callbacks:
+ if cb == callback:
+ self.callbacks.remove(cb)
+ return cb
+
+ def remove_callback(self, callback):
+ if isinstance(callback, type):
+ for cb in self.callbacks:
+ if isinstance(cb, callback):
+ self.callbacks.remove(cb)
+ return
+ else:
+ self.callbacks.remove(callback)
+
+ def on_init_end(self, state: CacheState):
+ return self.call_event("on_init_end", state)
+
+ def on_forward_begin(self, state: CacheState):
+ return self.call_event("on_forward_begin", state)
+
+ def on_forward_remaining_begin(self, state: CacheState):
+ return self.call_event("on_forward_remaining_begin", state)
+
+ def on_forward_end(self, state: CacheState):
+ return self.call_event("on_forward_end", state)
+
+ def call_event(self, event, state, **kwargs):
+ for callback in self.callbacks:
+ getattr(callback, event)(
+ state,
+ **kwargs,
+ )
+
+
+#--------- CacheContext ---------#
+@dataclasses.dataclass
+class CacheContext:
+ buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
+ coefficients: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
+
+ def __post_init__(self):
+ self.coefficients["default"] = torch.Tensor([1, 0]).cuda()
+ self.coefficients["flux"] = torch.Tensor([4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]).cuda()
+
+ def get_coef(self, name):
+ return self.coefficients.get(name)
+
+ def get_buffer(self, name):
+ return self.buffers.get(name)
+
+ def set_buffer(self, name, buffer):
+ self.buffers[name] = buffer
+
+ def clear_buffer(self):
+ self.buffers.clear()
+
+
+#--------- torch version of poly1d ---------#
+class TorchPoly1D:
+ def __init__(self, coefficients):
+ self.coefficients = torch.tensor(coefficients, dtype=torch.float32)
+ self.degree = len(coefficients) - 1
+
+ def __call__(self, x):
+ result = torch.zeros_like(x)
+ for i, coef in enumerate(self.coefficients):
+ result += coef * (x ** (self.degree - i))
+ return result
+
+
+class CachedTransformerBlocks(torch.nn.Module, ABC):
+ def __init__(
+ self,
+ transformer_blocks,
+ single_transformer_blocks=None,
+ *,
+ transformer=None,
+ rel_l1_thresh=0.6,
+ return_hidden_states_first=True,
+ num_steps=-1,
+ name="default",
+ callbacks: Optional[List[CacheCallback]] = None,
+ ):
+ super().__init__()
+ self.state = CacheState(
+ transformer=transformer,
+ transformer_blocks=transformer_blocks,
+ single_transformer_blocks=single_transformer_blocks,
+ cache_context=CacheContext(),
+ )
+ self.state.rel_l1_thresh = rel_l1_thresh
+ self.state.return_hidden_states_first = return_hidden_states_first
+ self.state.use_cache = False
+        self.state.num_steps = num_steps
+        self.state.name = name
+ self.callback_handler = CallbackHandler(callbacks)
+ self.callback_handler.on_init_end(self.state)
+
+ def is_parallelized(self):
+ if get_sequence_parallel_world_size() > 1:
+ return True
+ return False
+
+ def all_reduce(self, input_, op):
+ if get_sequence_parallel_world_size() > 1:
+ return get_sp_group().all_reduce(input_=input_, op=op)
+        raise NotImplementedError("Cache method does not support parallelism other than SP")
+
+ def l1_distance_two_tensor(self, t1, t2):
+ mean_diff = (t1 - t2).abs().mean()
+ mean_t1 = t1.abs().mean()
+ if self.is_parallelized():
+ mean_diff = self.all_reduce(mean_diff.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0]
+ mean_t1 = self.all_reduce(mean_t1.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0]
+ diff = mean_diff / mean_t1
+ return diff
+
+ @abstractmethod
+ def are_two_tensor_similar(self, t1, t2, threshold):
+ pass
+
+ def run_one_block_transformer(self, block, hidden_states, encoder_hidden_states, *args, **kwargs):
+ hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+ return (
+ (hidden_states, encoder_hidden_states)
+ if self.state.return_hidden_states_first
+ else (encoder_hidden_states, hidden_states)
+ )
+
+ @abstractmethod
+ def get_start_idx(self):
+ pass
+
+ def get_remaining_block_result(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+ original_hidden_states = self.state.cache_context.get_buffer("original_hidden_states")
+ original_encoder_hidden_states = self.state.cache_context.get_buffer("original_encoder_hidden_states")
+ start_idx = self.get_start_idx()
+ if start_idx == -1:
+ return (hidden_states, encoder_hidden_states)
+ for block in self.state.transformer_blocks[start_idx:]:
+ hidden_states, encoder_hidden_states = \
+ self.run_one_block_transformer(block, hidden_states, encoder_hidden_states, *args, **kwargs)
+ if self.state.single_transformer_blocks is not None:
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ for block in self.state.single_transformer_blocks:
+ hidden_states = block(hidden_states, *args, **kwargs)
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ self.state.cache_context.set_buffer("hidden_states_residual", hidden_states - original_hidden_states)
+ self.state.cache_context.set_buffer("encoder_hidden_states_residual",
+ encoder_hidden_states - original_encoder_hidden_states)
+ return (hidden_states, encoder_hidden_states)
+
+ @abstractmethod
+ def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+ pass
+
+ def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+ self.callback_handler.on_forward_begin(self.state)
+ modulated_inputs, prev_modulated_inputs, original_hidden_states, original_encoder_hidden_states = \
+ self.get_modulated_inputs(hidden_states, encoder_hidden_states, *args, **kwargs)
+
+ self.state.cache_context.set_buffer("original_hidden_states", original_hidden_states)
+ self.state.cache_context.set_buffer("original_encoder_hidden_states", original_encoder_hidden_states)
+
+ self.state.use_cache = prev_modulated_inputs is not None and self.are_two_tensor_similar(
+ t1=prev_modulated_inputs, t2=modulated_inputs, threshold=self.state.rel_l1_thresh)
+
+ self.callback_handler.on_forward_remaining_begin(self.state)
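+        # Cache hit: reuse the stored residuals. Cache miss: run the remaining blocks and refresh the residuals.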
+ if self.state.use_cache:
+ hidden_states += self.state.cache_context.get_buffer("hidden_states_residual")
+ encoder_hidden_states += self.state.cache_context.get_buffer("encoder_hidden_states_residual")
+ else:
+ hidden_states, encoder_hidden_states = self.get_remaining_block_result(
+ original_hidden_states, original_encoder_hidden_states, *args, **kwargs)
+
+ self.callback_handler.on_forward_end(self.state)
+ return (
+ (hidden_states, encoder_hidden_states)
+ if self.state.return_hidden_states_first
+ else (encoder_hidden_states, hidden_states)
+ )
+
+
+class FBCachedTransformerBlocks(CachedTransformerBlocks):
+ def __init__(
+ self,
+ transformer_blocks,
+ single_transformer_blocks=None,
+ *,
+ transformer=None,
+ rel_l1_thresh=0.6,
+ return_hidden_states_first=True,
+ num_steps=-1,
+ name="default",
+ callbacks: Optional[List[CacheCallback]] = None,
+ ):
+ super().__init__(transformer_blocks,
+ single_transformer_blocks=single_transformer_blocks,
+ transformer=transformer,
+ rel_l1_thresh=rel_l1_thresh,
+ num_steps=num_steps,
+ return_hidden_states_first=return_hidden_states_first,
+ name=name,
+ callbacks=callbacks)
+
+ def get_start_idx(self):
+ return 1
+
+ def are_two_tensor_similar(self, t1, t2, threshold):
+ return self.l1_distance_two_tensor(t1, t2) < threshold
+
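+    # First-Block-Cache: the residual of the first transformer block serves as the indicator for cache reuse.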
+ def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+ original_hidden_states = hidden_states
+ first_transformer_block = self.state.transformer_blocks[0]
+ hidden_states, encoder_hidden_states = \
+ self.run_one_block_transformer(first_transformer_block, hidden_states, encoder_hidden_states, *args, **kwargs)
+ first_hidden_states_residual = hidden_states - original_hidden_states
+ prev_first_hidden_states_residual = self.state.cache_context.get_buffer("modulated_inputs")
+ self.state.cache_context.set_buffer("modulated_inputs", first_hidden_states_residual)
+
+ return first_hidden_states_residual, prev_first_hidden_states_residual, hidden_states, encoder_hidden_states
+
+
+class TeaCachedTransformerBlocks(CachedTransformerBlocks):
+ def __init__(
+ self,
+ transformer_blocks,
+ single_transformer_blocks=None,
+ *,
+ transformer=None,
+ rel_l1_thresh=0.6,
+ return_hidden_states_first=True,
+ num_steps=-1,
+ name="default",
+ callbacks: Optional[List[CacheCallback]] = None,
+ ):
+ super().__init__(transformer_blocks,
+ single_transformer_blocks=single_transformer_blocks,
+ transformer=transformer,
+ rel_l1_thresh=rel_l1_thresh,
+ num_steps=num_steps,
+ return_hidden_states_first=return_hidden_states_first,
+ name=name,
+ callbacks=callbacks)
+        self.state.cnt = 0
+        self.state.accumulated_rel_l1_distance = 0
+        self.state.rescale_func = TorchPoly1D(self.state.cache_context.get_coef(self.state.name))
+
+ def get_start_idx(self):
+ return 0
+
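+    # TeaCache: accumulate the rescaled relative L1 distance across steps and reuse the cache until it
+    # crosses the threshold; always recompute on the first and last inference step.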
+ def are_two_tensor_similar(self, t1, t2, threshold):
+ if self.state.cnt == 0 or self.state.cnt == self.state.num_steps-1:
+ self.state.accumulated_rel_l1_distance = 0
+ self.state.use_cache = False
+ else:
+ diff = self.l1_distance_two_tensor(t1, t2)
+ self.state.accumulated_rel_l1_distance += self.state.rescale_func(diff)
+ if self.state.accumulated_rel_l1_distance < threshold:
+ self.state.use_cache = True
+ else:
+ self.state.use_cache = False
+ self.state.accumulated_rel_l1_distance = 0
+ self.state.cnt += 1
+ if self.state.cnt == self.state.num_steps:
+ self.state.cnt = 0
+ return self.state.use_cache
+
+
+ def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+ inp = hidden_states.clone()
+ temb_ = kwargs.get("temb", None)
+ if temb_ is not None:
+ temb_ = temb_.clone()
+ else:
+ raise ValueError("'temb' not found in kwargs")
+ modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.state.transformer_blocks[0].norm1(inp, emb=temb_)
+ previous_modulated_input = self.state.cache_context.get_buffer("modulated_inputs")
+ self.state.cache_context.set_buffer("modulated_inputs", modulated_inp)
+ return modulated_inp, previous_modulated_input, hidden_states, encoder_hidden_states
\ No newline at end of file