feat: add sample bundles for training a LoRA and generating images fr… #155
Open

crowecawcaw wants to merge 8 commits into aws-deadline:mainline from crowecawcaw:lora
Commits (8):
- a7f64e7  feat: add sample bundles for training a LoRA and generating images fr… (crowecawcaw)
- dd5e252  feat: add GPU requirement to diffusion job bundles (crowecawcaw)
- 29b8317  Merge branch 'mainline' into lora (crowecawcaw)
- 29935ab  make accelerator more flexible (crowecawcaw)
- 8f0a026  use conda (crowecawcaw)
- 13c00cb  Merge branch 'mainline' into lora (crowecawcaw)
- d592571  refactor: Make Python output in lora sample unbuffered (mwiebe)
- fd6f80f  refactor: Clean up the intro with examples, modify the code to preser… (mwiebe)
Files changed:
```diff
@@ -16,3 +16,4 @@ conda_recipes/archive_files/vraystd_*
 **/*.pyc
 **/*.aex
 **/*.conda
+**/*.safetensors
```
README for the new sample (new file, 109 additions):
# Diffusers LoRA Training and Image Generation

Train your own AI image models with just 10-100 photos, then generate unlimited new images.

## What can you do with this?

**Train on your pet:** Collect ~50 photos of your dog from different angles, train a LoRA for 30 minutes, then generate images of your dog surfing, wearing a tuxedo, or sitting on the moon.

**Capture an art style:** Gather 20-30 images in a specific illustration style, train a LoRA, then apply that style to any prompt—"a cyberpunk city in [your style]" or "portrait of a knight in [your style]."

**Product photography:** Train on photos of your product, then generate it in various settings, lighting conditions, or contexts without expensive photo shoots.

## How it works

These AWS Deadline Cloud job bundles use [Hugging Face Diffusers](https://huggingface.co/docs/diffusers/index) to fine-tune Stable Diffusion with LoRA (Low-Rank Adaptation)—a technique that creates small, efficient model adapters from your images.

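For intuition, a LoRA adapter leaves the base weights frozen and learns a small low-rank correction scaled by `alpha / rank`. The sketch below is illustrative only (the layer size and initialization are made up for the example; the real adapters are attached to the UNet attention projections by the training script), with rank 4 and alpha 16 matching the CLI example later in this README:

```python
# Minimal illustration of the LoRA update on one linear layer.
# Only A and B are trained; W stays frozen.
# output = x @ W.T + (alpha / rank) * x @ A.T @ B.T
import torch

d_in, d_out, rank, alpha = 768, 768, 4, 16   # toy sizes
W = torch.randn(d_out, d_in)                 # frozen base weight
A = torch.randn(rank, d_in) * 0.01           # trainable down-projection
B = torch.zeros(d_out, rank)                 # trainable up-projection, zero-initialized

def lora_linear(x: torch.Tensor) -> torch.Tensor:
    return x @ W.T + (alpha / rank) * ((x @ A.T) @ B.T)

y = lora_linear(torch.randn(1, d_in))
print(y.shape)  # torch.Size([1, 768])
```

Because only the small A and B matrices are saved, the resulting `.safetensors` file is tiny compared with the base model.
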
**Workflow:**
1. **Prepare Training Data** - Collect images of your subject in a directory
2. **Train LoRA** - Submit a training job using the `lora_training` bundle
3. **Download Training Output** - Download the trained LoRA weights
4. **Generate Images** - Submit a generation job using the `image_generation` bundle with your trained LoRA

## Job bundles

### 1. lora_training
Train custom LoRA adapters for Stable Diffusion models using your own image datasets.

**Fleet requirements:**
- GPU with 16GB+ VRAM (or CPU for testing)
- Linux OS

**Key Parameters:**
- **Base Model**: SD 1.4, SD 1.5, SD 2.1, or SDXL
- **Dataset Path**: Local directory containing training images (.jpg, .png, .jpeg)
- **Instance Prompt**: Text describing your training images
- **Max Training Steps**: Number of training iterations
- **LoRA Rank**: Rank of the LoRA matrices
- **LoRA Alpha**: Scaling factor for LoRA strength
- **Output Directory**: Where to save the trained LoRA weights

**Example:**
Use the job bundle GUI submitter to select parameter values:

```bash
deadline bundle gui-submit ./lora_training
```

Or, use the CLI submitter:

```bash
deadline bundle submit ./lora_training \
  --parameter DatasetPath=./sample_data \
  --parameter InstancePrompt="a photo of Luna, my dog" \
  --parameter OutputDir=/tmp/lora_output \
  --parameter MaxTrainSteps=500 \
  --parameter LoRARank=4 \
  --parameter LoRAAlpha=16
```

**Output:** LoRA weights are saved as `pytorch_lora_weights.safetensors` with embedded metadata (base model, rank, alpha, instance prompt). The generation job reads this metadata automatically—no need to re-enter training parameters.

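If you want to confirm what was embedded before submitting a generation job, the metadata can be read directly with the `safetensors` library. A quick check, assuming the output path from the CLI example above and the key names these bundles write:

```python
# Inspect the metadata embedded in the trained LoRA weights.
from safetensors import safe_open

lora_file = "/tmp/lora_output/pytorch_lora_weights.safetensors"  # path from the example above
with safe_open(lora_file, framework="pt") as f:
    print(f.metadata())  # expected keys include base_model, lora_rank, lora_alpha, instance_prompt
```
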
**Download Output:**
After training completes, download the LoRA weights to use in generation:
```bash
deadline job download-output --job-id <training-job-id>
```

---

### 2. image_generation
Generate images using Stable Diffusion with trained LoRA adapters.

**Fleet requirements:**
- Linux worker
- GPU recommended for faster generation (CPU will work, but slowly)
- Trained LoRA adapter from a lora_training job

**Key Parameters:**
- **LoRA Path**: Path to the trained LoRA `.safetensors` file (base model and rank are read from embedded metadata)
- **Prompt**: Text description of the image to generate
- **Negative Prompt**: What to avoid in generation
- **Number of Images**: Total images to generate (one task per image, run in parallel)
- **Width/Height**: Output dimensions (default: 512x512)
- **Inference Steps**: Denoising steps (default: 50)
- **Seed**: Random seed for reproducibility (-1 for random; see the sketch after this list)

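The seed handling in `generate_image.py` is worth spelling out: with the default of -1, each task derives a deterministic seed from its image index, so re-running the job reproduces the same set of images; with a fixed seed, every task uses that same value. A small sketch mirroring the script's logic, with illustrative values:

```python
# Per-task seeding as implemented in generate_image.py.
import torch

def task_seed(seed_param: int, image_index: int) -> int:
    # -1 means "derive a seed from the task's image index" so results are reproducible.
    return image_index * 12345 if seed_param == -1 else seed_param

generator = torch.Generator().manual_seed(task_seed(-1, 3))  # image 3 -> seed 37035
```
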
**Example:**
Use the job bundle GUI submitter to select parameter values:

```bash
deadline bundle gui-submit ./image_generation
```

Or, use the CLI submitter:

```bash
deadline bundle submit ./image_generation \
  --parameter LoRAPath=/tmp/lora_output/pytorch_lora_weights.safetensors \
  --parameter Prompt="Luna, my dog, wearing a tuxedo" \
  --parameter OutputDir=/tmp/dog_images \
  --parameter NumImages=10
```

**Output:** PNG images saved as `image_0001.png`, `image_0002.png`, etc.

**Note:** After job completion, download outputs with:
```bash
deadline job download-output --job-id <job-id>
```

job_bundles/lora_image_diffusion/image_generation/generate_image.py (new file, 109 additions):
```python
import argparse
import os

# Set the Hugging Face cache directory before importing the libraries that
# read these environment variables at import time.
cache_dir = os.path.expanduser("~/.models/huggingface")
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HF_DATASETS_CACHE"] = cache_dir

import torch
from diffusers import StableDiffusionPipeline
from safetensors import safe_open
from safetensors.torch import load_file
from peft import LoraConfig, PeftModel, set_peft_model_state_dict

parser = argparse.ArgumentParser()
parser.add_argument("--lora-path", required=True)
parser.add_argument("--prompt", required=True)
parser.add_argument("--negative-prompt", required=True)
parser.add_argument("--width", type=int, required=True)
parser.add_argument("--height", type=int, required=True)
parser.add_argument("--num-inference-steps", type=int, required=True)
parser.add_argument("--guidance-scale", type=float, required=True)
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--image-index", type=int, required=True)
parser.add_argument("--output-dir", required=True)
args = parser.parse_args()


def load_lora_with_metadata(filepath):
    """Load LoRA weights and extract embedded metadata from safetensors file."""
    if not os.path.isfile(filepath):
        raise ValueError(f"LoRA weights file not found: {filepath}")

    with safe_open(filepath, framework="pt") as f:
        metadata = f.metadata()

    if not metadata:
        raise ValueError(
            f"No metadata found in LoRA file '{filepath}'. "
            "This file may have been created with an older version of the training script."
        )

    required_keys = ["base_model", "lora_rank"]
    missing = [k for k in required_keys if k not in metadata]
    if missing:
        raise ValueError(f"LoRA file missing required metadata: {missing}")

    state_dict = load_file(filepath)
    return state_dict, metadata


# Load LoRA weights and metadata
lora_state_dict, lora_metadata = load_lora_with_metadata(args.lora_path)
base_model = lora_metadata["base_model"]
lora_rank = int(lora_metadata["lora_rank"])
lora_alpha = int(lora_metadata.get("lora_alpha", lora_rank))

print(f"Loaded LoRA metadata: base_model={base_model}, rank={lora_rank}, alpha={lora_alpha}")
print(f"Loaded {len(lora_state_dict)} LoRA tensors from file")

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using device: {device}")

pipe = StableDiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
).to(device)

lora_config = LoraConfig(
    r=lora_rank, lora_alpha=lora_alpha, target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
pipe.unet = PeftModel(pipe.unet, lora_config)

set_peft_model_state_dict(pipe.unet, lora_state_dict)

lora_params_loaded = len(
    [k for k in pipe.unet.state_dict().keys() if "lora" in k.lower()]
)
if lora_params_loaded == 0:
    raise RuntimeError("FAILED: 0 LoRA parameters loaded into model!")

print(f"Successfully loaded {lora_params_loaded} LoRA parameters into UNet")
print(f"LoRA scale: {lora_alpha / lora_rank}")

if args.seed == -1:
    seed = args.image_index * 12345
else:
    seed = args.seed

generator = torch.Generator(device=device).manual_seed(seed)

image = pipe(
    prompt=args.prompt,
    negative_prompt=args.negative_prompt,
    width=args.width,
    height=args.height,
    num_inference_steps=args.num_inference_steps,
    guidance_scale=args.guidance_scale,
    generator=generator,
).images[0]

filename = f"image_{args.image_index:04d}.png"
output_path = os.path.join(args.output_dir, filename)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
image.save(output_path)
print(f"Generated image saved to {output_path}")
```
job_bundles/lora_image_diffusion/image_generation/template.yaml (new file, 171 additions):
```yaml
specificationVersion: 'jobtemplate-2023-09'
name: Diffusers Image Generation
description: Generate images using Stable Diffusion with optional LoRA adapters

parameterDefinitions:
- name: LoRAPath
  type: PATH
  objectType: FILE
  dataFlow: IN
  userInterface:
    control: CHOOSE_INPUT_FILE
    label: LoRA Weights
    groupLabel: Model Settings
  description: Path to trained LoRA weights .safetensors file (required)

- name: Prompt
  type: STRING
  userInterface:
    control: LINE_EDIT
    label: Prompt
    groupLabel: Generation Settings
  description: Text prompt for image generation
  default: "a photo of an astronaut riding a horse on mars"

- name: NegativePrompt
  type: STRING
  userInterface:
    control: LINE_EDIT
    label: Negative Prompt
    groupLabel: Generation Settings
  description: Negative prompt to guide what to avoid
  default: "blurry, low quality"

- name: OutputDir
  type: PATH
  objectType: DIRECTORY
  dataFlow: OUT
  userInterface:
    control: CHOOSE_DIRECTORY
    label: Output Directory
    groupLabel: Output Settings

- name: NumImages
  type: INT
  userInterface:
    control: SPIN_BOX
    label: Number of Images
    groupLabel: Generation Settings
  description: Total number of images to generate
  default: 10
  minValue: 1
  maxValue: 1000

- name: Width
  type: INT
  userInterface:
    control: SPIN_BOX
    label: Width
    groupLabel: Generation Settings
  default: 512
  minValue: 256
  maxValue: 1024

- name: Height
  type: INT
  userInterface:
    control: SPIN_BOX
    label: Height
    groupLabel: Generation Settings
  default: 512
  minValue: 256
  maxValue: 1024

- name: NumInferenceSteps
  type: INT
  userInterface:
    control: SPIN_BOX
    label: Inference Steps
    groupLabel: Generation Settings
  description: Number of denoising steps
  default: 50
  minValue: 1
  maxValue: 150

- name: GuidanceScale
  type: STRING
  userInterface:
    control: LINE_EDIT
    label: Guidance Scale
    groupLabel: Generation Settings
  description: Classifier-free guidance scale
  default: "7.5"

- name: Seed
  type: STRING
  userInterface:
    control: LINE_EDIT
    label: Seed
    groupLabel: Generation Settings
  description: Random seed for reproducibility (-1 for random)
  default: "-1"

- name: GenerateScript
  description: Generation script
  userInterface:
    control: HIDDEN
  type: PATH
  objectType: FILE
  dataFlow: IN
  default: generate_image.py

- name: CondaChannels
  type: STRING
  userInterface:
    control: HIDDEN
  default: conda-forge
  description: A list of conda channels to get packages from. The job expects a Queue Environment to handle this.

- name: CondaPackages
  type: STRING
  userInterface:
    control: HIDDEN
  default: "pytorch diffusers transformers accelerate peft safetensors pillow"
  description: >
    Conda packages to install for the environment. These packages must be available in a channel
    made available by a Queue Environment.

jobEnvironments:
- name: UnbufferedOutput
  variables:
    # Turn off buffering of Python's output
    PYTHONUNBUFFERED: "True"

steps:
- name: GenerateImages
  parameterSpace:
    taskParameterDefinitions:
    - name: ImageIndex
      type: INT
      range: "1-{{Param.NumImages}}"
  script:
    actions:
      onRun:
        command: bash
        args: ['-c', '{{Task.File.Run}}']
    embeddedFiles:
    - name: Run
      type: TEXT
      runnable: true
      data: |
        #!/bin/bash
        set -euo pipefail

        python -u {{Param.GenerateScript}} \
          --lora-path "{{Param.LoRAPath}}" \
          --prompt "{{Param.Prompt}}" \
          --negative-prompt "{{Param.NegativePrompt}}" \
          --width {{Param.Width}} \
          --height {{Param.Height}} \
          --num-inference-steps {{Param.NumInferenceSteps}} \
          --guidance-scale {{Param.GuidanceScale}} \
          --seed {{Param.Seed}} \
          --image-index {{Task.Param.ImageIndex}} \
          --output-dir "{{Param.OutputDir}}"
  hostRequirements:
    attributes:
    - name: attr.worker.os.family
      anyOf: [linux]
    amounts:
    - name: amount.worker.gpu
      min: 1
```
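For intuition on the `parameterSpace` block above: the `ImageIndex` range expands into one task per value, which is how the README's Number of Images parameter parallelizes generation. A rough illustration of that expansion (the expansion itself happens on the Deadline Cloud side; this is only a sketch):

```python
# Illustrative expansion of the task parameter range "1-{{Param.NumImages}}".
num_images = 10  # value of the NumImages job parameter
tasks = [{"ImageIndex": i} for i in range(1, num_images + 1)]
print(len(tasks), tasks[0], tasks[-1])  # 10 {'ImageIndex': 1} {'ImageIndex': 10}
```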