diff --git a/docs/assets/ray_baseline.png b/docs/assets/ray_baseline.png new file mode 100644 index 00000000..6aec07da Binary files /dev/null and b/docs/assets/ray_baseline.png differ diff --git a/docs/assets/ray_tp2.png b/docs/assets/ray_tp2.png new file mode 100644 index 00000000..6aec07da Binary files /dev/null and b/docs/assets/ray_tp2.png differ diff --git a/docs/assets/ray_wrapper.png b/docs/assets/ray_wrapper.png new file mode 100644 index 00000000..99dd9856 Binary files /dev/null and b/docs/assets/ray_wrapper.png differ diff --git a/docs/user_guide/RAY.md b/docs/user_guide/RAY.md new file mode 100644 index 00000000..1df79f5e --- /dev/null +++ b/docs/user_guide/RAY.md @@ -0,0 +1,145 @@ +# Ray Wrapper + +
+ +The Ray Wrapper lets cache-dit create and manage the distributed worker processes for you. After enabling it, user code can still look like normal single-process Diffusers code: load a pipeline, call `cache_dit.enable_cache(...)`, then call the pipeline as usual. + + + +This means you do not need to write manual distributed inference code. In the common case, you do not need `torchrun`, `dist.init_process_group`, rank/world-size branching, per-rank device placement, or explicit model sharding code. cache-dit starts Ray actors, places workers on GPUs, initializes the worker process group, transfers the model snapshot, applies cache-dit parallelism, and proxies calls back through the original pipeline object. + +|Baseline|Ray Wrapper with TP=2 + Compile| +|:---:|:---:| +|47.41s|24.86s| +||| + +## Pipeline-Level Wrapper + +```python +import torch +from diffusers import Flux2KleinPipeline + +import cache_dit +from cache_dit import ParallelismConfig + +# Load the pipeline on CPU; cache-dit handles the GPU +# transfer inside the Ray workers. +pipe = Flux2KleinPipeline.from_pretrained( + "/path/to/FLUX.2-klein-base-9B", + torch_dtype=torch.bfloat16, +) + +# NOTE: For pipeline-level parallelism the Ray wrapper transfers the +# pipeline to CUDA automatically, so we keep the original pipeline +# on CPU to avoid redundant GPU memory usage. +cache_dit.enable_cache( + pipe, + parallelism_config=ParallelismConfig( + tp_size=2, + use_ray=True, + ), +) + +# Call the pipeline as usual; no code changes are needed for +# Ray parallelism to work. +image = pipe( + prompt="A cat holding a sign that says hello world", + height=1024, + width=1024, + num_inference_steps=28, +).images[0] + +image.save("ray_wrapper.png") +cache_dit.disable_cache(pipe) +``` + +The code above is still a normal single-process script. Run it with `python your_script.py`; cache-dit and Ray handle the distributed execution internally. + +## Transformer-Level Wrapper + +You can also wrap only the transformer module. 
This is useful when you want the text encoders, VAE, scheduler, and other pipeline components to stay in the main process while only the transformer is executed by Ray workers. + +```python +cache_dit.enable_cache( + pipe.transformer, + parallelism_config=ParallelismConfig( + ulysses_size=2, + use_ray=True, + ), +) + +# NOTE: Only the transformer is parallelized and transferred to GPU, +# so we need to move the pipeline to GPU as well for the forward pass. +pipe.to("cuda") +image = pipe(prompt="A cinematic mountain lake at sunrise").images[0] +cache_dit.disable_cache(pipe.transformer) +``` + +When the transformer-level wrapper is enabled, cache-dit patches the Ray-owned transformer so `pipe.to("cuda")` does not move the main-process transformer copy back onto the GPU. The executable transformer copies live inside the Ray workers. + +## Tensor Parallelism and Context Parallelism + +Set the normal cache-dit parallelism fields and add `use_ray=True`: + +```python +ParallelismConfig(tp_size=2, use_ray=True) +ParallelismConfig(ulysses_size=2, use_ray=True) +ParallelismConfig(ring_size=2, use_ray=True) +``` + +Use the explicit field names `tp_size`, `ulysses_size`, and `ring_size`. Short aliases such as `tp`, `ulysses`, and `ring` are intentionally not supported. + +## Optional Compile + +Ray workers can compile the transformer after loading and applying cache-dit parallelism: + +```python +cache_dit.enable_cache( + pipe, + parallelism_config=ParallelismConfig( + tp_size=2, + use_ray=True, + ray_use_compile=True, + ), +) +``` + +If the transformer provides `compile_repeated_blocks()`, cache-dit calls that method first. Otherwise it falls back to `transformer.compile()` when available. + +## Cache and Quantization + +When `use_ray=True`, cache hooks and quantization are applied inside the Ray workers after the model snapshot is loaded. 
This preserves the same user-facing `enable_cache` API while avoiding main-process hooks or quantized modules being lost during model transfer. + +```python +from cache_dit import DBCacheConfig +from cache_dit import ParallelismConfig +from cache_dit import QuantizeConfig + +cache_dit.enable_cache( + pipe, + cache_config=DBCacheConfig(...), + parallelism_config=ParallelismConfig( + tp_size=2, + use_ray=True, + ), + quantize_config=QuantizeConfig(...), +) +``` + +## Quick Start + +A complete runnable example is available at `examples/ray/ray_wrapper_example.py`. For example: + +```bash +# Baseline +python3 examples/ray/ray_wrapper_example.py \ + --model-path $FLUX_2_KLEIN_BASE_9B_DIR \ + --save-path ./tmp/baseline.png + +# Ray wrapper with TP=2 and compile enabled +python3 examples/ray/ray_wrapper_example.py \ + --model-path $FLUX_2_KLEIN_BASE_9B_DIR \ + --tp 2 \ + --compile \ + --save-path ./tmp/ray.png +``` diff --git a/examples/ray/ray_wrapper_example.py b/examples/ray/ray_wrapper_example.py new file mode 100644 index 00000000..3b3bf18d --- /dev/null +++ b/examples/ray/ray_wrapper_example.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import argparse +import os +import time +from pathlib import Path + +import torch +from diffusers import Flux2KleinPipeline + +import cache_dit +from cache_dit import DBCacheConfig +from cache_dit import ParallelismConfig +from cache_dit import QuantizeConfig + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments for the Ray wrapper example. + + :returns: Parsed command line arguments. 
+ """ + + parser = argparse.ArgumentParser( + description="Run FLUX.2-klein-base-9B with optional cache-dit Ray wrapper.") + parser.add_argument("--model-path", + type=str, + default=None, + help="Path to FLUX.2-klein-base-9B model.") + parser.add_argument("--prompt", type=str, default="A cat holding a sign that says hello world") + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--num-inference-steps", "--steps", type=int, default=28) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--warmup", + type=int, + default=1, + help="Number of warmup generations before timing.") + parser.add_argument("--repeat", type=int, default=1, help="Number of timed generations.") + parser.add_argument( + "--cache", + action="store_true", + help="Enable cache-dit with the default DBCacheConfig.", + ) + parser.add_argument( + "--quantize", + action="store_true", + help="Enable quantization with the default QuantizeConfig.", + ) + parser.add_argument( + "--ulysses", + type=int, + default=1, + help="Ulysses size. Values > 1 enable Ray.", + ) + parser.add_argument( + "--tp", + type=int, + default=1, + help="Tensor parallel size. 
Values > 1 enable Ray tensor parallelism.", + ) + parser.add_argument("--save-path", type=str, default=".tmp/ray_wrapper.png") + parser.add_argument( + "--target", + choices=("transformer", "pipeline"), + default="pipeline", + help="Enable Ray wrapper on pipe.transformer or on the pipeline object.", + ) + parser.add_argument( + "--use-flashpack-transfer", + action="store_true", + help="Use Diffusers serialization with use_flashpack=True for Ray pipeline snapshots.", + ) + parser.add_argument( + "--use-compile", + "--compile", + action="store_true", + help="Compile the Ray-owned transformer after loading and parallelization.", + ) + return parser.parse_args() + + +def main() -> None: + """Run the Ray wrapper example and save the generated image.""" + + args = parse_args() + model_path = args.model_path or os.environ.get( + "FLUX_2_KLEIN_BASE_9B_DIR", + "black-forest-labs/FLUX.2-klein-base-9B", + ) + use_ray = args.ulysses > 1 or args.tp > 1 + pipe = Flux2KleinPipeline.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + ) # .to("cuda") will be called inside the Ray wrapper if use_ray is True + + if not use_ray: + pipe.to("cuda") + + cache_config = DBCacheConfig(Fn_compute_blocks=1) if args.cache else None + quantize_config = QuantizeConfig(quant_type="float8_per_tensor") if args.quantize else None + parallelism_config = None + accelerate_enabled = use_ray or cache_config is not None or quantize_config is not None + + if use_ray: + parallelism_config = ParallelismConfig( + ulysses_size=args.ulysses if args.ulysses > 1 else None, + tp_size=args.tp if args.tp > 1 else None, + use_ray=True, + ray_use_flashpack=args.use_flashpack_transfer, + ray_use_compile=args.use_compile, + ) + + if accelerate_enabled: + if args.target == "pipeline": + # NOTE: Will auto transfer to cuda inside by ray wrapper for + # pipeline-level parallelism, so we keep the original pipeline + # on CPU to avoid redundant GPU memory usage. 
+ cache_dit.enable_cache( + pipe, + cache_config=cache_config, + parallelism_config=parallelism_config, + quantize_config=quantize_config, + ) + else: + cache_dit.enable_cache( + pipe.transformer, + cache_config=cache_config, + parallelism_config=parallelism_config, + quantize_config=quantize_config, + ) + if use_ray: + # NOTE: Only the transformer is parallelized and transferred to GPU, + # so we need to move the pipeline to GPU as well for the forward pass. + pipe.to("cuda") + + if args.warmup < 0: + raise ValueError("--warmup must be greater than or equal to 0.") + if args.repeat < 1: + raise ValueError("--repeat must be greater than or equal to 1.") + + def run_generation(): + generator = torch.Generator("cpu").manual_seed(args.seed) + return pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + generator=generator, + ).images[0] + + # Call the pipeline as usual; No code changes are needed for + # Ray parallelism to work. 
+ for _ in range(args.warmup): + run_generation() + + # Use a monotonic clock so the timing is immune to wall-clock adjustments. + start_time = time.perf_counter() + image = None + for _ in range(args.repeat): + image = run_generation() + elapsed = time.perf_counter() - start_time + assert image is not None + + save_path = Path(args.save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + image.save(save_path) + print(f"Warmup: {args.warmup}") + print(f"Repeat: {args.repeat}") + print(f"Total Inference Time: {elapsed:.2f}s") + print(f"Average Inference Time: {elapsed / args.repeat:.2f}s") + print(f"Saved image to {save_path}") + + if accelerate_enabled: + cache_dit.disable_cache(pipe if args.target == "pipeline" else pipe.transformer) + + +if __name__ == "__main__": + main() + # Example usage: + # python3 ray_wrapper_example.py # baseline with no Ray parallelism + # python3 ray_wrapper_example.py --ulysses 2 --save-path ray_ulysses2_output.png + # python3 ray_wrapper_example.py --tp 2 --cache --quantize --save-path ray_tp2_output.png diff --git a/mkdocs.yml b/mkdocs.yml index d800f532..9f8b2f5b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -62,6 +62,7 @@ nav: - DBCache Design: user_guide/DBCACHE_DESIGN.md - Context Parallelism: user_guide/CONTEXT_PARALLEL.md - Tensor Parallelism: user_guide/TENSOR_PARALLEL.md + - Ray Wrapper: user_guide/RAY.md - TE-P, VAE-P and CN-P : user_guide/EXTRA_PARALLEL.md - 2D and 3D Parallelism: user_guide/HYBRID_PARALLEL.md - Low-Bits Quantization: user_guide/QUANTIZATION.md diff --git a/pyproject.toml b/pyproject.toml index 0693e63b..82c8401f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,11 @@ parallelism = [ "einops>=0.8.1", ] +ray = [ + "ray>=2.0", + "safetensors>=0.5.3", +] + quantization = [ "torchao>=0.14.1", "bitsandbytes>=0.48.1", diff --git a/src/cache_dit/caching/cache_adapters/cache_adapter.py b/src/cache_dit/caching/cache_adapters/cache_adapter.py index 3b60fa57..d9d9a1ea 100644 --- a/src/cache_dit/caching/cache_adapters/cache_adapter.py +++ 
b/src/cache_dit/caching/cache_adapters/cache_adapter.py @@ -532,6 +532,9 @@ def _release_blocks_hooks(blocks): return def _release_transformer_hooks(transformer): + from ...ray import disable_ray_parallelism + + disable_ray_parallelism(transformer) if hasattr(transformer, "_original_forward"): original_forward = transformer._original_forward transformer.forward = original_forward.__get__(transformer) @@ -550,6 +553,9 @@ def _release_transformer_hooks(transformer): del transformer._context_names def _release_pipeline_hooks(pipe): + from ...ray import disable_ray_pipeline_parallelism + + disable_ray_pipeline_parallelism(pipe) if hasattr(pipe, "_original_call"): original_call = pipe.__class__._original_call pipe.__class__.__call__ = original_call diff --git a/src/cache_dit/caching/cache_interface.py b/src/cache_dit/caching/cache_interface.py index 04a8bf16..d8df050d 100644 --- a/src/cache_dit/caching/cache_interface.py +++ b/src/cache_dit/caching/cache_interface.py @@ -160,7 +160,17 @@ def enable_cache( if params_modifiers is not None: context_kwargs["params_modifiers"] = params_modifiers - if cache_config is not None: + if quantize_config is not None: + assert isinstance(quantize_config, + QuantizeConfig), "quantize_config should be of type QuantizeConfig." 
+ + ray_enabled = isinstance(parallelism_config, ParallelismConfig) and parallelism_config.use_ray + ray_cache_context_kwargs = copy.deepcopy( + context_kwargs) if ray_enabled and cache_config is not None else None + ray_quantize_config = copy.deepcopy( + quantize_config) if ray_enabled and quantize_config is not None else None + + if cache_config is not None and not ray_enabled: if isinstance( pipe_or_adapter, (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin), @@ -173,6 +183,8 @@ def enable_cache( raise ValueError(f"type: {type(pipe_or_adapter)} is not valid, " "Please pass DiffusionPipeline or BlockAdapter" "for the 1's position param: pipe_or_adapter") + elif cache_config is not None: + logger.info("Ray parallelism is enabled; cache hooks will be applied inside Ray workers.") else: logger.warning("cache_config is None, skip cache acceleration for " f"{pipe_or_adapter.__class__.__name__}.") @@ -209,9 +221,13 @@ def enable_cache( adapter = BlockAdapter.normalize(adapter, unique=False) transformers = BlockAdapter.flatten(adapter.transformer) else: - if not BlockAdapter.is_normalized(pipe_or_adapter): + if isinstance(pipe_or_adapter, (torch.nn.Module, ModelMixin)): + transformers = [pipe_or_adapter] + elif not BlockAdapter.is_normalized(pipe_or_adapter): pipe_or_adapter = BlockAdapter.normalize(pipe_or_adapter, unique=False) - transformers = BlockAdapter.flatten(pipe_or_adapter.transformer) + transformers = BlockAdapter.flatten(pipe_or_adapter.transformer) + else: + transformers = BlockAdapter.flatten(pipe_or_adapter.transformer) if len(transformers) == 0: logger.warning("No transformer is detected in the BlockAdapter, skip enabling " @@ -242,14 +258,35 @@ def enable_cache( if extra_parallel_module is not None and pipe is not None: parallelism_config.extra_parallel_modules = parse_extra_modules(pipe, extra_parallel_module) - for i, transformer in enumerate(transformers): - # Enable parallelism for the transformer inplace - transformers[i] = 
enable_parallelism(transformer, parallelism_config) + if parallelism_config.use_ray: + from ..ray import enable_ray_parallelism + from ..ray import enable_ray_pipeline_parallelism + + if isinstance(pipe_or_adapter, DiffusionPipeline): + pipe_or_adapter = enable_ray_pipeline_parallelism( + pipe_or_adapter, + parallelism_config, + cache_context_kwargs=ray_cache_context_kwargs, + quantize_config=ray_quantize_config, + ) + else: + for i, transformer in enumerate(transformers): + transformers[i] = enable_ray_parallelism( + transformer, + parallelism_config, + cache_context_kwargs=ray_cache_context_kwargs, + quantize_config=ray_quantize_config, + ) + else: + for i, transformer in enumerate(transformers): + # Enable parallelism for the transformer inplace + transformers[i] = enable_parallelism(transformer, parallelism_config) # Enable quantization if quantize_config is provided. if quantize_config is not None: - assert isinstance(quantize_config, - QuantizeConfig), "quantize_config should be of type QuantizeConfig." + if ray_enabled: + logger.info("Ray parallelism is enabled; quantization will be applied inside Ray workers.") + return pipe_or_adapter # By default, we will try to apply quantization to transformer module(s) # for better performance. 
User can specify the quantization modules more diff --git a/src/cache_dit/caching/load_configs.py b/src/cache_dit/caching/load_configs.py index 71a5f1af..b3fe73c8 100644 --- a/src/cache_dit/caching/load_configs.py +++ b/src/cache_dit/caching/load_configs.py @@ -215,6 +215,15 @@ def load_parallelism_config(path_or_dict: str | dict, return None parallelism_config_kwargs = parallel_kwargs["parallelism_config"] + for alias, canonical in (("ulysses", "ulysses_size"), ("ring", "ring_size"), ("tp", "tp_size")): + if alias not in parallelism_config_kwargs: + continue + if (canonical in parallelism_config_kwargs + and parallelism_config_kwargs[canonical] != parallelism_config_kwargs[alias]): + raise ValueError(f"Both {alias} and {canonical} are set with different values: " + f"{parallelism_config_kwargs[alias]} vs " + f"{parallelism_config_kwargs[canonical]}.") + parallelism_config_kwargs[canonical] = parallelism_config_kwargs[alias] if "backend" in parallelism_config_kwargs: backend_str = parallelism_config_kwargs["backend"] parallelism_config_kwargs["backend"] = ParallelismBackend.from_str(backend_str) diff --git a/src/cache_dit/distributed/config.py b/src/cache_dit/distributed/config.py index ffeeaae1..7a4c1046 100644 --- a/src/cache_dit/distributed/config.py +++ b/src/cache_dit/distributed/config.py @@ -32,6 +32,49 @@ class ParallelismConfig: # The degree of tensor parallelism. tp_size: int = None + # Ray wrapper config + # use_ray (`bool`, *optional*): + # Whether cache-dit should manage distributed workers through Ray instead of requiring users to + # launch the program with torchrun. + use_ray: bool = False + # ray_num_workers (`int`, *optional*): + # Number of Ray GPU actors. Defaults to the parallel world size derived from the parallelism + # fields. + ray_num_workers: Optional[int] = None + # ray_address (`str`, *optional*): + # Ray cluster address forwarded to ray.init. 
+ ray_address: Optional[str] = None + # ray_runtime_env (`dict`, *optional*): + # Runtime environment forwarded to ray.init. + ray_runtime_env: Optional[Dict[str, Any]] = None + # ray_init_kwargs (`dict`, *optional*): + # Extra keyword arguments forwarded to ray.init. + ray_init_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + # ray_worker_options (`dict`, *optional*): + # Extra options forwarded to Ray actor `.options(...)`. + ray_worker_options: Dict[str, Any] = dataclasses.field(default_factory=dict) + # ray_master_port (`int`, *optional*): + # TCPStore port used by Ray workers to initialize torch.distributed. + ray_master_port: int = 0 + # ray_auto_shutdown (`bool`, *optional*): + # Whether the wrapper should shutdown Ray when disabling cache if it initialized Ray itself. + ray_auto_shutdown: bool = True + # ray_transfer_backend (`str`, *optional*): + # How to move the model copy to Ray workers. "auto" uses a local safetensors file for + # diffusers ModelMixin instances, save_pretrained snapshots for saveable diffusers pipelines, + # and Ray object store for smaller generic modules. + ray_transfer_backend: str = "auto" + # ray_use_flashpack (`bool`, *optional*): + # Whether Ray model snapshots should call save_pretrained/from_pretrained with + # use_flashpack=True. Requires flashpack and a diffusers version that supports it. + ray_use_flashpack: bool = False + # ray_use_compile (`bool`, *optional*): + # Whether Ray workers should compile the executable transformer copy after loading and native + # parallelization. If available, compile_repeated_blocks() is preferred over nn.Module.compile(). + ray_use_compile: bool = False + # Internal test hook: skip cache-dit native parallel planner inside Ray workers. + _ray_skip_native_parallelism: bool = False + # cp_plan: (`cp plan`, *optional*): # The custom context parallelism plan pass by user. 
cp_plan: Optional[Any] = None @@ -150,7 +193,7 @@ def __post_init__(self): raise ValueError("No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size " "to enable parallelism.") - if self.hybrid_enabled(): + if self.hybrid_enabled() and not self.use_ray: try: self._maybe_init_hybrid_meshes() except Exception as e: diff --git a/src/cache_dit/ray/__init__.py b/src/cache_dit/ray/__init__.py new file mode 100644 index 00000000..2c9e157f --- /dev/null +++ b/src/cache_dit/ray/__init__.py @@ -0,0 +1,11 @@ +from .wrapper import disable_ray_parallelism +from .wrapper import disable_ray_pipeline_parallelism +from .wrapper import enable_ray_parallelism +from .wrapper import enable_ray_pipeline_parallelism + +__all__ = [ + "disable_ray_parallelism", + "disable_ray_pipeline_parallelism", + "enable_ray_parallelism", + "enable_ray_pipeline_parallelism", +] diff --git a/src/cache_dit/ray/_tree.py b/src/cache_dit/ray/_tree.py new file mode 100644 index 00000000..6dabe824 --- /dev/null +++ b/src/cache_dit/ray/_tree.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import dataclasses +from collections.abc import Mapping, Sequence +from typing import Any, Callable + +import torch + + +def tree_map_tensors(value: Any, fn: Callable[[torch.Tensor], torch.Tensor]) -> Any: + """Apply a function to every tensor in a nested Python object. + + :param value: Nested object that may contain tensors. + :param fn: Function applied to each tensor leaf. + :returns: Object with the same container structure and transformed tensor leaves. 
+ """ + + if isinstance(value, torch.Tensor): + return fn(value) + if dataclasses.is_dataclass(value) and not isinstance(value, type): + updates = { + field.name: tree_map_tensors(getattr(value, field.name), fn) + for field in dataclasses.fields(value) + } + return dataclasses.replace(value, **updates) + if isinstance(value, Mapping): + return type(value)((key, tree_map_tensors(item, fn)) for key, item in value.items()) + if isinstance(value, tuple) and hasattr(value, "_fields"): + return type(value)(*(tree_map_tensors(item, fn) for item in value)) + if isinstance(value, tuple): + return tuple(tree_map_tensors(item, fn) for item in value) + if isinstance(value, list): + return [tree_map_tensors(item, fn) for item in value] + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return type(value)(tree_map_tensors(item, fn) for item in value) + return value + + +def cpu_tensor_tree(value: Any) -> Any: + """Detach and move tensor leaves to CPU for Ray object-store transport. + + :param value: Nested object that may contain tensors. + :returns: A nested object whose tensor leaves live on CPU. + """ + + return tree_map_tensors(value, lambda tensor: tensor.detach().cpu()) + + +def device_tensor_tree(value: Any, device: torch.device) -> Any: + """Move tensor leaves in a nested object to a target device. + + :param value: Nested object that may contain tensors. + :param device: Destination torch device. + :returns: A nested object whose tensor leaves live on ``device``. + """ + + return tree_map_tensors(value, lambda tensor: tensor.to(device=device)) + + +def first_tensor_device(value: Any) -> torch.device | None: + """Return the device of the first tensor leaf in a nested object. + + :param value: Nested object that may contain tensors. + :returns: Device of the first tensor leaf, or ``None`` if no tensor exists. 
+ """ + + if isinstance(value, torch.Tensor): + return value.device + if dataclasses.is_dataclass(value) and not isinstance(value, type): + for field in dataclasses.fields(value): + device = first_tensor_device(getattr(value, field.name)) + if device is not None: + return device + if isinstance(value, Mapping): + for item in value.values(): + device = first_tensor_device(item) + if device is not None: + return device + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + device = first_tensor_device(item) + if device is not None: + return device + return None diff --git a/src/cache_dit/ray/dist.py b/src/cache_dit/ray/dist.py new file mode 100644 index 00000000..7f65c86a --- /dev/null +++ b/src/cache_dit/ray/dist.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import torch +import torch.distributed as dist + + +def init_worker_process_group( + rank: int, + world_size: int, + master_port: int, +) -> None: + """Initialize torch.distributed inside one Ray worker actor. + + :param rank: Global rank assigned by the Ray engine. + :param world_size: Number of Ray workers participating in the model-parallel group. + :param master_port: TCPStore port shared by all workers. 
+ """ + + if dist.is_available() and dist.is_initialized(): + return + + backend = "cpu:gloo,cuda:nccl" if torch.cuda.is_available() else "gloo" + store = dist.TCPStore( + host_name="127.0.0.1", + port=master_port, + world_size=world_size, + is_master=(rank == 0), + ) + dist.init_process_group( + backend=backend, + store=store, + rank=rank, + world_size=world_size, + device_id=None, + ) + dist.barrier() + + +def destroy_worker_process_group() -> None: + """Destroy the process group owned by a Ray worker actor.""" + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() diff --git a/src/cache_dit/ray/engine.py b/src/cache_dit/ray/engine.py new file mode 100644 index 00000000..99427b67 --- /dev/null +++ b/src/cache_dit/ray/engine.py @@ -0,0 +1,474 @@ +from __future__ import annotations + +import socket +import shutil +import uuid +import time +import warnings +from pathlib import Path +from typing import Any + +import torch +from diffusers import DiffusionPipeline +from diffusers.models.modeling_utils import ModelMixin + +from ..distributed import ParallelismConfig +from ..logger import init_logger +from ..quantization import QuantizeConfig +from ..utils import maybe_empty_cache +from ._tree import cpu_tensor_tree +from ._tree import device_tensor_tree +from ._tree import first_tensor_device +from .worker import RayTransformerWorker +from .worker import RayPipelineWorker + +logger = init_logger(__name__) + + +def _require_ray(): + try: + import ray + except ImportError as exc: + raise ImportError("Ray wrapper requires ray. Install it with `pip install ray` or " + "`pip install cache-dit[ray]`.") from exc + return ray + + +def _init_ray(ray: Any, **init_kwargs: Any) -> None: + """Initialize Ray while suppressing its accelerator override transition warning. + + :param ray: Imported Ray module. + :param init_kwargs: Keyword arguments forwarded to ``ray.init``. 
+ """ + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=("Tip: In future versions of Ray, Ray will no longer override accelerator " + "visible devices env var if num_gpus=0 or num_gpus=None.*"), + category=FutureWarning, + module="ray\\._private\\.worker", + ) + ray.init(**init_kwargs) + + +def _get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _import_safetensors_save_file(): + try: + from safetensors.torch import save_file + except ImportError as exc: + raise ImportError("Ray safetensors transfer requires `safetensors`. Install with " + "`pip install cache-dit[ray]` or `pip install safetensors`.") from exc + return save_file + + +def _save_state_dict_safetensors(module: torch.nn.Module, path: Path) -> None: + save_file = _import_safetensors_save_file() + tensors = { + name: tensor.detach().cpu().contiguous() + for name, tensor in module.state_dict().items() + } + save_file(tensors, str(path)) + + +def _first_parameter_device(module: torch.nn.Module) -> torch.device | None: + for parameter in module.parameters(recurse=True): + return parameter.device + return None + + +def _first_parameter_dtype(module: torch.nn.Module) -> torch.dtype | None: + for parameter in module.parameters(recurse=True): + return parameter.dtype + return None + + +def _first_pipeline_module_device(pipe: DiffusionPipeline) -> torch.device | None: + for component in pipe.components.values(): + if isinstance(component, torch.nn.Module): + device = _first_parameter_device(component) + if device is not None: + return device + return None + + +def _first_pipeline_module_dtype(pipe: DiffusionPipeline) -> torch.dtype | None: + for component in pipe.components.values(): + if isinstance(component, torch.nn.Module): + dtype = _first_parameter_dtype(component) + if dtype is not None: + return dtype + return None + + +def _move_pipeline_modules(pipe: 
DiffusionPipeline, device: str | torch.device) -> None: + for component in pipe.components.values(): + if isinstance(component, torch.nn.Module): + component.to(device) + + +def _pipeline_supports_save_pretrained(pipe: DiffusionPipeline) -> bool: + for component in pipe.components.values(): + if component is None: + continue + if not callable(getattr(component, "save_pretrained", None)): + return False + return callable(getattr(pipe, "save_pretrained", None)) + + +def _model_supports_save_pretrained(model: ModelMixin) -> bool: + return (callable(getattr(model, "save_pretrained", None)) + and callable(getattr(model.__class__, "from_pretrained", None))) + + +class RayParallelEngine: + """Main-process engine that dispatches transformer forwards to Ray actors. + + :param transformer: User-visible transformer whose forward will be proxied. + :param parallelism_config: Ray-enabled parallelism configuration. + """ + + def __init__( + self, + transformer: torch.nn.Module | ModelMixin, + parallelism_config: ParallelismConfig, + cache_context_kwargs: dict[str, Any] | None = None, + quantize_config: QuantizeConfig | None = None, + ): + self.ray = _require_ray() + self.parallelism_config = parallelism_config + self.cache_context_kwargs = cache_context_kwargs + self.quantize_config = quantize_config + self._source_transformer = transformer + self._source_device = _first_parameter_device(transformer) + parallel_world_size = parallelism_config._get_world_size() + self.world_size = parallelism_config.ray_num_workers or parallel_world_size + if self.world_size <= 1: + raise ValueError("Ray parallelism requires a world size greater than 1.") + if (parallelism_config.ray_num_workers is not None + and parallelism_config.ray_num_workers != parallel_world_size): + raise ValueError("ray_num_workers must match the parallelism world size for the minimal " + f"Ray wrapper. 
Got ray_num_workers={parallelism_config.ray_num_workers}, " + f"world_size={parallel_world_size}.") + + self._ray_initialized_by_engine = False + self._transfer_dir: Path | None = None + if not self.ray.is_initialized(): + init_kwargs = dict(parallelism_config.ray_init_kwargs) + if parallelism_config.ray_address is not None: + init_kwargs["address"] = parallelism_config.ray_address + if parallelism_config.ray_runtime_env is not None: + init_kwargs["runtime_env"] = parallelism_config.ray_runtime_env + _init_ray(self.ray, **init_kwargs) + self._ray_initialized_by_engine = True + + self.master_port = parallelism_config.ray_master_port or _get_free_port() + self._actors = self._create_workers(transformer) + + def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any]: + remote_worker = self.ray.remote(RayTransformerWorker) + worker_options = {"num_gpus": 1} + worker_options.update(self.parallelism_config.ray_worker_options) + actors = [ + remote_worker.options(**worker_options).remote( + rank, + self.world_size, + self.parallelism_config, + self.cache_context_kwargs, + self.quantize_config, + self.master_port, + ) for rank in range(self.world_size) + ] + self.ray.get([actor.ready.remote() for actor in actors]) + device_infos = self.ray.get([actor.device_info.remote() for actor in actors]) + logger.info(f"Ray transformer worker placement before load: {device_infos}") + + if self._source_device is not None and self._source_device.type != "cpu": + logger.info("Moving the main-process transformer to CPU before Ray worker loading.") + offload_start = time.perf_counter() + transformer.to("cpu") + maybe_empty_cache() + logger.info(f"Moved the main-process transformer to CPU in " + f"{time.perf_counter() - offload_start:.2f}s.") + else: + logger.info("The main-process transformer is already on CPU before Ray worker loading.") + + transfer_backend = self.parallelism_config.ray_transfer_backend + if transfer_backend == "auto": + transfer_backend = "file" 
if isinstance(transformer, ModelMixin) else "object_store" + + if transfer_backend == "file": + if not isinstance(transformer, ModelMixin): + raise ValueError("ray_transfer_backend='file' currently requires a diffusers ModelMixin " + "transformer. Use ray_transfer_backend='object_store' for generic " + "torch.nn.Module instances.") + self._transfer_dir = Path.cwd() / ".tmp" / "cache_dit_ray" / uuid.uuid4().hex + self._transfer_dir.mkdir(parents=True, exist_ok=True) + save_start = time.perf_counter() + load_start = time.perf_counter() + if _model_supports_save_pretrained(transformer): + transformer_path = self._transfer_dir / "transformer" + logger.info(f"Saving the current transformer snapshot for Ray workers to " + f"{transformer_path}.") + transformer.save_pretrained( + transformer_path, + safe_serialization=True, + use_flashpack=self.parallelism_config.ray_use_flashpack, + ) + logger.info(f"Saved the transformer snapshot in " + f"{time.perf_counter() - save_start:.2f}s.") + load_infos = self.ray.get([ + actor.load_transformer_from_pretrained.remote( + transformer.__class__, + str(transformer_path), + _first_parameter_dtype(transformer), + self.parallelism_config.ray_use_flashpack, + ) for actor in actors + ]) + logger.info(f"Loaded pretrained transformer snapshots on Ray workers in " + f"{time.perf_counter() - load_start:.2f}s.") + else: + transformer_path = self._transfer_dir / "transformer.safetensors" + logger.info(f"Saving the CPU transformer state_dict for Ray workers to " + f"{transformer_path}.") + _save_state_dict_safetensors(transformer, transformer_path) + logger.info(f"Saved the transformer safetensors file in " + f"{time.perf_counter() - save_start:.2f}s.") + load_infos = self.ray.get([ + actor.load_transformer_from_safetensors.remote( + transformer.__class__, + dict(transformer.config), + str(transformer_path), + ) for actor in actors + ]) + logger.info(f"Loaded the safetensors transformer on Ray workers in " + f"{time.perf_counter() - 
load_start:.2f}s.") + elif transfer_backend == "object_store": + logger.info("Putting the CPU transformer into the Ray object store.") + put_start = time.perf_counter() + transformer_ref = self.ray.put(transformer) + logger.info(f"Put the CPU transformer into the Ray object store in " + f"{time.perf_counter() - put_start:.2f}s.") + load_start = time.perf_counter() + load_infos = self.ray.get( + [actor.load_transformer.remote(transformer_ref) for actor in actors]) + logger.info(f"Loaded the object-store transformer on Ray workers in " + f"{time.perf_counter() - load_start:.2f}s.") + else: + raise ValueError(f"Unsupported ray_transfer_backend: {transfer_backend!r}.") + logger.info(f"Ray transformer worker placement after load: {load_infos}") + return actors + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Proxy one transformer forward through all Ray ranks. + + :param args: Positional arguments passed to the original transformer forward. + :param kwargs: Keyword arguments passed to the original transformer forward. + :returns: Rank-0 output moved back to the caller's tensor device when possible. 
+ """ + + output_device = first_tensor_device((args, kwargs)) + cpu_args = cpu_tensor_tree(args) + cpu_kwargs = cpu_tensor_tree(kwargs) + results = self.ray.get([actor.forward.remote(cpu_args, cpu_kwargs) for actor in self._actors]) + rank0_output = next((result for result in results if result is not None), None) + if rank0_output is None: + raise RuntimeError("Ray transformer workers did not return a rank-0 output.") + if output_device is not None: + rank0_output = device_tensor_tree(rank0_output, output_device) + return rank0_output + + def shutdown(self) -> None: + """Shutdown worker actors and optionally the Ray runtime initialized by this engine.""" + + if self._actors: + self.ray.get([actor.shutdown.remote() for actor in self._actors]) + for actor in self._actors: + self.ray.kill(actor) + self._actors = [] + if self._ray_initialized_by_engine and self.parallelism_config.ray_auto_shutdown: + self.ray.shutdown() + if self._source_device is not None and self._source_device.type != "cpu": + restore_start = time.perf_counter() + self._source_transformer.to(self._source_device) + maybe_empty_cache() + logger.info(f"Restored the main-process transformer to {self._source_device} in " + f"{time.perf_counter() - restore_start:.2f}s.") + if self._transfer_dir is not None and self._transfer_dir.exists(): + shutil.rmtree(self._transfer_dir) + self._transfer_dir = None + + +class RayPipelineEngine: + """Main-process engine that dispatches full pipeline calls to Ray actors. + + :param pipe: User-visible pipeline whose ``__call__`` will be proxied. + :param parallelism_config: Ray-enabled parallelism configuration. 
+ """ + + def __init__( + self, + pipe: DiffusionPipeline, + parallelism_config: ParallelismConfig, + cache_context_kwargs: dict[str, Any] | None = None, + quantize_config: QuantizeConfig | None = None, + ): + if not hasattr(pipe, "transformer"): + raise ValueError("Ray pipeline parallelism requires a pipeline with a transformer " + "attribute.") + self.ray = _require_ray() + self.parallelism_config = parallelism_config + self.cache_context_kwargs = cache_context_kwargs + self.quantize_config = quantize_config + self._source_pipe = pipe + self._source_device = _first_pipeline_module_device(pipe) + self._source_dtype = _first_pipeline_module_dtype(pipe) + parallel_world_size = parallelism_config._get_world_size() + self.world_size = parallelism_config.ray_num_workers or parallel_world_size + if self.world_size <= 1: + raise ValueError("Ray parallelism requires a world size greater than 1.") + if (parallelism_config.ray_num_workers is not None + and parallelism_config.ray_num_workers != parallel_world_size): + raise ValueError("ray_num_workers must match the parallelism world size for the minimal " + f"Ray wrapper. 
Got ray_num_workers={parallelism_config.ray_num_workers}, " + f"world_size={parallel_world_size}.") + + self._ray_initialized_by_engine = False + self._transfer_dir: Path | None = None + if not self.ray.is_initialized(): + init_kwargs = dict(parallelism_config.ray_init_kwargs) + if parallelism_config.ray_address is not None: + init_kwargs["address"] = parallelism_config.ray_address + if parallelism_config.ray_runtime_env is not None: + init_kwargs["runtime_env"] = parallelism_config.ray_runtime_env + _init_ray(self.ray, **init_kwargs) + self._ray_initialized_by_engine = True + + self.master_port = parallelism_config.ray_master_port or _get_free_port() + self._actors = self._create_workers(pipe) + + def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: + remote_worker = self.ray.remote(RayPipelineWorker) + worker_options = {"num_gpus": 1} + worker_options.update(self.parallelism_config.ray_worker_options) + actors = [ + remote_worker.options(**worker_options).remote( + rank, + self.world_size, + self.parallelism_config, + self.cache_context_kwargs, + self.quantize_config, + self.master_port, + ) for rank in range(self.world_size) + ] + self.ray.get([actor.ready.remote() for actor in actors]) + device_infos = self.ray.get([actor.device_info.remote() for actor in actors]) + logger.info(f"Ray pipeline worker placement before load: {device_infos}") + + if self._source_device is not None and self._source_device.type != "cpu": + logger.info("Moving the main-process pipeline to CPU before Ray worker loading.") + offload_start = time.perf_counter() + _move_pipeline_modules(pipe, "cpu") + maybe_empty_cache() + logger.info(f"Moved the main-process pipeline to CPU in " + f"{time.perf_counter() - offload_start:.2f}s.") + else: + logger.info("The main-process pipeline is already on CPU before Ray worker loading.") + + transfer_backend = self.parallelism_config.ray_transfer_backend + model_path = getattr(pipe, "name_or_path", None) + if transfer_backend == "auto": + 
transfer_backend = "save_pretrained" if _pipeline_supports_save_pretrained( + pipe) else "object_store" + + load_start = time.perf_counter() + if transfer_backend == "save_pretrained": + self._transfer_dir = Path.cwd() / ".tmp" / "cache_dit_ray" / uuid.uuid4().hex + pipeline_path = self._transfer_dir / "pipeline" + pipeline_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving the current pipeline snapshot for Ray workers to {pipeline_path}.") + save_start = time.perf_counter() + pipe.save_pretrained( + pipeline_path, + safe_serialization=True, + use_flashpack=self.parallelism_config.ray_use_flashpack, + ) + logger.info(f"Saved the pipeline snapshot in {time.perf_counter() - save_start:.2f}s.") + load_infos = self.ray.get([ + actor.load_pipeline_from_pretrained.remote( + pipe.__class__, + str(pipeline_path), + self._source_dtype, + self.parallelism_config.ray_use_flashpack, + ) for actor in actors + ]) + logger.info(f"Loaded saved pipeline snapshots on Ray workers in " + f"{time.perf_counter() - load_start:.2f}s.") + elif transfer_backend == "from_pretrained": + if model_path is None: + raise ValueError("ray_transfer_backend='from_pretrained' requires pipeline.name_or_path.") + logger.info(f"Loading Ray worker pipelines from pretrained source: {model_path}.") + load_infos = self.ray.get([ + actor.load_pipeline_from_pretrained.remote( + pipe.__class__, + model_path, + self._source_dtype, + self.parallelism_config.ray_use_flashpack, + ) for actor in actors + ]) + logger.info(f"Loaded pretrained pipelines on Ray workers in " + f"{time.perf_counter() - load_start:.2f}s.") + elif transfer_backend == "object_store": + logger.info("Putting the CPU pipeline into the Ray object store.") + put_start = time.perf_counter() + pipe_ref = self.ray.put(pipe) + logger.info(f"Put the CPU pipeline into the Ray object store in " + f"{time.perf_counter() - put_start:.2f}s.") + load_infos = self.ray.get([actor.load_pipeline.remote(pipe_ref) for actor in actors]) + 
logger.info(f"Loaded the object-store pipeline on Ray workers in " + f"{time.perf_counter() - load_start:.2f}s.") + else: + raise ValueError(f"Unsupported Ray pipeline transfer backend: {transfer_backend!r}.") + logger.info(f"Ray pipeline worker placement after load: {load_infos}") + return actors + + def call(self, *args: Any, **kwargs: Any) -> Any: + """Proxy one full pipeline call through all Ray ranks. + + :param args: Positional arguments passed to the original pipeline call. + :param kwargs: Keyword arguments passed to the original pipeline call. + :returns: Rank-0 pipeline output. + """ + + cpu_args = cpu_tensor_tree(args) + cpu_kwargs = cpu_tensor_tree(kwargs) + results = self.ray.get([actor.call.remote(cpu_args, cpu_kwargs) for actor in self._actors]) + rank0_output = next((result for result in results if result is not None), None) + if rank0_output is None: + raise RuntimeError("Ray pipeline workers did not return a rank-0 output.") + return rank0_output + + def shutdown(self) -> None: + """Shutdown worker actors and optionally the Ray runtime initialized by this engine.""" + + if self._actors: + self.ray.get([actor.shutdown.remote() for actor in self._actors]) + for actor in self._actors: + self.ray.kill(actor) + self._actors = [] + if self._ray_initialized_by_engine and self.parallelism_config.ray_auto_shutdown: + self.ray.shutdown() + if self._source_device is not None and self._source_device.type != "cpu": + restore_start = time.perf_counter() + _move_pipeline_modules(self._source_pipe, self._source_device) + maybe_empty_cache() + logger.info(f"Restored the main-process pipeline to {self._source_device} in " + f"{time.perf_counter() - restore_start:.2f}s.") + if self._transfer_dir is not None and self._transfer_dir.exists(): + shutil.rmtree(self._transfer_dir) + self._transfer_dir = None diff --git a/src/cache_dit/ray/worker.py b/src/cache_dit/ray/worker.py new file mode 100644 index 00000000..f8913afa --- /dev/null +++ b/src/cache_dit/ray/worker.py 
@@ -0,0 +1,426 @@ +from __future__ import annotations + +import copy +import os +from typing import Any + +import torch +from diffusers import DiffusionPipeline +from diffusers.models.modeling_utils import ModelMixin + +from ..distributed import ParallelismConfig +from ..distributed import enable_parallelism +from ..logger import init_logger +from ..quantization import QuantizeConfig +from ..quantization import quantize +from ..utils import parse_extra_modules +from ._tree import cpu_tensor_tree +from ._tree import device_tensor_tree +from .dist import destroy_worker_process_group +from .dist import init_worker_process_group + +logger = init_logger(__name__) + + +def _maybe_compile_transformer( + transformer: torch.nn.Module | ModelMixin, + parallelism_config: ParallelismConfig, +) -> torch.nn.Module | ModelMixin: + """Compile a Ray-owned transformer when requested by the parallelism config. + + :param transformer: Transformer copy already moved to the actor device and parallelized. + :param parallelism_config: Ray-enabled parallelism configuration. + :returns: The same transformer object after any in-place compile step. 
+ """ + + if not parallelism_config.ray_use_compile: + return transformer + + compile_repeated_blocks = getattr(transformer, "compile_repeated_blocks", None) + if callable(compile_repeated_blocks): + logger.info("Compiling Ray-owned transformer with compile_repeated_blocks().") + transformer.compile_repeated_blocks() + return transformer + + compile_module = getattr(transformer, "compile", None) + if callable(compile_module): + logger.info("Compiling Ray-owned transformer with nn.Module.compile().") + transformer.compile() + return transformer + + logger.warning("ray_use_compile=True, but transformer does not support " + "compile_repeated_blocks() or nn.Module.compile(); skipping compile.") + return transformer + + +def _maybe_apply_cache( + module_or_pipe: torch.nn.Module | ModelMixin | DiffusionPipeline, + cache_context_kwargs: dict[str, Any] | None, +) -> torch.nn.Module | ModelMixin | DiffusionPipeline: + """Apply cache hooks inside a Ray worker when cache config is provided. + + :param module_or_pipe: Worker-local transformer or pipeline copy. + :param cache_context_kwargs: Cache context keyword arguments from ``cache_dit.enable_cache``. + :returns: The same object with cache hooks applied, when requested. + """ + + if cache_context_kwargs is None: + return module_or_pipe + + from ..caching.cache_adapters import CachedAdapter + + logger.info(f"Applying cache hooks inside Ray worker for {module_or_pipe.__class__.__name__}.") + return CachedAdapter.apply(module_or_pipe, **copy.deepcopy(cache_context_kwargs)) + + +def _maybe_quantize_transformer( + transformer: torch.nn.Module | ModelMixin, + quantize_config: QuantizeConfig | None, +) -> torch.nn.Module | ModelMixin: + """Quantize a worker-local transformer when requested. + + :param transformer: Worker-local transformer after cache and parallelism have been applied. + :param quantize_config: Optional quantization configuration. + :returns: Quantized transformer when requested, otherwise the original transformer. 
+ """ + + if quantize_config is None: + return transformer + logger.info(f"Applying quantization inside Ray worker for {transformer.__class__.__name__}.") + return quantize(transformer, quantize_config=copy.deepcopy(quantize_config)) + + +def _maybe_quantize_pipeline( + pipe: DiffusionPipeline, + quantize_config: QuantizeConfig | None, +) -> DiffusionPipeline: + """Quantize worker-local pipeline components when requested. + + :param pipe: Worker-local pipeline after cache and transformer parallelism have been applied. + :param quantize_config: Optional quantization configuration. + :returns: Pipeline with requested components quantized. + """ + + if quantize_config is None: + return pipe + if quantize_config.components_to_quantize is None: + pipe.transformer = _maybe_quantize_transformer(pipe.transformer, quantize_config) + return pipe + + expanded_quantize_configs = QuantizeConfig.expand_configs(quantize_config) + for config in expanded_quantize_configs: + components_to_quantize = config.components_to_quantize + components = parse_extra_modules(pipe, components_to_quantize) + assert len(components) == len(components_to_quantize), ( + f"Some components in quantize_config.components_to_quantize: {components_to_quantize} " + "are not found in the pipeline, please check the component names or directly pass the " + "actual modules in components_to_quantize.") + for component, name in zip(components, components_to_quantize): + name = getattr(component, "_actual_module_name", name) + quantized_component = quantize(component, quantize_config=copy.deepcopy(config)) + setattr(pipe, name, quantized_component) + return pipe + + +class RayTransformerWorker: + """Ray actor body that owns one rank of a cache-dit parallel transformer. + + :param rank: Global rank assigned by the Ray engine. + :param world_size: Number of Ray actors in the model-parallel group. + :param parallelism_config: Parallelism configuration to apply inside the actor. 
+ :param master_port: TCPStore port shared by all actors. + """ + + def __init__( + self, + rank: int, + world_size: int, + parallelism_config: ParallelismConfig, + cache_context_kwargs: dict[str, Any] | None, + quantize_config: QuantizeConfig | None, + master_port: int, + ): + self.rank = rank + self.world_size = world_size + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.device.type == "cuda": + torch.cuda.set_device(torch.cuda.current_device()) + + init_worker_process_group( + rank=rank, + world_size=world_size, + master_port=master_port, + ) + + self.parallelism_config = copy.deepcopy(parallelism_config) + self.cache_context_kwargs = copy.deepcopy(cache_context_kwargs) + self.quantize_config = copy.deepcopy(quantize_config) + self.parallelism_config.ray_num_workers = world_size + if self.parallelism_config.hybrid_enabled(): + self.parallelism_config._maybe_init_hybrid_meshes() + + self.transformer: torch.nn.Module | ModelMixin | None = None + + def load_transformer(self, transformer: torch.nn.Module | ModelMixin) -> dict[str, Any]: + """Load and parallelize the transformer copy owned by this actor. + + :param transformer: CPU transformer copy from the Ray object store. + :returns: Device placement and memory information after loading. + """ + + self.transformer = transformer.to(self.device) + self.transformer.eval() + self.transformer = _maybe_apply_cache(self.transformer, self.cache_context_kwargs) + if not self.parallelism_config._ray_skip_native_parallelism: + self.transformer = enable_parallelism(self.transformer, self.parallelism_config) + self.transformer = _maybe_quantize_transformer(self.transformer, self.quantize_config) + self.transformer = _maybe_compile_transformer(self.transformer, self.parallelism_config) + return self.device_info() + + def load_transformer_from_file(self, path: str) -> dict[str, Any]: + """Load and parallelize a transformer serialized on the local filesystem. 
+ + :param path: Path to a CPU transformer checkpoint written by the Ray engine. + :returns: Device placement and memory information after loading. + """ + + transformer = torch.load(path, map_location="cpu", weights_only=False) + return self.load_transformer(transformer) + + def load_transformer_from_safetensors( + self, + transformer_cls: type[ModelMixin], + transformer_config: dict[str, Any], + path: str, + ) -> dict[str, Any]: + """Load a diffusers transformer from a safetensors state dict. + + :param transformer_cls: Diffusers transformer class used to reconstruct the module. + :param transformer_config: Serialized transformer config passed to ``from_config``. + :param path: Path to a safetensors state dict written by the Ray engine. + :returns: Device placement and memory information after loading. + """ + + try: + from safetensors.torch import load_file + except ImportError as exc: + raise ImportError("Ray safetensors transfer requires `safetensors`. Install with " + "`pip install cache-dit[ray]` or `pip install safetensors`.") from exc + + with torch.device("meta"): + transformer = transformer_cls.from_config(transformer_config) + state_dict = load_file(path, device=str(self.device)) + transformer.load_state_dict(state_dict, assign=True) + self.transformer = transformer.eval() + self.transformer = _maybe_apply_cache(self.transformer, self.cache_context_kwargs) + if not self.parallelism_config._ray_skip_native_parallelism: + self.transformer = enable_parallelism(self.transformer, self.parallelism_config) + self.transformer = _maybe_quantize_transformer(self.transformer, self.quantize_config) + self.transformer = _maybe_compile_transformer(self.transformer, self.parallelism_config) + return self.device_info() + + def load_transformer_from_pretrained( + self, + transformer_cls: type[ModelMixin], + model_path: str, + torch_dtype: torch.dtype | None, + use_flashpack: bool, + ) -> dict[str, Any]: + """Load a diffusers transformer snapshot inside this actor. 
+ + :param transformer_cls: Diffusers transformer class used to reload the module. + :param model_path: Local snapshot directory written by the Ray engine. + :param torch_dtype: Optional dtype for model loading. + :param use_flashpack: Whether to prefer FlashPack weights during loading. + :returns: Device placement and memory information after loading. + """ + + load_kwargs = { + "use_safetensors": True, + "use_flashpack": use_flashpack, + } + if torch_dtype is not None: + load_kwargs["torch_dtype"] = torch_dtype + transformer = transformer_cls.from_pretrained(model_path, **load_kwargs) + return self.load_transformer(transformer) + + def ready(self) -> int: + """Return the rank after actor initialization has completed. + + :returns: The actor rank. + """ + + return self.rank + + def device_info(self) -> dict[str, Any]: + """Return actor device placement details for diagnostics. + + :returns: Rank, torch device, and visible accelerator ids for this actor. + """ + + info = { + "rank": self.rank, + "device": str(self.device), + "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), + } + if self.device.type == "cuda": + info["memory_allocated_mib"] = torch.cuda.memory_allocated() // 1024 // 1024 + info["memory_reserved_mib"] = torch.cuda.memory_reserved() // 1024 // 1024 + return info + + def forward(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any | None: + """Run a transformer forward on this Ray rank. + + :param args: CPU-staged positional arguments from the main process. + :param kwargs: CPU-staged keyword arguments from the main process. + :returns: CPU-staged output on rank 0 and ``None`` on other ranks. 
+ """ + + device_args = device_tensor_tree(args, self.device) + device_kwargs = device_tensor_tree(kwargs, self.device) + if self.transformer is None: + raise RuntimeError("RayTransformerWorker.forward called before load_transformer.") + with torch.no_grad(): + output = self.transformer(*device_args, **device_kwargs) + if self.rank != 0: + return None + return cpu_tensor_tree(output) + + def shutdown(self) -> None: + """Release actor-local distributed state.""" + + destroy_worker_process_group() + + +class RayPipelineWorker: + """Ray actor body that owns one full pipeline and one distributed rank. + + :param rank: Global rank assigned by the Ray engine. + :param world_size: Number of Ray actors in the model-parallel group. + :param parallelism_config: Parallelism configuration to apply inside the actor. + :param master_port: TCPStore port shared by all actors. + """ + + def __init__( + self, + rank: int, + world_size: int, + parallelism_config: ParallelismConfig, + cache_context_kwargs: dict[str, Any] | None, + quantize_config: QuantizeConfig | None, + master_port: int, + ): + self.rank = rank + self.world_size = world_size + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.device.type == "cuda": + torch.cuda.set_device(torch.cuda.current_device()) + + init_worker_process_group( + rank=rank, + world_size=world_size, + master_port=master_port, + ) + + self.parallelism_config = copy.deepcopy(parallelism_config) + self.cache_context_kwargs = copy.deepcopy(cache_context_kwargs) + self.quantize_config = copy.deepcopy(quantize_config) + self.parallelism_config.ray_num_workers = world_size + if self.parallelism_config.hybrid_enabled(): + self.parallelism_config._maybe_init_hybrid_meshes() + + self.pipe: DiffusionPipeline | None = None + + def load_pipeline(self, pipe: DiffusionPipeline) -> dict[str, Any]: + """Load a pipeline copy and parallelize its transformer inside this actor. 
+ + :param pipe: CPU pipeline copy from the Ray object store. + :returns: Device placement and memory information after loading. + """ + + for component in pipe.components.values(): + if isinstance(component, torch.nn.Module): + component.to(self.device) + self.pipe = pipe + self.pipe.set_progress_bar_config(disable=True) + self.pipe = _maybe_apply_cache(self.pipe, self.cache_context_kwargs) + self.pipe.transformer.eval() + if not self.parallelism_config._ray_skip_native_parallelism: + self.pipe.transformer = enable_parallelism(self.pipe.transformer, self.parallelism_config) + self.pipe = _maybe_quantize_pipeline(self.pipe, self.quantize_config) + self.pipe.transformer = _maybe_compile_transformer( + self.pipe.transformer, + self.parallelism_config, + ) + return self.device_info() + + def load_pipeline_from_pretrained( + self, + pipe_cls: type[DiffusionPipeline], + model_path: str, + torch_dtype: torch.dtype | None, + use_flashpack: bool, + ) -> dict[str, Any]: + """Load a pipeline from its pretrained directory inside this actor. + + :param pipe_cls: Diffusers pipeline class used to reconstruct the pipeline. + :param model_path: Local path or model id passed to ``from_pretrained``. + :param torch_dtype: Optional dtype for model loading. + :param use_flashpack: Whether to prefer FlashPack weights during loading. + :returns: Device placement and memory information after loading. + """ + + load_kwargs = {} + if torch_dtype is not None: + load_kwargs["torch_dtype"] = torch_dtype + load_kwargs["use_safetensors"] = True + load_kwargs["use_flashpack"] = use_flashpack + pipe = pipe_cls.from_pretrained(model_path, **load_kwargs) + return self.load_pipeline(pipe) + + def ready(self) -> int: + """Return the rank after actor initialization has completed. + + :returns: The actor rank. + """ + + return self.rank + + def device_info(self) -> dict[str, Any]: + """Return actor device placement details for diagnostics. 
+ + :returns: Rank, torch device, and visible accelerator ids for this actor. + """ + + info = { + "rank": self.rank, + "device": str(self.device), + "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), + } + if self.device.type == "cuda": + info["memory_allocated_mib"] = torch.cuda.memory_allocated() // 1024 // 1024 + info["memory_reserved_mib"] = torch.cuda.memory_reserved() // 1024 // 1024 + return info + + def call(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any | None: + """Run a full pipeline call on this Ray rank. + + :param args: Positional arguments from the main process. + :param kwargs: Keyword arguments from the main process. + :returns: Pipeline output on rank 0 and ``None`` on other ranks. + """ + + if self.pipe is None: + raise RuntimeError("RayPipelineWorker.call called before load_pipeline.") + device_args = device_tensor_tree(args, self.device) + device_kwargs = device_tensor_tree(kwargs, self.device) + with torch.no_grad(): + output = self.pipe(*device_args, **device_kwargs) + if self.rank != 0: + return None + return cpu_tensor_tree(output) + + def shutdown(self) -> None: + """Release actor-local distributed state.""" + + destroy_worker_process_group() diff --git a/src/cache_dit/ray/wrapper.py b/src/cache_dit/ray/wrapper.py new file mode 100644 index 00000000..8c6efa3d --- /dev/null +++ b/src/cache_dit/ray/wrapper.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import types +from typing import Any + +import torch +from diffusers import DiffusionPipeline +from diffusers.models.modeling_utils import ModelMixin + +from ..distributed import ParallelismConfig +from ..logger import init_logger +from ..quantization import QuantizeConfig +from .engine import RayParallelEngine +from .engine import RayPipelineEngine + +logger = init_logger(__name__) + + +def enable_ray_parallelism( + transformer: torch.nn.Module | ModelMixin, + parallelism_config: ParallelismConfig, + cache_context_kwargs: dict[str, Any] | None = None, + 
quantize_config: QuantizeConfig | None = None, +) -> torch.nn.Module | ModelMixin: + """Patch a transformer so its forward is executed by Ray worker actors. + + :param transformer: User-visible transformer module to patch in-place. + :param parallelism_config: Ray-enabled parallelism configuration. + :param cache_context_kwargs: Optional cache context keyword arguments to apply inside Ray workers. + :param quantize_config: Optional quantization configuration to apply inside Ray workers. + :returns: The same transformer object with a proxied forward method. + """ + + if getattr(transformer, "_cache_dit_ray_enabled", False): + logger.warning("Ray parallelism is already enabled for this transformer. Skipping.") + return transformer + + engine = RayParallelEngine(transformer, parallelism_config, cache_context_kwargs, quantize_config) + transformer._cache_dit_ray_original_forward = transformer.forward # type: ignore[attr-defined] + transformer._cache_dit_ray_original_to = transformer.to # type: ignore[attr-defined] + transformer._cache_dit_ray_engine = engine # type: ignore[attr-defined] + + def ray_forward(self, *args, **kwargs): + return self._cache_dit_ray_engine.forward(*args, **kwargs) + + def ray_to(self, *args, **kwargs): + logger.info(f"Skipping .to(...) for Ray-owned {self.__class__.__name__}; " + "worker actors own the executable transformer copies.") + return self + + transformer.forward = types.MethodType(ray_forward, transformer) + transformer.to = types.MethodType(ray_to, transformer) + transformer._cache_dit_ray_enabled = True # type: ignore[attr-defined] + logger.info(f"Enabled Ray parallelism for {transformer.__class__.__name__} with " + f"world_size={engine.world_size}.") + return transformer + + +def disable_ray_parallelism(transformer: torch.nn.Module | ModelMixin) -> None: + """Restore a transformer patched by :func:`enable_ray_parallelism`. + + :param transformer: Transformer module that may own a Ray engine. 
+    """
+
+    engine = getattr(transformer, "_cache_dit_ray_engine", None)
+    if hasattr(transformer, "_cache_dit_ray_original_to"):
+        transformer.to = transformer._cache_dit_ray_original_to
+        del transformer._cache_dit_ray_original_to
+    if engine is not None:
+        engine.shutdown()
+        del transformer._cache_dit_ray_engine
+    if hasattr(transformer, "_cache_dit_ray_original_forward"):
+        transformer.forward = transformer._cache_dit_ray_original_forward
+        del transformer._cache_dit_ray_original_forward
+    if hasattr(transformer, "_cache_dit_ray_enabled"):
+        del transformer._cache_dit_ray_enabled
+
+
+def enable_ray_pipeline_parallelism(
+    pipe: DiffusionPipeline,
+    parallelism_config: ParallelismConfig,
+    cache_context_kwargs: dict[str, Any] | None = None,
+    quantize_config: QuantizeConfig | None = None,
+) -> DiffusionPipeline:
+    """Patch a pipeline so each full inference call is executed by Ray workers.
+
+    :param pipe: User-visible diffusion pipeline to patch in-place.
+    :param parallelism_config: Ray-enabled parallelism configuration.
+    :param cache_context_kwargs: Optional cache context keyword arguments to apply inside Ray workers.
+    :param quantize_config: Optional quantization configuration to apply inside Ray workers.
+    :returns: The same pipeline object with a proxied ``__call__`` method.
+    """
+
+    if getattr(pipe, "_cache_dit_ray_pipeline_enabled", False):
+        logger.warning("Ray parallelism is already enabled for this pipeline. Skipping.")
+        return pipe
+
+    engine = RayPipelineEngine(pipe, parallelism_config, cache_context_kwargs, quantize_config)
+    original_class = pipe.__class__
+
+    def ray_pipeline_call(self, *args, **kwargs):
+        return self._cache_dit_ray_pipeline_engine.call(*args, **kwargs)
+
+    ray_class = type(
+        f"CacheDitRay{original_class.__name__}",
+        (original_class, ),
+        {"__call__": ray_pipeline_call},
+    )
+    pipe._cache_dit_ray_pipeline_original_class = original_class  # type: ignore[attr-defined]
+    pipe._cache_dit_ray_pipeline_engine = engine  # type: ignore[attr-defined]
+    pipe.__class__ = ray_class
+    pipe._cache_dit_ray_pipeline_enabled = True  # type: ignore[attr-defined]
+    logger.info(f"Enabled Ray parallelism for {original_class.__name__} with "
+                f"world_size={engine.world_size}.")
+    return pipe
+
+
+def disable_ray_pipeline_parallelism(pipe: DiffusionPipeline) -> None:
+    """Restore a pipeline patched by :func:`enable_ray_pipeline_parallelism`.
+
+    :param pipe: Pipeline that may own a Ray pipeline engine.
+    """
+
+    engine = getattr(pipe, "_cache_dit_ray_pipeline_engine", None)
+    if engine is not None:
+        engine.shutdown()
+        del pipe._cache_dit_ray_pipeline_engine
+    original_class = getattr(pipe, "_cache_dit_ray_pipeline_original_class", None)
+    if original_class is not None:
+        pipe.__class__ = original_class
+        del pipe._cache_dit_ray_pipeline_original_class
+    if hasattr(pipe, "_cache_dit_ray_pipeline_enabled"):
+        del pipe._cache_dit_ray_pipeline_enabled
diff --git a/tests/ray/test_ray_wrapper.py b/tests/ray/test_ray_wrapper.py
new file mode 100644
index 00000000..4608464f
--- /dev/null
+++ b/tests/ray/test_ray_wrapper.py
@@ -0,0 +1,336 @@
+from __future__ import annotations
+
+import os
+import shutil
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from diffusers import DiffusionPipeline
+
+import cache_dit
+from cache_dit import ParallelismConfig
+from cache_dit.caching.cache_contexts import DBCacheConfig
+from cache_dit.metrics import compute_psnr
+from cache_dit.quantization import QuantizeConfig
+
+ray = pytest.importorskip("ray")
+
+_REPO_ROOT = Path(__file__).resolve().parents[2]
+# Defaults to the interpreter running pytest; override via CACHE_DIT_TEST_RAY_PYTHON_BIN.
+_PYTHON_BIN = Path(os.getenv("CACHE_DIT_TEST_RAY_PYTHON_BIN", sys.executable))
+_DEFAULT_VISIBLE_DEVICES = os.getenv("CACHE_DIT_TEST_RAY_CUDA_VISIBLE_DEVICES", "6,7")
+_ENABLE_FLUX_TEST = os.getenv("CACHE_DIT_TEST_RAY_FLUX", "0").lower() == "1"
+_DEFAULT_MODEL_SOURCE = os.getenv(
+    "FLUX_2_KLEIN_BASE_9B_DIR",
+    "/workspace/dev/vipdev/hf_models/FLUX.2-klein-9B",
+)
+_TEST_OUTPUT_DIR = _REPO_ROOT / ".tmp" / "tests" / "ray_wrapper"
+
+
+class ToyPipeline(DiffusionPipeline):
+    """Minimal DiffusionPipeline-shaped object with a transformer attribute."""
+
+    def __init__(self, transformer: torch.nn.Module) -> None:
+        super().__init__()
+        self.register_modules(transformer=transformer)
+
+    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.transformer(hidden_states)
+
+
+class ToyCompilableTransformer(torch.nn.Module):
+    """Tiny transformer that exposes compile_repeated_blocks for Ray compile tests."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = torch.nn.Linear(4, 4)
+        self.repeated_blocks_compiled = False
+        with torch.no_grad():
+            self.linear.weight.copy_(torch.eye(4))
+            self.linear.bias.zero_()
+
+    def compile_repeated_blocks(self) -> None:
+        """Mark repeated blocks as compiled without invoking torch.compile in unit tests."""
+
+        self.repeated_blocks_compiled = True
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        marker = torch.tensor(float(self.repeated_blocks_compiled), device=hidden_states.device)
+        return self.linear(hidden_states) + marker
+
+
+class ToyExecutionContextTransformer(torch.nn.Module):
+    """Tiny transformer that validates Ray worker execution context."""
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if torch.is_inference_mode_enabled():
+            raise RuntimeError("Ray worker forward should not run under torch.inference_mode().")
+        if torch.is_grad_enabled():
+            raise RuntimeError("Ray worker forward should run with gradients disabled.")
+        return hidden_states + 1.0
+
+
+@pytest.fixture(autouse=True)
+def shutdown_ray_runtime():
+    yield
+    if ray.is_initialized():
+        ray.shutdown()
+
+
+def _toy_transformer() -> torch.nn.Sequential:
+    linear = torch.nn.Linear(4, 4)
+    with torch.no_grad():
+        linear.weight.copy_(torch.eye(4))
+        linear.bias.copy_(torch.arange(4, dtype=torch.float32))
+    return torch.nn.Sequential(linear, torch.nn.Sigmoid())
+
+
+def _toy_parallel_config() -> ParallelismConfig:
+    return ParallelismConfig(
+        ulysses_size=2,
+        use_ray=True,
+        ray_runtime_env={
+            "env_vars": {
+                "PYTHONPATH": f"{_REPO_ROOT / 'tests' / 'ray'}{os.pathsep}{_REPO_ROOT / 'src'}",
+            },
+        },
+        ray_worker_options={"num_gpus": 0},
+        _ray_skip_native_parallelism=True,
+    )
+
+
+def test_parallelism_config_rejects_short_aliases() -> None:
+    """ParallelismConfig should require canonical size field names."""
+
+    with pytest.raises(TypeError):
+        ParallelismConfig(ulysses=2)
+    with pytest.raises(TypeError):
+        ParallelismConfig(ring=2)
+    with pytest.raises(TypeError):
+        ParallelismConfig(tp=2)
+
+
+def test_ray_wrapper_transformer_only_toy_model() -> None:
+    hidden_states = torch.arange(8, dtype=torch.float32).reshape(2, 4)
+    transformer = _toy_transformer()
+    baseline = transformer(hidden_states)
+
+    returned = cache_dit.enable_cache(transformer, parallelism_config=_toy_parallel_config())
+
+    assert returned is transformer
+    assert getattr(transformer, "_cache_dit_ray_enabled", False)
+    original_dtype = next(transformer.parameters()).dtype
+    assert transformer.to(torch.float64) is transformer
+    assert next(transformer.parameters()).dtype == original_dtype
+    result = transformer(hidden_states)
+    torch.testing.assert_close(result, baseline)
+
+    cache_dit.disable_cache(transformer)
+    assert not hasattr(transformer, "_cache_dit_ray_enabled")
+    transformer.to(torch.float64)
+    assert next(transformer.parameters()).dtype == torch.float64
+    transformer.to(original_dtype)
+    torch.testing.assert_close(transformer(hidden_states), baseline)
+
+
+def test_ray_wrapper_pipeline_level_toy_model() -> None:
+    hidden_states = torch.arange(8, dtype=torch.float32).reshape(2, 4)
+    pipe = ToyPipeline(_toy_transformer())
+    baseline = pipe(hidden_states)
+
+    returned = cache_dit.enable_cache(pipe, parallelism_config=_toy_parallel_config())
+
+    assert returned is pipe
+    assert getattr(pipe, "_cache_dit_ray_pipeline_enabled", False)
+    result = pipe(hidden_states)
+    torch.testing.assert_close(result, baseline)
+
+    cache_dit.disable_cache(pipe)
+    assert not hasattr(pipe, "_cache_dit_ray_pipeline_enabled")
+    torch.testing.assert_close(pipe(hidden_states), baseline)
+
+
+def test_ray_wrapper_compile_repeated_blocks_toy_model() -> None:
+    hidden_states = torch.arange(8, dtype=torch.float32).reshape(2, 4)
+    transformer = ToyCompilableTransformer()
+    baseline = transformer(hidden_states)
+    parallelism_config = _toy_parallel_config()
+    parallelism_config.ray_use_compile = True
+
+    cache_dit.enable_cache(transformer, parallelism_config=parallelism_config)
+
+    result = transformer(hidden_states)
+    torch.testing.assert_close(result, baseline + 1.0)
+
+    cache_dit.disable_cache(transformer)
+
+
+def test_ray_wrapper_worker_uses_no_grad_not_inference_mode() -> None:
+    hidden_states = torch.arange(8, dtype=torch.float32).reshape(2, 4)
+    transformer = ToyExecutionContextTransformer()
+
+    cache_dit.enable_cache(transformer, parallelism_config=_toy_parallel_config())
+
+    result = transformer(hidden_states)
+    torch.testing.assert_close(result, hidden_states + 1.0)
+
+    cache_dit.disable_cache(transformer)
+
+
+def test_ray_wrapper_defers_cache_hooks_to_workers(monkeypatch: pytest.MonkeyPatch) -> None:
+    transformer = _toy_transformer()
+    parallelism_config = _toy_parallel_config()
+    captured: dict[str, object] = {}
+
+    def fail_main_process_cache_apply(*args, **kwargs):
+        raise AssertionError("Ray mode should not apply cache hooks in the main process.")
+
+    def fake_enable_ray_parallelism(
+        transformer_arg,
+        parallelism_config_arg,
+        cache_context_kwargs=None,
+        quantize_config=None,
+    ):
+        captured["transformer"] = transformer_arg
+        captured["parallelism_config"] = parallelism_config_arg
+        captured["cache_context_kwargs"] = cache_context_kwargs
+        captured["quantize_config"] = quantize_config
+        return transformer_arg
+
+    monkeypatch.setattr(
+        "cache_dit.caching.cache_interface.CachedAdapter.apply",
+        fail_main_process_cache_apply,
+    )
+    monkeypatch.setattr(
+        "cache_dit.ray.enable_ray_parallelism",
+        fake_enable_ray_parallelism,
+    )
+
+    returned = cache_dit.enable_cache(
+        transformer,
+        cache_config=DBCacheConfig(),
+        parallelism_config=parallelism_config,
+    )
+
+    assert returned is transformer
+    assert captured["transformer"] is transformer
+    assert captured["parallelism_config"] is parallelism_config
+    cache_context_kwargs = captured["cache_context_kwargs"]
+    assert isinstance(cache_context_kwargs, dict)
+    assert isinstance(cache_context_kwargs["cache_config"], DBCacheConfig)
+    assert captured["quantize_config"] is None
+
+
+def test_ray_wrapper_defers_quantize_to_workers(monkeypatch: pytest.MonkeyPatch) -> None:
+    transformer = _toy_transformer()
+    parallelism_config = _toy_parallel_config()
+    quantize_config = QuantizeConfig(quant_type="float8_per_row")
+    captured: dict[str, object] = {}
+
+    def fail_main_process_quantize(*args, **kwargs):
+        raise AssertionError("Ray mode should not quantize in the main process.")
+
+    def fake_enable_ray_parallelism(
+        transformer_arg,
+        parallelism_config_arg,
+        cache_context_kwargs=None,
+        quantize_config=None,
+    ):
+        captured["transformer"] = transformer_arg
+        captured["parallelism_config"] = parallelism_config_arg
+        captured["cache_context_kwargs"] = cache_context_kwargs
+        captured["quantize_config"] = quantize_config
+        return transformer_arg
+
+    monkeypatch.setattr(
+        "cache_dit.caching.cache_interface.quantize",
+        fail_main_process_quantize,
+    )
+    monkeypatch.setattr(
+        "cache_dit.ray.enable_ray_parallelism",
+        fake_enable_ray_parallelism,
+    )
+
+    returned = cache_dit.enable_cache(
+        transformer,
+        parallelism_config=parallelism_config,
+        quantize_config=quantize_config,
+    )
+
+    assert returned is transformer
+    assert captured["transformer"] is transformer
+    assert captured["parallelism_config"] is parallelism_config
+    assert captured["cache_context_kwargs"] is None
+    assert isinstance(captured["quantize_config"], QuantizeConfig)
+    assert captured["quantize_config"] is not quantize_config
+
+
+@pytest.mark.skipif(
+    not _ENABLE_FLUX_TEST,
+    reason="FLUX Ray wrapper test requires CACHE_DIT_TEST_RAY_FLUX=1.",
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="FLUX Ray wrapper test requires CUDA.")
+def test_ray_wrapper_flux_example_psnr() -> None:
+    visible_devices = [
+        device.strip() for device in _DEFAULT_VISIBLE_DEVICES.split(",") if device.strip()
+    ]
+    if len(visible_devices) < 2:
+        pytest.skip("FLUX Ray wrapper test requires at least two visible CUDA devices.")
+    if not _PYTHON_BIN.is_file():
+        pytest.skip("The configured cdit python binary is unavailable.")
+
+    if _TEST_OUTPUT_DIR.exists():
+        shutil.rmtree(_TEST_OUTPUT_DIR)
+    _TEST_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+    baseline_path = _TEST_OUTPUT_DIR / "baseline.png"
+    ray_path = _TEST_OUTPUT_DIR / "ray2.png"
+    env = os.environ.copy()
+    env["PYTHONPATH"] = str(_REPO_ROOT / "src")
+    env["CUDA_VISIBLE_DEVICES"] = _DEFAULT_VISIBLE_DEVICES
+
+    subprocess.run(
+        [
+            str(_PYTHON_BIN),
+            "examples/ray/ray_wrapper_example.py",
+            "--model-path",
+            _DEFAULT_MODEL_SOURCE,
+            "--ulysses",
+            "1",
+            "--num-inference-steps",
+            "4",
+            "--warmup",
+            "0",
+            "--save-path",
+            str(baseline_path),
+        ],
+        cwd=_REPO_ROOT,
+        env=env,
+        check=True,
+    )
+    subprocess.run(
+        [
+            str(_PYTHON_BIN),
+            "examples/ray/ray_wrapper_example.py",
+            "--model-path",
+            _DEFAULT_MODEL_SOURCE,
+            "--ulysses",
+            "2",
+            "--num-inference-steps",
+            "4",
+            "--warmup",
+            "0",
+            "--save-path",
+            str(ray_path),
+        ],
+        cwd=_REPO_ROOT,
+        env=env,
+        check=True,
+    )
+
+    psnr, count = compute_psnr(str(baseline_path), str(ray_path))
+    assert count == 1
+    assert psnr is not None and psnr > 20.0