diff --git a/src/cache_dit/caching/cache_adapters/cache_adapter.py b/src/cache_dit/caching/cache_adapters/cache_adapter.py index d9d9a1ea..543a76f8 100644 --- a/src/cache_dit/caching/cache_adapters/cache_adapter.py +++ b/src/cache_dit/caching/cache_adapters/cache_adapter.py @@ -532,9 +532,9 @@ def _release_blocks_hooks(blocks): return def _release_transformer_hooks(transformer): - from ...ray import disable_ray_parallelism + from ...ray import disable_ray_module_parallelism - disable_ray_parallelism(transformer) + disable_ray_module_parallelism(transformer) if hasattr(transformer, "_original_forward"): original_forward = transformer._original_forward transformer.forward = original_forward.__get__(transformer) diff --git a/src/cache_dit/caching/cache_interface.py b/src/cache_dit/caching/cache_interface.py index 59e2e3d0..cf7f9dd9 100644 --- a/src/cache_dit/caching/cache_interface.py +++ b/src/cache_dit/caching/cache_interface.py @@ -89,10 +89,8 @@ def enable_cache( >>> output = pipe(...) >>> stats = cache_dit.summary(pipe) """ - ray_enabled = isinstance( - parallelism_config, - ParallelismConfig, - ) and parallelism_config.use_ray + ray_enabled = isinstance(parallelism_config, ParallelismConfig) and parallelism_config.use_ray + if not ray_enabled and isinstance( parallelism_config, ParallelismConfig) and parallelism_config.ray_transfer_fn is not None: logger.warning("ray_transfer_fn is set but use_ray=False; the function will be ignored. " @@ -101,7 +99,12 @@ def enable_cache( if not (ray_enabled and parallelism_config.ray_transfer_fn is not None): raise ValueError( "pipe_or_adapter can only be None when use_ray=True and ray_transfer_fn is set.") + if ray_enabled: + # BlockAdapter is not supported currently in Ray wrapper. + assert not isinstance(pipe_or_adapter, BlockAdapter), \ + "BlockAdapter is not supported in Ray wrapper currently." 
+ return _enable_cache_with_ray_impl( pipe_or_adapter, cache_config=cache_config, @@ -112,6 +115,7 @@ def enable_cache( quantize_config=quantize_config, **kwargs, ) + return _enable_cache_impl( pipe_or_adapter, cache_config=cache_config, @@ -127,7 +131,6 @@ def enable_cache( def _enable_cache_with_ray_impl( pipe_or_adapter: Union[ DiffusionPipeline, - BlockAdapter, torch.nn.Module, ModelMixin, None, @@ -150,7 +153,6 @@ def _enable_cache_with_ray_impl( DiffusionPipeline, torch.nn.Module, ModelMixin, - BlockAdapter, ]: """Internal: Ray-path passthrough — packages configs and delegates to Ray wrappers. @@ -158,7 +160,6 @@ def _enable_cache_with_ray_impl( attention backend resolution, parallelism tuning, quantization) is deferred to :func:`_enable_cache_impl` running inside each Ray worker. """ - if attention_backend is not None: parallelism_config.attention_backend = attention_backend @@ -176,39 +177,7 @@ def _enable_cache_with_ray_impl( ray_qconfig = copy.deepcopy(quantize_config) if quantize_config is not None else None from ..ray import enable_ray_parallelism - from ..ray import enable_ray_pipeline_parallelism - - # pipe_or_adapter=None path: create engine directly, no pipe to patch. - # enable_ray_pipeline_parallelism does two things: (1) create RayPipelineEngine, - # (2) monkey-patch pipe.__class__ so pipe(...) proxies to engine.call(...). - # When there is no pipe, step (2) is meaningless — there is nothing to patch. - # We skip straight to step (1) and return the engine, which is callable via its - # own __call__ method. The callers then use engine(prompt=...) directly. - if pipe_or_adapter is None: - from ..ray.engine import RayPipelineEngine - - engine = RayPipelineEngine( - None, - parallelism_config, - cache_context_kwargs=ray_kwargs, - quantize_config=ray_qconfig, - ) - return engine - - # BlockAdapter is not supported currently in Ray wrapper. 
- assert not isinstance(pipe_or_adapter, - BlockAdapter), "BlockAdapter is not supported in Ray wrapper currently." - - # Case 1: DiffusionPipeline - if isinstance(pipe_or_adapter, DiffusionPipeline): - return enable_ray_pipeline_parallelism( - pipe_or_adapter, - parallelism_config, - cache_context_kwargs=ray_kwargs, - quantize_config=ray_qconfig, - ) - # Case 2: Transformer return enable_ray_parallelism( pipe_or_adapter, parallelism_config, diff --git a/src/cache_dit/ray/__init__.py b/src/cache_dit/ray/__init__.py index 2c9e157f..2988eec2 100644 --- a/src/cache_dit/ray/__init__.py +++ b/src/cache_dit/ray/__init__.py @@ -1,11 +1,13 @@ -from .wrapper import disable_ray_parallelism +from .wrapper import disable_ray_module_parallelism from .wrapper import disable_ray_pipeline_parallelism +from .wrapper import enable_ray_module_parallelism from .wrapper import enable_ray_parallelism from .wrapper import enable_ray_pipeline_parallelism __all__ = [ - "disable_ray_parallelism", + "disable_ray_module_parallelism", "disable_ray_pipeline_parallelism", + "enable_ray_module_parallelism", "enable_ray_parallelism", "enable_ray_pipeline_parallelism", ] diff --git a/src/cache_dit/ray/wrapper.py b/src/cache_dit/ray/wrapper.py index 8c6efa3d..16f65c1c 100644 --- a/src/cache_dit/ray/wrapper.py +++ b/src/cache_dit/ray/wrapper.py @@ -17,6 +17,67 @@ def enable_ray_parallelism( + pipe_or_adapter: (DiffusionPipeline | torch.nn.Module | ModelMixin | None), + parallelism_config: ParallelismConfig, + cache_context_kwargs: dict[str, Any] | None = None, + quantize_config: QuantizeConfig | None = None, +) -> DiffusionPipeline | torch.nn.Module | ModelMixin: + """Enable Ray parallelism for a pipeline, adapter, or module, dispatching by type. + + This is the single public entry point for Ray-parallel execution in cache-dit. 
It routes + to the appropriate backend based on the type of ``pipe_or_adapter``: + + * ``None`` — creates a :class:`RayPipelineEngine` directly (for init-fn flows where no + pipeline exists in the main process). + * :class:`DiffusionPipeline` — wraps the whole pipeline via + :func:`enable_ray_pipeline_parallelism`. + * transformer :class:`torch.nn.Module` / :class:`ModelMixin` — wraps the module via + :func:`enable_ray_module_parallelism`. + + :class:`BlockAdapter` is not supported; the check lives in :func:`enable_cache`, not here. + + :param pipe_or_adapter: Pipeline, transformer module, or ``None``. + :param parallelism_config: Ray-enabled parallelism configuration. + :param cache_context_kwargs: Optional cache context keyword arguments to apply inside Ray + workers. + :param quantize_config: Optional quantization configuration to apply inside Ray workers. + :returns: The wrapped object with proxied forward / ``__call__``, or the engine for ``None``. + """ + + # pipe_or_adapter=None path: create engine directly, no pipe to patch. + # enable_ray_pipeline_parallelism does two things: (1) create RayPipelineEngine, + # (2) monkey-patch pipe.__class__ so pipe(...) proxies to engine.call(...). + # When there is no pipe, step (2) is meaningless — there is nothing to patch. + # We skip straight to step (1) and return the engine, which is callable via its + # own __call__ method. The callers then use engine(prompt=...) directly. 
+ if pipe_or_adapter is None: + engine = RayPipelineEngine( + None, + parallelism_config, + cache_context_kwargs=cache_context_kwargs, + quantize_config=quantize_config, + ) + return engine + + # Case 1: DiffusionPipeline + if isinstance(pipe_or_adapter, DiffusionPipeline): + return enable_ray_pipeline_parallelism( + pipe_or_adapter, + parallelism_config, + cache_context_kwargs=cache_context_kwargs, + quantize_config=quantize_config, + ) + + # Case 2: Transformer / module + return enable_ray_module_parallelism( + pipe_or_adapter, + parallelism_config, + cache_context_kwargs=cache_context_kwargs, + quantize_config=quantize_config, + ) + + +def enable_ray_module_parallelism( transformer: torch.nn.Module | ModelMixin, parallelism_config: ParallelismConfig, cache_context_kwargs: dict[str, Any] | None = None, @@ -32,7 +93,7 @@ def enable_ray_parallelism( """ if getattr(transformer, "_cache_dit_ray_enabled", False): - logger.warning("Ray parallelism is already enabled for this transformer. Skipping.") + logger.warning("Ray module parallelism is already enabled for this transformer. Skipping.") return transformer engine = RayParallelEngine(transformer, parallelism_config, cache_context_kwargs, quantize_config) @@ -51,13 +112,13 @@ def ray_to(self, *args, **kwargs): transformer.forward = types.MethodType(ray_forward, transformer) transformer.to = types.MethodType(ray_to, transformer) transformer._cache_dit_ray_enabled = True # type: ignore[attr-defined] - logger.info(f"Enabled Ray parallelism for {transformer.__class__.__name__} with " + logger.info(f"Enabled Ray module parallelism for {transformer.__class__.__name__} with " f"world_size={engine.world_size}.") return transformer -def disable_ray_parallelism(transformer: torch.nn.Module | ModelMixin) -> None: - """Restore a transformer patched by :func:`enable_ray_parallelism`. 
+def disable_ray_module_parallelism(transformer: torch.nn.Module | ModelMixin) -> None: + """Restore a transformer patched by :func:`enable_ray_module_parallelism`. :param transformer: Transformer module that may own a Ray engine. """ diff --git a/tests/ray/test_ray_wrapper.py b/tests/ray/test_ray_wrapper.py index 3683f0d6..933a4d79 100644 --- a/tests/ray/test_ray_wrapper.py +++ b/tests/ray/test_ray_wrapper.py @@ -259,16 +259,16 @@ def fail_main_process_cache_apply(*args, **kwargs): raise AssertionError("Ray mode should not apply cache hooks in the main process.") def fake_enable_ray_parallelism( - transformer_arg, + pipe_or_adapter, parallelism_config_arg, cache_context_kwargs=None, quantize_config=None, ): - captured["transformer"] = transformer_arg + captured["pipe_or_adapter"] = pipe_or_adapter captured["parallelism_config"] = parallelism_config_arg captured["cache_context_kwargs"] = cache_context_kwargs captured["quantize_config"] = quantize_config - return transformer_arg + return pipe_or_adapter monkeypatch.setattr( "cache_dit.caching.cache_interface.CachedAdapter.apply", @@ -286,7 +286,7 @@ def fake_enable_ray_parallelism( ) assert returned is transformer - assert captured["transformer"] is transformer + assert captured["pipe_or_adapter"] is transformer assert captured["parallelism_config"] is parallelism_config cache_context_kwargs = captured["cache_context_kwargs"] assert isinstance(cache_context_kwargs, dict) @@ -304,16 +304,16 @@ def fail_main_process_quantize(*args, **kwargs): raise AssertionError("Ray mode should not quantize in the main process.") def fake_enable_ray_parallelism( - transformer_arg, + pipe_or_adapter, parallelism_config_arg, cache_context_kwargs=None, quantize_config=None, ): - captured["transformer"] = transformer_arg + captured["pipe_or_adapter"] = pipe_or_adapter captured["parallelism_config"] = parallelism_config_arg captured["cache_context_kwargs"] = cache_context_kwargs captured["quantize_config"] = quantize_config - return 
transformer_arg + return pipe_or_adapter monkeypatch.setattr( "cache_dit.caching.cache_interface.quantize", @@ -331,7 +331,7 @@ def fake_enable_ray_parallelism( ) assert returned is transformer - assert captured["transformer"] is transformer + assert captured["pipe_or_adapter"] is transformer assert captured["parallelism_config"] is parallelism_config assert captured["cache_context_kwargs"] is None assert isinstance(captured["quantize_config"], QuantizeConfig)