
Commit 84e137a
- add 1xH20 performance
- add 4xH20 performance and 1xH20 performance with torch.compile
Binary2355 committed Feb 25, 2025
1 parent 90c2d72 commit 84e137a
Showing 6 changed files with 214 additions and 251 deletions.
45 changes: 39 additions & 6 deletions docs/performance/flux.md
@@ -117,15 +117,48 @@ The quality of image generation at 2048px, 3072px, and 4096px resolutions is as

## Cache Methods

-We tested the performance of TeaCache and First-Block-Cache on 4xH20 with SP=4.
+We tested the performance of TeaCache and First-Block-Cache on 4xH20 with SP=4 and on 1xH20.
The performance is shown below:

<div align="center">

-| Method | Latency (s) |
-|----------------|--------|
-| Baseline | 2.02s |
-| use_teacache | 1.58s |
-| use_fbcache | 0.93s |
+<table>
+<tr>
+<th rowspan="3">Method</th>
+<th colspan="4">Latency (s)</th>
+</tr>
+<tr>
+<th colspan="2">without torch.compile</th>
+<th colspan="2">with torch.compile</th>
+</tr>
+<tr>
+<th>4xH20</th>
+<th>1xH20</th>
+<th>4xH20</th>
+<th>1xH20</th>
+</tr>
+<tr>
+<td>Baseline</td>
+<td>2.02</td>
+<td>6.10</td>
+<td>1.81</td>
+<td>5.02</td>
+</tr>
+<tr>
+<td>use_teacache</td>
+<td>1.60</td>
+<td>4.67</td>
+<td>1.50</td>
+<td>3.92</td>
+</tr>
+<tr>
+<td>use_fbcache</td>
+<td>0.93</td>
+<td>2.51</td>
+<td>0.85</td>
+<td>2.09</td>
+</tr>
+</table>

</div>
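Both cache methods exploit the fact that diffusion-transformer activations drift slowly between adjacent denoising steps; TeaCache applies a similar skip-or-reuse decision using timestep-modulated inputs. As a rough mental model, First-Block-Cache runs only the first transformer block each step and, when that output's relative L1 change since the previous step falls below `rel_l1_thresh` (0.6 in the example script below), reuses a cached residual in place of all remaining blocks. The following is a self-contained toy sketch of that idea, not xfuser's implementation; `FirstBlockCache` and `forward_with_cache` are invented names:

```python
import torch

class FirstBlockCache:
    """Toy cache: skip the network tail when the first block's output
    barely moves between consecutive denoising steps."""

    def __init__(self, rel_l1_thresh: float = 0.6):
        self.rel_l1_thresh = rel_l1_thresh
        self.prev_first_out = None   # first-block output from the previous step
        self.cached_residual = None  # cached contribution of the remaining blocks

    def should_skip(self, first_out: torch.Tensor) -> bool:
        if self.prev_first_out is None or self.cached_residual is None:
            return False  # nothing cached yet: run the full stack once
        rel_l1 = ((first_out - self.prev_first_out).abs().mean()
                  / self.prev_first_out.abs().mean())
        return rel_l1.item() < self.rel_l1_thresh

def forward_with_cache(blocks, hidden, cache):
    first_out = blocks[0](hidden)
    if cache.should_skip(first_out):
        out = first_out + cache.cached_residual   # reuse the cached tail
    else:
        out = first_out
        for block in blocks[1:]:
            out = block(out)
        cache.cached_residual = out - first_out   # refresh the cache
    cache.prev_first_out = first_out
    return out

# Toy usage across three "denoising steps".
blocks = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(4))
cache = FirstBlockCache(rel_l1_thresh=0.6)
x = torch.randn(2, 16)
for _ in range(3):
    x = forward_with_cache(blocks, x, cache)
```

The `with torch.compile` columns presumably wrap the transformer in PyTorch's compiler (e.g. `pipe.transformer = torch.compile(pipe.transformer)`); the exact integration in this repo is not shown in the diff.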
30 changes: 14 additions & 16 deletions examples/flux_example.py
@@ -33,9 +33,19 @@ def main():
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

+    cache_args = {
+        "use_teacache": engine_args.use_teacache,
+        "use_fbcache": engine_args.use_fbcache,
+        "rel_l1_thresh": 0.6,
+        "return_hidden_states_first": False,
+        "num_steps": input_config.num_inference_steps,
+    }
+
    pipe = xFuserFluxPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
+        engine_args=engine_args,
+        cache_args=cache_args,
        torch_dtype=torch.bfloat16,
        text_encoder_2=text_encoder_2,
    )
@@ -48,28 +58,16 @@ def main():

    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

-    pipe.prepare_run(input_config, steps=1)
+    pipe.prepare_run(input_config, steps=input_config.num_inference_steps)

    use_cache = engine_args.use_teacache or engine_args.use_fbcache
-    if (use_cache
+    use_cache = (
+        use_cache
        and get_pipeline_parallel_world_size() == 1
        and get_classifier_free_guidance_world_size() == 1
        and get_tensor_model_parallel_world_size() == 1
-    ):
-        cache_args = {
-            "rel_l1_thresh": 0.6,
-            "return_hidden_states_first": False,
-            "num_steps": input_config.num_inference_steps,
-        }
-
-        if engine_args.use_fbcache and engine_args.use_teacache:
-            cache_args["use_cache"] = "Fb"
-        elif engine_args.use_teacache:
-            cache_args["use_cache"] = "Tea"
-        elif engine_args.use_fbcache:
-            cache_args["use_cache"] = "Fb"
+    )

    pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args)
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
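Net effect of the two hunks above: `cache_args` is assembled once up front and handed to the pipeline constructor, the warmup in `prepare_run` now covers all inference steps, and the parallelism check collapses into a single boolean rather than an `if` block wrapping the cache setup. A condensed sketch of the resulting flow, reassembled from the diff (whether the final call is still guarded by `use_cache` is not visible here and is an assumption):

```python
# Condensed sketch reassembled from the diff above.
cache_args = {
    "use_teacache": engine_args.use_teacache,
    "use_fbcache": engine_args.use_fbcache,
    "rel_l1_thresh": 0.6,
    "return_hidden_states_first": False,
    "num_steps": input_config.num_inference_steps,
}
pipe = xFuserFluxPipeline.from_pretrained(
    pretrained_model_name_or_path=engine_config.model_config.model,
    engine_config=engine_config,
    engine_args=engine_args,
    cache_args=cache_args,
    torch_dtype=torch.bfloat16,
    text_encoder_2=text_encoder_2,
)
pipe.prepare_run(input_config, steps=input_config.num_inference_steps)

# Caching only makes sense without pipeline/CFG/tensor parallelism.
use_cache = (
    (engine_args.use_teacache or engine_args.use_fbcache)
    and get_pipeline_parallel_world_size() == 1
    and get_classifier_free_guidance_world_size() == 1
    and get_tensor_model_parallel_world_size() == 1
)
if use_cache:  # assumption: the guard survives the refactor
    pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args)
```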
6 changes: 5 additions & 1 deletion xfuser/model_executor/cache/diffusers_adapters/__init__.py
@@ -5,12 +5,16 @@
import importlib
from typing import Type, Dict, TypeVar
from xfuser.model_executor.cache.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY
+from xfuser.logger import init_logger

+logger = init_logger(__name__)


def apply_cache_on_transformer(transformer, *args, **kwargs):
    adapter_name = TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer))
    if not adapter_name:
-        raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}")
+        logger.error(f"Unknown transformer class: {transformer.__class__.__name__}")
+        return transformer

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
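A user-visible consequence of this change: applying a cache to a transformer class that has no registered adapter now logs an error and returns the model untouched instead of raising `ValueError`. A hypothetical illustration (`DummyTransformer` is an invented stand-in, not a class in the repo):

```python
from xfuser.model_executor.cache.diffusers_adapters import apply_cache_on_transformer

class DummyTransformer:
    pass

model = DummyTransformer()
# Logs "Unknown transformer class: DummyTransformer" and returns the model as-is.
patched = apply_cache_on_transformer(model, use_fbcache=True)
assert patched is model
```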