Commit 4feb4d6

update

1 parent 7d7f274 commit 4feb4d6

File tree

3 files changed: +28 -20 lines changed


docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@
 - local: optimization/memory
   title: Reduce memory usage
 - local: optimization/speed-memory-optims
-  title: Compile and offloading
+  title: Compile and offloading quantized models
 - local: optimization/xformers
   title: xFormers
 - local: optimization/tome

docs/source/en/optimization/memory.md

Lines changed: 10 additions & 11 deletions
@@ -172,7 +172,13 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
 > [!WARNING]
 > [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.
 
-## CPU offloading
+## Offloading
+
+Offloading strategies move layers or models that aren't currently active to the CPU to avoid increasing GPU memory usage. These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.
+
+Refer to the [Compile and offloading quantized models](./speed-memory-optims) guide for more details.
+
+### CPU offloading
 
 CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU, and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.

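For readers skimming the renamed subsection, CPU offloading comes down to a single call. A minimal sketch (the checkpoint and prompt are illustrative, and `enable_sequential_cpu_offload` requires Accelerate):

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint; any supported pipeline works the same way.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Submodule weights are copied to the GPU only while they are needed and moved
# back to the CPU afterwards: slow, but with the largest memory savings.
pipeline.enable_sequential_cpu_offload()

image = pipeline("a photo of an astronaut riding a horse on the moon").images[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```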
@@ -203,7 +209,7 @@ pipeline(
 print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
 ```
 
-## Model offloading
+### Model offloading
 
 Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, and VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they're completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is memory savings won't be as large.

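Model offloading is the same pattern with a different call. A minimal sketch with the same illustrative checkpoint:

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Whole components (text encoders, UNet, VAE) move to the GPU one at a time
# instead of submodule by submodule: faster than sequential CPU offloading,
# with smaller (but still significant) memory savings.
pipeline.enable_model_cpu_offload()

image = pipeline("a photo of an astronaut riding a horse on the moon").images[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```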
@@ -234,7 +240,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
 
 [`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoder's hidden states.
 
-## Group offloading
+### Group offloading
 
 Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.

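Group offloading is enabled per model rather than per pipeline. A hedged, self-contained sketch, assuming a diffusers release that provides `enable_group_offload` on model components; the checkpoint, prompt, and `offload_type` choice are illustrative:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

# Illustrative video checkpoint; group offloading is configured per model component.
pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.text_encoder.to(onload_device)  # keep the text encoder on the GPU for this sketch

# Groups of layers are onloaded right before they run and offloaded afterwards.
# "leaf_level" offloads at the level of individual leaf modules.
for component in (pipeline.transformer, pipeline.vae):
    component.enable_group_offload(
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type="leaf_level",
    )

video = pipeline("A panda playing a guitar by a quiet river").frames[0]
export_to_video(video, "output.mp4", fps=8)
```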
@@ -278,7 +284,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
 export_to_video(video, "output.mp4", fps=8)
 ```
 
-### CUDA stream
+#### CUDA stream
 
 The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly, so ensure you have 2x the amount of memory as the model size.

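The `use_stream` flag, together with the `low_cpu_mem_usage` flag discussed in the next hunk, is a keyword argument to the same `enable_group_offload` call. A hedged sketch, reusing the illustrative checkpoint from the previous sketch:

```python
import torch
from diffusers import CogVideoXPipeline

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Same call as in the group offloading sketch, now with stream-based prefetching:
# the next group of layers is transferred while the current group is still running.
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,         # needs a CUDA device with async transfer support
    low_cpu_mem_usage=True,  # pin tensors on the fly to reduce CPU memory (see below)
)
```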
@@ -295,13 +301,6 @@ pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_d
 
 The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.
 
-<Tip>
-
-The offloading strategies can be combined with [quantization](../quantization/overview.md) to enable further memory savings. For image generation, combining [quantization and model offloading](#model-offloading) can often give the best trade-off between quality, speed, and memory. However, for video generation, as the models are more
-compute-bound, [group-offloading](#group-offloading) tends to be better. Group offloading provides considerable benefits when weight transfers can be overlapped with computation (must use streams). When applying group offloading with quantization on image generation models at typical resolutions (1024x1024, for example), it is usually not possible to *fully* overlap weight transfers if the compute kernel finishes faster, making it communication bound between CPU/GPU (due to device synchronizations).
-
-</Tip>
-
 ## Layerwise casting
 
 Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.

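The layerwise casting paragraph kept at the end of the hunk above maps to a single model-level call. A minimal sketch; the checkpoint is illustrative, and the `storage_dtype`/`compute_dtype` argument names are assumed from recent diffusers releases:

```python
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Store transformer weights in fp8 and upcast them to bf16 only for computation.
# Normalization and modulation layers are skipped to avoid degrading quality.
pipeline.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
pipeline.to("cuda")

image = pipeline("a tiny astronaut hatching from an egg on the moon").images[0]
```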
docs/source/en/optimization/speed-memory-optims.md

Lines changed: 17 additions & 8 deletions
@@ -10,21 +10,30 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
 specific language governing permissions and limitations under the License.
 -->
 
-# Compile and offloading
+# Compile and offloading quantized models
 
 When optimizing models, you often face trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it comes at the cost of increased memory consumption since it needs to store intermediate attention layer outputs.
 
-A more balanced optimization strategy combines [torch.compile](./fp16#torchcompile) with various offloading methods. This approach not only accelerates inference but also helps lower memory-usage.
+A more balanced optimization strategy combines [torch.compile](./fp16#torchcompile) with various [offloading methods](./memory#offloading) on a quantized model. This approach not only accelerates inference but also helps lower memory usage.
 
-The table below provides a comparison of optimization strategy combinations and their impact on latency and memory-usage.
+For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster than the transfer. This results in some communication overhead between the CPU and GPU.
 
-| combination | latency | memory-usage |
+For video generation, combining quantization and [group offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
+
+The table below compares optimization strategy combinations and their impact on latency and memory usage for Flux.
+
+| combination | latency (s) | memory usage (GB) |
 |---|---|---|
-| quantization, torch.compile | | |
-| quantization, torch.compile, model CPU offloading | | |
-| quantization, torch.compile, group offloading | | |
+| quantization | 32.602 | 14.9453 |
+| quantization, torch.compile | 25.847 | 14.9448 |
+| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
+| quantization, torch.compile, group offloading | 60.235 | 12.2369 |
+<small>These results are benchmarked on Flux with an RTX 4090. The `transformer` and `text_encoder_2` components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>
+
+> [!TIP]
+> We recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/) for better torch.compile support.
 
-This guide will show you how to compile and offload a model.
+This guide will show you how to compile and offload a quantized model.
 
 ## Quantization and torch.compile

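To make the benchmark rows concrete, here is a hedged sketch of the "quantization, torch.compile, model CPU offloading" combination. It assumes a diffusers release that exposes `PipelineQuantizationConfig` with the bitsandbytes backend and a PyTorch version with `nn.Module.compile`; the checkpoint, quantization settings, and prompt are illustrative:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Quantize the two heaviest components, matching the benchmark setup above.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Offload whole components to the CPU while they are idle ...
pipeline.enable_model_cpu_offload()
# ... and compile the transformer (the first call pays the compilation cost).
pipeline.transformer.compile()

image = pipeline("a photo of an astronaut riding a horse on the moon", num_inference_steps=28).images[0]
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```

Swapping `enable_model_cpu_offload` for group offloading on the transformer corresponds to the last row of the table: roughly the same memory footprint, but slower for this image model.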