You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.
174
174
175
-
## CPU offloading
175
+
## Offloading
176
+
177
+
Offloading strategies move not currently active layers or models to the CPU to avoid increasing GPU memory. These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.
178
+
179
+
Refer to the [Compile and offloading quantized models](./speed-memory-optims) guide for more details.
180
+
181
+
### CPU offloading
176
182
177
183
CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, and VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stays on the GPU until its completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is memory savings won't be as large.
[`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoders hidden state.
236
242
237
-
## Group offloading
243
+
###Group offloading
238
244
239
245
Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.
The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.
297
303
298
-
<Tip>
299
-
300
-
The offloading strategies can be combined with [quantization](../quantization/overview.md) to enable further memory savings. For image generation, combining [quantization and model offloading](#model-offloading) can often give the best trade-off between quality, speed, and memory. However, for video generation, as the models are more
301
-
compute-bound, [group-offloading](#group-offloading) tends to be better. Group offloading provides considerable benefits when weight transfers can be overlapped with computation (must use streams). When applying group offloading with quantization on image generation models at typical resolutions (1024x1024, for example), it is usually not possible to *fully* overlap weight transfers if the compute kernel finishes faster, making it communication bound between CPU/GPU (due to device synchronizations).
302
-
303
-
</Tip>
304
-
305
304
## Layerwise casting
306
305
307
306
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
Copy file name to clipboardExpand all lines: docs/source/en/optimization/speed-memory-optims.md
+17-8Lines changed: 17 additions & 8 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -10,21 +10,30 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
10
10
specific language governing permissions and limitations under the License.
11
11
-->
12
12
13
-
# Compile and offloading
13
+
# Compile and offloading quantized models
14
14
15
15
When optimizing models, you often face trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it comes at the cost of increased memory consumption since it needs to store intermediate attention layer outputs.
16
16
17
-
A more balanced optimization strategy combines [torch.compile](./fp16#torchcompile) with various offloading methods. This approach not only accelerates inference but also helps lower memory-usage.
17
+
A more balanced optimization strategy combines [torch.compile](./fp16#torchcompile) with various [offloading methods](./memory#offloading) on a quantized model. This approach not only accelerates inference but also helps lower memory-usage.
18
18
19
-
The table below provides a comparison of optimization strategy combinations and their impact on latency and memory-usage.
19
+
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
20
20
21
-
| combination | latency | memory-usage |
21
+
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
22
+
23
+
The table below provides a comparison of optimization strategy combinations and their impact on latency and memory-usage for Flux.
24
+
25
+
| combination | latency (s) | memory-usage (GB) |
22
26
|---|---|---|
23
-
| quantization, torch.compile |||
24
-
| quantization, torch.compile, model CPU offloading |||
25
-
| quantization, torch.compile, group offloading |||
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
30
+
| quantization, torch.compile, group offloading | 60.235 | 12.2369 |
31
+
<small>These results are benchmarked on Flux with a RTX 4090. The `transformer` and `text_encoder_2` components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>
32
+
33
+
> [!TIP]
34
+
> We recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/) for better torch.compile support.
26
35
27
-
This guide will show you how to compile and offload a model.
36
+
This guide will show you how to compile and offload a quantized model.
0 commit comments