
Commit 3524e18

cherry pick 3689 to 2.8 release:flux fp4 (#3696)
1 parent c19791f commit 3524e18

File tree: 5 files changed, +69 −25 lines

- examples/apps/README.md
- examples/apps/flux_demo.py
- py/torch_tensorrt/dynamo/_compiler.py
- py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
- tools/perf/Flux/flux_perf.py

examples/apps/README.md

Lines changed: 5 additions & 0 deletions

@@ -23,6 +23,11 @@ python flux_demo.py

 ### Using Different Precision Modes

+- FP4 mode:
+```bash
+python flux_demo.py --dtype fp4
+```
+
 - FP8 mode:
 ```bash
 python flux_demo.py --dtype fp8

examples/apps/flux_demo.py

Lines changed: 29 additions & 13 deletions

@@ -12,10 +12,6 @@
 from diffusers import FluxPipeline
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

-# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
-from register_sdpa import *
-
 DEVICE = "cuda:0"


@@ -24,8 +20,17 @@ def compile_model(
 ) -> tuple[
     FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
 ]:
+    use_explicit_typing = False
+    if args.dtype == "fp4":
+        use_explicit_typing = True
+        enabled_precisions = {torch.float4_e2m1fn_x2}
+        ptq_config = mtq.NVFP4_DEFAULT_CFG
+        if args.fp4_mha:
+            from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG
+
+            ptq_config = NVFP4_FP8_MHA_CONFIG

-    if args.dtype == "fp8":
+    elif args.dtype == "fp8":
         enabled_precisions = {torch.float8_e4m3fn, torch.float16}
         ptq_config = mtq.FP8_DEFAULT_CFG

@@ -107,26 +112,33 @@ def forward_loop(mod):
         "enabled_precisions": enabled_precisions,
         "truncate_double": True,
         "min_block_size": 1,
-        "use_python_runtime": True,
+        "use_python_runtime": False,
         "immutable_weights": False,
-        "offload_module_to_cpu": True,
+        "offload_module_to_cpu": args.low_vram_mode,
+        "use_explicit_typing": use_explicit_typing,
     }
     if args.low_vram_mode:
         pipe.remove_all_hooks()
         pipe.enable_sequential_cpu_offload()
         remove_hook_from_module(pipe.transformer, recurse=True)
         pipe.transformer.to(DEVICE)
+
     trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
     if dynamic_shapes:
         trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
     pipe.transformer = trt_gm
-
+    seed = 42
     image = pipe(
-        "Test",
+        [
+            "enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"
+        ],
         output_type="pil",
-        num_inference_steps=2,
+        num_inference_steps=30,
         num_images_per_prompt=batch_size,
+        generator=torch.Generator("cuda").manual_seed(seed),
     ).images
+    print(f"generated {len(image)} images")
+    image[0].save("/tmp/forest.png")

     torch.cuda.empty_cache()

@@ -242,12 +254,16 @@ def main(args):
     parser = argparse.ArgumentParser(
         description="Run Flux quantization with different dtypes"
     )
-
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
     )
     parser.add_argument(
         "--low_vram_mode",

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 4 deletions

@@ -258,10 +258,11 @@ def cross_compile_for_windows(

     if use_explicit_typing:
         if len(enabled_precisions) != 1 or not any(
-            x in enabled_precisions for x in {torch.float32, dtype.f32}
+            x in enabled_precisions
+            for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
         ):
             raise AssertionError(
-                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
             )

     if use_fp32_acc:
@@ -591,10 +592,11 @@ def compile(

     if use_explicit_typing:
         if len(enabled_precisions) != 1 or not any(
-            x in enabled_precisions for x in {torch.float32, dtype.f32}
+            x in enabled_precisions
+            for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
         ):
             raise AssertionError(
-                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
             )

     if use_fp32_acc:
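
For context on what the relaxed assertion allows at the API surface, the sketch below passes only the FP4 dtype in enabled_precisions together with use_explicit_typing=True, the combination that previously tripped the error. TinyMLP, the shapes, and the compile options other than those two are placeholders; this is a validation-level illustration under the assumption of a CUDA build of Torch-TensorRT with an FP4-capable TensorRT, not a complete FP4 quantization workflow.

```python
import torch
import torch_tensorrt


class TinyMLP(torch.nn.Module):
    """Placeholder module; any exportable model would do for this check."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(64, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


model = TinyMLP().eval().cuda()
example = torch.randn(8, 64, device="cuda")

# With this commit, a single FP4 entry in enabled_precisions is accepted when
# use_explicit_typing=True; previously only torch.float32 / dtype.f32 passed
# the assertion.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[example],
    use_explicit_typing=True,
    enabled_precisions={torch.float4_e2m1fn_x2},
    min_block_size=1,
)
print(trt_model(example).shape)
```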

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ def export_fn() -> torch.export.ExportedProgram:
         # Check if any quantization precision is enabled
         if self.enabled_precisions and any(
             precision in self.enabled_precisions
-            for precision in (torch.float8_e4m3fn, torch.int8)
+            for precision in (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2)
         ):
             try:
                 from modelopt.torch.quantization.utils import export_torch_mode
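
Restated outside the class for clarity: the refit path now treats FP4 like FP8 and INT8 when deciding whether torch.export must run under modelopt's export_torch_mode. The helper name below is hypothetical; only the dtype tuple comes from the diff, and the sketch assumes torch.float4_e2m1fn_x2 is available in the installed PyTorch.

```python
import torch


def needs_quantized_export(enabled_precisions: set) -> bool:
    """Return True when re-export should run under modelopt's export_torch_mode()."""
    quantized_dtypes = (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2)
    return bool(enabled_precisions) and any(
        precision in enabled_precisions for precision in quantized_dtypes
    )


print(needs_quantized_export({torch.float4_e2m1fn_x2}))  # True after this commit
print(needs_quantized_export({torch.float16}))           # still False: plain export path
```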

tools/perf/Flux/flux_perf.py

Lines changed: 28 additions & 7 deletions

@@ -3,12 +3,29 @@
 import sys
 from time import time

+import torch
+
 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../examples/apps"))
 from flux_demo import compile_model


 def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
+    print(f"Running warmup with {batch_size=} {inference_step=} iterations=10")
+    # warmup
+    for i in range(10):
+        start = time()
+        images = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+        print(
+            f"Warmup {i} done in {time() - start} seconds, with {batch_size=} {inference_step=}, generated {len(images)} images"
+        )

+    # actual benchmark
+    print(f"Running benchmark with {batch_size=} {inference_step=} {iterations=}")
     start = time()
     for i in range(iterations):
         image = pipe(
@@ -18,32 +35,36 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
             num_images_per_prompt=batch_size,
         ).images
     end = time()
-
     print(f"Batch Size: {batch_size}")
     print("Time Elapse for", iterations, "iterations:", end - start)
     print(
         "Average Latency Per Step:",
         (end - start) / inference_step / iterations / batch_size,
     )
-    return image
+    return


 def main(args):
+    print(f"Running flux_perfwith args: {args}")
     pipe, backbone, trt_gm = compile_model(args)
-    for batch_size in range(1, args.max_batch_size + 1):
-        benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3)
+
+    benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Run Flux quantization with different dtypes"
     )
-
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
     )
     parser.add_argument(
         "--low_vram_mode",
