From d57daf938a5635e4629bb07c2ea84bc2ae0cedda Mon Sep 17 00:00:00 2001
From: Geremie Yeo <100673850+bogoconic1@users.noreply.github.com>
Date: Tue, 3 Jun 2025 19:58:45 +0800
Subject: [PATCH 1/3] Update requirements.txt

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index e083ffba8..77435f2c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ regex
 tqdm
 torch
 torchvision
+safetensors

From c2bdbb6ded7d92819f938655543d643082e8bea7 Mon Sep 17 00:00:00 2001
From: Geremie Yeo <100673850+bogoconic1@users.noreply.github.com>
Date: Tue, 3 Jun 2025 19:59:43 +0800
Subject: [PATCH 2/3] load safetensors version in clip

---
 clip/clip.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/clip/clip.py b/clip/clip.py
index 398a6282c..4329e81b4 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -12,6 +12,7 @@
 
 from .model import build_model
 from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from safetensors import safe_open
 
 try:
     from torchvision.transforms import InterpolationMode
@@ -124,16 +125,20 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
 
     with open(model_path, 'rb') as opened_file:
-        try:
-            # loading JIT archive
-            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
-            state_dict = None
-        except RuntimeError:
-            # loading saved state dict
-            if jit:
-                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
-            jit = False
-            state_dict = torch.load(opened_file, map_location="cpu")
+        if model_path.endswith('.safetensors'):
+            with safe_open(model_path, framework="pt", device="cpu") as f:
+                state_dict = {key: f.get_tensor(key) for key in f.keys()}
+        else:
+            try:
+                # loading JIT archive
+                model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
+                state_dict = None
+            except RuntimeError:
+                # loading saved state dict
+                if jit:
+                    warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+                jit = False
+                state_dict = torch.load(opened_file, map_location="cpu")
 
     if not jit:
         model = build_model(state_dict or model.state_dict()).to(device)

From 1486984f3ac7969f15b0d5ad26c3f9660bf66393 Mon Sep 17 00:00:00 2001
From: bogoconic1
Date: Tue, 3 Jun 2025 15:00:21 +0000
Subject: [PATCH 3/3] add support for huggingface bin/safetensors variant

---
 clip/clip.py | 212 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 211 insertions(+), 1 deletion(-)

diff --git a/clip/clip.py b/clip/clip.py
index 4329e81b4..a86e3d82f 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -91,6 +91,210 @@ def available_models() -> List[str]:
     """Returns the names of available CLIP models"""
     return list(_MODELS.keys())
 
+def convert_hf_to_openai_full(hf_state_dict):
+    """
+    Complete conversion from Hugging Face CLIP to OpenAI CLIP format
+    """
+    converted_dict = {}
+
+    for key, tensor in hf_state_dict.items():
+        # Skip position_ids - they're not needed in OpenAI format
+        if 'position_ids' in key:
+            print(f"Skipping {key} - not needed in OpenAI format")
+            continue
+
+        # Handle projection weights
+        elif key == 'visual_projection.weight':
+            # In OpenAI CLIP, this is stored as visual.proj (transposed)
+            converted_dict['visual.proj'] = tensor.T  # Note the transpose!
+            continue
+
+        elif key == 'text_projection.weight':
+            # In OpenAI CLIP, this is stored as text_projection (transposed)
+            converted_dict['text_projection'] = tensor.T  # Note the transpose!
+            continue
+
+        # Handle other standard mappings
+        elif key.startswith('text_model.'):
+            if 'embeddings.token_embedding.weight' in key:
+                converted_dict['token_embedding.weight'] = tensor
+            elif 'embeddings.position_embedding.weight' in key:
+                converted_dict['positional_embedding'] = tensor
+            elif 'final_layer_norm.weight' in key:
+                converted_dict['ln_final.weight'] = tensor
+            elif 'final_layer_norm.bias' in key:
+                converted_dict['ln_final.bias'] = tensor
+            elif 'encoder.layers.' in key:
+                converted_dict.update(convert_text_layer(key, tensor))
+
+        elif key.startswith('vision_model.'):
+            if 'embeddings.patch_embedding.weight' in key:
+                converted_dict['visual.conv1.weight'] = tensor
+            elif 'embeddings.position_embedding.weight' in key:
+                converted_dict['visual.positional_embedding'] = tensor
+            elif 'embeddings.class_embedding' in key:
+                converted_dict['visual.class_embedding'] = tensor
+            elif 'pre_layrnorm.weight' in key:
+                converted_dict['visual.ln_pre.weight'] = tensor
+            elif 'pre_layrnorm.bias' in key:
+                converted_dict['visual.ln_pre.bias'] = tensor
+            elif 'post_layernorm.weight' in key:
+                converted_dict['visual.ln_post.weight'] = tensor
+            elif 'post_layernorm.bias' in key:
+                converted_dict['visual.ln_post.bias'] = tensor
+            elif 'encoder.layers.' in key:
+                converted_dict.update(convert_vision_layer(key, tensor))
+
+        elif key == 'logit_scale':
+            converted_dict['logit_scale'] = tensor
+
+        else:
+            print(f"Unhandled key: {key}")
+
+    # Handle the q/k/v -> in_proj_weight conversion
+    converted_dict = combine_qkv_projections_complete(hf_state_dict, converted_dict)
+
+    return converted_dict
+
+def convert_text_layer(key, tensor):
+    """Convert text layer keys"""
+    import re
+    result = {}
+
+    layer_match = re.search(r'encoder\.layers\.(\d+)', key)
+    if not layer_match:
+        return result
+
+    layer_num = layer_match.group(1)
+
+    if 'self_attn.out_proj.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.attn.out_proj.weight'] = tensor
+    elif 'self_attn.out_proj.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.attn.out_proj.bias'] = tensor
+    elif 'layer_norm1.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_1.weight'] = tensor
+    elif 'layer_norm1.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_1.bias'] = tensor
+    elif 'layer_norm2.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_2.weight'] = tensor
+    elif 'layer_norm2.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_2.bias'] = tensor
+    elif 'mlp.fc1.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_fc.weight'] = tensor
+    elif 'mlp.fc1.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_fc.bias'] = tensor
+    elif 'mlp.fc2.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_proj.weight'] = tensor
+    elif 'mlp.fc2.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_proj.bias'] = tensor
+    # Skip q/k/v proj weights - handled separately
+
+    return result
+
+def convert_vision_layer(key, tensor):
+    """Convert vision layer keys"""
+    import re
+    result = {}
+
+    layer_match = re.search(r'encoder\.layers\.(\d+)', key)
+    if not layer_match:
+        return result
+
+    layer_num = layer_match.group(1)
+
+    if 'self_attn.out_proj.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.attn.out_proj.weight'] = tensor
+    elif 'self_attn.out_proj.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.attn.out_proj.bias'] = tensor
+    elif 'layer_norm1.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_1.weight'] = tensor
+    elif 'layer_norm1.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_1.bias'] = tensor
+    elif 'layer_norm2.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_2.weight'] = tensor
+    elif 'layer_norm2.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_2.bias'] = tensor
+    elif 'mlp.fc1.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_fc.weight'] = tensor
+    elif 'mlp.fc1.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_fc.bias'] = tensor
+    elif 'mlp.fc2.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_proj.weight'] = tensor
+    elif 'mlp.fc2.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_proj.bias'] = tensor
+    # Skip q/k/v proj weights - handled separately
+
+    return result
+
+def combine_qkv_projections_complete(hf_state_dict, converted_dict):
+    """Combine q, k, v projections for both text and vision models"""
+    import re
+
+    # Process text model layers
+    for key in hf_state_dict.keys():
+        if 'text_model.encoder.layers.' in key and 'self_attn.q_proj.weight' in key:
+            layer_match = re.search(r'layers\.(\d+)', key)
+            if layer_match:
+                layer_num = layer_match.group(1)
+
+                q_key = f'text_model.encoder.layers.{layer_num}.self_attn.q_proj.weight'
+                k_key = f'text_model.encoder.layers.{layer_num}.self_attn.k_proj.weight'
+                v_key = f'text_model.encoder.layers.{layer_num}.self_attn.v_proj.weight'
+
+                if all(k in hf_state_dict for k in [q_key, k_key, v_key]):
+                    combined_weight = torch.cat([
+                        hf_state_dict[q_key],
+                        hf_state_dict[k_key],
+                        hf_state_dict[v_key]
+                    ], dim=0)
+                    converted_dict[f'transformer.resblocks.{layer_num}.attn.in_proj_weight'] = combined_weight
+
+                # Handle biases if they exist
+                q_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.q_proj.bias'
+                k_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.k_proj.bias'
+                v_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.v_proj.bias'
+
+                if all(k in hf_state_dict for k in [q_bias_key, k_bias_key, v_bias_key]):
+                    combined_bias = torch.cat([
+                        hf_state_dict[q_bias_key],
+                        hf_state_dict[k_bias_key],
+                        hf_state_dict[v_bias_key]
+                    ], dim=0)
+                    converted_dict[f'transformer.resblocks.{layer_num}.attn.in_proj_bias'] = combined_bias
+
+    # Process vision model layers
+    for key in hf_state_dict.keys():
+        if 'vision_model.encoder.layers.' in key and 'self_attn.q_proj.weight' in key:
+            layer_match = re.search(r'layers\.(\d+)', key)
+            if layer_match:
+                layer_num = layer_match.group(1)
+
+                q_key = f'vision_model.encoder.layers.{layer_num}.self_attn.q_proj.weight'
+                k_key = f'vision_model.encoder.layers.{layer_num}.self_attn.k_proj.weight'
+                v_key = f'vision_model.encoder.layers.{layer_num}.self_attn.v_proj.weight'
+
+                if all(k in hf_state_dict for k in [q_key, k_key, v_key]):
+                    combined_weight = torch.cat([
+                        hf_state_dict[q_key],
+                        hf_state_dict[k_key],
+                        hf_state_dict[v_key]
+                    ], dim=0)
+                    converted_dict[f'visual.transformer.resblocks.{layer_num}.attn.in_proj_weight'] = combined_weight
+
+                # Handle biases
+                q_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.q_proj.bias'
+                k_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.k_proj.bias'
+                v_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.v_proj.bias'
+
+                if all(k in hf_state_dict for k in [q_bias_key, k_bias_key, v_bias_key]):
+                    combined_bias = torch.cat([
+                        hf_state_dict[q_bias_key],
+                        hf_state_dict[k_bias_key],
+                        hf_state_dict[v_bias_key]
+                    ], dim=0)
+                    converted_dict[f'visual.transformer.resblocks.{layer_num}.attn.in_proj_bias'] = combined_bias
+
+    return converted_dict
 
 def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
     """Load a CLIP model
@@ -125,7 +329,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
 
     with open(model_path, 'rb') as opened_file:
-        if model_path.endswith('.safetensors'):
+        if model_path.endswith('.bin'):
+            state_dict = torch.load(model_path, map_location="cpu")
+        elif model_path.endswith('.safetensors'):
             with safe_open(model_path, framework="pt", device="cpu") as f:
                 state_dict = {key: f.get_tensor(key) for key in f.keys()}
         else:
@@ -140,6 +346,10 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
                 jit = False
                 state_dict = torch.load(opened_file, map_location="cpu")
 
+    if model_path.endswith('.bin') or model_path.endswith('.safetensors'):
+        state_dict = convert_hf_to_openai_full(state_dict)
+
+
     if not jit:
         model = build_model(state_dict or model.state_dict()).to(device)
         if str(device) == "cpu":
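
For reference, a minimal usage sketch of the patched loader, assuming the three patches above are applied. The checkpoint path below is hypothetical (any locally downloaded Hugging Face CLIP export, e.g. from openai/clip-vit-base-patch32, would do), and jit is kept False because the new .bin/.safetensors branches only produce a state dict for build_model; they never build a TorchScript archive.

import torch
from PIL import Image

import clip

# Hypothetical local path to a Hugging Face CLIP checkpoint;
# a pytorch_model.bin path would take the torch.load branch instead.
checkpoint_path = "clip-vit-base-patch32/model.safetensors"

# jit must stay False: the HF-format branches bypass torch.jit.load,
# so only the build_model(state_dict) path applies.
model, preprocess = clip.load(checkpoint_path, device="cpu", jit=False)

# Same inference flow as the repo README (CLIP.png ships with the repo).
image = preprocess(Image.open("CLIP.png")).unsqueeze(0)
text = clip.tokenize(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1)  # probabilities over the three prompts

print(probs)

Note that convert_hf_to_openai_full transposes the two projection matrices because Hugging Face stores them as nn.Linear weights of shape [embed_dim, width], while the OpenAI model multiplies activations by visual.proj and text_projection directly, so it expects [width, embed_dim].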