Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/cache_dit/caching/cache_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ def _enable_cache_with_ray_impl(
attention backend resolution, parallelism tuning, quantization) is deferred to
:func:`_enable_cache_impl` running inside each Ray worker.
"""
# Soft check: only trigger the ray import when entering the Ray path.
try:
import ray # noqa: F401
except ImportError as exc:
raise ImportError("Ray wrapper requires ray. Install it with `pip install ray` or "
"`pip install cache-dit[ray]`.") from exc

if attention_backend is not None:
parallelism_config.attention_backend = attention_backend

Expand Down
82 changes: 33 additions & 49 deletions src/cache_dit/ray/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any

import ray
import torch
from diffusers import DiffusionPipeline
from diffusers import pipelines as diffusers_pipelines
Expand Down Expand Up @@ -54,7 +55,6 @@ def _maybe_user_module(cls: type) -> bool:


def _warn_if_ray_pre_initialized(
ray: Any,
cls: type,
config: ParallelismConfig,
ray_initialized_by_engine: bool = False,
Expand All @@ -67,7 +67,6 @@ def _warn_if_ray_pre_initialized(
to every worker. If that was not configured, worker-side pickle deserialization
(or ``from_pretrained`` class imports) will fail with ``ModuleNotFoundError``.

:param ray: Imported Ray module.
:param cls: The pipeline or transformer class being transferred.
:param config: The user-provided parallelism configuration.
:param ray_initialized_by_engine: ``True`` when cache-dit itself just called
Expand All @@ -94,19 +93,9 @@ def _warn_if_ray_pre_initialized(
f"cache_dit initialize Ray for you.")


def _require_ray():
try:
import ray
except ImportError as exc:
raise ImportError("Ray wrapper requires ray. Install it with `pip install ray` or "
"`pip install cache-dit[ray]`.") from exc
return ray


def _init_ray(ray: Any, **init_kwargs: Any) -> None:
def _init_ray(**init_kwargs: Any) -> None:
"""Initialize Ray while suppressing its accelerator override transition warning.

