Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/cache_dit/caching/cache_adapters/cache_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,9 @@ def _release_blocks_hooks(blocks):
return

def _release_transformer_hooks(transformer):
from ...ray import disable_ray_parallelism
from ...ray import disable_ray_module_parallelism

disable_ray_parallelism(transformer)
disable_ray_module_parallelism(transformer)
if hasattr(transformer, "_original_forward"):
original_forward = transformer._original_forward
transformer.forward = original_forward.__get__(transformer)
Expand Down
47 changes: 8 additions & 39 deletions src/cache_dit/caching/cache_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,8 @@ def enable_cache(
>>> output = pipe(...)
>>> stats = cache_dit.summary(pipe)
"""
ray_enabled = isinstance(
parallelism_config,
ParallelismConfig,
) and parallelism_config.use_ray
ray_enabled = isinstance(parallelism_config, ParallelismConfig) and parallelism_config.use_ray

if not ray_enabled and isinstance(
parallelism_config, ParallelismConfig) and parallelism_config.ray_transfer_fn is not None:
logger.warning("ray_transfer_fn is set but use_ray=False; the function will be ignored. "
Expand All @@ -101,7 +99,12 @@ def enable_cache(
if not (ray_enabled and parallelism_config.ray_transfer_fn is not None):
raise ValueError(
"pipe_or_adapter can only be None when use_ray=True and ray_transfer_fn is set.")

if ray_enabled:
# BlockAdapter is not supported currently in Ray wrapper.
assert not isinstance(pipe_or_adapter, BlockAdapter), \
"BlockAdapter is not supported in Ray wrapper currently."

return _enable_cache_with_ray_impl(
pipe_or_adapter,
cache_config=cache_config,
Expand All @@ -112,6 +115,7 @@ def enable_cache(
quantize_config=quantize_config,
**kwargs,
)

return _enable_cache_impl(
pipe_or_adapter,
cache_config=cache_config,
Expand All @@ -127,7 +131,6 @@ def enable_cache(
def _enable_cache_with_ray_impl(
pipe_or_adapter: Union[
DiffusionPipeline,
BlockAdapter,
torch.nn.Module,
ModelMixin,
None,
Expand All @@ -150,15 +153,13 @@ def _enable_cache_with_ray_impl(
DiffusionPipeline,
torch.nn.Module,
ModelMixin,
BlockAdapter,
]:
"""Internal: Ray-path passthrough — packages configs and delegates to Ray wrappers.

All preprocessing (deprecated params, default configs, calibrator setup,
attention backend resolution, parallelism tuning, quantization) is deferred to
:func:`_enable_cache_impl` running inside each Ray worker.
"""

if attention_backend is not None:
parallelism_config.attention_backend = attention_backend

Expand All @@ -176,39 +177,7 @@ def _enable_cache_with_ray_impl(
ray_qconfig = copy.deepcopy(quantize_config) if quantize_config is not None else None

from ..ray import enable_ray_parallelism
from ..ray import enable_ray_pipeline_parallelism

# pipe_or_adapter=None path: create engine directly, no pipe to patch.
# enable_ray_pipeline_parallelism does two things: (1) create RayPipelineEngine,
# (2) monkey-patch pipe.__class__ so pipe(...) proxies to engine.call(...).
# When there is no pipe, step (2) is meaningless — there is nothing to patch.
# We skip straight to step (1) and return the engine, which is callable via its
# own __call__ method. The callers then use engine(prompt=...) directly.
if pipe_or_adapter is None:
from ..ray.engine import RayPipelineEngine

engine = RayPipelineEngine(
None,
parallelism_config,
cache_context_kwargs=ray_kwargs,
quantize_config=ray_qconfig,
)
return engine

# BlockAdapter is not supported currently in Ray wrapper.
assert not isinstance(pipe_or_adapter,
BlockAdapter), "BlockAdapter is not supported in Ray wrapper currently."

# Case 1: DiffusionPipeline
if isinstance(pipe_or_adapter, DiffusionPipeline):
return enable_ray_pipeline_parallelism(
pipe_or_adapter,
parallelism_config,
cache_context_kwargs=ray_kwargs,
quantize_config=ray_qconfig,
)

# Case 2: Transformer
return enable_ray_parallelism(
pipe_or_adapter,
parallelism_config,
Expand Down
6 changes: 4 additions & 2 deletions src/cache_dit/ray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .wrapper import disable_ray_parallelism
from .wrapper import disable_ray_module_parallelism
from .wrapper import disable_ray_pipeline_parallelism
from .wrapper import enable_ray_module_parallelism
from .wrapper import enable_ray_parallelism
from .wrapper import enable_ray_pipeline_parallelism

__all__ = [
"disable_ray_parallelism",
"disable_ray_module_parallelism",
"disable_ray_pipeline_parallelism",
"enable_ray_module_parallelism",
"enable_ray_parallelism",
"enable_ray_pipeline_parallelism",
]
69 changes: 65 additions & 4 deletions src/cache_dit/ray/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,67 @@


def enable_ray_parallelism(
    pipe_or_adapter: (DiffusionPipeline | torch.nn.Module | ModelMixin | None),
    parallelism_config: ParallelismConfig,
    cache_context_kwargs: dict[str, Any] | None = None,
    quantize_config: QuantizeConfig | None = None,
) -> DiffusionPipeline | torch.nn.Module | ModelMixin:
    """Enable Ray parallelism for a pipeline or module, dispatching by type.

    Routes to a backend based on the type of ``pipe_or_adapter``:

    * ``None`` — builds and returns a :class:`RayPipelineEngine` directly. There is no
      pipeline object to patch in this case, so the caller invokes the returned engine
      itself.
    * :class:`DiffusionPipeline` — delegates to :func:`enable_ray_pipeline_parallelism`.
    * anything else (expected: a transformer :class:`torch.nn.Module` /
      :class:`ModelMixin`) — delegates to :func:`enable_ray_module_parallelism`.

    ``BlockAdapter`` is not supported. It is rejected up front so that it can never fall
    through to the module branch and be patched as if it were a transformer.

    :param pipe_or_adapter: Pipeline, transformer module, or ``None``.
    :param parallelism_config: Parallelism configuration forwarded unchanged to the
        selected backend.
    :param cache_context_kwargs: Optional cache context keyword arguments forwarded
        unchanged to the selected backend.
    :param quantize_config: Optional quantization configuration forwarded unchanged to
        the selected backend.
    :returns: Whatever the selected backend returns; for ``pipe_or_adapter=None`` this is
        the newly created :class:`RayPipelineEngine`.
    :raises AssertionError: If ``pipe_or_adapter`` is a ``BlockAdapter`` instance.
    """

    # Enforce the documented contract here instead of relying on every caller to
    # pre-validate. The check is done by class name across the MRO (so subclasses are
    # caught too) because BlockAdapter is not known to be importable from this module.
    # An explicit raise is used rather than `assert` so the guard survives `python -O`.
    if any(klass.__name__ == "BlockAdapter" for klass in type(pipe_or_adapter).__mro__):
        raise AssertionError("BlockAdapter is not supported in Ray wrapper currently.")

    # pipe_or_adapter=None path: create engine directly, no pipe to patch.
    # enable_ray_pipeline_parallelism does two things: (1) create RayPipelineEngine,
    # (2) monkey-patch pipe.__class__ so pipe(...) proxies to engine.call(...).
    # When there is no pipe, step (2) is meaningless — there is nothing to patch.
    # We skip straight to step (1) and return the engine, which is callable via its
    # own __call__ method. The callers then use engine(prompt=...) directly.
    if pipe_or_adapter is None:
        return RayPipelineEngine(
            None,
            parallelism_config,
            cache_context_kwargs=cache_context_kwargs,
            quantize_config=quantize_config,
        )

    # Case 1: DiffusionPipeline
    if isinstance(pipe_or_adapter, DiffusionPipeline):
        return enable_ray_pipeline_parallelism(
            pipe_or_adapter,
            parallelism_config,
            cache_context_kwargs=cache_context_kwargs,
            quantize_config=quantize_config,
        )

    # Case 2: Transformer / module
    return enable_ray_module_parallelism(
        pipe_or_adapter,
        parallelism_config,
        cache_context_kwargs=cache_context_kwargs,
        quantize_config=quantize_config,
    )


def enable_ray_module_parallelism(
transformer: torch.nn.Module | ModelMixin,
parallelism_config: ParallelismConfig,
cache_context_kwargs: dict[str, Any] | None = None,
Expand All @@ -32,7 +93,7 @@ def enable_ray_parallelism(
"""

if getattr(transformer, "_cache_dit_ray_enabled", False):
logger.warning("Ray parallelism is already enabled for this transformer. Skipping.")
logger.warning("Ray module parallelism is already enabled for this transformer. Skipping.")
return transformer

engine = RayParallelEngine(transformer, parallelism_config, cache_context_kwargs, quantize_config)
Expand All @@ -51,13 +112,13 @@ def ray_to(self, *args, **kwargs):
transformer.forward = types.MethodType(ray_forward, transformer)
transformer.to = types.MethodType(ray_to, transformer)
transformer._cache_dit_ray_enabled = True # type: ignore[attr-defined]
logger.info(f"Enabled Ray parallelism for {transformer.__class__.__name__} with "
logger.info(f"Enabled Ray module parallelism for {transformer.__class__.__name__} with "
f"world_size={engine.world_size}.")
return transformer


def disable_ray_parallelism(transformer: torch.nn.Module | ModelMixin) -> None:
"""Restore a transformer patched by :func:`enable_ray_parallelism`.
def disable_ray_module_parallelism(transformer: torch.nn.Module | ModelMixin) -> None:
"""Restore a transformer patched by :func:`enable_ray_module_parallelism`.

:param transformer: Transformer module that may own a Ray engine.
"""
Expand Down
16 changes: 8 additions & 8 deletions tests/ray/test_ray_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,16 +259,16 @@ def fail_main_process_cache_apply(*args, **kwargs):
raise AssertionError("Ray mode should not apply cache hooks in the main process.")

def fake_enable_ray_parallelism(
transformer_arg,
pipe_or_adapter,
parallelism_config_arg,
cache_context_kwargs=None,
quantize_config=None,
):
captured["transformer"] = transformer_arg
captured["pipe_or_adapter"] = pipe_or_adapter
captured["parallelism_config"] = parallelism_config_arg
captured["cache_context_kwargs"] = cache_context_kwargs
captured["quantize_config"] = quantize_config
return transformer_arg
return pipe_or_adapter

monkeypatch.setattr(
"cache_dit.caching.cache_interface.CachedAdapter.apply",
Expand All @@ -286,7 +286,7 @@ def fake_enable_ray_parallelism(
)

assert returned is transformer
assert captured["transformer"] is transformer
assert captured["pipe_or_adapter"] is transformer
assert captured["parallelism_config"] is parallelism_config
cache_context_kwargs = captured["cache_context_kwargs"]
assert isinstance(cache_context_kwargs, dict)
Expand All @@ -304,16 +304,16 @@ def fail_main_process_quantize(*args, **kwargs):
raise AssertionError("Ray mode should not quantize in the main process.")

def fake_enable_ray_parallelism(
transformer_arg,
pipe_or_adapter,
parallelism_config_arg,
cache_context_kwargs=None,
quantize_config=None,
):
captured["transformer"] = transformer_arg
captured["pipe_or_adapter"] = pipe_or_adapter
captured["parallelism_config"] = parallelism_config_arg
captured["cache_context_kwargs"] = cache_context_kwargs
captured["quantize_config"] = quantize_config
return transformer_arg
return pipe_or_adapter

monkeypatch.setattr(
"cache_dit.caching.cache_interface.quantize",
Expand All @@ -331,7 +331,7 @@ def fake_enable_ray_parallelism(
)

assert returned is transformer
assert captured["transformer"] is transformer
assert captured["pipe_or_adapter"] is transformer
assert captured["parallelism_config"] is parallelism_config
assert captured["cache_context_kwargs"] is None
assert isinstance(captured["quantize_config"], QuantizeConfig)
Expand Down
Loading