Merged
Binary file added docs/assets/ray_baseline.png
Binary file added docs/assets/ray_tp2.png
Binary file added docs/assets/ray_wrapper.png
145 changes: 145 additions & 0 deletions docs/user_guide/RAY.md
@@ -0,0 +1,145 @@
# Ray Wrapper

<div id="ray-wrapper"></div>

The Ray Wrapper lets cache-dit create and manage the distributed worker processes for you. After enabling it, user code can still look like normal single-process Diffusers code: load a pipeline, call `cache_dit.enable_cache(...)`, then call the pipeline as usual.

![Ray Wrapper](../assets/ray_wrapper.png)

This means you do not need to write manual distributed inference code. In the common case, you do not need `torchrun`, `dist.init_process_group`, rank/world-size branching, per-rank device placement, or explicit model sharding code. cache-dit starts Ray actors, places workers on GPUs, initializes the worker process group, transfers the model snapshot, applies cache-dit parallelism, and proxies calls back through the original pipeline object.
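
For contrast, the per-rank bookkeeping that a hand-written `torchrun` script usually carries can be sketched as below. This is only an illustration of what the wrapper lets you skip, not cache-dit code: the helper name `read_rank_env` is hypothetical, and the only external fact it relies on is that `torchrun` exports `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` to each worker.

```python
import os
from typing import Mapping, Optional


def read_rank_env(env: Optional[Mapping[str, str]] = None) -> dict:
    """Collect the per-rank settings a manual torchrun script usually needs.

    Hypothetical helper for illustration; it is not part of cache-dit.
    """
    env = os.environ if env is None else env
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    return {
        "rank": rank,
        "world_size": world_size,
        # Each process would place its model shard on its own local GPU.
        "device": f"cuda:{local_rank}",
        # Only one rank would normally save images or print logs.
        "is_main": rank == 0,
    }


# Simulate the environment of the second worker in a 2-GPU launch.
settings = read_rank_env({"RANK": "1", "WORLD_SIZE": "2", "LOCAL_RANK": "1"})
print(settings)
```

With `use_ray=True`, none of this branching appears in user code; the script stays a plain single-process program.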

|Baseline|Ray Wrapper with TP=2 + Compile|
|:---:|:---:|
|47.41s|24.86s|
|![](../assets/ray_baseline.png)|![](../assets/ray_tp2.png)|

## Pipeline-Level Wrapper

```python
import torch
from diffusers import Flux2KleinPipeline

import cache_dit
from cache_dit import ParallelismConfig

# Load the pipeline on CPU; cache-dit handles the GPU
# transfer inside the Ray workers.
pipe = Flux2KleinPipeline.from_pretrained(
    "/path/to/FLUX.2-klein-base-9B",
    torch_dtype=torch.bfloat16,
)

# NOTE: For pipeline-level parallelism, the Ray wrapper moves the
# model to CUDA inside the workers, so keep the original pipeline
# on CPU to avoid redundant GPU memory usage.
cache_dit.enable_cache(
    pipe,
    parallelism_config=ParallelismConfig(
        tp_size=2,
        use_ray=True,
    ),
)

# Call the pipeline as usual; no code changes are needed for
# Ray parallelism to work.
image = pipe(
    prompt="A cat holding a sign that says hello world",
    height=1024,
    width=1024,
    num_inference_steps=28,
).images[0]

image.save("ray_wrapper.png")
cache_dit.disable_cache(pipe)
```

The code above is still a normal single-process script. Run it with `python your_script.py`; cache-dit and Ray handle the distributed execution internally.

## Transformer-Level Wrapper

You can also wrap only the transformer module. This is useful when you want the text encoders, VAE, scheduler, and other pipeline components to stay in the main process while only the transformer is executed by Ray workers.

```python
cache_dit.enable_cache(
    pipe.transformer,
    parallelism_config=ParallelismConfig(
        ulysses_size=2,
        use_ray=True,
    ),
)

# NOTE: Only the transformer is parallelized and transferred to GPU,
# so we need to move the pipeline to GPU as well for the forward pass.
pipe.to("cuda")
image = pipe(prompt="A cinematic mountain lake at sunrise").images[0]
cache_dit.disable_cache(pipe.transformer)
```

When the transformer-level wrapper is enabled, cache-dit patches the Ray-owned transformer so `pipe.to("cuda")` does not move the main-process transformer copy back onto the GPU. The executable transformer copies live inside the Ray workers.
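
The idea of keeping one component in place while the rest of the pipeline moves can be illustrated with a toy sketch. Everything here is an assumption made for illustration: `ToyTransformer`, `pin_device`, and `unpin_device` are hypothetical names, and the source does not describe how cache-dit implements its patch, so this is not a description of it.

```python
class ToyTransformer:
    """Hypothetical stand-in for a module with a device-moving `to()` method."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "ToyTransformer":
        self.device = device
        return self


def pin_device(module):
    """Shadow `module.to` with an instance-level no-op that returns the module."""
    module.to = lambda *args, **kwargs: module
    return module


def unpin_device(module):
    """Remove the instance-level override so the class method is used again."""
    module.__dict__.pop("to", None)
    return module


transformer = pin_device(ToyTransformer())
transformer.to("cuda")  # ignored while pinned
device_while_pinned = transformer.device

unpin_device(transformer).to("cuda")  # the normal method is back
device_after_unpin = transformer.device

print(device_while_pinned, device_after_unpin)
```

In this sketch the override lives on the instance, so removing it restores the original behavior without touching the class.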

## Tensor Parallelism and Context Parallelism

Set the normal cache-dit parallelism fields and add `use_ray=True`:

```python
ParallelismConfig(tp_size=2, use_ray=True)
ParallelismConfig(ulysses_size=2, use_ray=True)
ParallelismConfig(ring_size=2, use_ray=True)
```

Use the explicit field names `tp_size`, `ulysses_size`, and `ring_size`. Short aliases such as `tp`, `ulysses`, and `ring` are intentionally not supported.
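
A toy dataclass shows why explicit field names fail fast when a short alias is passed. `ToyParallelismConfig` and `try_build` are hypothetical stand-ins with the field names mentioned above; the real `ParallelismConfig` may validate its arguments differently.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ToyParallelismConfig:
    """Hypothetical stand-in that mirrors only the field names in this guide."""

    tp_size: Optional[int] = None
    ulysses_size: Optional[int] = None
    ring_size: Optional[int] = None
    use_ray: bool = False


def try_build(**kwargs) -> Union[ToyParallelismConfig, str]:
    """Return the config, or the error message if a keyword is not a field."""
    try:
        return ToyParallelismConfig(**kwargs)
    except TypeError as exc:
        return f"rejected: {exc}"


ok = try_build(tp_size=2, use_ray=True)
bad = try_build(tp=2, use_ray=True)  # short alias is not a declared field

print(ok)
print(bad)
```

Rejecting unknown keywords at construction time means a typo such as `tp=2` surfaces immediately instead of silently running without parallelism.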

## Optional Compile

Ray workers can compile the transformer after loading and applying cache-dit parallelism:

```python
cache_dit.enable_cache(
    pipe,
    parallelism_config=ParallelismConfig(
        tp_size=2,
        use_ray=True,
        ray_use_compile=True,
    ),
)
```

If the transformer provides `compile_repeated_blocks()`, cache-dit calls that method first. Otherwise it falls back to `transformer.compile()` when available.
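
The fallback order described above can be sketched as a small dispatch helper. This is a minimal illustration under stated assumptions, not the cache-dit implementation: `compile_transformer` and the three toy classes are hypothetical names, and the toy methods only record that they were called.

```python
from typing import Optional


def compile_transformer(transformer) -> Optional[str]:
    """Prefer compile_repeated_blocks(), then compile(); return the name used."""
    for method_name in ("compile_repeated_blocks", "compile"):
        method = getattr(transformer, method_name, None)
        if callable(method):
            method()
            return method_name
    return None  # nothing to compile with; leave the module untouched


class RegionalModel:
    """Toy model exposing both entry points."""

    def __init__(self) -> None:
        self.calls = []

    def compile_repeated_blocks(self) -> None:
        self.calls.append("compile_repeated_blocks")

    def compile(self) -> None:
        self.calls.append("compile")


class PlainModel:
    """Toy model exposing only compile()."""

    def __init__(self) -> None:
        self.calls = []

    def compile(self) -> None:
        self.calls.append("compile")


class BareModel:
    """Toy model with no compile entry point."""


regional, plain = RegionalModel(), PlainModel()
used = [compile_transformer(m) for m in (regional, plain, BareModel())]
print(used)
```

Only the first available method is called, so a model that offers both entry points is compiled once.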

## Cache and Quantization

When `use_ray=True`, cache hooks and quantization are applied inside the Ray workers after the model snapshot is loaded. This preserves the same user-facing `enable_cache` API while avoiding main-process hooks or quantized modules being lost during model transfer.

```python
from cache_dit import DBCacheConfig
from cache_dit import ParallelismConfig
from cache_dit import QuantizeConfig

cache_dit.enable_cache(
    pipe,
    cache_config=DBCacheConfig(...),
    parallelism_config=ParallelismConfig(
        tp_size=2,
        use_ray=True,
    ),
    quantize_config=QuantizeConfig(...),
)
```

## Quick Start

A complete runnable example is available at `examples/ray/ray_wrapper_example.py`. For example:

```bash
# Baseline
python3 examples/ray/ray_wrapper_example.py \
    --model-path $FLUX_2_KLEIN_BASE_9B_DIR \
    --save-path ./tmp/baseline.png

# Ray wrapper with TP=2 and compile enabled
python3 examples/ray/ray_wrapper_example.py \
    --model-path $FLUX_2_KLEIN_BASE_9B_DIR \
    --tp 2 \
    --compile \
    --save-path ./tmp/ray.png
```
181 changes: 181 additions & 0 deletions examples/ray/ray_wrapper_example.py
@@ -0,0 +1,181 @@
from __future__ import annotations

import argparse
import os
import time
from pathlib import Path

import torch
from diffusers import Flux2KleinPipeline

import cache_dit
from cache_dit import DBCacheConfig
from cache_dit import ParallelismConfig
from cache_dit import QuantizeConfig


def parse_args() -> argparse.Namespace:
    """Parse command line arguments for the Ray wrapper example.

    :returns: Parsed command line arguments.
    """

    parser = argparse.ArgumentParser(
        description="Run FLUX.2-klein-base-9B with optional cache-dit Ray wrapper.")
    parser.add_argument("--model-path",
                        type=str,
                        default=None,
                        help="Path to FLUX.2-klein-base-9B model.")
    parser.add_argument("--prompt", type=str, default="A cat holding a sign that says hello world")
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument("--num-inference-steps", "--steps", type=int, default=28)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--warmup",
                        type=int,
                        default=1,
                        help="Number of warmup generations before timing.")
    parser.add_argument("--repeat", type=int, default=1, help="Number of timed generations.")
    parser.add_argument(
        "--cache",
        action="store_true",
        help="Enable cache-dit with the default DBCacheConfig.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Enable quantization with the default QuantizeConfig.",
    )
    parser.add_argument(
        "--ulysses",
        type=int,
        default=1,
        help="Ulysses size. Values > 1 enable Ray.",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="Tensor parallel size. Values > 1 enable Ray tensor parallelism.",
    )
    parser.add_argument("--save-path", type=str, default=".tmp/ray_wrapper.png")
    parser.add_argument(
        "--target",
        choices=("transformer", "pipeline"),
        default="pipeline",
        help="Enable Ray wrapper on pipe.transformer or on the pipeline object.",
    )
    parser.add_argument(
        "--use-flashpack-transfer",
        action="store_true",
        help="Use Diffusers serialization with use_flashpack=True for Ray pipeline snapshots.",
    )
    parser.add_argument(
        "--use-compile",
        "--compile",
        action="store_true",
        help="Compile the Ray-owned transformer after loading and parallelization.",
    )
    return parser.parse_args()


def main() -> None:
    """Run the Ray wrapper example and save the generated image."""

    args = parse_args()
    model_path = args.model_path or os.environ.get(
        "FLUX_2_KLEIN_BASE_9B_DIR",
        "black-forest-labs/FLUX.2-klein-base-9B",
    )
    use_ray = args.ulysses > 1 or args.tp > 1
    pipe = Flux2KleinPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )  # .to("cuda") will be called inside the Ray wrapper if use_ray is True

    if not use_ray:
        pipe.to("cuda")

    cache_config = DBCacheConfig(Fn_compute_blocks=1) if args.cache else None
    quantize_config = QuantizeConfig(quant_type="float8_per_tensor") if args.quantize else None
    parallelism_config = None
    accelerate_enabled = use_ray or cache_config is not None or quantize_config is not None

    if use_ray:
        parallelism_config = ParallelismConfig(
            ulysses_size=args.ulysses if args.ulysses > 1 else None,
            tp_size=args.tp if args.tp > 1 else None,
            use_ray=True,
            ray_use_flashpack=args.use_flashpack_transfer,
            ray_use_compile=args.use_compile,
        )

    if accelerate_enabled:
        if args.target == "pipeline":
            # NOTE: For pipeline-level parallelism, the Ray wrapper moves the
            # model to CUDA inside the workers, so keep the original pipeline
            # on CPU to avoid redundant GPU memory usage.
            cache_dit.enable_cache(
                pipe,
                cache_config=cache_config,
                parallelism_config=parallelism_config,
                quantize_config=quantize_config,
            )
        else:
            cache_dit.enable_cache(
                pipe.transformer,
                cache_config=cache_config,
                parallelism_config=parallelism_config,
                quantize_config=quantize_config,
            )
            if use_ray:
                # NOTE: Only the transformer is parallelized and transferred to GPU,
                # so we need to move the pipeline to GPU as well for the forward pass.
                pipe.to("cuda")

    if args.warmup < 0:
        raise ValueError("--warmup must be greater than or equal to 0.")
    if args.repeat < 1:
        raise ValueError("--repeat must be greater than or equal to 1.")

    def run_generation():
        generator = torch.Generator("cpu").manual_seed(args.seed)
        return pipe(
            prompt=args.prompt,
            height=args.height,
            width=args.width,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
        ).images[0]

    # Call the pipeline as usual; no code changes are needed for
    # Ray parallelism to work.
    for _ in range(args.warmup):
        run_generation()

    start_time = time.time()
    image = None
    for _ in range(args.repeat):
        image = run_generation()
    elapsed = time.time() - start_time
    assert image is not None

    save_path = Path(args.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(save_path)
    print(f"Warmup: {args.warmup}")
    print(f"Repeat: {args.repeat}")
    print(f"Total Inference Time: {elapsed:.2f}s")
    print(f"Average Inference Time: {elapsed / args.repeat:.2f}s")
    print(f"Saved image to {save_path}")

    if accelerate_enabled:
        cache_dit.disable_cache(pipe if args.target == "pipeline" else pipe.transformer)


if __name__ == "__main__":
    main()
# Example usage:
# python3 ray_wrapper_example.py # baseline with no Ray parallelism
# python3 ray_wrapper_example.py --ulysses 2 --save-path ray_ulysses2_output.png
# python3 ray_wrapper_example.py --tp 2 --cache --quantize --save-path ray_tp2_output.png
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -62,6 +62,7 @@ nav:
- DBCache Design: user_guide/DBCACHE_DESIGN.md
- Context Parallelism: user_guide/CONTEXT_PARALLEL.md
- Tensor Parallelism: user_guide/TENSOR_PARALLEL.md
- Ray Wrapper: user_guide/RAY.md
- TE-P, VAE-P and CN-P : user_guide/EXTRA_PARALLEL.md
- 2D and 3D Parallelism: user_guide/HYBRID_PARALLEL.md
- Low-Bits Quantization: user_guide/QUANTIZATION.md
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -22,6 +22,11 @@ parallelism = [
"einops>=0.8.1",
]

ray = [
"ray>=2.0",
"safetensors>=0.5.3",
]

quantization = [
"torchao>=0.14.1",
"bitsandbytes>=0.48.1",
6 changes: 6 additions & 0 deletions src/cache_dit/caching/cache_adapters/cache_adapter.py
@@ -532,6 +532,9 @@ def _release_blocks_hooks(blocks):
return

def _release_transformer_hooks(transformer):
    from ...ray import disable_ray_parallelism

    disable_ray_parallelism(transformer)
    if hasattr(transformer, "_original_forward"):
        original_forward = transformer._original_forward
        transformer.forward = original_forward.__get__(transformer)
@@ -550,6 +553,9 @@ def _release_transformer_hooks(transformer):
del transformer._context_names

def _release_pipeline_hooks(pipe):
    from ...ray import disable_ray_pipeline_parallelism

    disable_ray_pipeline_parallelism(pipe)
    if hasattr(pipe, "_original_call"):
        original_call = pipe.__class__._original_call
        pipe.__class__.__call__ = original_call