-
Notifications
You must be signed in to change notification settings - Fork 133
/
Copy pathflux_usp_example.py
179 lines (152 loc) · 7.34 KB
/
flux_usp_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Flux inference with USP
# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py
import functools
from typing import List, Optional
import logging
import time
import torch
from diffusers import DiffusionPipeline, FluxPipeline
from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_world_size,
get_data_parallel_rank,
get_runtime_state,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_cfg_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
is_dp_last_group,
initialize_runtime_state,
get_pipeline_parallel_world_size,
)
from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0
def _split_batch_for_cfg(tensor, batch_size: int):
    """Return this rank's classifier-free-guidance shard of a per-sample tensor.

    A tensor is split along dim 0 only when it is a non-scalar ``torch.Tensor``
    whose leading dimension equals ``batch_size``. This is the same rule that is
    applied to ``timestep``. Anything else (``None``, scalars, tensors with a
    different leading size) is returned unchanged.
    """
    if isinstance(tensor, torch.Tensor) and tensor.ndim != 0 and tensor.shape[0] == batch_size:
        return torch.chunk(tensor, get_classifier_free_guidance_world_size(), dim=0)[
            get_classifier_free_guidance_rank()
        ]
    return tensor


def parallelize_transformer(pipe: DiffusionPipeline):
    """Patch ``pipe.transformer.forward`` for CFG + sequence-parallel (USP) inference.

    The attention processor of every block in ``transformer_blocks`` and
    ``single_transformer_blocks`` is replaced by ``xFuserFluxAttnProcessor2_0``
    once, when this function is called. The patched forward then:

    1. splits the batch dimension (dim 0) of the inputs across the
       classifier-free-guidance group;
    2. splits the sequence dimension (dim -2) of ``hidden_states`` and
       ``img_ids`` across the sequence-parallel group. ``encoder_hidden_states``
       and ``txt_ids`` are split too, but only when the text length divides
       evenly by the SP world size;
    3. calls the original forward on the local shard;
    4. all-gathers the first output over the SP group (dim -2) and then the CFG
       group (dim 0), so every rank returns the full sample.

    Args:
        pipe: Pipeline whose ``transformer`` exposes ``transformer_blocks`` and
            ``single_transformer_blocks``. Call this after
            ``initialize_runtime_state`` (``main`` does), because the attention
            processors are constructed here.
            NOTE(review): presumably the processor reads the parallel config at
            construction time — confirm.
    """
    transformer = pipe.transformer
    original_forward = transformer.forward

    # Install the USP attention processors once, up front. Doing this inside the
    # forward re-created them on every denoising step and mutated modules from
    # within a (possibly torch.compile'd) forward.
    for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
        block.attn.processor = xFuserFluxAttnProcessor2_0()

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        *args,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        **kwargs,
    ):
        cfg_world_size = get_classifier_free_guidance_world_size()
        cfg_rank = get_classifier_free_guidance_rank()
        sp_world_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()

        batch_size = hidden_states.shape[0]
        assert batch_size % cfg_world_size == 0, \
            f"Cannot split dim 0 of hidden_states ({hidden_states.shape[0]}) into {get_classifier_free_guidance_world_size()} parts."
        if encoder_hidden_states is None:
            raise ValueError("encoder_hidden_states is required for sequence-parallel Flux inference.")

        # Text tokens are sharded across SP ranks only when they divide evenly.
        # The decision is published on the runtime state.
        # NOTE(review): presumably the xFuser attention processor reads this
        # flag — confirm.
        split_text = encoder_hidden_states.shape[-2] % sp_world_size == 0
        get_runtime_state().split_text_embed_in_sp = split_text

        # Per-sample conditioning must be sharded over CFG exactly like
        # hidden_states, otherwise the batch dims disagree inside the model.
        timestep = _split_batch_for_cfg(timestep, batch_size)
        for key in ("guidance", "pooled_projections"):
            if key in kwargs:
                kwargs[key] = _split_batch_for_cfg(kwargs[key], batch_size)

        hidden_states = torch.chunk(hidden_states, cfg_world_size, dim=0)[cfg_rank]
        hidden_states = torch.chunk(hidden_states, sp_world_size, dim=-2)[sp_rank]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, cfg_world_size, dim=0)[cfg_rank]
        img_ids = torch.chunk(img_ids, sp_world_size, dim=-2)[sp_rank]
        if split_text:
            encoder_hidden_states = torch.chunk(encoder_hidden_states, sp_world_size, dim=-2)[sp_rank]
            txt_ids = torch.chunk(txt_ids, sp_world_size, dim=-2)[sp_rank]

        output = original_forward(
            hidden_states,
            encoder_hidden_states,
            *args,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            **kwargs,
        )

        # Preserve the caller's output container: a plain tuple, or the model's
        # output class.
        return_dict = not isinstance(output, tuple)
        sample = output[0]
        # Undo the sharding in reverse order: sequence first, then CFG batch.
        sample = get_sp_group().all_gather(sample, dim=-2)
        sample = get_cfg_group().all_gather(sample, dim=0)
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

    # Bind as a method of this instance only; the class forward is untouched.
    transformer.forward = new_forward.__get__(transformer)
