comfy/model_management.py: 35 additions & 6 deletions
@@ -101,7 +101,7 @@ def get_supported_float8_types():
 lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

 try:
-    import intel_extension_for_pytorch as ipex
+    import intel_extension_for_pytorch as ipex  # noqa: F401
     _ = torch.xpu.device_count()
     xpu_available = xpu_available or torch.xpu.is_available()
 except:
@@ -186,8 +186,12 @@ def get_total_memory(dev=None, torch_total_too=False):
     elif is_intel_xpu():
         stats = torch.xpu.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
+        if torch_version_numeric < (2, 6):
+            mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
+        else:
+            _, mem_total_xpu = torch.xpu.mem_get_info(dev)
         mem_total_torch = mem_reserved
-        mem_total = torch.xpu.get_device_properties(dev).total_memory
+        mem_total = mem_total_xpu
     elif is_ascend_npu():
         stats = torch.npu.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
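For reference, a minimal sketch of the version gate this hunk introduces. `torch.xpu.mem_get_info()` returns a `(free, total)` tuple and only exists from PyTorch 2.6 on, so older builds fall back to the static device property; the helper name below is illustrative, not from the PR:

```python
# Minimal sketch, assuming an XPU-enabled PyTorch build with at least
# one device. xpu_total_memory() is a hypothetical helper for illustration.
import torch

def xpu_total_memory(dev):
    major, minor = (int(x) for x in torch.__version__.split(".")[:2])
    if (major, minor) < (2, 6):
        # Older PyTorch: only the static device property is available.
        return torch.xpu.get_device_properties(dev).total_memory
    # PyTorch >= 2.6 reports live (free, total) numbers from the driver.
    _, total = torch.xpu.mem_get_info(dev)
    return total
```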
@@ -929,7 +933,7 @@ def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False #pytorch bug? mps doesn't support non blocking
     if is_intel_xpu():
-        return False
+        return True
     if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
     if directml_enabled:
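Flipping this to `True` lets tensor transfers to and from XPU overlap with compute instead of serializing on each copy. A sketch of what that enables, assuming an XPU-enabled PyTorch build; pinned host memory is needed for the copy to be genuinely asynchronous:

```python
# Sketch of a non-blocking host-to-XPU transfer (assumes an XPU build
# of PyTorch with at least one device available).
import torch

host = torch.randn(1024, 1024, pin_memory=True)  # page-locked host buffer
dev = torch.device("xpu")
weight = host.to(dev, non_blocking=True)  # queues the copy and returns
torch.xpu.synchronize()                   # wait before reading `weight`
```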
@@ -968,6 +972,8 @@ def get_offload_stream(device):
             stream_counter = (stream_counter + 1) % len(ss)
             if is_device_cuda(device):
                 ss[stream_counter].wait_stream(torch.cuda.current_stream())
+            elif is_device_xpu(device):
+                ss[stream_counter].wait_stream(torch.xpu.current_stream())
             stream_counters[device] = stream_counter
             return s
     elif is_device_cuda(device):
@@ -979,13 +985,24 @@ def get_offload_stream(device):
         stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
+    elif is_device_xpu(device):
+        ss = []
+        for k in range(NUM_STREAMS):
+            ss.append(torch.xpu.Stream(device=device, priority=0))
+        STREAMS[device] = ss
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        stream_counters[device] = stream_counter
+        return s
     return None

 def sync_stream(device, stream):
     if stream is None:
         return
     if is_device_cuda(device):
         torch.cuda.current_stream().wait_stream(stream)
+    elif is_device_xpu(device):
+        torch.xpu.current_stream().wait_stream(stream)

 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
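The new XPU branches mirror the existing CUDA logic: a fixed pool of side streams is created once per device, handed out round-robin, and ordered against the default stream so offloaded copies never race compute. A condensed sketch of the pattern, assuming an XPU device; `NUM_STREAMS` stands in for the module-level constant:

```python
# Condensed sketch of the round-robin offload-stream pattern on XPU.
# Assumes an XPU-enabled PyTorch build; NUM_STREAMS mirrors the constant
# in model_management.py.
import torch

NUM_STREAMS = 2
device = torch.device("xpu")
pool = [torch.xpu.Stream(device=device, priority=0) for _ in range(NUM_STREAMS)]
counter = 0

def next_offload_stream():
    global counter
    s = pool[counter]
    counter = (counter + 1) % len(pool)
    # Order the next checkout after work already queued on the default stream.
    pool[counter].wait_stream(torch.xpu.current_stream())
    return s

s = next_offload_stream()
with torch.xpu.stream(s):
    staged = torch.zeros(16, device=device)   # work queued on the side stream
torch.xpu.current_stream().wait_stream(s)     # what sync_stream() does
```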
@@ -1092,8 +1109,11 @@ def get_free_memory(dev=None, torch_free_too=False):
         stats = torch.xpu.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
+        if torch_version_numeric < (2, 6):
+            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
+        else:
+            mem_free_xpu, _ = torch.xpu.mem_get_info(dev)
         mem_free_torch = mem_reserved - mem_active
-        mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
         mem_free_total = mem_free_xpu + mem_free_torch
     elif is_ascend_npu():
         stats = torch.npu.memory_stats(dev)
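As with `get_total_memory()`, the driver-side number comes from `mem_get_info()` on new enough builds; usable memory is then the driver-free amount plus whatever the caching allocator has reserved but is not actively using. A sketch of that accounting, assuming PyTorch >= 2.6 with XPU support:

```python
# Sketch of the XPU free-memory accounting on PyTorch >= 2.6 (XPU build
# assumed; the stat keys below are the standard memory_stats() names).
import torch

dev = torch.device("xpu")
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']      # bytes held by live tensors
mem_reserved = stats['reserved_bytes.all.current']  # bytes held by the allocator
mem_free_xpu, _ = torch.xpu.mem_get_info(dev)       # free per the driver
mem_free_torch = mem_reserved - mem_active          # cached but reusable
print("usable:", mem_free_xpu + mem_free_torch)
```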
@@ -1142,6 +1162,9 @@ def is_device_cpu(device):
 def is_device_mps(device):
     return is_device_type(device, 'mps')

+def is_device_xpu(device):
+    return is_device_type(device, 'xpu')
+
 def is_device_cuda(device):
     return is_device_type(device, 'cuda')

@@ -1173,7 +1196,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
         return False

     if is_intel_xpu():
-        return True
+        if torch_version_numeric < (2, 3):
+            return True
+        else:
+            return torch.xpu.get_device_properties(device).has_fp16

     if is_ascend_npu():
         return True
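Rather than assuming fp16 on every XPU, newer builds expose the capability on the device properties. A sketch of the query, assuming PyTorch >= 2.3 with XPU support, where `has_fp16` is an attribute of the properties object:

```python
# Sketch: read fp16 support from the XPU device properties (assumes
# PyTorch >= 2.3 built with XPU support).
import torch

props = torch.xpu.get_device_properties(torch.device("xpu"))
print(props.name, "supports fp16:", props.has_fp16)
```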
@@ -1236,7 +1262,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
         return False

     if is_intel_xpu():
-        return True
+        if torch_version_numeric < (2, 6):
+            return True
+        else:
+            return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']

     if is_ascend_npu():
         return True
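bf16 gets the same treatment: on PyTorch >= 2.6, `torch.xpu.get_device_capability()` returns a dict of feature flags that includes the conversion support this hunk checks. A sketch, assuming an XPU build where that key is present:

```python
# Sketch: bf16 capability check on XPU (assumes PyTorch >= 2.6 with an
# XPU device; get_device_capability() returns a dict of feature flags).
import torch

caps = torch.xpu.get_device_capability(torch.device("xpu"))
print("bf16 conversions:", caps['has_bfloat16_conversions'])
```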