
Commit

teacache added
Binary2355 committed Feb 20, 2025
1 parent 836cf85 commit c940dfe
Showing 6 changed files with 227 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/flux_example.py
@@ -11,6 +11,7 @@
    get_data_parallel_world_size,
    get_runtime_state,
    is_dp_last_group,
    get_pipeline_parallel_world_size,
)


@@ -45,6 +46,12 @@ def main():
    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    pipe.prepare_run(input_config, steps=1)
    from xfuser.model_executor.plugins.teacache.diffusers_adapters import apply_cache_on_transformer
    # Disable TeaCache when pipeline parallelism is enabled.
    use_teacache = True
    if get_pipeline_parallel_world_size() > 1:
        use_teacache = False
    pipe.transformer = apply_cache_on_transformer(
        pipe.transformer, rel_l1_thresh=0.6, use_teacache=use_teacache, num_steps=input_config.num_inference_steps)

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
Empty file.
Empty file.
33 changes: 33 additions & 0 deletions xfuser/model_executor/plugins/teacache/diffusers_adapters/__init__.py
@@ -0,0 +1,33 @@
import importlib

from diffusers import DiffusionPipeline


def apply_cache_on_transformer(transformer, *args, **kwargs):
    # Dispatch on the transformer class name to the matching adapter module (only flux for now).
    transformer_cls_name = transformer.__class__.__name__
    if transformer_cls_name.startswith("Flux") or transformer_cls_name.startswith("xFuserFlux"):
        adapter_name = "flux"
    else:
        raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)


def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)

    # Dispatch on the pipeline class name in the same way.
    pipe_cls_name = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        adapter_name = "flux"
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
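
For reference, a minimal usage sketch of this dispatcher, mirroring what the flux_example.py change above does. The checkpoint id and step count are placeholders and not part of this commit.

# --- usage sketch, not part of this commit; checkpoint id and step count are placeholders ---
import torch
from diffusers import FluxPipeline

from xfuser.model_executor.plugins.teacache.diffusers_adapters import apply_cache_on_transformer

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

# The dispatcher matches the class name ("FluxTransformer2DModel" starts with "Flux")
# and forwards to the flux adapter.
pipe.transformer = apply_cache_on_transformer(
    pipe.transformer, rel_l1_thresh=0.6, use_teacache=True, num_steps=28
)

image = pipe("an astronaut riding a horse", num_inference_steps=28).images[0]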
86 changes: 86 additions & 0 deletions xfuser/model_executor/plugins/teacache/diffusers_adapters/flux.py
@@ -0,0 +1,86 @@
import functools
import unittest.mock

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from xfuser.model_executor.plugins.teacache import utils


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
    *,
    rel_l1_thresh=0.05,
    use_teacache=True,
    num_steps=8,
    return_hidden_states_first=False,
    coefficients=(4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01),
):
    # Wrap both block lists in a single TeaCachedTransformerBlocks module that decides,
    # per step, whether to rerun the blocks or reuse cached residuals.
    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.TeaCachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer.single_transformer_blocks,
                transformer=transformer,
                enable_teacache=use_teacache,
                num_steps=num_steps,
                rel_l1_thresh=rel_l1_thresh,
                return_hidden_states_first=return_hidden_states_first,
                coefficients=coefficients,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        # Temporarily swap in the cached blocks; the single blocks are replaced by an empty
        # list because the cached module runs them itself.
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ), unittest.mock.patch.object(
            self,
            "single_transformer_blocks",
            dummy_single_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    **kwargs,
):
    original_call = pipe.__class__.__call__

    if not getattr(original_call, "_is_cached", False):

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call

        new_call._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    pipe._is_cached = True

    return pipe
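
Note that apply_cache_on_pipe above wraps the pipeline's __call__ in utils.cache_context(utils.create_cache_context()), but those two helpers are not part of the utils.py added in this commit (its contextlib and dataclasses imports are currently unused). Below is a minimal sketch of what they could look like; the names mirror the calls in flux.py, while the bodies are an assumption, since TeaCachedTransformerBlocks keeps all of its state on the module itself.

# --- hedged sketch, not part of this commit: helpers assumed by apply_cache_on_pipe ---
import contextlib
import dataclasses
from typing import Optional

_current_cache_context = None  # assumption: a module-level slot for the active context


@dataclasses.dataclass
class CacheContext:
    # Placeholder for per-run cached state; TeaCachedTransformerBlocks currently keeps its
    # counters and residuals on the module, so this can stay empty.
    pass


def create_cache_context() -> CacheContext:
    return CacheContext()


def get_current_cache_context() -> Optional[CacheContext]:
    return _current_cache_context


@contextlib.contextmanager
def cache_context(ctx: CacheContext):
    # Install ctx for the duration of one pipeline call, then restore the previous one.
    global _current_cache_context
    previous = _current_cache_context
    _current_cache_context = ctx
    try:
        yield ctx
    finally:
        _current_cache_context = previous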
101 changes: 101 additions & 0 deletions xfuser/model_executor/plugins/teacache/utils.py
@@ -0,0 +1,101 @@
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict

import torch
import numpy as np


class TeaCachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        enable_teacache=True,
        num_steps=8,
        rel_l1_thresh=0.6,
        return_hidden_states_first=True,
        coefficients=(4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01),
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.cnt = 0
        self.enable_teacache = enable_teacache
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_residual_encoder = None
        self.coefficients = coefficients
        self.return_hidden_states_first = return_hidden_states_first

    def forward(self, hidden_states, encoder_hidden_states, temb, *args, **kwargs):
        if not self.enable_teacache:
            # TeaCache disabled: run every block on every step.
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, temb, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:]
            return (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )

        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        # Estimate how much the modulated input changed since the previous step.
        first_transformer_block = self.transformer_blocks[0]
        inp = hidden_states.clone()
        temb_ = temb.clone()
        modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = first_transformer_block.norm1(inp, emb=temb_)
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            # Always recompute on the first and last steps.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0

        if not should_calc:
            # Skip the blocks and reuse the residuals cached on the last full computation.
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, temb, *args, **kwargs)
                encoder_hidden_states, hidden_states = hidden_states.split(
                    [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
                )
            self.previous_residual = hidden_states - original_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - original_encoder_hidden_states

        return (
            (hidden_states, encoder_hidden_states)
            if self.return_hidden_states_first
            else (encoder_hidden_states, hidden_states)
        )
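
To make the gating rule in TeaCachedTransformerBlocks.forward easier to follow in isolation, here is a small self-contained sketch of the same decision: the relative L1 change of the modulated input is rescaled by the fitted polynomial, accumulated across steps, and the transformer blocks are recomputed only when the accumulator crosses rel_l1_thresh, with the first and last steps always recomputed. The tensor is a random stand-in for norm1's modulated output.

# --- standalone sketch of the TeaCache skip decision; the tensor is a random stand-in ---
import numpy as np
import torch

coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
rescale_func = np.poly1d(coefficients)
rel_l1_thresh, num_steps = 0.6, 8

accumulated_rel_l1_distance = 0.0
previous_modulated_input = None
for cnt in range(num_steps):
    modulated_inp = torch.randn(1, 1024, 64)  # stand-in for norm1(hidden_states, emb=temb)[0]
    if cnt == 0 or cnt == num_steps - 1:
        should_calc = True
        accumulated_rel_l1_distance = 0.0
    else:
        rel_l1 = ((modulated_inp - previous_modulated_input).abs().mean()
                  / previous_modulated_input.abs().mean()).item()
        accumulated_rel_l1_distance += rescale_func(rel_l1)
        should_calc = accumulated_rel_l1_distance >= rel_l1_thresh
        if should_calc:
            accumulated_rel_l1_distance = 0.0
    previous_modulated_input = modulated_inp
    print(cnt, "recompute blocks" if should_calc else "reuse cached residuals")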
