Skip to content

Commit 9f91305

Browse files
authored
Cosmos Predict2 (#11695)
* support text-to-image * update example * make fix-copies * support use_flow_sigmas in EDM scheduler instead of maintain cosmos-specific scheduler * support video-to-world * update * rename text2image pipeline * make fix-copies * add t2i test * add test for v2w pipeline * support edm dpmsolver multistep * update * update * update * update tests * fix tests * safety checker * make conversion script work without guardrail
1 parent 368958d commit 9f91305

14 files changed

+2471
-60
lines changed

docs/source/en/api/pipelines/cosmos.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
3636
- all
3737
- __call__
3838

39+
## Cosmos2TextToImagePipeline
40+
41+
[[autodoc]] Cosmos2TextToImagePipeline
42+
- all
43+
- __call__
44+
45+
## Cosmos2VideoToWorldPipeline
46+
47+
[[autodoc]] Cosmos2VideoToWorldPipeline
48+
- all
49+
- __call__
50+
3951
## CosmosPipelineOutput
4052

4153
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
54+
55+
## CosmosImagePipelineOutput
56+
57+
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosImagePipelineOutput

scripts/convert_cosmos_to_diffusers.py

Lines changed: 186 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,17 @@
77
from huggingface_hub import snapshot_download
88
from transformers import T5EncoderModel, T5TokenizerFast
99

10-
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
10+
from diffusers import (
11+
AutoencoderKLCosmos,
12+
AutoencoderKLWan,
13+
Cosmos2TextToImagePipeline,
14+
Cosmos2VideoToWorldPipeline,
15+
CosmosTextToWorldPipeline,
16+
CosmosTransformer3DModel,
17+
CosmosVideoToWorldPipeline,
18+
EDMEulerScheduler,
19+
FlowMatchEulerDiscreteScheduler,
20+
)
1121

1222

1323
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -29,7 +39,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
2939
state_dict[new_key] = state_dict.pop(key)
3040

3141

32-
TRANSFORMER_KEYS_RENAME_DICT = {
42+
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
3343
"t_embedder.1": "time_embed.t_embedder",
3444
"affline_norm": "time_embed.norm",
3545
".blocks.0.block.attn": ".attn1",
@@ -56,14 +66,53 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
5666
"final_layer.linear": "proj_out",
5767
}
5868

59-
TRANSFORMER_SPECIAL_KEYS_REMAP = {
69+
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
6070
"blocks.block": rename_transformer_blocks_,
6171
"logvar.0.freqs": remove_keys_,
6272
"logvar.0.phases": remove_keys_,
6373
"logvar.1.weight": remove_keys_,
6474
"pos_embedder.seq": remove_keys_,
6575
}
6676

77+
# Mapping of Cosmos 2.0 checkpoint key substrings -> diffusers CosmosTransformer3DModel
# key substrings. Applied by substring replacement during state-dict conversion
# (Cosmos 1.0 checkpoints use TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 instead).
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
    "t_embedder.1": "time_embed.t_embedder",
    "t_embedding_norm": "time_embed.norm",
    "blocks": "transformer_blocks",
    # Per-block AdaLN modulation MLPs split by the sub-layer they modulate.
    "adaln_modulation_self_attn.1": "norm1.linear_1",
    "adaln_modulation_self_attn.2": "norm1.linear_2",
    "adaln_modulation_cross_attn.1": "norm2.linear_1",
    "adaln_modulation_cross_attn.2": "norm2.linear_2",
    "adaln_modulation_mlp.1": "norm3.linear_1",
    "adaln_modulation_mlp.2": "norm3.linear_2",
    # Attention projections follow the diffusers Attention naming convention.
    "self_attn": "attn1",
    "cross_attn": "attn2",
    "q_proj": "to_q",
    "k_proj": "to_k",
    "v_proj": "to_v",
    "output_proj": "to_out.0",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
    "mlp.layer1": "ff.net.0.proj",
    "mlp.layer2": "ff.net.2",
    "x_embedder.proj.1": "patch_embed.proj",
    # "extra_pos_embedder": "learnable_pos_embed",
    "final_layer.adaln_modulation.1": "norm_out.linear_1",
    "final_layer.adaln_modulation.2": "norm_out.linear_2",
    "final_layer.linear": "proj_out",
}

