diff --git a/src/cache_dit/caching/cache_interface.py b/src/cache_dit/caching/cache_interface.py index cf7f9dd9..1e610cfd 100644 --- a/src/cache_dit/caching/cache_interface.py +++ b/src/cache_dit/caching/cache_interface.py @@ -160,6 +160,13 @@ def _enable_cache_with_ray_impl( attention backend resolution, parallelism tuning, quantization) is deferred to :func:`_enable_cache_impl` running inside each Ray worker. """ + # Soft check: only trigger the ray import when entering the Ray path. + try: + import ray # noqa: F401 + except ImportError as exc: + raise ImportError("Ray wrapper requires ray. Install it with `pip install ray` or " + "`pip install cache-dit[ray]`.") from exc + if attention_backend is not None: parallelism_config.attention_backend = attention_backend diff --git a/src/cache_dit/ray/engine.py b/src/cache_dit/ray/engine.py index 50b3d2a0..b4a7585e 100644 --- a/src/cache_dit/ray/engine.py +++ b/src/cache_dit/ray/engine.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import Any +import ray import torch from diffusers import DiffusionPipeline from diffusers import pipelines as diffusers_pipelines @@ -54,7 +55,6 @@ def _maybe_user_module(cls: type) -> bool: def _warn_if_ray_pre_initialized( - ray: Any, cls: type, config: ParallelismConfig, ray_initialized_by_engine: bool = False, @@ -67,7 +67,6 @@ def _warn_if_ray_pre_initialized( to every worker. If that was not configured, worker-side pickle deserialization (or ``from_pretrained`` class imports) will fail with ``ModuleNotFoundError``. - :param ray: Imported Ray module. :param cls: The pipeline or transformer class being transferred. :param config: The user-provided parallelism configuration. :param ray_initialized_by_engine: ``True`` when cache-dit itself just called @@ -94,19 +93,9 @@ def _warn_if_ray_pre_initialized( f"cache_dit initialize Ray for you.") -def _require_ray(): - try: - import ray - except ImportError as exc: - raise ImportError("Ray wrapper requires ray. Install it with `pip install ray` or " - "`pip install cache-dit[ray]`.") from exc - return ray - - -def _init_ray(ray: Any, **init_kwargs: Any) -> None: +def _init_ray(**init_kwargs: Any) -> None: """Initialize Ray while suppressing its accelerator override transition warning. - :param ray: Imported Ray module. :param init_kwargs: Keyword arguments forwarded to ``ray.init``. """ @@ -287,7 +276,6 @@ def __init__( cache_context_kwargs: dict[str, Any] | None = None, quantize_config: QuantizeConfig | None = None, ): - self.ray = _require_ray() self.parallelism_config = parallelism_config self.cache_context_kwargs = cache_context_kwargs self.quantize_config = quantize_config @@ -305,17 +293,16 @@ def __init__( self._ray_initialized_by_engine = False self._transfer_dir: Path | None = None - if not self.ray.is_initialized(): + if not ray.is_initialized(): init_kwargs = dict(parallelism_config.ray_init_kwargs) if parallelism_config.ray_address is not None: init_kwargs["address"] = parallelism_config.ray_address if parallelism_config.ray_runtime_env is not None: init_kwargs["runtime_env"] = parallelism_config.ray_runtime_env - _init_ray(self.ray, **init_kwargs) + _init_ray(**init_kwargs) self._ray_initialized_by_engine = True _warn_if_ray_pre_initialized( - self.ray, transformer.__class__, self.parallelism_config, ray_initialized_by_engine=self._ray_initialized_by_engine, @@ -329,7 +316,7 @@ def __init__( self._actors = self._create_workers(transformer) def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any]: - remote_worker = self.ray.remote(RayTransformerWorker) + remote_worker = ray.remote(RayTransformerWorker) worker_options = {"num_gpus": 1} worker_options.update(self.parallelism_config.ray_worker_options) actors = [ @@ -342,8 +329,8 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any self.master_port, ) for rank in range(self.world_size) ] - self.ray.get([actor.ready.remote() for actor in actors]) - device_infos = self.ray.get([actor.device_info.remote() for actor in actors]) + ray.get([actor.ready.remote() for actor in actors]) + device_infos = ray.get([actor.device_info.remote() for actor in actors]) logger.info(f"Ray transformer worker placement before load: {device_infos}") if self._source_device is not None and self._source_device.type != "cpu": @@ -381,7 +368,7 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any logger.info(f"Saved the transformer snapshot in " f"{time.perf_counter() - save_start:.2f}s.") tfmr_cls_ref = _qualified_class_name(transformer.__class__) - load_infos = self.ray.get([ + load_infos = ray.get([ actor.load_transformer_from_pretrained.remote( tfmr_cls_ref, str(transformer_path), @@ -399,7 +386,7 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any logger.info(f"Saved the transformer safetensors file in " f"{time.perf_counter() - save_start:.2f}s.") tfmr_cls_ref = _qualified_class_name(transformer.__class__) - load_infos = self.ray.get([ + load_infos = ray.get([ actor.load_transformer_from_safetensors.remote( tfmr_cls_ref, dict(transformer.config), @@ -411,12 +398,11 @@ def _create_workers(self, transformer: torch.nn.Module | ModelMixin) -> list[Any elif transfer_backend == "object_store": logger.info("Putting the CPU transformer into the Ray object store.") put_start = time.perf_counter() - transformer_ref = self.ray.put(transformer) + transformer_ref = ray.put(transformer) logger.info(f"Put the CPU transformer into the Ray object store in " f"{time.perf_counter() - put_start:.2f}s.") load_start = time.perf_counter() - load_infos = self.ray.get( - [actor.load_transformer.remote(transformer_ref) for actor in actors]) + load_infos = ray.get([actor.load_transformer.remote(transformer_ref) for actor in actors]) logger.info(f"Loaded the object-store transformer on Ray workers in " f"{time.perf_counter() - load_start:.2f}s.") else: @@ -435,7 +421,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: output_device = first_tensor_device((args, kwargs)) cpu_args = cpu_tensor_tree(args) cpu_kwargs = cpu_tensor_tree(kwargs) - results = self.ray.get([actor.forward.remote(cpu_args, cpu_kwargs) for actor in self._actors]) + results = ray.get([actor.forward.remote(cpu_args, cpu_kwargs) for actor in self._actors]) rank0_output = next((result for result in results if result is not None), None) if rank0_output is None: raise RuntimeError("Ray transformer workers did not return a rank-0 output.") @@ -447,12 +433,12 @@ def shutdown(self) -> None: """Shutdown worker actors and optionally the Ray runtime initialized by this engine.""" if self._actors: - self.ray.get([actor.shutdown.remote() for actor in self._actors]) + ray.get([actor.shutdown.remote() for actor in self._actors]) for actor in self._actors: - self.ray.kill(actor) + ray.kill(actor) self._actors = [] if self._ray_initialized_by_engine and self.parallelism_config.ray_auto_shutdown: - self.ray.shutdown() + ray.shutdown() if self._source_device is not None and self._source_device.type != "cpu": restore_start = time.perf_counter() self._source_transformer.to(self._source_device) @@ -482,7 +468,6 @@ def __init__( if not hasattr(pipe, "transformer"): raise ValueError("Ray pipeline parallelism requires a pipeline with a transformer " "attribute.") - self.ray = _require_ray() self.parallelism_config = parallelism_config self.cache_context_kwargs = cache_context_kwargs self.quantize_config = quantize_config @@ -501,18 +486,17 @@ def __init__( self._ray_initialized_by_engine = False self._transfer_dir: Path | None = None - if not self.ray.is_initialized(): + if not ray.is_initialized(): init_kwargs = dict(parallelism_config.ray_init_kwargs) if parallelism_config.ray_address is not None: init_kwargs["address"] = parallelism_config.ray_address if parallelism_config.ray_runtime_env is not None: init_kwargs["runtime_env"] = parallelism_config.ray_runtime_env - _init_ray(self.ray, **init_kwargs) + _init_ray(**init_kwargs) self._ray_initialized_by_engine = True if pipe is not None: _warn_if_ray_pre_initialized( - self.ray, pipe.__class__, self.parallelism_config, ray_initialized_by_engine=self._ray_initialized_by_engine, @@ -527,7 +511,7 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: if transfer_fn is not None: return self._create_workers_with_init_fn(pipe, transfer_fn) - remote_worker = self.ray.remote(RayPipelineWorker) + remote_worker = ray.remote(RayPipelineWorker) worker_options = {"num_gpus": 1} worker_options.update(self.parallelism_config.ray_worker_options) actors = [ @@ -540,8 +524,8 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: self.master_port, ) for rank in range(self.world_size) ] - self.ray.get([actor.ready.remote() for actor in actors]) - device_infos = self.ray.get([actor.device_info.remote() for actor in actors]) + ray.get([actor.ready.remote() for actor in actors]) + device_infos = ray.get([actor.device_info.remote() for actor in actors]) logger.info(f"Ray pipeline worker placement before load: {device_infos}") if self._source_device is not None and self._source_device.type != "cpu": @@ -590,7 +574,7 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: json.dump(custom_class_map, f, indent=2) pipe_cls_ref = _qualified_class_name(pipe.__class__) - load_infos = self.ray.get([ + load_infos = ray.get([ actor.load_pipeline_from_pretrained.remote( pipe_cls_ref, str(pipeline_path), @@ -608,7 +592,7 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: pipe_cls_ref = _qualified_class_name(pipe.__class__) custom_class_map = _build_custom_class_map( pipe) if self.parallelism_config.ray_transfer_custom_obj else None - load_infos = self.ray.get([ + load_infos = ray.get([ actor.load_pipeline_from_pretrained.remote( pipe_cls_ref, model_path, @@ -622,10 +606,10 @@ def _create_workers(self, pipe: DiffusionPipeline) -> list[Any]: elif transfer_backend == "object_store": logger.info("Putting the CPU pipeline into the Ray object store.") put_start = time.perf_counter() - pipe_ref = self.ray.put(pipe) + pipe_ref = ray.put(pipe) logger.info(f"Put the CPU pipeline into the Ray object store in " f"{time.perf_counter() - put_start:.2f}s.") - load_infos = self.ray.get([actor.load_pipeline.remote(pipe_ref) for actor in actors]) + load_infos = ray.get([actor.load_pipeline.remote(pipe_ref) for actor in actors]) logger.info(f"Loaded the object-store pipeline on Ray workers in " f"{time.perf_counter() - load_start:.2f}s.") else: @@ -655,7 +639,7 @@ def _create_workers_with_init_fn(self, pipe: DiffusionPipeline, transfer_fn) -> # attempts to cloudpickle the function as part of the actor constructor args. worker_config = dataclasses.replace(self.parallelism_config, ray_transfer_fn=None) - remote_worker = self.ray.remote(RayPipelineWorker) + remote_worker = ray.remote(RayPipelineWorker) worker_options = {"num_gpus": 1} worker_options.update(worker_config.ray_worker_options) actors = [ @@ -668,8 +652,8 @@ def _create_workers_with_init_fn(self, pipe: DiffusionPipeline, transfer_fn) -> self.master_port, ) for rank in range(self.world_size) ] - self.ray.get([actor.ready.remote() for actor in actors]) - device_infos = self.ray.get([actor.device_info.remote() for actor in actors]) + ray.get([actor.ready.remote() for actor in actors]) + device_infos = ray.get([actor.device_info.remote() for actor in actors]) logger.info(f"Ray pipeline worker placement before load: {device_infos}") # Offload the main-process pipeline to CPU to free GPU memory for workers. @@ -684,8 +668,8 @@ def _create_workers_with_init_fn(self, pipe: DiffusionPipeline, transfer_fn) -> logger.info("The main-process pipeline is already on CPU before Ray worker loading.") load_start = time.perf_counter() - fn_ref = self.ray.put(transfer_fn) - load_infos = self.ray.get([ + fn_ref = ray.put(transfer_fn) + load_infos = ray.get([ actor.load_pipeline_with_init_fn.remote( fn_ref, self.cache_context_kwargs, @@ -712,7 +696,7 @@ def call(self, *args: Any, **kwargs: Any) -> Any: cpu_args = cpu_tensor_tree(args) cpu_kwargs = cpu_tensor_tree(kwargs) - results = self.ray.get([actor.call.remote(cpu_args, cpu_kwargs) for actor in self._actors]) + results = ray.get([actor.call.remote(cpu_args, cpu_kwargs) for actor in self._actors]) rank0_output = next((result for result in results if result is not None), None) if rank0_output is None: raise RuntimeError("Ray pipeline workers did not return a rank-0 output.") @@ -722,12 +706,12 @@ def shutdown(self) -> None: """Shutdown worker actors and optionally the Ray runtime initialized by this engine.""" if self._actors: - self.ray.get([actor.shutdown.remote() for actor in self._actors]) + ray.get([actor.shutdown.remote() for actor in self._actors]) for actor in self._actors: - self.ray.kill(actor) + ray.kill(actor) self._actors = [] if self._ray_initialized_by_engine and self.parallelism_config.ray_auto_shutdown: - self.ray.shutdown() + ray.shutdown() if self._source_pipe is not None and self._source_device is not None and self._source_device.type != "cpu": restore_start = time.perf_counter() _move_pipeline_modules(self._source_pipe, self._source_device) diff --git a/src/cache_dit/ray/wrapper.py b/src/cache_dit/ray/wrapper.py index 16f65c1c..65768d79 100644 --- a/src/cache_dit/ray/wrapper.py +++ b/src/cache_dit/ray/wrapper.py @@ -10,8 +10,6 @@ from ..distributed import ParallelismConfig from ..logger import init_logger from ..quantization import QuantizeConfig -from .engine import RayParallelEngine -from .engine import RayPipelineEngine logger = init_logger(__name__) @@ -51,6 +49,8 @@ def enable_ray_parallelism( # We skip straight to step (1) and return the engine, which is callable via its # own __call__ method. The callers then use engine(prompt=...) directly. if pipe_or_adapter is None: + from .engine import RayPipelineEngine + engine = RayPipelineEngine( None, parallelism_config, @@ -96,6 +96,8 @@ def enable_ray_module_parallelism( logger.warning("Ray module parallelism is already enabled for this transformer. Skipping.") return transformer + from .engine import RayParallelEngine + engine = RayParallelEngine(transformer, parallelism_config, cache_context_kwargs, quantize_config) transformer._cache_dit_ray_original_forward = transformer.forward # type: ignore[attr-defined] transformer._cache_dit_ray_original_to = transformer.to # type: ignore[attr-defined] @@ -156,6 +158,8 @@ def enable_ray_pipeline_parallelism( logger.warning("Ray parallelism is already enabled for this pipeline. Skipping.") return pipe + from .engine import RayPipelineEngine + engine = RayPipelineEngine(pipe, parallelism_config, cache_context_kwargs, quantize_config) original_class = pipe.__class__