Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions docs/user_guide/RAY.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,64 @@ cache_dit.enable_cache(
)
```

## Ray Runtime Env

When your application code (custom pipeline classes, helper modules) lives in a directory that is **not** part of the default Python path, Ray workers cannot import those modules. This causes `ModuleNotFoundError` during worker-side model loading or pickle deserialization.

Cache-DiT exposes `ParallelismConfig.ray_runtime_env` to pass a Ray `runtime_env` dictionary to `ray.init()`. The most common and **lightweight** pattern is to use `env_vars` to set `PYTHONPATH`:

```python
cache_dit.enable_cache(
pipe,
parallelism_config=ParallelismConfig(
ulysses_size=2,
use_ray=True,
ray_runtime_env={
# One or more paths can be given, separated by ":" (os.pathsep; ";" on Windows)
"env_vars": {
"PYTHONPATH": "/path/to/your/app/src",
# "PYTHONPATH": "xxx/path1:xxx/path2:xxx/path3",
}
},
),
)
```

### What is PYTHONPATH?

`PYTHONPATH` is an environment variable that lists additional directories Python searches for modules. At startup, Python inserts each `:`-separated path (`;` on Windows) from `PYTHONPATH` into `sys.path`, ahead of the standard-library and site-packages entries. Multiple paths are supported:

```python
"PYTHONPATH": "/app/src:/app/libs:/shared/utils"
```

Ray workers receive `env_vars` entries as process environment variables — **nothing is packaged, zipped, or uploaded**. Workers resolve `import your_module` directly from the filesystem path. The table below compares `env_vars` with `working_dir` and `py_modules`:

| field | behavior | large files (>2 GB) |
|-------|----------|---------------------|
| `env_vars.PYTHONPATH` | sets an environment variable; **zero packaging** | works |
| `working_dir` | zips the entire directory, uploads to workers | **fails** — Ray packaging times out or errors |
| `py_modules` | packages the module tree, uploads to workers | **fails** — same packaging issue |

> **Rule of thumb:** when your module tree contains model weights (e.g. a 54 GB
> `FLUX.1-dev` checkpoint), always use `env_vars.PYTHONPATH`. The files must be
> accessible from every worker node via a shared filesystem (NFS, Lustre, or
> — on a single-node cluster — the local disk).

### Pre-initialized Ray

If you already called `ray.init()` before `cache_dit.enable_cache()`, Cache-DiT will **not** re-initialize Ray and therefore cannot inject the `runtime_env` from `ParallelismConfig`. In that case, configure `runtime_env` in your own `ray.init()` call:

```python
ray.init(runtime_env={"env_vars": {"PYTHONPATH": "/path/to/your/app/src"}})
cache_dit.enable_cache(
pipe,
parallelism_config=ParallelismConfig(ulysses_size=2, use_ray=True),
)
```

Cache-DiT emits a warning when it detects this situation and the pipeline class is from a user-application module.

## Quick Start

A complete runnable example is available at `examples/ray/ray_wrapper_example.py`. For example:
Expand Down
92 changes: 88 additions & 4 deletions src/cache_dit/ray/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,72 @@
logger = init_logger(__name__)


def _qualified_class_name(cls: type) -> str:
    """Build the ``module:qualname`` string reference for *cls*.

    The engine hands this string to Ray worker actors in place of the class
    object itself, so the class is looked up by name on the worker side
    rather than being serialized as a task argument.

    :param cls: Class to encode.
    :returns: ``"<cls.__module__>:<cls.__qualname__>"``; nested classes keep
        their dotted qualname (e.g. ``"pkg.mod:Outer.Inner"``).
    """

    module_name = cls.__module__
    qualified_name = cls.__qualname__
    return "{}:{}".format(module_name, qualified_name)


def _maybe_user_module(cls: type) -> bool:
    """Return ``True`` when *cls* is likely defined in user-application code.

    A class counts as "well known" only when the *top-level package* of its
    ``__module__`` is exactly one of ``diffusers``, ``transformers``, ``torch``,
    ``cache_dit`` or ``builtins``. Comparing the whole first path component
    (instead of a raw string prefix) keeps look-alike user modules such as
    ``torch_custom`` or ``diffusers_ext.pipelines`` from being mistaken for the
    real libraries.

    :param cls: Class whose defining module is inspected.
    :returns: ``False`` for classes from the well-known packages above,
        ``True`` for everything else (including classes without a usable
        ``__module__``).
    """

    well_known = frozenset({"diffusers", "transformers", "torch", "cache_dit", "builtins"})
    # ``__module__`` can be missing or ``None`` on dynamically built classes;
    # treat those as user code rather than raising.
    module_name = getattr(cls, "__module__", None) or ""
    top_level_package = module_name.partition(".")[0]
    return top_level_package not in well_known


def _warn_if_ray_pre_initialized(
    ray: Any,
    cls: type,
    config: ParallelismConfig,
    ray_initialized_by_engine: bool = False,
) -> None:
    """Warn when Ray was initialized before cache-dit and *cls* looks like user code.

    When Ray is already running, cache-dit does not call ``ray.init()`` itself,
    so ``config.ray_runtime_env`` is never applied; the user's own
    ``ray.init()`` call is responsible for making the application code
    importable on every worker.

    No warning is emitted when Ray is not initialized, when cache-dit itself
    initialized Ray, or when ``_maybe_user_module(cls)`` is false.

    :param ray: Imported Ray module.
    :param cls: The pipeline or transformer class being transferred.
    :param config: The user-provided parallelism configuration.
    :param ray_initialized_by_engine: ``True`` when cache-dit itself just called
        ``ray.init()``. In that case the ``runtime_env`` from *config* was already
        applied and no warning is needed.
    """

    if not ray.is_initialized():
        return
    if ray_initialized_by_engine:
        return
    if not _maybe_user_module(cls):
        return

    extra = ""
    if config.ray_runtime_env is not None:
        extra = (" (ParallelismConfig.ray_runtime_env is set but will NOT be applied "
                 "because Ray was already initialized)")
    # The example uses env_vars/PYTHONPATH to match the user guide, which
    # recommends it over working_dir because nothing is packaged or uploaded.
    logger.warning(f"Ray is already initialized and {cls.__name__} is defined in "
                   f"'{cls.__module__}', which may not be importable on Ray workers.{extra} "
                   "Ensure your ray.init() call includes a runtime_env that exposes this "
                   "module, e.g. ray.init(runtime_env={'env_vars': {'PYTHONPATH': '...'}}). "
                   "Alternatively, configure ParallelismConfig.ray_runtime_env and let "
                   "cache_dit initialize Ray for you.")


def _require_ray():
try:
import ray
Expand Down Expand Up @@ -167,6 +233,13 @@ def __init__(
_init_ray(self.ray, **init_kwargs)
self._ray_initialized_by_engine = True

_warn_if_ray_pre_initialized(
self.ray,
transformer.__class__,
self.parallelism_config,
ray_initialized_by_engine=self._ray_initialized_by_engine,
)

self.master_port = parallelism_config.ray_master_port or _get_free_port()
self._actors = self._create_workers(transformer)

Expand Down Expand Up @@ -222,9 +295,10 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any
)
logger.info(f"Saved the transformer snapshot in "
f"{time.perf_counter() - save_start:.2f}s.")
tfmr_cls_ref = _qualified_class_name(transformer.__class__)
load_infos = self.ray.get([
actor.load_transformer_from_pretrained.remote(
transformer.__class__,
tfmr_cls_ref,
str(transformer_path),
_first_parameter_dtype(transformer),
self.parallelism_config.ray_use_flashpack,
Expand All @@ -239,9 +313,10 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any
_save_state_dict_safetensors(transformer, transformer_path)
logger.info(f"Saved the transformer safetensors file in "
f"{time.perf_counter() - save_start:.2f}s.")
tfmr_cls_ref = _qualified_class_name(transformer.__class__)
load_infos = self.ray.get([
actor.load_transformer_from_safetensors.remote(
transformer.__class__,
tfmr_cls_ref,
dict(transformer.config),
str(transformer_path),
) for actor in actors
Expand Down Expand Up @@ -349,6 +424,13 @@ def __init__(
_init_ray(self.ray, **init_kwargs)
self._ray_initialized_by_engine = True

_warn_if_ray_pre_initialized(
self.ray,
pipe.__class__,
self.parallelism_config,
ray_initialized_by_engine=self._ray_initialized_by_engine,
)

self.master_port = parallelism_config.ray_master_port or _get_free_port()
self._actors = self._create_workers(pipe)

Expand Down Expand Up @@ -399,9 +481,10 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]:
use_flashpack=self.parallelism_config.ray_use_flashpack,
)
logger.info(f"Saved the pipeline snapshot in {time.perf_counter() - save_start:.2f}s.")
pipe_cls_ref = _qualified_class_name(pipe.__class__)
load_infos = self.ray.get([
actor.load_pipeline_from_pretrained.remote(
pipe.__class__,
pipe_cls_ref,
str(pipeline_path),
self._source_dtype,
self.parallelism_config.ray_use_flashpack,
Expand All @@ -413,9 +496,10 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]:
if model_path is None:
raise ValueError("ray_transfer_backend='from_pretrained' requires pipeline.name_or_path.")
logger.info(f"Loading Ray worker pipelines from pretrained source: {model_path}.")
pipe_cls_ref = _qualified_class_name(pipe.__class__)
load_infos = self.ray.get([
actor.load_pipeline_from_pretrained.remote(
pipe.__class__,
pipe_cls_ref,
model_path,
self._source_dtype,
self.parallelism_config.ray_use_flashpack,
Expand Down
50 changes: 42 additions & 8 deletions src/cache_dit/ray/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import importlib
import os
from typing import Any

Expand All @@ -22,6 +23,35 @@
logger = init_logger(__name__)


def _resolve_class(class_ref: str) -> type:
    """Resolve a ``module:qualname`` string into the class it names.

    The module part is imported with :func:`importlib.import_module`; the
    qualname part is then walked attribute by attribute, so dotted qualnames of
    nested classes (e.g. ``"pkg.mod:Outer.Inner"``) are supported.

    :param class_ref: Encoded class reference in ``module:qualname`` form.
    :returns: The resolved Python class object.
    :raises TypeError: *class_ref* is not a string (e.g. a raw class object was
        passed), or the resolved attribute is not a class.
    :raises ValueError: *class_ref* has no ``module:`` part.
    :raises ModuleNotFoundError: The module portion of *class_ref* cannot be imported.
    :raises AttributeError: The qualname portion of *class_ref* is missing from the module.
    """

    if not isinstance(class_ref, str):
        raise TypeError(f"_resolve_class expected a 'module:qualname' string, got "
                        f"{type(class_ref).__name__} ({class_ref!r}). This usually means "
                        f"the calling code passed a raw class object instead of using "
                        f"_qualified_class_name().")

    module_name, separator, qualname = class_ref.rpartition(":")
    if not separator or not module_name:
        # Without this check importlib fails with the opaque "Empty module name".
        raise ValueError(f"_resolve_class expected a 'module:qualname' string, got "
                         f"{class_ref!r}, which has no module part.")

    module = importlib.import_module(module_name)
    # qualname may contain dots for nested classes, e.g. "Outer.Inner"
    obj = module
    for part in qualname.split("."):
        obj = getattr(obj, part)

    if not isinstance(obj, type):
        # Fail here rather than later, when the caller tries to use it as a class.
        raise TypeError(f"_resolve_class resolved {class_ref!r} to "
                        f"{type(obj).__name__} ({obj!r}), which is not a class.")
    return obj


def _maybe_compile_transformer(
transformer: torch.nn.Module | ModelMixin,
parallelism_config: ParallelismConfig,
Expand Down Expand Up @@ -189,13 +219,13 @@ def load_transformer_from_file(self, path: str) -> dict[str, Any]:

def load_transformer_from_safetensors(
self,
transformer_cls: type[ModelMixin],
transformer_cls_ref: str,
transformer_config: dict[str, Any],
path: str,
) -> dict[str, Any]:
"""Load a diffusers transformer from a safetensors state dict.

:param transformer_cls: Diffusers transformer class used to reconstruct the module.
:param transformer_cls_ref: ``module:qualname`` string encoding the transformer class.
:param transformer_config: Serialized transformer config passed to ``from_config``.
:param path: Path to a safetensors state dict written by the Ray engine.
:returns: Device placement and memory information after loading.
Expand All @@ -204,9 +234,11 @@ def load_transformer_from_safetensors(
try:
from safetensors.torch import load_file
except ImportError as exc:
raise ImportError("Ray safetensors transfer requires `safetensors`. Install with "
"`pip install cache-dit[ray]` or `pip install safetensors`.") from exc
raise ImportError(
"Ray safetensors transfer requires `safetensors`. Install with "
"`pip install cache-dit[ray,parallelism]` or `pip install safetensors`.") from exc

transformer_cls = _resolve_class(transformer_cls_ref)
with torch.device("meta"):
transformer = transformer_cls.from_config(transformer_config)
state_dict = load_file(path, device=str(self.device))
Expand All @@ -221,20 +253,21 @@ def load_transformer_from_safetensors(

def load_transformer_from_pretrained(
self,
transformer_cls: type[ModelMixin],
transformer_cls_ref: str,
model_path: str,
torch_dtype: torch.dtype | None,
use_flashpack: bool,
) -> dict[str, Any]:
"""Load a diffusers transformer snapshot inside this actor.

:param transformer_cls: Diffusers transformer class used to reload the module.
:param transformer_cls_ref: ``module:qualname`` string encoding the transformer class.
:param model_path: Local snapshot directory written by the Ray engine.
:param torch_dtype: Optional dtype for model loading.
:param use_flashpack: Whether to prefer FlashPack weights during loading.
:returns: Device placement and memory information after loading.
"""

transformer_cls = _resolve_class(transformer_cls_ref)
load_kwargs = {
"use_safetensors": True,
"use_flashpack": use_flashpack,
Expand Down Expand Up @@ -356,20 +389,21 @@ def load_pipeline(self, pipe: DiffusionPipeline) -> dict[str, Any]:

def load_pipeline_from_pretrained(
self,
pipe_cls: type[DiffusionPipeline],
pipe_cls_ref: str,
model_path: str,
torch_dtype: torch.dtype | None,
use_flashpack: bool,
) -> dict[str, Any]:
"""Load a pipeline from its pretrained directory inside this actor.

:param pipe_cls: Diffusers pipeline class used to reconstruct the pipeline.
:param pipe_cls_ref: ``module:qualname`` string encoding the pipeline class.
:param model_path: Local path or model id passed to ``from_pretrained``.
:param torch_dtype: Optional dtype for model loading.
:param use_flashpack: Whether to prefer FlashPack weights during loading.
:returns: Device placement and memory information after loading.
"""

pipe_cls = _resolve_class(pipe_cls_ref)
load_kwargs = {}
if torch_dtype is not None:
load_kwargs["torch_dtype"] = torch_dtype
Expand Down
Loading