# Cosmos 2.0 checkpoint keys matched by substring and routed to a special handler;
# every entry below maps to remove_keys_ (training bookkeeping / buffers that have
# no counterpart in the diffusers model).
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
    "accum_video_sample_counter": remove_keys_,
    "accum_image_sample_counter": remove_keys_,
    "accum_iteration": remove_keys_,
    "accum_train_in_hours": remove_keys_,
    "pos_embedder.seq": remove_keys_,
    "pos_embedder.dim_spatial_range": remove_keys_,
    "pos_embedder.dim_temporal_range": remove_keys_,
    "_extra_state": remove_keys_,
}
114+
115+
67116
TRANSFORMER_CONFIGS = {
68117
"Cosmos-1.0-Diffusion-7B-Text2World": {
69118
"in_channels": 16,
@@ -125,6 +174,66 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
125174
"concat_padding_mask": True,
126175
"extra_pos_embed_type": "learnable",
127176
},
177+
"Cosmos-2.0-Diffusion-2B-Text2Image": {
178+
"in_channels": 16,
179+
"out_channels": 16,
180+
"num_attention_heads": 16,
181+
"attention_head_dim": 128,
182+
"num_layers": 28,
183+
"mlp_ratio": 4.0,
184+
"text_embed_dim": 1024,
185+
"adaln_lora_dim": 256,
186+
"max_size": (128, 240, 240),
187+
"patch_size": (1, 2, 2),
188+
"rope_scale": (1.0, 4.0, 4.0),
189+
"concat_padding_mask": True,
190+
"extra_pos_embed_type": None,
191+
},
192+
"Cosmos-2.0-Diffusion-14B-Text2Image": {
193+
"in_channels": 16,
194+
"out_channels": 16,
195+
"num_attention_heads": 40,
196+
"attention_head_dim": 128,
197+
"num_layers": 36,
198+
"mlp_ratio": 4.0,
199+
"text_embed_dim": 1024,
200+
"adaln_lora_dim": 256,
201+
"max_size": (128, 240, 240),
202+
"patch_size": (1, 2, 2),
203+
"rope_scale": (1.0, 4.0, 4.0),
204+
"concat_padding_mask": True,
205+
"extra_pos_embed_type": None,
206+
},
207+
"Cosmos-2.0-Diffusion-2B-Video2World": {
208+
"in_channels": 16 + 1,
209+
"out_channels": 16,
210+
"num_attention_heads": 16,
211+
"attention_head_dim": 128,
212+
"num_layers": 28,
213+
"mlp_ratio": 4.0,
214+
"text_embed_dim": 1024,
215+
"adaln_lora_dim": 256,
216+
"max_size": (128, 240, 240),
217+
"patch_size": (1, 2, 2),
218+
"rope_scale": (1.0, 3.0, 3.0),
219+
"concat_padding_mask": True,
220+
"extra_pos_embed_type": None,
221+
},
222+
"Cosmos-2.0-Diffusion-14B-Video2World": {
223+
"in_channels": 16 + 1,
224+
"out_channels": 16,
225+
"num_attention_heads": 40,
226+
"attention_head_dim": 128,
227+
"num_layers": 36,
228+
"mlp_ratio": 4.0,
229+
"text_embed_dim": 1024,
230+
"adaln_lora_dim": 256,
231+
"max_size": (128, 240, 240),
232+
"patch_size": (1, 2, 2),
233+
"rope_scale": (20 / 24, 2.0, 2.0),
234+
"concat_padding_mask": True,
235+
"extra_pos_embed_type": None,
236+
},
128237
}
129238

