Optimize GPU Memory Usage in CogVideo (#306)
xibosun authored Oct 14, 2024
1 parent 8a08550 commit 3b9cfb3
Showing 5 changed files with 30 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -39,4 +39,4 @@ jobs:
- name: Remove Files
run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}}_${{github.run_number}} sh -c "rm -r *"
- name: Destroy docker
run: docker stop xfuser_test_docker_${{github.repository_owner_id}}_${{github.run_number}}
run: docker stop xfuser_test_docker_${{github.repository_owner_id}}_${{github.run_number}}
12 changes: 11 additions & 1 deletion examples/cogvideox_example.py
@@ -1,3 +1,4 @@
import logging
import time
import torch
import torch.distributed
@@ -35,12 +36,21 @@ def main():
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    elif args.enable_model_cpu_offload:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
        pipe.vae.enable_tiling()
        logging.info(f"rank {local_rank} model CPU offload enabled")
    else:
        device = torch.device(f"cuda:{local_rank}")
        pipe = pipe.to(device)

    if args.enable_tiling:
        pipe.vae.enable_tiling()

    if args.enable_slicing:
        pipe.vae.enable_slicing()

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

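The patched example selects one of three placement strategies (sequential CPU offload, model CPU offload, or keeping the whole pipeline on the local GPU) and then optionally turns on VAE tiling and slicing. The sketch below reproduces the same memory options on a single GPU with plain diffusers, outside xFuser's distributed setup; the model id THUDM/CogVideoX-2b, the hard-coded flag values, and the prompt are illustrative assumptions, not part of this commit.

# A minimal single-GPU sketch of the memory options above, using plain diffusers
# without xFuser parallelism. The model id, flag values, and prompt are assumed
# for illustration only.
import logging
import torch
from diffusers import CogVideoXPipeline

logging.basicConfig(level=logging.INFO)

enable_sequential_cpu_offload = False  # lowest VRAM use, slowest
enable_model_cpu_offload = True        # moderate VRAM savings
enable_tiling = True                   # VAE decodes one spatial tile at a time
enable_slicing = True                  # VAE decodes the batch one slice at a time

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16
)

if enable_sequential_cpu_offload:
    pipe.enable_sequential_cpu_offload()
    logging.info("sequential CPU offload enabled")
elif enable_model_cpu_offload:
    pipe.enable_model_cpu_offload()
    logging.info("model CPU offload enabled")
else:
    pipe = pipe.to("cuda")

if enable_tiling:
    pipe.vae.enable_tiling()
if enable_slicing:
    pipe.vae.enable_slicing()

torch.cuda.reset_peak_memory_stats()
video = pipe(prompt="A small dog", num_inference_steps=50).frames[0]
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")

Sequential offload keeps only a small part of the model on the GPU at any moment and is the slowest option, while model offload moves whole components (text encoder, transformer, VAE) on and off the GPU as each is used; that trade-off is what the new --enable_model_cpu_offload flag exposes.
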
2 changes: 2 additions & 0 deletions examples/run_cogvideo.sh
@@ -22,6 +22,7 @@ CFG_ARGS="--use_cfg_parallel"
# PIPEFUSION_ARGS="--num_pipeline_patch 8"
# OUTPUT_ARGS="--output_type latent"
# PARALLLEL_VAE="--use_parallel_vae"
ENABLE_TILING="--enable_tiling"
# COMPILE_FLAG="--use_torch_compile"

torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
@@ -35,4 +36,5 @@ $OUTPUT_ARGS \
--prompt "A small dog" \
$CFG_ARGS \
$PARALLLEL_VAE \
$ENABLE_TILING \
$COMPILE_FLAG
15 changes: 15 additions & 0 deletions xfuser/config/args.py
@@ -247,6 +247,21 @@ def add_cli_args(parser: FlexibleArgumentParser):
action="store_true",
help="Offloading the weights to the CPU.",
)
runtime_group.add_argument(
"--enable_model_cpu_offload",
action="store_true",
help="Offloading the weights to the CPU.",
)
runtime_group.add_argument(
"--enable_tiling",
action="store_true",
help="Making VAE decode a tile at a time to save GPU memory.",
)
    runtime_group.add_argument(
        "--enable_slicing",
        action="store_true",
        help="Making VAE decode one slice of the batch at a time to save GPU memory.",
    )

    # DiTFastAttn arguments
    fast_attn_group = parser.add_argument_group("DiTFastAttn Options")
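For context, the new options are all plain store_true flags. Below is a minimal sketch, using the standard library's argparse rather than xFuser's FlexibleArgumentParser, of how such flags could be declared and then dispatched to the corresponding pipeline calls; apply_memory_args and the pipe object it receives are hypothetical stand-ins for the example script's setup, not code from this commit.

import argparse

def add_memory_args(parser: argparse.ArgumentParser) -> None:
    # Mirrors the three flags added to the runtime group above.
    group = parser.add_argument_group("Runtime Options")
    group.add_argument("--enable_model_cpu_offload", action="store_true",
                       help="Offload model components to the CPU between uses.")
    group.add_argument("--enable_tiling", action="store_true",
                       help="Make the VAE decode one spatial tile at a time.")
    group.add_argument("--enable_slicing", action="store_true",
                       help="Make the VAE decode the batch one slice at a time.")

def apply_memory_args(pipe, args: argparse.Namespace):
    # Hypothetical helper: maps each flag to the corresponding diffusers call.
    if args.enable_model_cpu_offload:
        pipe.enable_model_cpu_offload()
    if args.enable_tiling:
        pipe.vae.enable_tiling()
    if args.enable_slicing:
        pipe.vae.enable_slicing()
    return pipe

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_memory_args(parser)
    args = parser.parse_args()
    print(args)  # e.g. Namespace(enable_model_cpu_offload=False, enable_tiling=True, enable_slicing=False)
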
1 change: 1 addition & 0 deletions xfuser/model_executor/layers/attention_processor.py
@@ -1161,6 +1161,7 @@ def __call__(
        # dropout
        hidden_states = attn.to_out[1](hidden_states)


        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, latent_seq_length], dim=1
        )