def _generate(pipe, input_config, num_inference_steps: int, device: str):
    """Run ``pipe`` once with the configured prompt and size and a seeded generator.

    The generator is created on ``device`` (this rank's GPU). A bare ``"cuda"``
    device would resolve to the *current* CUDA device, which is not necessarily
    ``cuda:{local_rank}``.
    """
    return pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=num_inference_steps,
        output_type=input_config.output_type,
        generator=torch.Generator(device=device).manual_seed(input_config.seed),
    )


def _save_images(images, parallel_info: str, use_torch_compile, batch_size: int) -> None:
    """Save ``images`` as PNGs under ``./results`` if this rank is in the last DP group.

    File names encode the parallel layout, a global image index and the
    torch.compile flag. The global image index is derived from the data-parallel
    rank and the per-group batch size.
    """
    import os

    dp_group_index = get_data_parallel_rank()
    num_dp_groups = get_data_parallel_world_size()
    # Ceiling division: number of images each data-parallel group handles.
    dp_batch_size = (batch_size + num_dp_groups - 1) // num_dp_groups
    if not is_dp_last_group():
        return
    # image.save fails if the output directory is missing.
    os.makedirs("./results", exist_ok=True)
    for i, image in enumerate(images):
        image_rank = dp_group_index * dp_batch_size + i
        image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{use_torch_compile}.png"
        image.save(f"./results/{image_name}")
        print(f"image {i} saved to ./results/{image_name}")


def main():
    """Benchmark Flux text-to-image inference with xFuser USP parallelism.

    Steps:
    1. Parse the xFuser CLI arguments and load ``FluxPipeline`` in bfloat16,
       either moved to this rank's GPU or with sequential CPU offload.
    2. Initialise the xFuser runtime state and patch the transformer with
       ``parallelize_transformer``.
    3. Optionally ``torch.compile`` the transformer, followed by a one-step
       warm-up run.
    4. Time one full generation and save PIL outputs.
    5. Print timing and memory figures from the last world rank, then tear down
       the distributed environment.
    """
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    local_rank = get_world_group().local_rank
    device = f"cuda:{local_rank}"
    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."

    pipe = FluxPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info("rank %s sequential CPU offload enabled", local_rank)
    else:
        pipe = pipe.to(device)

    # Peak allocation so far, i.e. after the model weights have been placed.
    parameter_peak_memory = torch.cuda.max_memory_allocated(device=device)

    initialize_runtime_state(pipe, engine_config)
    get_runtime_state().set_input_parameters(
        height=input_config.height,
        width=input_config.width,
        batch_size=1,
        num_inference_steps=input_config.num_inference_steps,
        max_condition_sequence_length=512,
        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
    )
    parallelize_transformer(pipe)

    if engine_config.runtime_config.use_torch_compile:
        # Explicit import rather than relying on lazy `torch._inductor` attribute access.
        from torch._inductor import config as inductor_config

        inductor_config.reorder_for_compute_comm_overlap = True
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
        # One step to warm up the torch compiler so compilation is not timed below.
        _generate(pipe, input_config, 1, device)

    # Pass the device explicitly: the no-argument form only resets the current
    # CUDA device, which may not be this rank's GPU.
    torch.cuda.reset_peak_memory_stats(device=device)
    torch.cuda.synchronize(device)
    start_time = time.perf_counter()
    output = _generate(pipe, input_config, input_config.num_inference_steps, device)
    # Wait for queued GPU work so the elapsed time covers all of it, whatever
    # the output type.
    torch.cuda.synchronize(device)
    elapsed_time = time.perf_counter() - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=device)

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if input_config.output_type == "pil":
        _save_images(output.images, parallel_info, engine_args.use_torch_compile, input_config.batch_size)

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
        )
    # Method name spelled as in the xFuser runtime-state API.
    get_runtime_state().destory_distributed_env()
if __name__ == "__main__":
    # Entry point: run the benchmark only when executed as a script, not on import.
    main()