1 change: 1 addition & 0 deletions .gitignore
@@ -16,3 +16,4 @@ conda_recipes/archive_files/vraystd_*
**/*.pyc
**/*.aex
**/*.conda
**/*.safetensors
109 changes: 109 additions & 0 deletions job_bundles/lora_image_diffusion/README.md
@@ -0,0 +1,109 @@
# Diffusers LoRA Training and Image Generation

Train your own AI image models with just 10-100 photos, then generate unlimited new images.

## What can you do with this?

**Train on your pet:** Collect ~50 photos of your dog from different angles, train a LoRA for 30 minutes, then generate images of your dog surfing, wearing a tuxedo, or sitting on the moon.

**Capture an art style:** Gather 20-30 images in a specific illustration style, train a LoRA, then apply that style to any prompt—"a cyberpunk city in [your style]" or "portrait of a knight in [your style]."

**Product photography:** Train on photos of your product, then generate it in various settings, lighting conditions, or contexts without expensive photo shoots.

## How it works

These AWS Deadline Cloud job bundles use [Hugging Face Diffusers](https://huggingface.co/docs/diffusers/index) to fine-tune Stable Diffusion with LoRA (Low-Rank Adaptation)—a technique that creates small, efficient model adapters from your images.
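
In essence, LoRA freezes the base model's weights and trains only a small low-rank update on top of them. A toy sketch of the idea (hypothetical shapes, not the bundles' actual code):

```python
import torch

# LoRA in one line: W' = W + (alpha / r) * B @ A, where only the small
# factors A and B are trained and saved; W stays frozen.
d, r, alpha = 320, 4, 16               # hypothetical layer width, rank, alpha
W = torch.randn(d, d)                  # frozen base weight
A = torch.randn(r, d) * 0.01           # trained down-projection
B = torch.zeros(d, r)                  # trained up-projection (zero-initialized)
W_adapted = W + (alpha / r) * (B @ A)  # adapted weight used at inference
# Only A and B (2*d*r values instead of d*d) go into the .safetensors
# adapter, which is why LoRA files stay small.
```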

**Workflow:**
1. **Prepare Training Data** - Collect images of your subject in a directory
2. **Train LoRA** - Submit training job using the `lora_training` bundle
3. **Download Training Output** - Download the trained LoRA weights
4. **Generate Images** - Submit generation job using the `image_generation` bundle with your trained LoRA

## Job bundles

### 1. lora_training
Train custom LoRA adapters for Stable Diffusion models using your own image datasets.

**Fleet requirements:**
- GPU with 16GB+ VRAM (or CPU for testing)
- Linux OS

**Key Parameters:**
- **Base Model**: SD 1.4, SD 1.5, SD 2.1, or SDXL
- **Dataset Path**: Local directory containing training images (.jpg, .png, .jpeg)
- **Instance Prompt**: Text describing your training images
- **Max Training Steps**: Number of training iterations
- **LoRA Rank**: Rank of the LoRA update matrices (higher rank means more adapter capacity; see the sketch after this list)
- **LoRA Alpha**: Scaling factor for LoRA strength (the effective scale applied is alpha/rank)
- **Output Directory**: Where to save trained LoRA weights
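
Rank and alpha together set the adapter's effective strength: the scripts apply a scale of alpha/rank. A sketch of how the two parameters map onto a peft `LoraConfig`, mirroring what `generate_image.py` rebuilds from the embedded metadata:

```python
from peft import LoraConfig

# Mirrors the config constructed in generate_image.py. The example values
# match the CLI example below (rank 4, alpha 16 -> effective scale 4.0).
lora_config = LoraConfig(
    r=4,             # LoRARank parameter
    lora_alpha=16,   # LoRAAlpha parameter
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # UNet attention projections
)
```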

**Example:**
Use the job bundle GUI submitter to select parameter values:

```bash
deadline bundle gui-submit ./lora_training
```

Or, use the CLI submitter:

```bash
deadline bundle submit ./lora_training \
--parameter DatasetPath=./sample_data \
--parameter InstancePrompt="a photo of Luna, my dog" \
--parameter OutputDir=/tmp/lora_output \
--parameter MaxTrainSteps=500 \
--parameter LoRARank=4 \
--parameter LoRAAlpha=16
```

**Output:** LoRA weights saved as `pytorch_lora_weights.safetensors` with embedded metadata (base model, rank, alpha, instance prompt). The generation job reads this metadata automatically—no need to re-enter training parameters.
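
To verify what a trained adapter contains, the embedded metadata can be read back with `safetensors`; a minimal sketch using the same `safe_open` call the generation script uses (the file path is an example):

```python
from safetensors import safe_open

# Print the metadata the training job embeds in the weights file.
with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    print(f.metadata())  # base_model, lora_rank, lora_alpha, instance_prompt
```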

**Download Output:**
After training completes, download the LoRA weights to use in generation:
```bash
deadline job download-output --job-id <training-job-id>
```

---

### 2. image_generation
Generate images using Stable Diffusion with trained LoRA adapters.

**Fleet requirements:**
- Linux worker
- GPU recommended for faster generation (CPU works, but slowly)
- Trained LoRA adapter from lora_training job

**Key Parameters:**
- **LoRA Path**: Path to trained LoRA `.safetensors` file (base model and rank are read from embedded metadata)
- **Prompt**: Text description of image to generate
- **Negative Prompt**: What to avoid in generation
- **Number of Images**: Total images to generate (one task per image, parallelized across workers)
- **Width/Height**: Output dimensions (default: 512x512)
- **Inference Steps**: Denoising steps (default: 50)
- **Seed**: Random seed for reproducibility (-1 derives a deterministic per-image seed; see the sketch below)
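
Seed handling, condensed from `generate_image.py`: a fixed seed gives every task the same generator state, while `-1` derives a deterministic per-task seed from the image index so parallel tasks differ but remain repeatable.

```python
# Seed policy from generate_image.py, condensed:
seed = args.image_index * 12345 if args.seed == -1 else args.seed
generator = torch.Generator(device=device).manual_seed(seed)
```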

**Example:**
Use the job bundle GUI submitter to select parameter values:

```bash
deadline bundle gui-submit ./image_generation
```

Or, use the CLI submitter:

```bash
deadline bundle submit ./image_generation \
--parameter LoRAPath=/tmp/lora_output/pytorch_lora_weights.safetensors \
--parameter Prompt="Luna, my dog, wearing a tuxedo" \
--parameter OutputDir=/tmp/dog_images \
--parameter NumImages=10
```

**Output:** PNG images saved as `image_0001.png`, `image_0002.png`, etc.

**Note:** After job completion, download outputs with:
```bash
deadline job download-output --job-id <job-id>
```
109 changes: 109 additions & 0 deletions job_bundles/lora_image_diffusion/image_generation/generate_image.py
@@ -0,0 +1,109 @@
import argparse
import os
import torch
from diffusers import StableDiffusionPipeline
from safetensors import safe_open
from safetensors.torch import load_file
from peft import LoraConfig, PeftModel, set_peft_model_state_dict

# Set cache directory
cache_dir = os.path.expanduser("~/.models/huggingface")
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HF_DATASETS_CACHE"] = cache_dir

parser = argparse.ArgumentParser()
parser.add_argument("--lora-path", required=True)
parser.add_argument("--prompt", required=True)
parser.add_argument("--negative-prompt", required=True)
parser.add_argument("--width", type=int, required=True)
parser.add_argument("--height", type=int, required=True)
parser.add_argument("--num-inference-steps", type=int, required=True)
parser.add_argument("--guidance-scale", type=float, required=True)
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--image-index", type=int, required=True)
parser.add_argument("--output-dir", required=True)
args = parser.parse_args()


def load_lora_with_metadata(filepath):
    """Load LoRA weights and extract embedded metadata from safetensors file."""
    if not os.path.isfile(filepath):
        raise ValueError(f"LoRA weights file not found: {filepath}")

    with safe_open(filepath, framework="pt") as f:
        metadata = f.metadata()

    if not metadata:
        raise ValueError(
            f"No metadata found in LoRA file '{filepath}'. "
            "This file may have been created with an older version of the training script."
        )

    required_keys = ["base_model", "lora_rank"]
    missing = [k for k in required_keys if k not in metadata]
    if missing:
        raise ValueError(f"LoRA file missing required metadata: {missing}")

    state_dict = load_file(filepath)
    return state_dict, metadata


# Load LoRA weights and metadata; alpha defaults to rank if absent
lora_state_dict, lora_metadata = load_lora_with_metadata(args.lora_path)
base_model = lora_metadata["base_model"]
lora_rank = int(lora_metadata["lora_rank"])
lora_alpha = int(lora_metadata.get("lora_alpha", lora_rank))

print(f"Loaded LoRA metadata: base_model={base_model}, rank={lora_rank}, alpha={lora_alpha}")
print(f"Loaded {len(lora_state_dict)} LoRA tensors from file")

# Pick the available accelerator (torch.accelerator API), falling back to CPU
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using device: {device}")

pipe = StableDiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
).to(device)

# Rebuild the adapter config from the embedded metadata and wrap the UNet with it
lora_config = LoraConfig(
    r=lora_rank, lora_alpha=lora_alpha, target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
pipe.unet = PeftModel(pipe.unet, lora_config)

set_peft_model_state_dict(pipe.unet, lora_state_dict)

# Sanity check: confirm the LoRA tensors actually landed in the UNet
lora_params_loaded = len(
    [k for k in pipe.unet.state_dict().keys() if "lora" in k.lower()]
)
if lora_params_loaded == 0:
    raise RuntimeError("FAILED: 0 LoRA parameters loaded into model!")

print(f"Successfully loaded {lora_params_loaded} LoRA parameters into UNet")
print(f"LoRA scale: {lora_alpha / lora_rank}")

# A seed of -1 means "random": derive a deterministic per-task seed from the image index
if args.seed == -1:
    seed = args.image_index * 12345
else:
    seed = args.seed

generator = torch.Generator(device=device).manual_seed(seed)

image = pipe(
    prompt=args.prompt,
    negative_prompt=args.negative_prompt,
    width=args.width,
    height=args.height,
    num_inference_steps=args.num_inference_steps,
    guidance_scale=args.guidance_scale,
    generator=generator,
).images[0]

filename = f"image_{args.image_index:04d}.png"
output_path = os.path.join(args.output_dir, filename)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
image.save(output_path)
print(f"Generated image saved to {output_path}")
171 changes: 171 additions & 0 deletions job_bundles/lora_image_diffusion/image_generation/template.yaml
@@ -0,0 +1,171 @@
specificationVersion: 'jobtemplate-2023-09'
name: Diffusers Image Generation
description: Generate images using Stable Diffusion with optional LoRA adapters

parameterDefinitions:
  - name: LoRAPath
    type: PATH
    objectType: FILE
    dataFlow: IN
    userInterface:
      control: CHOOSE_INPUT_FILE
      label: LoRA Weights
      groupLabel: Model Settings
    description: Path to trained LoRA weights .safetensors file (required)

  - name: Prompt
    type: STRING
    userInterface:
      control: LINE_EDIT
      label: Prompt
      groupLabel: Generation Settings
    description: Text prompt for image generation
    default: "a photo of an astronaut riding a horse on mars"

  - name: NegativePrompt
    type: STRING
    userInterface:
      control: LINE_EDIT
      label: Negative Prompt
      groupLabel: Generation Settings
    description: Negative prompt to guide what to avoid
    default: "blurry, low quality"

  - name: OutputDir
    type: PATH
    objectType: DIRECTORY
    dataFlow: OUT
    userInterface:
      control: CHOOSE_DIRECTORY
      label: Output Directory
      groupLabel: Output Settings

  - name: NumImages
    type: INT
    userInterface:
      control: SPIN_BOX
      label: Number of Images
      groupLabel: Generation Settings
    description: Total number of images to generate
    default: 10
    minValue: 1
    maxValue: 1000

  - name: Width
    type: INT
    userInterface:
      control: SPIN_BOX
      label: Width
      groupLabel: Generation Settings
    default: 512
    minValue: 256
    maxValue: 1024

  - name: Height
    type: INT
    userInterface:
      control: SPIN_BOX
      label: Height
      groupLabel: Generation Settings
    default: 512
    minValue: 256
    maxValue: 1024

  - name: NumInferenceSteps
    type: INT
    userInterface:
      control: SPIN_BOX
      label: Inference Steps
      groupLabel: Generation Settings
    description: Number of denoising steps
    default: 50
    minValue: 1
    maxValue: 150

  - name: GuidanceScale
    type: STRING
    userInterface:
      control: LINE_EDIT
      label: Guidance Scale
      groupLabel: Generation Settings
    description: Classifier-free guidance scale
    default: "7.5"

  - name: Seed
    type: STRING
    userInterface:
      control: LINE_EDIT
      label: Seed
      groupLabel: Generation Settings
    description: Random seed for reproducibility (-1 for random)
    default: "-1"

  - name: GenerateScript
    description: Generation script
    userInterface:
      control: HIDDEN
    type: PATH
    objectType: FILE
    dataFlow: IN
    default: generate_image.py

  - name: CondaChannels
    type: STRING
    userInterface:
      control: HIDDEN
    default: conda-forge
    description: A list of conda channels to get packages from. The job expects a Queue Environment to handle this.

  - name: CondaPackages
    type: STRING
    userInterface:
      control: HIDDEN
    default: "pytorch diffusers transformers accelerate peft safetensors pillow"
    description: >
      Conda packages to install for the environment. These packages must be available in a channel
      made available by a Queue Environment.

jobEnvironments:
  - name: UnbufferedOutput
    variables:
      # Turn off buffering of Python's output
      PYTHONUNBUFFERED: "True"

steps:
  - name: GenerateImages
    parameterSpace:
      taskParameterDefinitions:
        - name: ImageIndex
          type: INT
          range: "1-{{Param.NumImages}}"
    script:
      actions:
        onRun:
          command: bash
          args: ['-c', '{{Task.File.Run}}']
      embeddedFiles:
        - name: Run
          type: TEXT
          runnable: true
          data: |
            #!/bin/bash
            set -euo pipefail

            python -u {{Param.GenerateScript}} \
              --lora-path "{{Param.LoRAPath}}" \
              --prompt "{{Param.Prompt}}" \
              --negative-prompt "{{Param.NegativePrompt}}" \
              --width {{Param.Width}} \
              --height {{Param.Height}} \
              --num-inference-steps {{Param.NumInferenceSteps}} \
              --guidance-scale {{Param.GuidanceScale}} \
              --seed {{Param.Seed}} \
              --image-index {{Task.Param.ImageIndex}} \
              --output-dir "{{Param.OutputDir}}"
    hostRequirements:
      attributes:
        - name: attr.worker.os.family
          anyOf: [linux]
      amounts:
        - name: amount.worker.gpu
          min: 1