:param ray: Imported Ray module.
:param init_kwargs: Keyword arguments forwarded to ``ray.init``.
"""

Expand Down Expand Up @@ -287,7 +276,6 @@ def __init__(
cache_context_kwargs: dict[str, Any] | None = None,
quantize_config: QuantizeConfig | None = None,
):
self.ray = _require_ray()
self.parallelism_config = parallelism_config
self.cache_context_kwargs = cache_context_kwargs
self.quantize_config = quantize_config
Expand All @@ -305,17 +293,16 @@ def __init__(

self._ray_initialized_by_engine = False
self._transfer_dir: Path | None = None
if not self.ray.is_initialized():
if not ray.is_initialized():
init_kwargs = dict(parallelism_config.ray_init_kwargs)
if parallelism_config.ray_address is not None:
init_kwargs["address"] = parallelism_config.ray_address
if parallelism_config.ray_runtime_env is not None:
init_kwargs["runtime_env"] = parallelism_config.ray_runtime_env
_init_ray(self.ray, **init_kwargs)
_init_ray(**init_kwargs)
self._ray_initialized_by_engine = True

_warn_if_ray_pre_initialized(
self.ray,
transformer.__class__,
self.parallelism_config,
ray_initialized_by_engine=self._ray_initialized_by_engine,
Expand All @@ -329,7 +316,7 @@ def __init__(
self._actors = self._create_workers(transformer)

def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any]:
remote_worker = self.ray.remote(RayTransformerWorker)
remote_worker = ray.remote(RayTransformerWorker)
worker_options = {"num_gpus": 1}
worker_options.update(self.parallelism_config.ray_worker_options)
actors = [
Expand All @@ -342,8 +329,8 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any
self.master_port,
) for rank in range(self.world_size)
]
self.ray.get([actor.ready.remote() for actor in actors])
device_infos = self.ray.get([actor.device_info.remote() for actor in actors])
ray.get([actor.ready.remote() for actor in actors])
device_infos = ray.get([actor.device_info.remote() for actor in actors])
logger.info(f"Ray transformer worker placement before load: {device_infos}")

if self._source_device is not None and self._source_device.type != "cpu":
Expand Down Expand Up @@ -381,7 +368,7 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any
logger.info(f"Saved the transformer snapshot in "
f"{time.perf_counter() - save_start:.2f}s.")
tfmr_cls_ref = _qualified_class_name(transformer.__class__)
load_infos = self.ray.get([
load_infos = ray.get([
actor.load_transformer_from_pretrained.remote(
tfmr_cls_ref,
str(transformer_path),
Expand All @@ -399,7 +386,7 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any
logger.info(f"Saved the transformer safetensors file in "
f"{time.perf_counter() - save_start:.2f}s.")
tfmr_cls_ref = _qualified_class_name(transformer.__class__)
load_infos = self.ray.get([
load_infos = ray.get([
actor.load_transformer_from_safetensors.remote(
tfmr_cls_ref,
dict(transformer.config),
Expand All @@ -411,12 +398,11 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any
elif transfer_backend == "object_store":
logger.info("Putting the CPU transformer into the Ray object store.")
put_start = time.perf_counter()
transformer_ref = self.ray.put(transformer)
transformer_ref = ray.put(transformer)
logger.info(f"Put the CPU transformer into the Ray object store in "
f"{time.perf_counter() - put_start:.2f}s.")
load_start = time.perf_counter()
load_infos = self.ray.get(
[actor.load_transformer.remote(transformer_ref) for actor in actors])
load_infos = ray.get([actor.load_transformer.remote(transformer_ref) for actor in actors])
logger.info(f"Loaded the object-store transformer on Ray workers in "
f"{time.perf_counter() - load_start:.2f}s.")
else:
Expand All @@ -435,7 +421,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
output_device = first_tensor_device((args, kwargs))
cpu_args = cpu_tensor_tree(args)
cpu_kwargs = cpu_tensor_tree(kwargs)
results = self.ray.get([actor.forward.remote(cpu_args, cpu_kwargs) for actor in self._actors])
results = ray.get([actor.forward.remote(cpu_args, cpu_kwargs) for actor in self._actors])
rank0_output = next((result for result in results if result is not None), None)
if rank0_output is None:
raise RuntimeError("Ray transformer workers did not return a rank-0 output.")
Expand All @@ -447,12 +433,12 @@ def shutdown(self) -> None:
"""Shutdown worker actors and optionally the Ray runtime initialized by this engine."""

if self._actors:
self.ray.get([actor.shutdown.remote() for actor in self._actors])
ray.get([actor.shutdown.remote() for actor in self._actors])
for actor in self._actors:
self.ray.kill(actor)
ray.kill(actor)
self._actors = []
if self._ray_initialized_by_engine and self.parallelism_config.ray_auto_shutdown:
self.ray.shutdown()
ray.shutdown()
if self._source_device is not None and self._source_device.type != "cpu":
restore_start = time.perf_counter()
self._source_transformer.to(self._source_device)
Expand Down Expand Up @@ -482,7 +468,6 @@ def __init__(
if not hasattr(pipe, "transformer"):
raise ValueError("Ray pipeline parallelism requires a pipeline with a transformer "
"attribute.")
self.ray = _require_ray()
self.parallelism_config = parallelism_config
self.cache_context_kwargs = cache_context_kwargs
self.quantize_config = quantize_config
Expand All @@ -501,18 +486,17 @@ def __init__(

self._ray_initialized_by_engine = False
self._transfer_dir: Path | None = None
if not self.ray.is_initialized():
if not ray.is_initialized():
init_kwargs = dict(parallelism_config.ray_init_kwargs)
if parallelism_config.ray_address is not None:
init_kwargs["address"] = parallelism_config.ray_address
if parallelism_config.ray_runtime_env is not None:
init_kwargs["runtime_env"] = parallelism_config.ray_runtime_env
_init_ray(self.ray, **init_kwargs)
_init_ray(**init_kwargs)
self._ray_initialized_by_engine = True

if pipe is not None:
_warn_if_ray_pre_initialized(
self.ray,
pipe.__class__,
self.parallelism_config,
ray_initialized_by_engine=self._ray_initialized_by_engine,
Expand All @@ -527,7 +511,7 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]:
if transfer_fn is not None:
return self._create_workers_with_init_fn(pipe, transfer_fn)

remote_worker = self.ray.remote(RayPipelineWorker)
remote_worker = ray.remote(RayPipelineWorker)
worker_options = {"num_gpus": 1}
worker_options.update(self.parallelism_config.ray_worker_options)
actors = [
Expand All @@ -540,8 +524,8 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]:
self.master_port,
) for rank in range(self.world_size)
]
self.ray.get([actor.ready.remote() for actor in actors])
device_infos = self.ray.get([actor.device_info.remote() for actor in actors])
ray.get([actor.ready.remote() for actor in actors])
device_infos = ray.get([actor.device_info.remote() for actor in actors])
logger.info(f"Ray pipeline worker placement before load: {device_infos}")

if self._source_device is not None and self._source_device.type != "cpu":
Expand Down Expand Up @@ -590,7 +574,7 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]:
json.dump(custom_class_map, f, indent=2)

pipe_cls_ref = _qualified_class_name(pipe.__class__)
load_infos = self.ray.get([
load_infos = ray.get([
actor.load_pipeline_from_pretrained.remote(
pipe_cls_ref,
str(pipeline_path),
Expand All @@ -608,7 +592,7 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]:
pipe_cls_ref = _qualified_class_name(pipe.__class__)
custom_class_map = _build_custom_class_map(
pipe) if self.parallelism_config.ray_transfer_custom_obj else None
load_infos = self.ray.get([
load_infos = ray.get([
actor.load_pipeline_from_pretrained.remote(
pipe_cls_ref,
model_path,
Expand All @@ -622,10 +606,10 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]:
elif transfer_backend == "object_store":
logger.info("Putting the CPU pipeline into the Ray object store.")
put_start = time.perf_counter()
pipe_ref = self.ray.put(pipe)
pipe_ref = ray.put(pipe)
logger.info(f"Put the CPU pipeline into the Ray object store in "
f"{time.perf_counter() - put_start:.2f}s.")
load_infos = self.ray.get([actor.load_pipeline.remote(pipe_ref) for actor in actors])
load_infos = ray.get([actor.load_pipeline.remote(pipe_ref) for actor in actors])
logger.info(f"Loaded the object-store pipeline on Ray workers in "
f"{time.perf_counter() - load_start:.2f}s.")
else:
Expand Down Expand Up @@ -655,7 +639,7 @@ def _create_workers_with_init_fn(self, pipe: DiffusionPipeline, transfer_fn) ->
# attempts to cloudpickle the function as part of the actor constructor args.
worker_config = dataclasses.replace(self.parallelism_config, ray_transfer_fn=None)

remote_worker = self.ray.remote(RayPipelineWorker)
remote_worker = ray.remote(RayPipelineWorker)
worker_options = {"num_gpus": 1}
worker_options.update(worker_config.ray_worker_options)
actors = [
Expand All @@ -668,8 +652,8 @@ def _create_workers_with_init_fn(self, pipe: DiffusionPipeline, transfer_fn) ->
self.master_port,
) for rank in range(self.world_size)
]
self.ray.get([actor.ready.remote() for actor in actors])
device_infos = self.ray.get([actor.device_info.remote() for actor in actors])
ray.get([actor.ready.remote() for actor in actors])
device_infos = ray.get([actor.device_info.remote() for actor in actors])
logger.info(f"Ray pipeline worker placement before load: {device_infos}")

# Offload the main-process pipeline to CPU to free GPU memory for workers.
Expand All @@ -684,8 +668,8 @@ def _create_workers_with_init_fn(self, pipe: DiffusionPipeline, transfer_fn) ->
logger.info("The main-process pipeline is already on CPU before Ray worker loading.")

load_start = time.perf_counter()
fn_ref = self.ray.put(transfer_fn)
load_infos = self.ray.get([
fn_ref = ray.put(transfer_fn)
load_infos = ray.get([
actor.load_pipeline_with_init_fn.remote(
fn_ref,
self.cache_context_kwargs,
Expand All @@ -712,7 +696,7 @@ def call(self, *args: Any, **kwargs: Any) -> Any:

cpu_args = cpu_tensor_tree(args)
cpu_kwargs = cpu_tensor_tree(kwargs)
results = self.ray.get([actor.call.remote(cpu_args, cpu_kwargs) for actor in self._actors])
results = ray.get([actor.call.remote(cpu_args, cpu_kwargs) for actor in self._actors])
rank0_output = next((result for result in results if result is not None), None)
if rank0_output is None:
raise RuntimeError("Ray pipeline workers did not return a rank-0 output.")
Expand All @@ -722,12 +706,12 @@ def shutdown(self) -> None:
"""Shutdown worker actors and optionally the Ray runtime initialized by this engine."""

if self._actors:
self.ray.get([actor.shutdown.remote() for actor in self._actors])
ray.get([actor.shutdown.remote() for actor in self._actors])
for actor in self._actors:
self.ray.kill(actor)
ray.kill(actor)
self._actors = []
if self._ray_initialized_by_engine and self.parallelism_config.ray_auto_shutdown:
self.ray.shutdown()
ray.shutdown()
if self._source_pipe is not None and self._source_device is not None and self._source_device.type != "cpu":
restore_start = time.perf_counter()
_move_pipeline_modules(self._source_pipe, self._source_device)
Expand Down
8 changes: 6 additions & 2 deletions src/cache_dit/ray/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from ..distributed import ParallelismConfig
from ..logger import init_logger
from ..quantization import QuantizeConfig
from .engine import RayParallelEngine
from .engine import RayPipelineEngine

logger = init_logger(__name__)

Expand Down Expand Up @@ -51,6 +49,8 @@ def enable_ray_parallelism(
# We skip straight to step (1) and return the engine, which is callable via its
# own __call__ method. The callers then use engine(prompt=...) directly.
if pipe_or_adapter is None:
from .engine import RayPipelineEngine

engine = RayPipelineEngine(
None,
parallelism_config,
Expand Down Expand Up @@ -96,6 +96,8 @@ def enable_ray_module_parallelism(
logger.warning("Ray module parallelism is already enabled for this transformer. Skipping.")
return transformer

from .engine import RayParallelEngine

engine = RayParallelEngine(transformer, parallelism_config, cache_context_kwargs, quantize_config)
transformer._cache_dit_ray_original_forward = transformer.forward # type: ignore[attr-defined]
transformer._cache_dit_ray_original_to = transformer.to # type: ignore[attr-defined]
Expand Down Expand Up @@ -156,6 +158,8 @@ def enable_ray_pipeline_parallelism(
logger.warning("Ray parallelism is already enabled for this pipeline. Skipping.")
return pipe

from .engine import RayPipelineEngine

engine = RayPipelineEngine(pipe, parallelism_config, cache_context_kwargs, quantize_config)
original_class = pipe.__class__

Expand Down
Loading