|
4 | 4 | # SPDX-License-Identifier: BSD-3-Clause |
5 | 5 | # |
6 | 6 | # ----------------------------------------------------------------------------- |
| 7 | + |
| 8 | +""" |
| 9 | +FLUX.1-schnell Image Generation Example |
| 10 | +
|
| 11 | +This example demonstrates how to use the QEFFFluxPipeline to generate images |
| 12 | +using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a |
| 13 | +fast, distilled version of the FLUX.1 text-to-image model optimized for |
| 14 | +speed with minimal quality loss. |
| 15 | +
|
| 16 | +Key Features: |
| 17 | +- Fast inference with only 4 steps |
| 18 | +- High-quality image generation from text prompts |
| 19 | +- Optimized for Qualcomm Cloud AI 100 using ONNX runtime |
| 20 | +- Deterministic output using fixed random seed |
| 21 | +
|
| 22 | +Output: |
| 23 | +- Generates an image based on the text prompt |
| 24 | +- Saves the image as 'cat_with_sign.png' in the current directory |
| 25 | +""" |
| 26 | + |
7 | 27 | import torch |
8 | 28 |
|
9 | 29 | from QEfficient import QEFFFluxPipeline |
10 | 30 |
|
11 | | -pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", use_onnx_function=True) |
| 31 | +# Initialize the FLUX.1-schnell pipeline from pretrained weights |
| 32 | +# use_onnx_function=True enables ONNX-based optimizations for faster compilation |
| 33 | +pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", use_onnx_function=False) |
12 | 34 |
|
| 35 | +# Generate an image from a text prompt |
13 | 36 | output = pipeline( |
14 | 37 | prompt="A cat holding a sign that says hello world", |
15 | 38 | guidance_scale=0.0, |
16 | 39 | num_inference_steps=4, |
17 | 40 | max_sequence_length=256, |
18 | 41 | generator=torch.manual_seed(42), |
19 | 42 | ) |
| 43 | + |
| 44 | +# Extract the generated image from the output |
20 | 45 | image = output.images[0] |
| 46 | + |
| 47 | +# Save the generated image to disk |
21 | 48 | image.save("cat_with_sign.png") |
22 | 49 |
|
| 50 | +# Print the output object (contains perf info) |
23 | 51 | print(output) |
0 commit comments