From 52e408743c2eaea898bae819631a546dd1a60995 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 11 May 2026 11:04:44 +0000 Subject: [PATCH] ray: pass runtime env to workers --- docs/user_guide/RAY.md | 58 +++++++++++++++++++++++ src/cache_dit/ray/engine.py | 92 +++++++++++++++++++++++++++++++++++-- src/cache_dit/ray/worker.py | 50 ++++++++++++++++---- 3 files changed, 188 insertions(+), 12 deletions(-) diff --git a/docs/user_guide/RAY.md b/docs/user_guide/RAY.md index 43fa7601..e63de732 100644 --- a/docs/user_guide/RAY.md +++ b/docs/user_guide/RAY.md @@ -153,6 +153,64 @@ cache_dit.enable_cache( ) ``` +## Ray Runtime Env + +When your application code (custom pipeline classes, helper modules) lives in a directory that is **not** part of the default Python path, Ray workers cannot import those modules. This causes ``ModuleNotFoundError`` during worker-side model loading or pickle deserialization. + +Cache-DiT exposes `ParallelismConfig.ray_runtime_env` to pass a Ray `runtime_env` dictionary to `ray.init()`. The most common and **lightweight** pattern is to use `env_vars` to set `PYTHONPATH`: + +```python +cache_dit.enable_cache( + pipe, + parallelism_config=ParallelismConfig( + ulysses_size=2, + use_ray=True, + ray_runtime_env={ + # Single or multiple paths can be added, separated by ":" (";" on Windows) + "env_vars": { + "PYTHONPATH": "/path/to/your/app/src", + # "PYTHONPATH": "xxx/path1:xxx/path2:xxx/path3", + } + }, + ), +) +``` + +### What is PYTHONPATH? + +`PYTHONPATH` is a Python environment variable that specifies additional directories where Python searches for modules. When Python starts, it inserts each `:`-separated path in `PYTHONPATH` into `sys.path`, ahead of the standard-library and site-packages entries. Multiple paths are supported: + +```python +"PYTHONPATH": "/app/src:/app/libs:/shared/utils" +``` + +Ray workers receive `env_vars` entries as process environment variables — **nothing is packaged, zipped, or uploaded**.
Workers resolve `import your_module` directly from the filesystem path. The table below compares `env_vars` with `working_dir` / `py_modules`: + +| field | behavior | large files (>2 GB) | +|-------|----------|---------------------| +| `env_vars.PYTHONPATH` | sets an environment variable; **zero packaging** | works | +| `working_dir` | zips the entire directory, uploads to workers | **fails** — Ray packaging times out or errors | +| `py_modules` | packages the module tree, uploads to workers | **fails** — same packaging issue | + +> **Rule of thumb:** when your module tree contains model weights (e.g. a 54 GB +> `FLUX.1-dev` checkpoint), always use `env_vars.PYTHONPATH`. The files must be +> accessible from every worker node via a shared filesystem (NFS, Lustre, or +> — on a single-node cluster — the local disk). + +### Pre-initialized Ray + +If you already called `ray.init()` before `cache_dit.enable_cache()`, Cache-DiT will **not** re-initialize Ray and therefore cannot inject the `runtime_env` from `ParallelismConfig`. In that case, configure `runtime_env` in your own `ray.init()` call: + +```python +ray.init(runtime_env={"env_vars": {"PYTHONPATH": "/path/to/your/app/src"}}) +cache_dit.enable_cache( + pipe, + parallelism_config=ParallelismConfig(ulysses_size=2, use_ray=True), +) +``` + +Cache-DiT emits a warning when it detects this situation and the pipeline (or transformer) class comes from a user-application module. + ## Quick Start A complete runnable example is available at `examples/ray/ray_wrapper_example.py`. For example: diff --git a/src/cache_dit/ray/engine.py b/src/cache_dit/ray/engine.py index 99427b67..a83d9bda 100644 --- a/src/cache_dit/ray/engine.py +++ b/src/cache_dit/ray/engine.py @@ -25,6 +25,72 @@ logger = init_logger(__name__) +def _qualified_class_name(cls: type) -> str: + """Encode a class as ``module:qualname`` to avoid pickle serialization of class objects.
+ + Passing a raw class object as a Ray task argument forces pickle to serialize + the entire inheritance chain, which fails on workers when the class (or any + base class) is defined in a user-application module that the workers cannot + import. A string reference defers the import to the worker side, where it can + be resolved via ``importlib`` after the application's ``runtime_env`` has been + applied. + """ + + return f"{cls.__module__}:{cls.__qualname__}" + + +def _maybe_user_module(cls: type) -> bool: + """Return ``True`` when *cls* is likely from user-application code. + + Classes whose top-level package is ``diffusers``, ``transformers``, ``torch``, or ``cache_dit`` are + expected to be importable on every Ray worker; look-alike names such as ``torch_utils`` are not. + """ + + _well_known = ("diffusers", "transformers", "torch", "cache_dit", "builtins") + return cls.__module__.partition(".")[0] not in _well_known # exact top-level package, not a prefix + + +def _warn_if_ray_pre_initialized( + ray: Any, + cls: type, + config: ParallelismConfig, + ray_initialized_by_engine: bool = False, +) -> None: + """Warn when Ray was initialized *before* cache-dit and the transferred class is from a user- + application module. + + When Ray is already running, cache-dit cannot inject ``runtime_env`` — the + user's ``ray.init()`` is responsible for making the application code available + to every worker. If that was not configured, worker-side pickle deserialization + (or ``from_pretrained`` class imports) will fail with ``ModuleNotFoundError``. + + :param ray: Imported Ray module. + :param cls: The pipeline or transformer class being transferred. + :param config: The user-provided parallelism configuration. + :param ray_initialized_by_engine: ``True`` when cache-dit itself just called + ``ray.init()``. In that case the ``runtime_env`` from *config* was already + applied and no warning is needed.
+ """ + + if not ray.is_initialized(): + return + if ray_initialized_by_engine: + return + if not _maybe_user_module(cls): + return + + extra = "" + if config.ray_runtime_env is not None: + extra = (" (ParallelismConfig.ray_runtime_env is set but will NOT be applied " + "because Ray was already initialized)") + logger.warning(f"Ray is already initialized and {cls.__name__} is defined in " + f"'{cls.__module__}', which may not be importable on Ray workers.{extra} " + f"Ensure your ray.init() call includes a runtime_env that exposes this " + f"module, e.g. ray.init(runtime_env={{'env_vars': {{'PYTHONPATH': '...'}}}}). " + f"Alternatively, configure ParallelismConfig.ray_runtime_env and let " + f"cache_dit initialize Ray for you.") + + def _require_ray(): try: import ray @@ -167,6 +233,13 @@ def __init__( _init_ray(self.ray, **init_kwargs) self._ray_initialized_by_engine = True + _warn_if_ray_pre_initialized( + self.ray, + transformer.__class__, + self.parallelism_config, + ray_initialized_by_engine=self._ray_initialized_by_engine, + ) + self.master_port = parallelism_config.ray_master_port or _get_free_port() self._actors = self._create_workers(transformer) @@ -222,9 +295,10 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any ) logger.info(f"Saved the transformer snapshot in " f"{time.perf_counter() - save_start:.2f}s.") + tfmr_cls_ref = _qualified_class_name(transformer.__class__) load_infos = self.ray.get([ actor.load_transformer_from_pretrained.remote( - transformer.__class__, + tfmr_cls_ref, str(transformer_path), _first_parameter_dtype(transformer), self.parallelism_config.ray_use_flashpack, @@ -239,9 +313,10 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any _save_state_dict_safetensors(transformer, transformer_path) logger.info(f"Saved the transformer safetensors file in " f"{time.perf_counter() - save_start:.2f}s.") + tfmr_cls_ref = _qualified_class_name(transformer.__class__) load_infos = self.ray.get([
actor.load_transformer_from_safetensors.remote( - transformer.__class__, + tfmr_cls_ref, dict(transformer.config), str(transformer_path), ) for actor in actors @@ -349,6 +424,13 @@ def __init__( _init_ray(self.ray, **init_kwargs) self._ray_initialized_by_engine = True + _warn_if_ray_pre_initialized( + self.ray, + pipe.__class__, + self.parallelism_config, + ray_initialized_by_engine=self._ray_initialized_by_engine, + ) + self.master_port = parallelism_config.ray_master_port or _get_free_port() self._actors = self._create_workers(pipe) @@ -399,9 +481,10 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: use_flashpack=self.parallelism_config.ray_use_flashpack, ) logger.info(f"Saved the pipeline snapshot in {time.perf_counter() - save_start:.2f}s.") + pipe_cls_ref = _qualified_class_name(pipe.__class__) load_infos = self.ray.get([ actor.load_pipeline_from_pretrained.remote( - pipe.__class__, + pipe_cls_ref, str(pipeline_path), self._source_dtype, self.parallelism_config.ray_use_flashpack, @@ -413,9 +496,10 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: if model_path is None: raise ValueError("ray_transfer_backend='from_pretrained' requires pipeline.name_or_path.") logger.info(f"Loading Ray worker pipelines from pretrained source: {model_path}.") + pipe_cls_ref = _qualified_class_name(pipe.__class__) load_infos = self.ray.get([ actor.load_pipeline_from_pretrained.remote( - pipe.__class__, + pipe_cls_ref, model_path, self._source_dtype, self.parallelism_config.ray_use_flashpack, diff --git a/src/cache_dit/ray/worker.py b/src/cache_dit/ray/worker.py index f8913afa..abe58278 100644 --- a/src/cache_dit/ray/worker.py +++ b/src/cache_dit/ray/worker.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import importlib import os from typing import Any @@ -22,6 +23,35 @@ logger = init_logger(__name__) +def _resolve_class(class_ref: str) -> type: + """Resolve a ``module:qualname`` string returned by
:func:`_qualified_class_name`. + + The format mirrors how pickle references a class (module plus qualified name), so it + round-trips module-level and nested classes; function-local classes are not resolvable. + + :param class_ref: Encoded class reference produced by ``_qualified_class_name``. + :returns: The resolved Python class object. + :raises ModuleNotFoundError: The module portion of *class_ref* cannot be imported. + :raises AttributeError: The qualname portion of *class_ref* is missing from the module. + :raises TypeError: *class_ref* is not a string (e.g. a raw class object was passed). + :raises ValueError: *class_ref* is not of the form ``module:qualname``. + """ + + if not isinstance(class_ref, str): + raise TypeError(f"_resolve_class expected a 'module:qualname' string, got " + f"{type(class_ref).__name__} ({class_ref!r}). This usually means " + f"the calling code passed a raw class object instead of using " + f"_qualified_class_name().") + + module_name, sep, qualname = class_ref.rpartition(":") + if not (sep and module_name and qualname): + raise ValueError(f"Malformed class reference {class_ref!r}; expected 'module:qualname'.") + obj = importlib.import_module(module_name) + for part in qualname.split("."): # dotted qualname walks nested classes, e.g. "Outer.Inner" + obj = getattr(obj, part) + return obj + + def _maybe_compile_transformer( transformer: torch.nn.Module | ModelMixin, parallelism_config: ParallelismConfig, @@ -189,13 +219,13 @@ def load_transformer_from_file(self, path: str) -> dict[str, Any]: def load_transformer_from_safetensors( self, - transformer_cls: type[ModelMixin], + transformer_cls_ref: str, transformer_config: dict[str, Any], path: str, ) -> dict[str, Any]: """Load a diffusers transformer from a safetensors state dict. - :param transformer_cls: Diffusers transformer class used to reconstruct the module. + :param transformer_cls_ref: ``module:qualname`` string encoding the transformer class. :param transformer_config: Serialized transformer config passed to ``from_config``. :param path: Path to a safetensors state dict written by the Ray engine.
:returns: Device placement and memory information after loading. @@ -204,9 +234,11 @@ def load_transformer_from_safetensors( try: from safetensors.torch import load_file except ImportError as exc: - raise ImportError("Ray safetensors transfer requires `safetensors`. Install with " - "`pip install cache-dit[ray]` or `pip install safetensors`.") from exc + raise ImportError( + "Ray safetensors transfer requires `safetensors`. Install with " + "`pip install cache-dit[ray,parallelism]` or `pip install safetensors`.") from exc + transformer_cls = _resolve_class(transformer_cls_ref) with torch.device("meta"): transformer = transformer_cls.from_config(transformer_config) state_dict = load_file(path, device=str(self.device)) @@ -221,20 +253,21 @@ def load_transformer_from_safetensors( def load_transformer_from_pretrained( self, - transformer_cls: type[ModelMixin], + transformer_cls_ref: str, model_path: str, torch_dtype: torch.dtype | None, use_flashpack: bool, ) -> dict[str, Any]: """Load a diffusers transformer snapshot inside this actor. - :param transformer_cls: Diffusers transformer class used to reload the module. + :param transformer_cls_ref: ``module:qualname`` string encoding the transformer class. :param model_path: Local snapshot directory written by the Ray engine. :param torch_dtype: Optional dtype for model loading. :param use_flashpack: Whether to prefer FlashPack weights during loading. :returns: Device placement and memory information after loading. """ + transformer_cls = _resolve_class(transformer_cls_ref) load_kwargs = { "use_safetensors": True, "use_flashpack": use_flashpack, @@ -356,20 +389,21 @@ def load_pipeline(self, pipe: DiffusionPipeline) -> dict[str, Any]: def load_pipeline_from_pretrained( self, - pipe_cls: type[DiffusionPipeline], + pipe_cls_ref: str, model_path: str, torch_dtype: torch.dtype | None, use_flashpack: bool, ) -> dict[str, Any]: """Load a pipeline from its pretrained directory inside this actor. 
- :param pipe_cls: Diffusers pipeline class used to reconstruct the pipeline. + :param pipe_cls_ref: ``module:qualname`` string encoding the pipeline class. :param model_path: Local path or model id passed to ``from_pretrained``. :param torch_dtype: Optional dtype for model loading. :param use_flashpack: Whether to prefer FlashPack weights during loading. :returns: Device placement and memory information after loading. """ + pipe_cls = _resolve_class(pipe_cls_ref) load_kwargs = {} if torch_dtype is not None: load_kwargs["torch_dtype"] = torch_dtype