Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
import shutil
import glob

from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis

from .utils import extract_arcface_bgr_embedding, tensor_to_np_image, np_image_to_tensor, resize_and_pad_pil_image, draw_kps, escape_path_for_url
from .utils import extract_arcface_bgr_embedding, tensor_to_np_image, np_image_to_tensor, resize_and_pad_pil_image, draw_kps, escape_path_for_url, init_arcface_model
from .infuse_net import load_infuse_net_flux
from .resampler import Resampler

Expand Down Expand Up @@ -110,7 +109,7 @@ def load_insightface(self, image_proj_model_name, image_proj_num_tokens, face_an
device = comfy.model_management.get_torch_device()

# Load arcface model
arcface_model = init_recognition_model('arcface', device=device)
arcface_model = init_arcface_model(device=device)

# Load image proj model
image_emb_dim = 512
Expand Down Expand Up @@ -197,8 +196,8 @@ def extract_id_embedding(self, face_detector, arcface_model, image_proj_model, i

face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
landmark = face_info['kps']
id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, arcface_model)
id_embed = id_embed.clone().unsqueeze(0).float().to(device)
id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, arcface_model, device=device)
id_embed = id_embed.clone().unsqueeze(0).float()
id_embed = id_embed.reshape([1, -1, 512])
id_embed = id_embed.to(device=device, dtype=torch.bfloat16)
with torch.no_grad():
Expand Down
17 changes: 15 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,29 @@
import torch
import numpy as np
from insightface.utils import face_align
from facexlib.utils import load_file_from_url
from facexlib.recognition import arcface_arch
from PIL import Image
import math
import cv2

def extract_arcface_bgr_embedding(in_image, landmark, arcface_model, in_settings=None):
ARCFACE_MODEL_URL = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth'

def init_arcface_model(device='cuda'):
model = arcface_arch.Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to(device).eval()

model_path = load_file_from_url(url=ARCFACE_MODEL_URL, model_dir='facexlib/weights')
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we store the model file in the comfyUI/models folder?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! There's no need to load it from URL, it can be loaded from a local file. That would actually be nicer in my personal opinion, I just copied the way it was done under the hood by the arcface module, to avoid loading the PR with other changes

model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()
model = model.to(device)
return model

def extract_arcface_bgr_embedding(in_image, landmark, arcface_model, in_settings=None, device='cuda'):
kps = landmark
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(kps), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0,3,1,2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.cuda().contiguous()
arc_face_image = arc_face_image.to(device).contiguous()
face_emb = arcface_model(arc_face_image)[0] # [512], normalized
return face_emb

Expand Down