From 3aa19491c0e04be934a397424cf0fd950ff8cf59 Mon Sep 17 00:00:00 2001 From: Santiago Lezica Date: Wed, 25 Jun 2025 22:10:59 -0300 Subject: [PATCH 1/2] Fix arcface loading for non-CUDA devices --- nodes.py | 9 ++++----- utils.py | 18 ++++++++++++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/nodes.py b/nodes.py index b87a569..1a41a62 100644 --- a/nodes.py +++ b/nodes.py @@ -23,10 +23,9 @@ import shutil import glob -from facexlib.recognition import init_recognition_model from insightface.app import FaceAnalysis -from .utils import extract_arcface_bgr_embedding, tensor_to_np_image, np_image_to_tensor, resize_and_pad_pil_image, draw_kps, escape_path_for_url +from .utils import extract_arcface_bgr_embedding, tensor_to_np_image, np_image_to_tensor, resize_and_pad_pil_image, draw_kps, escape_path_for_url, init_arcface_model from .infuse_net import load_infuse_net_flux from .resampler import Resampler @@ -110,7 +109,7 @@ def load_insightface(self, image_proj_model_name, image_proj_num_tokens, face_an device = comfy.model_management.get_torch_device() # Load arcface model - arcface_model = init_recognition_model('arcface', device=device) + arcface_model = init_arcface_model(device=device) # Load image proj model image_emb_dim = 512 @@ -197,8 +196,8 @@ def extract_id_embedding(self, face_detector, arcface_model, image_proj_model, i face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face landmark = face_info['kps'] - id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, arcface_model) - id_embed = id_embed.clone().unsqueeze(0).float().to(device) + id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, arcface_model, device=device) + id_embed = id_embed.clone().unsqueeze(0).float() id_embed = id_embed.reshape([1, -1, 512]) id_embed = id_embed.to(device=device, dtype=torch.bfloat16) with torch.no_grad(): diff --git a/utils.py b/utils.py index 4eef5ad..12425e4 100644 --- a/utils.py +++ b/utils.py @@ -15,16 +15,30 @@ import torch import numpy as np from insightface.utils import face_align +from facexlib.utils import load_file_from_url +from facexlib.recognition import arcface_arch from PIL import Image import math import cv2 -def extract_arcface_bgr_embedding(in_image, landmark, arcface_model, in_settings=None): +ARCFACE_MODEL_URL = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth' + +def init_arcface_model(device='cuda'): + model = arcface_arch.Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to(device).eval() + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth' + + model_path = load_file_from_url(url=ARCFACE_MODEL_URL, model_dir='facexlib/weights') + model.load_state_dict(torch.load(model_path, map_location=device), strict=True) + model.eval() + model = model.to(device) + return model + +def extract_arcface_bgr_embedding(in_image, landmark, arcface_model, in_settings=None, device='cuda'): kps = landmark arc_face_image = face_align.norm_crop(in_image, landmark=np.array(kps), image_size=112) arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0,3,1,2) / 255. arc_face_image = 2 * arc_face_image - 1 - arc_face_image = arc_face_image.cuda().contiguous() + arc_face_image = arc_face_image.to(device).contiguous() face_emb = arcface_model(arc_face_image)[0] # [512], normalized return face_emb From 8e8c3e8e2bd84d6bedc0a80af9fd0e310e69a321 Mon Sep 17 00:00:00 2001 From: Santiago Lezica Date: Wed, 25 Jun 2025 22:16:28 -0300 Subject: [PATCH 2/2] Remove unused line in init_arcface_model --- utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/utils.py b/utils.py index 12425e4..dba0e1e 100644 --- a/utils.py +++ b/utils.py @@ -25,7 +25,6 @@ def init_arcface_model(device='cuda'): model = arcface_arch.Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to(device).eval() - model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth' model_path = load_file_from_url(url=ARCFACE_MODEL_URL, model_dir='facexlib/weights') model.load_state_dict(torch.load(model_path, map_location=device), strict=True)