130239
VAE_KEYS_RENAME_DICT = {
@@ -216,9 +325,18 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
216325
return state_dict
217326

218327

219-
def convert_transformer(transformer_type: str, ckpt_path: str):
328+
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
220329
PREFIX_KEY = "net."
221-
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
330+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
331+
332+
if "Cosmos-1.0" in transformer_type:
333+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
334+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
335+
elif "Cosmos-2.0" in transformer_type:
336+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
337+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
338+
else:
339+
assert False
222340

223341
with init_empty_weights():
224342
config = TRANSFORMER_CONFIGS[transformer_type]
@@ -281,13 +399,61 @@ def convert_vae(vae_type: str):
281399
return vae
282400

283401

402+
def save_pipeline_cosmos_1_0(args, transformer, vae):
    """Assemble a Cosmos 1.0 pipeline from converted components and save it.

    Loads the T5 text encoder and tokenizer from the paths in ``args``, builds
    the EDM Euler scheduler used by Cosmos 1.0, selects the text-to-world or
    video-to-world pipeline class based on ``args.transformer_type``, and
    serializes the assembled pipeline to ``args.output_path``.
    """
    encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
    tok = T5TokenizerFast.from_pretrained(args.tokenizer_path)

    # The original code initializes the EDM config with sigma_min=0.0002 but does
    # not use it anywhere directly, so the effective sigma_min is the default 0.002.
    edm_scheduler = EDMEulerScheduler(
        sigma_min=0.002,
        sigma_max=80,
        sigma_data=0.5,
        sigma_schedule="karras",
        num_train_timesteps=1000,
        prediction_type="epsilon",
        rho=7.0,
        final_sigmas_type="sigma_min",
    )

    if "Text2World" in args.transformer_type:
        pipeline_class = CosmosTextToWorldPipeline
    else:
        pipeline_class = CosmosVideoToWorldPipeline

    pipeline = pipeline_class(
        text_encoder=encoder,
        tokenizer=tok,
        transformer=transformer,
        vae=vae,
        scheduler=edm_scheduler,
        # No-op stand-in so conversion can run without the guardrail safety checker.
        safety_checker=lambda *args, **kwargs: None,
    )
    pipeline.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
428+
429+
430+
def save_pipeline_cosmos_2_0(args, transformer, vae):
    """Assemble a Cosmos 2.0 pipeline from converted components and save it.

    Loads the T5 text encoder and tokenizer from the paths in ``args``, builds
    a flow-match Euler scheduler with Karras sigmas, selects the text-to-image
    or video-to-world pipeline class based on ``args.transformer_type``, and
    serializes the assembled pipeline to ``args.output_path``.
    """
    encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
    tok = T5TokenizerFast.from_pretrained(args.tokenizer_path)

    # Cosmos 2.0 uses flow matching rather than the EDM formulation of Cosmos 1.0.
    flow_scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)

    if "Text2Image" in args.transformer_type:
        pipeline_class = Cosmos2TextToImagePipeline
    else:
        pipeline_class = Cosmos2VideoToWorldPipeline

    pipeline = pipeline_class(
        text_encoder=encoder,
        tokenizer=tok,
        transformer=transformer,
        vae=vae,
        scheduler=flow_scheduler,
        # No-op stand-in so conversion can run without the guardrail safety checker.
        safety_checker=lambda *args, **kwargs: None,
    )
    pipeline.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
446+
447+
284448
def get_args():
285449
parser = argparse.ArgumentParser()
286450
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
287451
parser.add_argument(
288452
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
289453
)
290-
parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
454+
parser.add_argument(
455+
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
456+
)
291457
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
292458
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
293459
parser.add_argument("--save_pipeline", action="store_true")
@@ -316,37 +482,26 @@ def get_args():
316482
assert args.tokenizer_path is not None
317483

318484
if args.transformer_ckpt_path is not None:
319-
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
485+
weights_only = "Cosmos-1.0" in args.transformer_type
486+
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
320487
transformer = transformer.to(dtype=dtype)
321488
if not args.save_pipeline:
322489
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
323490

324491
if args.vae_type is not None:
325-
vae = convert_vae(args.vae_type)
492+
if "Cosmos-1.0" in args.transformer_type:
493+
vae = convert_vae(args.vae_type)
494+
else:
495+
vae = AutoencoderKLWan.from_pretrained(
496+
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
497+
)
326498
if not args.save_pipeline:
327499
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
328500

329501
if args.save_pipeline:
330-
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
331-
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
332-
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
333-
# So, the sigma_min values that is used is the default value of 0.002.
334-
scheduler = EDMEulerScheduler(
335-
sigma_min=0.002,
336-
sigma_max=80,
337-
sigma_data=0.5,
338-
sigma_schedule="karras",
339-
num_train_timesteps=1000,
340-
prediction_type="epsilon",
341-
rho=7.0,
342-
final_sigmas_type="sigma_min",
343-
)
344-
345-
pipe = CosmosTextToWorldPipeline(
346-
text_encoder=text_encoder,
347-
tokenizer=tokenizer,
348-
transformer=transformer,
349-
vae=vae,
350-
scheduler=scheduler,
351-
)
352-
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
502+
if "Cosmos-1.0" in args.transformer_type:
503+
save_pipeline_cosmos_1_0(args, transformer, vae)
504+
elif "Cosmos-2.0" in args.transformer_type:
505+
save_pipeline_cosmos_2_0(args, transformer, vae)
506+
else:
507+
assert False

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,8 @@
361361
"CogView4ControlPipeline",
362362
"CogView4Pipeline",
363363
"ConsisIDPipeline",
364+
"Cosmos2TextToImagePipeline",
365+
"Cosmos2VideoToWorldPipeline",
364366
"CosmosTextToWorldPipeline",
365367
"CosmosVideoToWorldPipeline",
366368
"CycleDiffusionPipeline",
@@ -949,6 +951,8 @@
949951
CogView4ControlPipeline,
950952
CogView4Pipeline,
951953
ConsisIDPipeline,
954+
Cosmos2TextToImagePipeline,
955+
Cosmos2VideoToWorldPipeline,
952956
CosmosTextToWorldPipeline,
953957
CosmosVideoToWorldPipeline,
954958
CycleDiffusionPipeline,

0 commit comments

Comments
 (0)