Add TeaCache and FBCache #451
Conversation
Force-pushed from d90cd4c to d574887.
xfuser/model_executor/plugins/teacache/diffusers_adapters/__init__.py
Outdated
elif transformer_cls_name.startswith("Flux") or transformer_cls_name.startswith("xFuserFlux"):
    adapter_name = "flux"
Could this check be done with isinstance instead? Relying on class names seems fragile, since names can be arbitrarily customized.
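A minimal sketch of the isinstance-based dispatch being suggested. The stand-in classes below are hypothetical (in the real code the base class would be diffusers' FluxTransformer2DModel, and the xFuser wrapper may not literally subclass it):

```python
# Hypothetical stand-ins for diffusers' FluxTransformer2DModel and an
# xFuser wrapper that inherits from it; names are for illustration only.
class FluxTransformer2DModel: ...
class xFuserFluxWrapper(FluxTransformer2DModel): ...

def select_adapter(transformer) -> str:
    # isinstance also matches subclasses, so a wrapper deriving from the
    # Flux transformer is handled without fragile name-prefix checks.
    if isinstance(transformer, FluxTransformer2DModel):
        return "flux"
    raise ValueError(f"unsupported transformer: {type(transformer).__name__}")

print(select_adapter(xFuserFluxWrapper()))  # flux
```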
Updated.
xfuser/config/config.py
Outdated
@@ -60,6 +60,8 @@ class RuntimeConfig:
    use_torch_compile: bool = False
    use_onediff: bool = False
    use_fp8_t5_encoder: bool = False
    use_teacache: bool = False
    use_fbcache: bool = True
Please set the default to False.
Updated.
Force-pushed from fdf69d1 to 54aede0.
@@ -0,0 +1,56 @@
import functools
If code from another repo is used, add a comment line at the top of the file noting the source.
use_cache=True,
num_steps=8,
return_hidden_states_first=False,
coefficients=[4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01],
Are these coefficients specific to Flux?
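For context, TeaCache-style caching rescales the relative L1 change of the modulated input with a model-specific fitted polynomial before accumulating it against a skip threshold, which is why a Flux-specific coefficient set would be expected here. A sketch (reusing the coefficients from the diff; the helper name is hypothetical, not the PR's actual code):

```python
import numpy as np

# Coefficients from the diff, highest degree first, as np.polyval expects.
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01,
                -3.82021401e+00, 2.64230861e-01]

def rescale_distance(rel_l1: float) -> float:
    # Map the raw relative-L1 distance through the fitted polynomial;
    # the result is what gets accumulated against the cache threshold.
    return float(np.polyval(coefficients, rel_l1))

print(rescale_distance(0.0))  # the constant term, 0.264230861
```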
if not self.enable_fbcache:
    # the branch to disable cache
    for block in self.transformer_blocks:
        hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
    if self.single_transformer_blocks is not None:
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for block in self.single_transformer_blocks:
            hidden_states = block(hidden_states, *args, **kwargs)
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
    return (
        (hidden_states, encoder_hidden_states)
        if self.return_hidden_states_first
        else (encoder_hidden_states, hidden_states)
    )
Doesn't this code duplicate the logic of FluxTransformer2DModel.forward? If fbcache is not used, could it just fall through to the original forward logic?
Also, since apply_cache_on_transformer has already been called, enable_cache must already be True, so there should be no need to check it again inside this function?
In short, shouldn't this function contain only the enable_fbcache logic?
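One way to implement what the reviewer is asking for, sketched with hypothetical names (the captured original_forward and the cached_forward parameter are illustrative): keep a reference to the unpatched forward and delegate to it when fbcache is off, instead of re-implementing the block loop.

```python
import functools

def apply_cache_on_transformer(transformer, use_fbcache, cached_forward):
    # cached_forward is a hypothetical stand-in for the fbcache code path.
    original_forward = transformer.forward  # the unpatched forward

    @functools.wraps(original_forward)
    def forward(*args, **kwargs):
        if not use_fbcache:
            # Cache disabled: fall through to the original logic rather
            # than duplicating the transformer-block loop here.
            return original_forward(*args, **kwargs)
        return cached_forward(*args, **kwargs)

    transformer.forward = forward
    return transformer

# Tiny usage demo with a dummy model:
class Dummy:
    def forward(self, x):
        return x + 1

t = apply_cache_on_transformer(Dummy(), use_fbcache=False,
                               cached_forward=lambda x: x + 100)
print(t.forward(1))  # 2: the original path runs when fbcache is off
```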
if not self.enable_teacache:
    # the branch to disable cache
    for block in self.transformer_blocks:
        hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb, *args, **kwargs)
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
    if self.single_transformer_blocks is not None:
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for block in self.single_transformer_blocks:
            hidden_states = block(hidden_states, temb, *args, **kwargs)
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
    return (
        (hidden_states, encoder_hidden_states)
        if self.return_hidden_states_first
        else (encoder_hidden_states, hidden_states)
    )
Same comment as for fbcache.
Consider WaveSpeed as well, as many have found it to work better than TeaCache.
Force-pushed from 78edca6 to 264c25d.
del first_hidden_states_residual
hidden_states += self.cache_context.hidden_states_residual
encoder_hidden_states += self.cache_context.encoder_hidden_states_residual
else:
Same as above.
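For readers unfamiliar with FBCache (first-block cache): the residuals being added in the snippet above were stored on an earlier step, and the decision to reuse them hinges on how much the first transformer block's output residual has changed. A NumPy sketch of that check (function name and threshold are illustrative, not the PR's actual code):

```python
import numpy as np

def can_use_cache(first_residual, prev_first_residual, threshold=0.1):
    # First-block cache heuristic: if the first block's residual barely
    # moved since the previous step, reuse the cached residuals of the
    # remaining blocks instead of running them.
    diff = np.abs(first_residual - prev_first_residual).mean()
    norm = np.abs(prev_first_residual).mean()
    return bool(diff / norm < threshold)

prev = np.ones(8)
print(can_use_cache(prev * 1.01, prev))  # True: ~1% change, under threshold
print(can_use_cache(prev * 2.00, prev))  # False: 100% change
```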
@@ -0,0 +1,70 @@
import functools
Add the source of the copied code.
Also, should the trailing underscore be dropped from the `cache_` directory name?
@@ -0,0 +1,341 @@
import contextlib
Add the source of the copied code.
Adapted from https://github.com/ali-vilab/TeaCache.git
Adapted from https://github.com/chengzeyi/ParaAttention.git
Description: adds TeaCache and FBCache plugins for the FLUX model; they are invoked only under SP (sequence) parallelism or when no parallelism is used.
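A sketch of the guard the description implies. The attribute names are assumptions, loosely mirroring xFuser's parallel-degree settings, not its actual config class: the cache plugins apply only under sequence parallelism or no parallelism, so every other parallel dimension must have degree 1.

```python
class ParallelConfig:
    # Hypothetical config object; xFuser's real ParallelConfig differs.
    def __init__(self, sp_degree=1, pp_degree=1, tp_degree=1):
        self.sp_degree = sp_degree
        self.pp_degree = pp_degree
        self.tp_degree = tp_degree

def cache_applicable(cfg):
    # TeaCache/FBCache are applied only with SP parallelism or with no
    # parallelism at all: any pipeline/tensor parallelism disables them.
    return cfg.pp_degree == 1 and cfg.tp_degree == 1

print(cache_applicable(ParallelConfig(sp_degree=4)))  # True: SP only
print(cache_applicable(ParallelConfig(pp_degree=2)))  # False
```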
Verification example on a single node with 4 GPUs:
clear && bash examples/run.sh
Performance comparison: