forked from nikopueringer/CorridorKey
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdevice_utils.py
More file actions
76 lines (59 loc) · 2.42 KB
/
device_utils.py
File metadata and controls
76 lines (59 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Centralized cross-platform device selection for CorridorKey."""
import logging
import os
import torch
# Module-level logger; detect_best_device() reports its selection through it.
logger: logging.Logger = logging.getLogger(__name__)
# Environment variable consulted by resolve_device() when no explicit device
# (or "auto") is requested.
DEVICE_ENV_VAR: str = "CORRIDORKEY_DEVICE"
# Device names accepted by resolve_device(); "auto" defers to detect_best_device().
# Order only affects the "Valid options" listing in resolve_device()'s error message.
VALID_DEVICES: tuple[str, ...] = ("auto", "cuda", "mps", "cpu")
def detect_best_device() -> str:
    """Return the most capable backend available on this machine.

    Backends are probed in priority order -- CUDA first, then Apple MPS --
    and the first one that reports itself available wins. When neither is
    usable the CPU is returned. The decision is logged at INFO level.

    Returns:
        One of "cuda", "mps", or "cpu".
    """
    # Ordered (name, probe) pairs; probes are only invoked until one succeeds,
    # so MPS is never queried when CUDA is available.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        # Older PyTorch builds have no torch.backends.mps at all.
        ("mps", lambda: hasattr(torch.backends, "mps") and torch.backends.mps.is_available()),
    )
    chosen = next((name for name, probe in probes if probe()), "cpu")
    logger.info("Auto-selected device: %s", chosen)
    return chosen
def resolve_device(requested: str | None = None) -> str:
    """Resolve device from explicit request > env var > auto-detect.

    Device names are matched case-insensitively with surrounding whitespace
    ignored, and a blank value is treated like "auto". An "auto" from either
    source always ends in auto-detection, so "auto" itself is never returned.

    Args:
        requested: Device string from CLI arg. None, a blank string, or "auto"
            (any case) triggers env var lookup then auto-detection.

    Returns:
        Validated device string ("cuda", "mps", or "cpu").

    Raises:
        RuntimeError: If the device name is unknown or the requested backend
            is unavailable.
    """

    def _normalize(name: str | None) -> str:
        # None, "" and whitespace-only values mean "no preference" -> "auto".
        return (name or "").strip().lower() or "auto"

    # CLI arg takes priority, then env var, then auto.
    # Normalizing *before* the "auto" checks matters: previously "AUTO" skipped
    # auto-detection, passed the VALID_DEVICES check and was returned verbatim.
    device = _normalize(requested)
    if device == "auto":
        device = _normalize(os.environ.get(DEVICE_ENV_VAR))
    if device == "auto":
        return detect_best_device()

    if device not in VALID_DEVICES:
        raise RuntimeError(f"Unknown device '{device}'. Valid options: {', '.join(VALID_DEVICES)}")

    # Validate that the explicitly requested backend is actually usable.
    if device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA requested but torch.cuda.is_available() is False. Install a CUDA-enabled PyTorch build."
        )
    if device == "mps":
        # Builds without the MPS backend lack torch.backends.mps entirely.
        if not hasattr(torch.backends, "mps"):
            raise RuntimeError(
                "MPS requested but this PyTorch build has no MPS support. Install PyTorch >= 1.12 with MPS backend."
            )
        if not torch.backends.mps.is_available():
            raise RuntimeError(
                "MPS requested but not available on this machine. Requires Apple Silicon (M1+) with macOS 12.3+."
            )
    return device
def clear_device_cache(device: torch.device | str) -> None:
"""Clear GPU memory cache if applicable (no-op for CPU)."""
device_type = device.type if isinstance(device, torch.device) else device
if device_type == "cuda":
torch.cuda.empty_cache()
elif device_type == "mps":
torch.mps.empty_cache()