diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py
index c70dbdaa5ee1..9792069fe3cb 100644
--- a/colossalai/utils/device.py
+++ b/colossalai/utils/device.py
@@ -45,12 +45,58 @@ def get_current_device() -> torch.device:
 
 
+def _resolve_backend(device) -> str:
+    """Pick the backend ("cuda", "npu" or "cpu") that serves a device request.
+
+    Args:
+        device: The requested device, e.g. ``"cuda"``, ``"cuda:1"``, ``"npu"``
+            or a ``torch.device``. It is matched on its string form.
+
+    Returns:
+        ``"cuda"`` or ``"npu"`` if the request names that backend and it is
+        available, otherwise ``"cpu"``.
+    """
+    # str() lets torch.device objects and indexed names such as "cuda:1" go
+    # through the same substring test as plain strings.
+    requested = str(device)
+    if "cuda" in requested and torch.cuda.is_available():
+        return "cuda"
+    if "npu" in requested and IS_NPU_AVAILABLE:
+        return "npu"
+    return "cpu"
+
+
 def _dispatch_device_func(fn_name: str, *args, **kwargs):
-    if torch.cuda.is_available():
+    """Call ``fn_name`` from the torch namespace of the selected backend.
+
+    A ``device`` keyword argument, when present, only selects the backend and
+    is not forwarded to the called function. Without it the backend is
+    auto-detected: CUDA first, then NPU, then CPU.
+
+    Raises:
+        RuntimeError: If the CPU backend is selected and ``torch`` has no
+            attribute named ``fn_name``.
+    """
+    if "device" in kwargs:
+        device = _resolve_backend(kwargs.pop("device"))
+    elif torch.cuda.is_available():
+        device = "cuda"
+    elif IS_NPU_AVAILABLE:
+        device = "npu"
+    else:
+        device = "cpu"
+
+    if device == "cuda":
         return getattr(torch.cuda, fn_name)(*args, **kwargs)
-    elif IS_NPU_AVAILABLE:
+    elif device == "npu":
         return getattr(torch.npu, fn_name)(*args, **kwargs)
-    else:
-        raise RuntimeError("No device available")
+    else:
+        # Resolve the attribute before calling it, so an AttributeError raised
+        # inside the function itself is not reported as "unsupported".
+        try:
+            cpu_fn = getattr(torch, fn_name)
+        except AttributeError as err:
+            raise RuntimeError(f"Current device does not support the function: {fn_name}") from err
+        return cpu_fn(*args, **kwargs)
 
 
 # device semantics
@@ -114,7 +160,18 @@ def utilization(device=None) -> int:
 
 
 def get_rng_state(device="cuda") -> torch.Tensor:
-    return _dispatch_device_func("get_rng_state", device)
+    """Return the random number generator state for ``device``.
+
+    Falls back to the CPU generator state when the requested backend is not
+    available.
+    """
+    backend = _resolve_backend(device)
+    if backend == "cpu":
+        # The CPU-side torch.get_rng_state is called without a device argument.
+        return _dispatch_device_func("get_rng_state", device="cpu")
+    # Forward the original request so an indexed device such as "cuda:1" is
+    # honoured instead of being collapsed to the backend's default device.
+    return _dispatch_device_func("get_rng_state", device, device=backend)
 
 
 def get_rng_state_all() -> List[torch.Tensor]:
@@ -122,7 +179,18 @@ def get_rng_state_all() -> List[torch.Tensor]:
 
 
 def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
-    return _dispatch_device_func("set_rng_state", new_state, device)
+    """Set the random number generator state for ``device``.
+
+    Falls back to setting the CPU generator state when the requested backend
+    is not available.
+    """
+    backend = _resolve_backend(device)
+    if backend == "cpu":
+        # The CPU-side torch.set_rng_state is called without a device argument.
+        return _dispatch_device_func("set_rng_state", new_state, device="cpu")
+    # Forward the original request so an indexed device such as "cuda:1" is
+    # honoured instead of being collapsed to the backend's default device.
+    return _dispatch_device_func("set_rng_state", new_state, device, device=backend)
 
 
 def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: