@@ -7,7 +7,17 @@
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast

-from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
+from diffusers import (
+    AutoencoderKLCosmos,
+    AutoencoderKLWan,
+    Cosmos2TextToImagePipeline,
+    Cosmos2VideoToWorldPipeline,
+    CosmosTextToWorldPipeline,
+    CosmosTransformer3DModel,
+    CosmosVideoToWorldPipeline,
+    EDMEulerScheduler,
+    FlowMatchEulerDiscreteScheduler,
+)


def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -29,7 +39,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
    state_dict[new_key] = state_dict.pop(key)


-TRANSFORMER_KEYS_RENAME_DICT = {
+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
    "t_embedder.1": "time_embed.t_embedder",
    "affline_norm": "time_embed.norm",
    ".blocks.0.block.attn": ".attn1",
@@ -56,14 +66,53 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
    "final_layer.linear": "proj_out",
}

-TRANSFORMER_SPECIAL_KEYS_REMAP = {
+TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
    "blocks.block": rename_transformer_blocks_,
    "logvar.0.freqs": remove_keys_,
    "logvar.0.phases": remove_keys_,
    "logvar.1.weight": remove_keys_,
    "pos_embedder.seq": remove_keys_,
}

+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+    "t_embedder.1": "time_embed.t_embedder",
+    "t_embedding_norm": "time_embed.norm",
+    "blocks": "transformer_blocks",
+    "adaln_modulation_self_attn.1": "norm1.linear_1",
+    "adaln_modulation_self_attn.2": "norm1.linear_2",
+    "adaln_modulation_cross_attn.1": "norm2.linear_1",
+    "adaln_modulation_cross_attn.2": "norm2.linear_2",
+    "adaln_modulation_mlp.1": "norm3.linear_1",
+    "adaln_modulation_mlp.2": "norm3.linear_2",
+    "self_attn": "attn1",
+    "cross_attn": "attn2",
+    "q_proj": "to_q",
+    "k_proj": "to_k",
+    "v_proj": "to_v",
+    "output_proj": "to_out.0",
+    "q_norm": "norm_q",
+    "k_norm": "norm_k",
+    "mlp.layer1": "ff.net.0.proj",
+    "mlp.layer2": "ff.net.2",
+    "x_embedder.proj.1": "patch_embed.proj",
+    # "extra_pos_embedder": "learnable_pos_embed",
+    "final_layer.adaln_modulation.1": "norm_out.linear_1",
+    "final_layer.adaln_modulation.2": "norm_out.linear_2",
+    "final_layer.linear": "proj_out",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+    "accum_video_sample_counter": remove_keys_,
+    "accum_image_sample_counter": remove_keys_,
+    "accum_iteration": remove_keys_,
+    "accum_train_in_hours": remove_keys_,
+    "pos_embedder.seq": remove_keys_,
+    "pos_embedder.dim_spatial_range": remove_keys_,
+    "pos_embedder.dim_temporal_range": remove_keys_,
+    "_extra_state": remove_keys_,
+}
+
+
TRANSFORMER_CONFIGS = {
    "Cosmos-1.0-Diffusion-7B-Text2World": {
        "in_channels": 16,
@@ -125,6 +174,66 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
        "concat_padding_mask": True,
        "extra_pos_embed_type": "learnable",
    },
+    "Cosmos-2.0-Diffusion-2B-Text2Image": {
+        "in_channels": 16,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 4.0, 4.0),
+        "concat_padding_mask": True,
+        "extra_pos_embed_type": None,
+    },
+    "Cosmos-2.0-Diffusion-14B-Text2Image": {
+        "in_channels": 16,
+        "out_channels": 16,
+        "num_attention_heads": 40,
+        "attention_head_dim": 128,
+        "num_layers": 36,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 4.0, 4.0),
+        "concat_padding_mask": True,
+        "extra_pos_embed_type": None,
+    },
+    "Cosmos-2.0-Diffusion-2B-Video2World": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 28,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        "extra_pos_embed_type": None,
+    },
+    "Cosmos-2.0-Diffusion-14B-Video2World": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 40,
+        "attention_head_dim": 128,
+        "num_layers": 36,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (20 / 24, 2.0, 2.0),
+        "concat_padding_mask": True,
+        "extra_pos_embed_type": None,
+    },
}

VAE_KEYS_RENAME_DICT = {
@@ -216,9 +325,18 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    return state_dict


-def convert_transformer(transformer_type: str, ckpt_path: str):
+def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
    PREFIX_KEY = "net."
-    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
+
+    if "Cosmos-1.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+    elif "Cosmos-2.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    else:
+        assert False

    with init_empty_weights():
        config = TRANSFORMER_CONFIGS[transformer_type]
@@ -281,13 +399,61 @@ def convert_vae(vae_type: str):
    return vae


+def save_pipeline_cosmos_1_0(args, transformer, vae):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+    # The original code initializes the EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
+    # So, the sigma_min value that is used is the default value of 0.002.
+    scheduler = EDMEulerScheduler(
+        sigma_min=0.002,
+        sigma_max=80,
+        sigma_data=0.5,
+        sigma_schedule="karras",
+        num_train_timesteps=1000,
+        prediction_type="epsilon",
+        rho=7.0,
+        final_sigmas_type="sigma_min",
+    )
+
+    pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
+    pipe = pipe_cls(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
+def save_pipeline_cosmos_2_0(args, transformer, vae):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+
+    scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
+
+    pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
+    pipe = pipe_cls(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
-    parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
+    parser.add_argument(
+        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+    )
    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--save_pipeline", action="store_true")
@@ -316,37 +482,26 @@ def get_args():
        assert args.tokenizer_path is not None

    if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
+        weights_only = "Cosmos-1.0" in args.transformer_type
+        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_type is not None:
-        vae = convert_vae(args.vae_type)
+        if "Cosmos-1.0" in args.transformer_type:
+            vae = convert_vae(args.vae_type)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+            )
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
-        text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
-        tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
-        # The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
-        # So, the sigma_min values that is used is the default value of 0.002.
-        scheduler = EDMEulerScheduler(
-            sigma_min=0.002,
-            sigma_max=80,
-            sigma_data=0.5,
-            sigma_schedule="karras",
-            num_train_timesteps=1000,
-            prediction_type="epsilon",
-            rho=7.0,
-            final_sigmas_type="sigma_min",
-        )
-
-        pipe = CosmosTextToWorldPipeline(
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            transformer=transformer,
-            vae=vae,
-            scheduler=scheduler,
-        )
-        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+        if "Cosmos-1.0" in args.transformer_type:
+            save_pipeline_cosmos_1_0(args, transformer, vae)
+        elif "Cosmos-2.0" in args.transformer_type:
+            save_pipeline_cosmos_2_0(args, transformer, vae)
+        else:
+            assert False
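To round this out, here is a minimal end-to-end sketch, not part of this diff, of the new Cosmos 2.0 path that the main block above follows when --save_pipeline is set. It assumes convert_transformer and save_pipeline_cosmos_2_0 from this script are in scope, the args namespace stands in for the parsed CLI arguments, and all paths are placeholders:

import argparse

import torch

from diffusers import AutoencoderKLWan


args = argparse.Namespace(
    transformer_type="Cosmos-2.0-Diffusion-2B-Video2World",
    transformer_ckpt_path="checkpoints/model.pt",  # placeholder path
    text_encoder_path="google-t5/t5-11b",
    tokenizer_path="google-t5/t5-11b",
    output_path="converted/Cosmos-2.0-Diffusion-2B-Video2World",  # placeholder path
)

# Cosmos 2.0 checkpoints are loaded with weights_only=False, matching the logic above.
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only=False)
transformer = transformer.to(dtype=torch.bfloat16)

# For Cosmos 2.0 the script reuses the Wan 2.1 VAE instead of AutoencoderKLCosmos.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

save_pipeline_cosmos_2_0(args, transformer, vae)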