     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
     "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
+    "cosmos-1.0": [
+        "net.x_embedder.proj.1.weight",
+        "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
+        "net.extra_pos_embedder.pos_emb_h",
+    ],
+    "cosmos-2.0": [
+        "net.x_embedder.proj.1.weight",
+        "net.blocks.0.self_attn.q_proj.weight",
+        "net.pos_embedder.dim_spatial_range",
+    ],
 }


 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -193,6 +203,14 @@
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
     "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
+    "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
+    "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
+    "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
+    "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
+    "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
+    "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
+    "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
+    "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
 }


 # Use to configure model sample size when original config is provided
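
Taken together, the two tables above drive single-file detection: `CHECKPOINT_KEY_NAMES` stores fingerprint keys whose presence identifies a checkpoint's architecture, and `DIFFUSERS_DEFAULT_PIPELINE_PATHS` maps each inferred model type to the Hub repo whose config gets reused. A minimal, self-contained sketch of the detection step (the `detect_cosmos_family` helper is illustrative, not part of the library's API):

```python
# Illustrative sketch only; mirrors the tables above rather than the real call chain.
CHECKPOINT_KEY_NAMES = {
    "cosmos-1.0": [
        "net.x_embedder.proj.1.weight",
        "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
        "net.extra_pos_embedder.pos_emb_h",
    ],
    "cosmos-2.0": [
        "net.x_embedder.proj.1.weight",
        "net.blocks.0.self_attn.q_proj.weight",
        "net.pos_embedder.dim_spatial_range",
    ],
}

def detect_cosmos_family(checkpoint: dict) -> str | None:
    """Return the Cosmos family whose fingerprint keys are all present, if any."""
    for family, keys in CHECKPOINT_KEY_NAMES.items():
        if all(key in checkpoint for key in keys):
            return family
    return None
```

Note that the Cosmos entries are lists checked with `all(...)` rather than single keys: `net.x_embedder.proj.1.weight` alone exists in both families, so the block-structure keys are what disambiguate 1.0 from 2.0.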
@@ -704,11 +722,32 @@ def infer_diffusers_model_type(checkpoint):
             model_type = "wan-t2v-14B"
         else:
             model_type = "wan-i2v-14B"
+
     elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
         # All Wan models use the same VAE so we can use the same default model repo to fetch the config
         model_type = "wan-t2v-14B"
+
     elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
         model_type = "hidream"
+
+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
+        x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
+        if x_embedder_shape[1] == 68:
+            model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
+        elif x_embedder_shape[1] == 72:
+            model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
+        else:
+            raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
+
+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
+        x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
+        if x_embedder_shape[1] == 68:
+            model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
+        elif x_embedder_shape[1] == 72:
+            model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
+        else:
+            raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
+
     else:
         model_type = "v1"
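
The new branches pin down the exact Cosmos variant from the patch-embedder weight alone: its second dimension separates text-conditioned checkpoints (68) from video-conditioned ones (72), and its first dimension separates model sizes (4096 for the 1.0 7B, 2048 for the 2.0 2B, with any other width falling through to the 14B repo). A toy restatement of the Cosmos 1.0 branch, runnable on dummy tensors (the helper name is made up; the shape values come from the code above):

```python
import torch

def infer_cosmos_1_0_variant(x_embedder_weight: torch.Tensor) -> str:
    # Assumes a 2-D linear weight of shape (out_features, in_features),
    # matching net.x_embedder.proj.1.weight in the branch above.
    out_dim, in_dim = x_embedder_weight.shape
    if in_dim == 68:  # text-to-world input width
        return "cosmos-1.0-t2w-7B" if out_dim == 4096 else "cosmos-1.0-t2w-14B"
    if in_dim == 72:  # video-to-world input width
        return "cosmos-1.0-v2w-7B" if out_dim == 4096 else "cosmos-1.0-v2w-14B"
    raise ValueError(f"Unexpected x_embedder shape: {tuple(x_embedder_weight.shape)}")

assert infer_cosmos_1_0_variant(torch.empty(4096, 68)) == "cosmos-1.0-t2w-7B"
# Any width other than 4096 falls through to the 14B repo in this scheme.
assert infer_cosmos_1_0_variant(torch.empty(5120, 72)) == "cosmos-1.0-v2w-14B"
```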
@@ -3479,3 +3518,116 @@ def swap_scale_shift(weight):
     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

     return converted_state_dict
+
+
+def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    def rename_transformer_blocks_(key: str, state_dict):
+        block_index = int(key.split(".")[1].removeprefix("block"))
+        new_key = key
+        old_prefix = f"blocks.block{block_index}"
+        new_prefix = f"transformer_blocks.{block_index}"
+        new_key = new_prefix + new_key.removeprefix(old_prefix)
+        state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
+        "t_embedder.1": "time_embed.t_embedder",
+        "affline_norm": "time_embed.norm",
+        ".blocks.0.block.attn": ".attn1",
+        ".blocks.1.block.attn": ".attn2",
+        ".blocks.2.block": ".ff",
+        ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
+        ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
+        ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
+        ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
+        ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
+        ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
+        "to_q.0": "to_q",
+        "to_q.1": "norm_q",
+        "to_k.0": "to_k",
+        "to_k.1": "norm_k",
+        "to_v.0": "to_v",
+        "layer1": "net.0.proj",
+        "layer2": "net.2",
+        "proj.1": "proj",
+        "x_embedder": "patch_embed",
+        "extra_pos_embedder": "learnable_pos_embed",
+        "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+        "final_layer.adaLN_modulation.2": "norm_out.linear_2",
+        "final_layer.linear": "proj_out",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
+        "blocks.block": rename_transformer_blocks_,
+        "logvar.0.freqs": remove_keys_,
+        "logvar.0.phases": remove_keys_,
+        "logvar.1.weight": remove_keys_,
+        "pos_embedder.seq": remove_keys_,
+    }
+
+    TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+        "t_embedder.1": "time_embed.t_embedder",
+        "t_embedding_norm": "time_embed.norm",
+        "blocks": "transformer_blocks",
+        "adaln_modulation_self_attn.1": "norm1.linear_1",
+        "adaln_modulation_self_attn.2": "norm1.linear_2",
+        "adaln_modulation_cross_attn.1": "norm2.linear_1",
+        "adaln_modulation_cross_attn.2": "norm2.linear_2",
+        "adaln_modulation_mlp.1": "norm3.linear_1",
+        "adaln_modulation_mlp.2": "norm3.linear_2",
+        "self_attn": "attn1",
+        "cross_attn": "attn2",
+        "q_proj": "to_q",
+        "k_proj": "to_k",
+        "v_proj": "to_v",
+        "output_proj": "to_out.0",
+        "q_norm": "norm_q",
+        "k_norm": "norm_k",
+        "mlp.layer1": "ff.net.0.proj",
+        "mlp.layer2": "ff.net.2",
+        "x_embedder.proj.1": "patch_embed.proj",
+        "final_layer.adaln_modulation.1": "norm_out.linear_1",
+        "final_layer.adaln_modulation.2": "norm_out.linear_2",
+        "final_layer.linear": "proj_out",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+        "accum_video_sample_counter": remove_keys_,
+        "accum_image_sample_counter": remove_keys_,
+        "accum_iteration": remove_keys_,
+        "accum_train_in_hours": remove_keys_,
+        "pos_embedder.seq": remove_keys_,
+        "pos_embedder.dim_spatial_range": remove_keys_,
+        "pos_embedder.dim_temporal_range": remove_keys_,
+        "_extra_state": remove_keys_,
+    }
+
+    PREFIX_KEY = "net."
+    if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in converted_state_dict:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+    else:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+
+    state_dict_keys = list(converted_state_dict.keys())
+    for key in state_dict_keys:
+        new_key = key[:]
+        if new_key.startswith(PREFIX_KEY):
+            new_key = new_key.removeprefix(PREFIX_KEY)
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    state_dict_keys = list(converted_state_dict.keys())
+    for key in state_dict_keys:
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
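
The converter makes two passes over the flat state dict: pass one strips the `net.` prefix and applies the order-sensitive substring renames from the selected rename table; pass two routes keys matching a special pattern to a handler that either renames whole blocks (`blocks.blockN` becomes `transformer_blocks.N` for Cosmos 1.0) or drops training bookkeeping tensors. The family check has to run against `converted_state_dict`, since the pop-comprehension at the top drains `checkpoint` itself. A standalone trace of both passes on one Cosmos 1.0 attention key (rename entries excerpted from the full table above):

```python
# Trace of the two-pass rewrite on a single Cosmos 1.0 key.
RENAMES = {  # excerpt of TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
    ".blocks.0.block.attn": ".attn1",
    "to_q.0": "to_q",
}

key = "net.blocks.block1.blocks.0.block.attn.to_q.0.weight"

# Pass 1: strip the prefix, then apply substring renames in table order.
new_key = key.removeprefix("net.")
for old, new in RENAMES.items():
    new_key = new_key.replace(old, new)
# -> "blocks.block1.attn1.to_q.weight"

# Pass 2: rename_transformer_blocks_ rewrites the surviving block prefix.
block_index = int(new_key.split(".")[1].removeprefix("block"))
new_key = f"transformer_blocks.{block_index}" + new_key.removeprefix(f"blocks.block{block_index}")
assert new_key == "transformer_blocks.1.attn1.to_q.weight"
```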