Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
Expand Down
11 changes: 3 additions & 8 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,7 @@ def get_total_memory(dev=None, torch_total_too=False):
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
if torch_version_numeric < (2, 6):
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
else:
_, mem_total_xpu = torch.xpu.mem_get_info(dev)
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved
mem_total = mem_total_xpu
elif is_ascend_npu():
Expand Down Expand Up @@ -880,6 +877,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return d

# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d

Expand Down Expand Up @@ -1109,10 +1107,7 @@ def get_free_memory(dev=None, torch_free_too=False):
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
if torch_version_numeric < (2, 6):
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
else:
mem_free_xpu, _ = torch.xpu.mem_get_info(dev)
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_xpu + mem_free_torch
elif is_ascend_npu():
Expand Down
9 changes: 9 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ def execute_script(script_path):
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

if __name__ == "__main__":
if args.default_device is not None:
default_dev = args.default_device
devices = list(range(32))
devices.remove(default_dev)
devices.insert(0, default_dev)
devices = ','.join(map(str, devices))
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)

if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
Expand Down
Loading