Merge branch 'typo-fix-1' of github.com:LRY89757/xDiT into typo-fix-1
LRY89757 committed Sep 25, 2024
2 parents e1fc90d + 4b383c5 commit 289dc65
Showing 2 changed files with 124 additions and 0 deletions.
47 changes: 47 additions & 0 deletions examples/flux.py
@@ -0,0 +1,47 @@
import torch
from diffusers import FluxPipeline
from torch.profiler import profile, record_function, ProfilerActivity

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda:1")
# pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to the CPU; remove this if you have enough GPU memory

def single_run(num_inference_steps=50):
    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    image.save("flux-dev.png")

# warmup
def warmup(times=3):
    for _ in range(times):
        single_run()

def run():
    single_run(num_inference_steps=30)
    num_inference_steps = 10
    # profile one pipeline run; the trace is streamed to TensorBoard via on_trace_ready
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        profile_memory=True,
        with_stack=True,
        with_flops=True,
        with_modules=True,
        record_shapes=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./tensorboard/flux"),
    ) as prof:
        with record_function("flux_pipeline"):
            single_run(num_inference_steps=num_inference_steps)
    # prof.export_chrome_trace("test_trace_" + "flux" + f"_steps_{num_inference_steps}" + ".json")
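    # A hedged addition (not in the original commit): print an aggregated
    # per-operator summary after profiling; key_averages().table() is
    # standard torch.profiler API.
    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))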

def main():
    warmup()
    run()

if __name__ == "__main__":
    main()
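
Note: the trace written by tensorboard_trace_handler above can be inspected with TensorBoard's PyTorch profiler plugin (torch-tb-profiler), assuming it is installed, by pointing tensorboard --logdir at ./tensorboard/flux.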
77 changes: 77 additions & 0 deletions examples/flux_profier.py
@@ -0,0 +1,77 @@
import logging
import time
import torch
import torch.distributed
from xfuser import xFuserFluxPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_runtime_state,
    is_dp_last_group,
)


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    pipe = xFuserFluxPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    else:
        pipe = pipe.to(f"cuda:{local_rank}")

    pipe.prepare_run(input_config)
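    # Judging by xDiT's other example scripts, prepare_run appears to perform a
    # one-time warmup/setup pass, so the timed section below should exclude
    # initialization cost (an assumption, not verified against this commit).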

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=256,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        # ceiling division: split the batch evenly across data-parallel groups
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if is_dp_last_group():
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
                image.save(f"./results/{image_name}")
                print(f"image {i} saved to ./results/{image_name}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9:.2f} GB")
    # "destory" is the method's actual spelling in the xfuser API
    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
    main()
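
As a usage sketch (flag names inferred from xFuserArgs and the attribute names used above, not verified against this commit), the profiling script would typically be launched with torchrun, e.g.:

torchrun --nproc_per_node=2 examples/flux_profier.py --model black-forest-labs/FLUX.1-dev --ulysses_degree 2 --num_inference_steps 20 --prompt "A cat holding a sign that says hello world"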
