
Add TeaCache and FBCache #451

Merged
merged 1 commit into xdit-project:main on Feb 24, 2025

Conversation


@Binary2355 Binary2355 commented Feb 20, 2025

Adapted from https://github.com/ali-vilab/TeaCache.git and https://github.com/chengzeyi/ParaAttention.git

  • Description: adds TeaCache and FBCache plugins for the FLUX model; they are invoked only under SP (sequence) parallelism or when no parallelism is used.

  • Single-node, 4-GPU validation example:

    • Modify run.sh:

```diff
diff --git a/examples/run.sh b/examples/run.sh
index 1289781..668eead 100644
--- a/examples/run.sh
+++ b/examples/run.sh
@@ -10,7 +10,7 @@ declare -A MODEL_CONFIGS=(
     ["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
     ["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
     ["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
-    ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
+    ["Flux"]="flux_example.py /file_system/models/dit/FLUX.1-schnell/ 8"
     ["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
 )

@@ -27,10 +27,13 @@ mkdir -p ./results
 # task args
 TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"

+# cache args
+# CACHE_ARGS="--use_teacache"
+# CACHE_ARGS="--use_fbcache"

 # On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"
+N_GPUS=4
+PARALLEL_ARGS="--pipefusion_parallel_degree 1 --ulysses_degree 4 --ring_degree 1"

 # CFG_ARGS="--use_cfg_parallel"

@@ -64,3 +67,4 @@ $CFG_ARGS \
 $PARALLLEL_VAE \
 $COMPILE_FLAG \
 $QUANTIZE_FLAG \
+$CACHE_ARGS \
```

    • Run the script:

```shell
clear && bash examples/run.sh
```
  • Performance comparison:

    Method        Performance
    baseline      2.02s
    use_teacache  1.58s
    use_fbcache   0.93s
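For context on the numbers above: FBCache ("first-block cache", from the ParaAttention repo referenced earlier) runs only the first transformer block each step and skips the remaining blocks when that block's output residual barely changes between steps. A minimal, framework-free sketch of the decision rule (class and function names, and the threshold, are illustrative, not the PR's actual API):

```python
def relative_change(prev, curr):
    """Mean absolute relative difference between two residual vectors."""
    num = sum(abs(c - p) for p, c in zip(prev, curr))
    den = sum(abs(p) for p in prev) or 1.0
    return num / den

class FirstBlockCache:
    """Decide whether to skip the remaining blocks using the first block's residual."""

    def __init__(self, threshold=0.1):
        self.threshold = threshold
        # First-block residual from the previous step; the caller updates it
        # after every fully computed step.
        self.prev_first_residual = None

    def should_skip(self, first_residual):
        # First step: nothing cached yet, so all blocks must run.
        if self.prev_first_residual is None:
            return False
        return relative_change(self.prev_first_residual, first_residual) < self.threshold
```

On a skip, the residuals cached at the last full step are added back to the hidden states, which is why the per-image time drops roughly in proportion to how often the check fires.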

@Binary2355 Binary2355 changed the title tecache added Add TeaCache and FBCache Feb 20, 2025
@Binary2355 Binary2355 force-pushed the main branch 5 times, most recently from d90cd4c to d574887 Compare February 20, 2025 10:59
Comment on lines 10 to 9
elif transformer_cls_name.startswith("Flux") or transformer_cls_name.startswith("xFuserFlux"):
adapter_name = "flux"
Collaborator:

Could this check use isinstance instead? Matching on the class name feels like it could match anything someone happens to name that way.

Contributor (author):

Fixed.
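A minimal sketch of the reviewer's suggestion, dispatching on the class via isinstance rather than on its name (the class definitions here are stand-ins for the real diffusers/xDiT classes, and ADAPTER_REGISTRY/resolve_adapter are hypothetical names):

```python
class FluxTransformer2DModel:                                   # stand-in for the diffusers class
    pass

class xFuserFluxTransformer2DWrapper(FluxTransformer2DModel):   # stand-in for the xDiT wrapper
    pass

# Map base classes to cache-adapter names. isinstance covers subclasses,
# so the xFuser wrapper resolves to "flux" without any name matching.
ADAPTER_REGISTRY = {FluxTransformer2DModel: "flux"}

def resolve_adapter(transformer):
    for cls, name in ADAPTER_REGISTRY.items():
        if isinstance(transformer, cls):
            return name
    raise ValueError(f"no cache adapter for {type(transformer).__name__}")
```

This also fails loudly for unsupported models instead of silently matching an unrelated class whose name happens to start with "Flux".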

```diff
@@ -60,6 +60,8 @@ class RuntimeConfig:
     use_torch_compile: bool = False
     use_onediff: bool = False
     use_fp8_t5_encoder: bool = False
+    use_teacache: bool = False
+    use_fbcache: bool = True
```
Collaborator:

Please set the default to False.

Contributor (author):

Fixed.

@Binary2355 Binary2355 force-pushed the main branch 6 times, most recently from fdf69d1 to 54aede0 Compare February 20, 2025 12:12
@@ -0,0 +1,56 @@
import functools
Collaborator:

If code from another repo is used, add a comment at the top of the file noting its source.

use_cache=True,
num_steps=8,
return_hidden_states_first=False,
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01],
Collaborator:

Are these coefficients specific to Flux?
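For context on the question above: in TeaCache, coefficients like these form a polynomial (highest degree first) that rescales the raw relative-L1 distance between timestep-modulated inputs, and the fitted values are model-specific, which is presumably why the reviewer asks. A small sketch of how such a polynomial would be evaluated, assuming Horner's rule (the function name is illustrative):

```python
# Coefficients copied from the snippet above, highest degree first.
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01,
                -3.82021401e+00, 2.64230861e-01]

def rescale_distance(rel_l1):
    """Evaluate the calibration polynomial at rel_l1 via Horner's rule."""
    acc = 0.0
    for c in coefficients:
        acc = acc * rel_l1 + c
    return acc
```

The rescaled distances are accumulated step by step; once the accumulator crosses a threshold, the cache is refreshed with a full forward pass.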

Comment on lines 46 to 62
if not self.enable_fbcache:
# the branch to disable cache
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
if not self.return_hidden_states_first:
hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
if self.single_transformer_blocks is not None:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for block in self.single_transformer_blocks:
hidden_states = block(hidden_states, *args, **kwargs)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
return (
(hidden_states, encoder_hidden_states)
if self.return_hidden_states_first
else (encoder_hidden_states, hidden_states)
)

Collaborator:

Doesn't this code duplicate the logic of FluxTransformer2DModel.forward? If fbcache is not used, could it just fall through to the original forward?

Also, since apply_cache_on_transformer has already been called, enable_cache is already True, so there is no need to check it again inside this function.

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L426

Collaborator:

In short, shouldn't this function only need to contain the enable_fb_cache logic?
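One way to realize the reviewer's suggestion is to keep a reference to the original forward and delegate to it when caching is off, instead of re-implementing the block loop. A hypothetical sketch (apply_cache_on_transformer's signature and DummyTransformer are illustrative stand-ins, not the PR's actual code):

```python
import functools

def apply_cache_on_transformer(transformer, use_cache):
    """Wrap transformer.forward; delegate to the original when caching is off."""
    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def cached_forward(*args, **kwargs):
        if not use_cache:
            # Fall back to the unmodified forward instead of duplicating its loop.
            return original_forward(*args, **kwargs)
        # The cache-aware path would go here; this sketch just delegates too.
        return original_forward(*args, **kwargs)

    transformer.forward = cached_forward
    return transformer

class DummyTransformer:  # stand-in for FluxTransformer2DModel
    def forward(self, x):
        return x + 1
```

With this shape, the wrapper only ever adds the cache branch, and the no-cache path cannot drift out of sync with upstream diffusers.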

Comment on lines 53 to 69
if not self.enable_teacache:
# the branch to disable cache
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb, *args, **kwargs)
if not self.return_hidden_states_first:
hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
if self.single_transformer_blocks is not None:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for block in self.single_transformer_blocks:
hidden_states = block(hidden_states, temb, *args, **kwargs)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
return (
(hidden_states, encoder_hidden_states)
if self.return_hidden_states_first
else (encoder_hidden_states, hidden_states)
)

Collaborator:

the same as fbcache

@cb88

cb88 commented Feb 21, 2025

Consider WaveSpeed as well, since many have found it to work better than TeaCache.
https://github.com/chengzeyi/Comfy-WaveSpeed

@Binary2355 Binary2355 force-pushed the main branch 2 times, most recently from 78edca6 to 264c25d Compare February 23, 2025 06:30
del first_hidden_states_residual
hidden_states += self.cache_context.hidden_states_residual
encoder_hidden_states += self.cache_context.encoder_hidden_states_residual
else:
Collaborator:

Same as above.
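For context, the snippet this comment refers to applies the residual-reuse step: on a cache hit, the residuals saved at the last fully computed step are added back to the hidden states instead of running the remaining blocks. A toy sketch with plain lists standing in for tensors (CacheContext and apply_cached_residuals are illustrative names):

```python
class CacheContext:
    """Holds the residuals saved at the last fully computed step."""
    def __init__(self):
        self.hidden_states_residual = None
        self.encoder_hidden_states_residual = None

def apply_cached_residuals(hidden, encoder_hidden, ctx):
    # Cheap path: add the saved residuals elementwise instead of
    # running the remaining transformer blocks.
    hidden = [h + r for h, r in zip(hidden, ctx.hidden_states_residual)]
    encoder_hidden = [h + r for h, r in
                      zip(encoder_hidden, ctx.encoder_hidden_states_residual)]
    return hidden, encoder_hidden
```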

@@ -0,0 +1,70 @@
import functools
Collaborator:

Add a comment noting the source of the copied code.
Also, should the directory name cache_ drop its trailing underscore?

@@ -0,0 +1,341 @@
import contextlib
Collaborator:

Add a comment noting the source of the copied code.

@feifeibear feifeibear merged commit 47f2071 into xdit-project:main Feb 24, 2025
1 of 3 checks passed