Commit 0273726
Changed weight map to tensor and fix the refit bug (#3573)
1 parent 07e4643 commit 0273726

7 files changed: +197 -56 lines changed

examples/apps/README.md

Lines changed: 56 additions & 0 deletions (new file)

# Flux Demo with Torch-TensorRT

This demo showcases the Flux image generation model accelerated using Torch-TensorRT, with support for different precision modes (FP8, INT8, FP16) and dynamic shapes.

## Installation

1. Install the required dependencies:

```bash
pip install gradio==5.29.0 nvidia-modelopt==0.27.1 diffusers==0.33.1 accelerate==1.3.0
```

## Usage

The demo can be run with different configurations:

### Basic Usage (FP16)

```bash
python flux_demo.py
```

### Using Different Precision Modes

- FP8 mode:
```bash
python flux_demo.py --dtype fp8
```

- INT8 mode:
```bash
python flux_demo.py --dtype int8
```

- FP16 mode (default):
```bash
python flux_demo.py --dtype fp16
```

### Additional Options

- Enable dynamic shapes (allows variable batch sizes):
```bash
python flux_demo.py --dynamic_shapes
```

- Low VRAM mode (for GPUs with ≤32GB VRAM):
```bash
python flux_demo.py --low_vram_mode
```

You can combine these options as needed. For example:
```bash
python flux_demo.py --dtype fp8 --dynamic_shapes --low_vram_mode
```
py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 15 additions & 6 deletions

```diff
@@ -1,7 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Union
 
-import numpy as np
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.types import TRTNetwork
@@ -24,10 +22,21 @@ class ConversionContext:
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
-    weight_refit_map: dict[str, np.array] = field(default_factory=dict)
-    cpu_weights_reference_holder: dict[str, Union[torch.Tensor]] = field(
-        default_factory=dict
-    )
+    weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
+    cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list)
+
+    def record_weight(self, name: str, weight: torch.Tensor) -> None:
+        """
+        Record the weight and name for refitting and CPU reference.
+        For the refit map, the key is the weight name that appears in the TRT engine and the value is the weight tensor.
+        For the CPU reference holder, we need to hold the reference to the weight tensor until the whole compilation process is complete.
+
+        Args:
+            name: Name of the weight
+            weight: Weight to record
+        """
+        self.weight_refit_map[name] = weight
+        self.cpu_weights_reference_holder.append(weight)
 
     def clear_cpu_weights_reference_holder(self) -> None:
         self.cpu_weights_reference_holder.clear()
```
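For orientation, here is a minimal sketch of the new bookkeeping. The weight value, its engine-side name, and the placeholder `net=None` are made up for illustration, and the context is constructed directly only for the sake of a runnable snippet; in the real flow the interpreter owns the context and `to_trt_weights` composes the name and calls `record_weight` during conversion.

```python
import torch

from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext

# Placeholder network for illustration only; the interpreter normally constructs
# the context around the live TensorRT INetworkDefinition.
ctx = ConversionContext(net=None)

weight = torch.randn(128, 64)  # hypothetical CPU-side weight tensor
engine_weight_name = "example_constant CONSTANT"  # hypothetical engine-side name

ctx.record_weight(engine_weight_name, weight)

# The tensor is now reachable from both structures:
# - weight_refit_map keys it by the engine-side weight name (consumed later by refit)
# - cpu_weights_reference_holder keeps it alive until the engine is built
assert ctx.weight_refit_map[engine_weight_name] is weight
assert ctx.cpu_weights_reference_holder[-1] is weight

# Once compilation finishes, the CPU references are released:
ctx.clear_cpu_weights_reference_holder()
assert len(ctx.cpu_weights_reference_holder) == 0
```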

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 11 deletions

```diff
@@ -404,7 +404,7 @@ def _construct_trt_network_def(self) -> None:
     @staticmethod
     def find_weight(
         weight_name: str,
-        np_map: dict[str, Any],
+        weight_refit_map: dict[str, Any],
         state_dict: dict[str, Any],
         device: torch.device,
     ) -> str:
@@ -417,7 +417,7 @@ def find_weight(
             state_dict: state of the graph module
         """
         with unset_fake_temporarily():
-            network_weight = torch.from_numpy(np_map[weight_name]).to(device)
+            network_weight = weight_refit_map[weight_name].to(device)
             for sd_w_name, sd_weight in state_dict.items():
                 if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                     del state_dict[sd_w_name]
@@ -431,8 +431,8 @@ def check_weight_equal(
         device: torch.device,
     ) -> Any:
         with unset_fake_temporarily():
-            if not isinstance(network_weight, torch.Tensor):
-                network_weight = torch.from_numpy(network_weight).to(device)
+            if network_weight.device != device:
+                network_weight = network_weight.to(device)
             try:
                 return sd_weight.shape == network_weight.shape and torch.all(
                     torch.abs(sd_weight - network_weight) < 0.01
```
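Since `weight_refit_map` now holds `torch.Tensor` values, the equality check no longer needs `torch.from_numpy`. Below is a trimmed, CPU-only sketch of that comparison; the real method additionally runs under `unset_fake_temporarily` and wraps the comparison in a try/except, and the tensor shapes here are arbitrary.

```python
import torch

def check_weight_equal_sketch(
    sd_weight: torch.Tensor, network_weight: torch.Tensor, device: torch.device
) -> bool:
    # The network weight is already a tensor; only move it if it lives on another device.
    if network_weight.device != device:
        network_weight = network_weight.to(device)
    # Same shape and element-wise difference below 0.01 counts as a match.
    return sd_weight.shape == network_weight.shape and bool(
        torch.all(torch.abs(sd_weight - network_weight) < 0.01)
    )

device = torch.device("cpu")  # a CUDA device during real compilation
w = torch.randn(16, 8)
print(check_weight_equal_sketch(w, w.clone(), device))           # True: identical values
print(check_weight_equal_sketch(w, torch.randn(16, 8), device))  # False: values differ
```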
```diff
@@ -501,8 +501,8 @@ def _save_weight_mapping(self) -> None:
         self.module.to(torch_device)
         sd = self.module.state_dict()
         weight_name_map: dict[str, Any] = {}
-        np_map = self.ctx.weight_refit_map
-        constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
+        weight_refit_map = self.ctx.weight_refit_map
+        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
@@ -544,7 +544,7 @@ def _save_weight_mapping(self) -> None:
             else:
                 sd_weight_name = f"{sd_weight_name}.{torch_attr}"
 
-            if engine_weight_name in np_map:
+            if engine_weight_name in weight_refit_map:
                 weight_name_map[engine_weight_name] = sd_weight_name
 
         # Stage 2: Value mapping
@@ -553,10 +553,10 @@ def _save_weight_mapping(self) -> None:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
             elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-                sd[sd_weight_name], np_map[engine_weight_name], torch_device
+                sd[sd_weight_name], weight_refit_map[engine_weight_name], torch_device
             ):
                 weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                    engine_weight_name, np_map, sd, torch_device
+                    engine_weight_name, weight_refit_map, sd, torch_device
                 )
                 if (
                     weight_name_map[engine_weight_name] != ""
@@ -567,12 +567,13 @@ def _save_weight_mapping(self) -> None:
 
             weight_name_map[engine_weight_name] = [
                 weight_name_map[engine_weight_name],
-                np_map[engine_weight_name].dtype,
+                weight_refit_map[engine_weight_name].dtype,
             ]
 
         weight_name_map["constant_mapping"] = constant_mapping
         self.weight_name_map = weight_name_map
-        del np_map, sd
+
+        del weight_refit_map, sd
         gc.collect()
         torch.cuda.empty_cache()
```
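The mapping that `_save_weight_mapping` produces then drives refit: each engine-side weight name points to the matching `state_dict` entry and the original dtype, and size-1 constants are collected under `constant_mapping`. The snippet below only illustrates that shape; every name and dtype in it is hypothetical.

```python
import torch

# Hypothetical example of the structure _save_weight_mapping builds; the keys come
# from ctx.weight_refit_map and the values are [state_dict name, original dtype].
weight_name_map = {
    "[CONVOLUTION]-[aten_ops.convolution]-[conv1] KERNEL": ["conv1.weight", torch.float16],
    "[CONVOLUTION]-[aten_ops.convolution]-[conv1] BIAS": ["conv1.bias", torch.float16],
    # Size-1 weights are additionally kept verbatim so small constants can be refit directly.
    "constant_mapping": {"scale CONSTANT": torch.tensor([0.5])},
}
```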

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 73 additions & 30 deletions

```diff
@@ -3,7 +3,18 @@
 import functools
 import logging
 import os
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    overload,
+)
 
 import numpy as np
 import tensorrt as trt
@@ -321,7 +332,16 @@ def cast_int_or_float_to_bool(
 
 
 def to_trt_weights(
-    value: Any, target_quantized_type: Optional[trt.DataType] = None
+    ctx: ConversionContext,
+    value: torch.Tensor,
+    name: str,
+    layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT"],
+    weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT"],
+    target: Optional[Union[Target, str]] = None,
+    source_ir: Optional[SourceIR] = None,
+    target_quantized_type: Optional[trt.DataType] = None,
+    dtype: Optional[trt.DataType] = None,
+    count: Optional[int] = None,
 ) -> trt.Weights:
     """
     Convert a PyTorch tensor or NumPy array to TensorRT weights.
@@ -336,20 +356,51 @@ def to_trt_weights(
     - Input tensors are made contiguous before conversion
     - Data type is preserved from the original tensor/array
     """
-    if isinstance(value, torch.Tensor):
-        # Tensor must be contiguous before conversion
-        value = value.contiguous()
-        value_trt_dtype = _enums.dtype._from(value.dtype).to(trt.DataType)
-        return trt.Weights(value_trt_dtype, value.data_ptr(), value.nelement())
-    elif isinstance(value, np.ndarray):
-        value = np.ascontiguousarray(value)
-        value_np_dtype = _enums.dtype._from(value.dtype).to(np.dtype, use_default=True)
-        return trt.Weights(value_np_dtype, value.data, value.size)
-    else:
+    if not isinstance(value, torch.Tensor):
         raise AssertionError(
-            f"to_trt_weights can only be called on torch.Tensor or np.ndarray, got an object of type: {type(value)}"
+            f"to_trt_weights can only be called on torch.Tensor, got an object of type: {type(value)}"
         )
 
+    # Weight Recording
+    supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT"]
+    supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
+    assert (
+        layer_type_name in supported_layer_types
+    ), f"Encountered unsupported layer type: {layer_type_name}. Supported types are: {supported_layer_types}. Manually calling to_trt_weights with a custom layer type is not intended for general use."
+    assert (
+        weight_type_name in supported_weight_types
+    ), f"Encountered unsupported weight type: {weight_type_name}. Supported types are: {supported_weight_types}. Manually calling to_trt_weights with a custom weight type is not intended for general use."
+
+    if weight_type_name == "CONSTANT" and layer_type_name == "CONSTANT":
+        weight_name = f"{name} CONSTANT"
+        ctx.record_weight(weight_name, value)
+
+    else:
+        assert (
+            target is not None
+        ), "target must be provided if the weight type and layer type is not CONSTANT"
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_name = (
+            f"{source_ir}_ops.{target}"
+            if isinstance(target, str)
+            else f"{source_ir}_ops.{target.__name__}"
+        )
+
+        weight_name = f"[{layer_type_name}]-[{target_name}]-[{name}] {weight_type_name}"
+        ctx.record_weight(weight_name, value)
+
+    # TRT Weights Creation
+
+    # Tensor must be contiguous before conversion
+    value = value.contiguous()
+    if dtype is None:
+        dtype = _enums.dtype._from(value.dtype).to(trt.DataType)
+
+    if count is None:
+        count = value.nelement()
+
+    return trt.Weights(dtype, value.data_ptr(), count)
+
 
 def create_constant(
     ctx: ConversionContext,
@@ -405,34 +456,26 @@ def create_constant(
                 "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
            )
        shape[-1] = shape[-1] * 2
-        weights = trt.Weights(
-            type=trt.DataType.FP4,
-            ptr=torch_value.data_ptr(),
+        weights = to_trt_weights(
+            ctx,
+            torch_value,
+            name,
+            "CONSTANT",
+            "CONSTANT",
+            dtype=trt.DataType.FP4,
             count=torch_value.numel() * 2,
         )
         constant = ctx.net.add_constant(
             shape,
             weights,
         )
         constant.name = name
-        ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
         return constant.get_output(0)
 
-    # TODO: Refit map uses numpy arrays. Remove this once refit is updated to use torch.Tensor
-    if torch_value.dtype == torch.bfloat16:
-        torch_value_fp32 = torch_value.to(torch.float32)
-        numpy_value = torch_value_fp32.numpy()
-    else:
-        numpy_value = torch_value.numpy()
-
-    # Used for refit
-    ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1)
-
-    # This is a buffer to hold the torch.Tensor so that they are alive during the course of TRT compilation.
-    ctx.cpu_weights_reference_holder[name] = torch_value
+    # Record the weight in ctx for refit and cpu memory reference
 
     # Convert the torch.Tensor to a trt.Weights object
-    trt_weights = to_trt_weights(torch_value)
+    trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT")
     constant = ctx.net.add_constant(
         shape,
         trt_weights,
```
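The refit fix hinges on the engine-side weight names that `to_trt_weights` now records through `ctx.record_weight`. A standalone sketch of that naming scheme follows; the layer type, weight type, target, and name values are hypothetical stand-ins for what a converter would pass.

```python
# Standalone illustration of the weight-name scheme recorded into weight_refit_map.
layer_type_name = "CONVOLUTION"  # one of CONVOLUTION / DECONVOLUTION / CONSTANT
weight_type_name = "KERNEL"      # one of KERNEL / BIAS / CONSTANT
source_ir = "aten"               # stands in for SourceIR.ATEN
target = "convolution"           # stands in for the ATen target being converted
name = "conv1"                   # the converter-assigned layer name

if layer_type_name == "CONSTANT" and weight_type_name == "CONSTANT":
    # Constant layers keep the simple "<name> CONSTANT" key.
    weight_name = f"{name} CONSTANT"
else:
    # Convolution/deconvolution kernels and biases get a structured key.
    target_name = f"{source_ir}_ops.{target}"
    weight_name = f"[{layer_type_name}]-[{target_name}]-[{name}] {weight_type_name}"

print(weight_name)  # [CONVOLUTION]-[aten_ops.convolution]-[conv1] KERNEL
```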

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 21 additions & 4 deletions

```diff
@@ -55,7 +55,15 @@ def convNd(
     # Process bias terms
     if isinstance(bias, (torch.Tensor, np.ndarray)):
         bias = to_torch(bias, dtype=input.dtype)
-        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+        bias = to_trt_weights(
+            ctx,
+            bias,
+            name,
+            layer_type_name="CONVOLUTION",
+            weight_type_name="BIAS",
+            target=target,
+            source_ir=source_ir,
+        )
 
     elif isinstance(bias, TRTTensor):
         bias = get_trt_tensor(ctx, bias, f"{name}_bias")
@@ -85,7 +93,15 @@ def convNd(
 
         num_output_maps = weight.shape[0]
         kernel_shape = weight.shape[2:]
-        weight = to_trt_weights(weight)
+        weight = to_trt_weights(
+            ctx,
+            weight,
+            name,
+            layer_type_name="CONVOLUTION",
+            weight_type_name="KERNEL",
+            target=target,
+            source_ir=source_ir,
+        )
 
     else:
         raise RuntimeError(
@@ -105,6 +121,9 @@ def convNd(
         kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
         bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
     )
+
+    set_layer_name(conv_layer, target, name, source_ir)
+
     # If the weight is a TRTTensor, set it as an input of the layer
     if isinstance(weight, TRTTensor):
         weight = cast_trt_tensor(ctx, weight, input.dtype, name)
@@ -145,8 +164,6 @@ def convNd(
         extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
     )
 
-    set_layer_name(conv_layer, target, name, source_ir)
-
     # Set relevant attributes of convolution layer
     if padding is not None:
         conv_layer.padding_nd = padding
```

0 commit comments