Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 566dd1e

Browse files
committed
Inpainting prototype
1 parent 51f0d76 commit 566dd1e

File tree

9 files changed

+333
-105
lines changed

9 files changed

+333
-105
lines changed

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ public interface IStableDiffusionService
2828
/// <param name="options">The Scheduler options.</param>
2929
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
3030
/// <param name="cancellationToken">The cancellation token.</param>
31-
/// <returns>The diffusion result as <see cref="SixLabors.ImageSharp.Image<Rgb24>"/></returns>
32-
Task<Image<Rgb24>> GenerateAsImageAsync(PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
31+
/// <returns>The diffusion result as <see cref="SixLabors.ImageSharp.Image<Rgba32>"/></returns>
32+
Task<Image<Rgba32>> GenerateAsImageAsync(PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
3333

3434
/// <summary>
3535
/// Generates the StableDiffusion image using the prompt and options provided.

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
6666
{
6767
// Create random seed if none was set
6868
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
69-
Console.WriteLine($"Scheduler: {promptOptions.SchedulerType}, Size: {schedulerOptions.Width}x{schedulerOptions.Height}, Seed: {schedulerOptions.Seed}, Steps: {schedulerOptions.InferenceSteps}, Guidance: {schedulerOptions.GuidanceScale}");
70-
69+
7170
// Get Scheduler
7271
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
7372
{
@@ -78,7 +77,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
7877
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
7978

8079
// Create latent sample
81-
var latentSample = PrepareLatents(promptOptions, schedulerOptions, scheduler, timesteps);
80+
var latents = PrepareLatents(promptOptions, schedulerOptions, scheduler, timesteps);
8281

8382
// Loop though the timesteps
8483
var step = 0;
@@ -87,8 +86,9 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
8786
cancellationToken.ThrowIfCancellationRequested();
8887

8988
// Create input tensor.
90-
var inputTensor = scheduler.ScaleInput(latentSample.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
89+
var inputTensor = scheduler.ScaleInput(latents.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
9190

91+
// Create Input Parameters
9292
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.Unet);
9393
var inputParameters = CreateInputParameters(
9494
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
@@ -98,27 +98,24 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
9898
// Run Inference
9999
using (var inferResult = await _onnxModelService.RunInferenceAsync(OnnxModelType.Unet, inputParameters))
100100
{
101-
var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
102-
103-
// Split tensors from 2,4,(H/8),(W/8) to 1,4,(H/8),(W/8)
104-
var splitTensors = resultTensor.SplitTensor(schedulerOptions.GetScaledDimension(), schedulerOptions.GetScaledHeight(), schedulerOptions.GetScaledWidth());
105-
var noisePred = splitTensors.Item1;
106-
var noisePredText = splitTensors.Item2;
101+
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
107102

108103
// Perform guidance
109-
noisePred = noisePred.PerformGuidance(noisePredText, schedulerOptions.GuidanceScale);
104+
if (schedulerOptions.GuidanceScale > 1.0f)
105+
{
106+
var (noisePredUncond, noisePredText) = noisePred.SplitTensor(schedulerOptions.GetScaledDimension());
107+
noisePred = noisePredUncond.PerformGuidance(noisePredText, schedulerOptions.GuidanceScale);
108+
}
110109

111-
// LMS Scheduler Step
112-
latentSample = scheduler.Step(noisePred, timestep, latentSample);
113-
// ImageHelpers.TensorToImageDebug(latentSample, 64, $@"Examples\StableDebug\Latent_{step}.png");
110+
// Scheduler Step
111+
latents = scheduler.Step(noisePred, timestep, latents);
114112
}
115113

116-
Console.WriteLine($"Step: {++step}/{timesteps.Count}");
117-
progress?.Invoke(step, timesteps.Count);
114+
progress?.Invoke(++step, timesteps.Count);
118115
}
119116

120117
// Decode Latents
121-
return await DecodeLatents(schedulerOptions, latentSample);
118+
return await DecodeLatents(schedulerOptions, latents);
122119
}
123120
}
124121

@@ -192,7 +189,7 @@ protected static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions o
192189
using (var image = imageTensor.ToImage())
193190
{
194191
// Resize image
195-
ImageHelpers.Resize(image, 224, 224);
192+
ImageHelpers.Resize(image, new[] { 1, 3, 224, 224 });
196193

197194
// Preprocess image
198195
var input = new DenseTensor<float>(new[] { 1, 3, 224, 224 });

OnnxStack.StableDiffusion/Diffusers/ImageDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, Schedul
5353
protected override DenseTensor<float> PrepareLatents(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
5454
{
5555
// Image input, decode, add noise, return as latent 0
56-
var imageTensor = prompt.InputImage.ToDenseTensor(options.Width, options.Height);
56+
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Width, options.Height });
5757
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.VaeEncoder);
5858
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
5959
using (var inferResult = _onnxModelService.RunInference(OnnxModelType.VaeEncoder, inputParameters))

0 commit comments

Comments
 (0)