Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit e31a85d

Browse files
committed
Inpaint Legacy process started
1 parent 566dd1e commit e31a85d

File tree

4 files changed

+276
-2
lines changed

4 files changed

+276
-2
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Config;
4+
using OnnxStack.Core.Services;
5+
using OnnxStack.StableDiffusion.Common;
6+
using OnnxStack.StableDiffusion.Config;
7+
using OnnxStack.StableDiffusion.Diffusers;
8+
using OnnxStack.StableDiffusion.Helpers;
9+
using SixLabors.ImageSharp;
10+
using SixLabors.ImageSharp.Processing;
11+
using SixLabors.ImageSharp.Processing.Processors.Transforms;
12+
using System;
13+
using System.Collections.Generic;
14+
using System.Linq;
15+
using System.Threading;
16+
using System.Threading.Tasks;
17+
18+
namespace OnnxStack.StableDiffusion.Services
19+
{
20+
public sealed class InpaintLegacyDiffuser : DiffuserBase
21+
{
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="InpaintLegacyDiffuser"/> class.
24+
/// </summary>
25+
/// <param name="configuration">The configuration.</param>
26+
/// <param name="onnxModelService">The onnx model service.</param>
27+
public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
28+
: base(onnxModelService, promptService)
29+
{
30+
}
31+
32+
public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progress = null, CancellationToken cancellationToken = default)
33+
{
34+
// Create random seed if none was set
35+
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
36+
37+
// Get Scheduler
38+
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
39+
{
40+
// Process prompts
41+
var promptEmbeddings = await _promptService.CreatePromptAsync(promptOptions.Prompt, promptOptions.NegativePrompt);
42+
43+
// Get timesteps
44+
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
45+
46+
// Create latent sample
47+
var latentsOriginal = PrepareLatents(promptOptions, schedulerOptions, scheduler, timesteps);
48+
49+
// Create masks sample
50+
var maskImage = PrepareMask(promptOptions, schedulerOptions);
51+
52+
// Generate some noise
53+
var noise = scheduler.CreateRandomSample(latentsOriginal.Dimensions);
54+
55+
// Add noise to original latent
56+
var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);
57+
58+
// Loop though the timesteps
59+
var step = 0;
60+
foreach (var timestep in timesteps)
61+
{
62+
cancellationToken.ThrowIfCancellationRequested();
63+
64+
// Create input tensor.
65+
var inputTensor = scheduler.ScaleInput(latents.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
66+
67+
// Create Input Parameters
68+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.Unet);
69+
var inputParameters = CreateInputParameters(
70+
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
71+
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
72+
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
73+
74+
// Run Inference
75+
using (var inferResult = await _onnxModelService.RunInferenceAsync(OnnxModelType.Unet, inputParameters))
76+
{
77+
var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();
78+
79+
// Perform guidance
80+
if (schedulerOptions.GuidanceScale > 1.0f)
81+
{
82+
var (noisePredUncond, noisePredText) = noisePred.SplitTensor(schedulerOptions.GetScaledDimension());
83+
noisePred = noisePredUncond.PerformGuidance(noisePredText, schedulerOptions.GuidanceScale);
84+
}
85+
86+
// Scheduler Step
87+
var steplatents = scheduler.Step(noisePred, timestep, latents);
88+
89+
// Add noise to original latent
90+
var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { timestep });
91+
92+
// Apply mask and combine
93+
latents = ApplyMaskedLatents(steplatents, initLatentsProper, maskImage);
94+
95+
ImageHelpers.TensorToImageDebug(latents, $@"D:\Debug\Latent{step}.png");
96+
}
97+
98+
progress?.Invoke(++step, timesteps.Count);
99+
}
100+
101+
// Decode Latents
102+
return await DecodeLatents(schedulerOptions, latents);
103+
}
104+
}
105+
106+
107+
/// <summary>
108+
/// Gets the timesteps.
109+
/// </summary>
110+
/// <param name="prompt">The prompt.</param>
111+
/// <param name="options">The options.</param>
112+
/// <param name="scheduler">The scheduler.</param>
113+
/// <returns></returns>
114+
protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler)
115+
{
116+
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
117+
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
118+
return scheduler.Timesteps.Skip(start).ToList();
119+
}
120+
121+
122+
/// <summary>
123+
/// Prepares the latents for inference.
124+
/// </summary>
125+
/// <param name="prompt">The prompt.</param>
126+
/// <param name="options">The options.</param>
127+
/// <param name="scheduler">The scheduler.</param>
128+
/// <returns></returns>
129+
protected override DenseTensor<float> PrepareLatents(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
130+
{
131+
// Image input, decode, add noise, return as latent 0
132+
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Width, options.Height });
133+
var inputNames = _onnxModelService.GetInputNames(OnnxModelType.VaeEncoder);
134+
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
135+
using (var inferResult = _onnxModelService.RunInference(OnnxModelType.VaeEncoder, inputParameters))
136+
{
137+
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
138+
var noisySample = sample
139+
.AddTensors(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
140+
.MultipleTensorByFloat(_configuration.ScaleFactor);
141+
return noisySample;
142+
}
143+
}
144+
145+
146+
/// <summary>
147+
/// Prepares the mask.
148+
/// </summary>
149+
/// <param name="promptOptions">The prompt options.</param>
150+
/// <param name="schedulerOptions">The scheduler options.</param>
151+
/// <returns></returns>
152+
private DenseTensor<float> PrepareMask(PromptOptions promptOptions, SchedulerOptions schedulerOptions)
153+
{
154+
using (var mask = promptOptions.InputImageMask.ToImage())
155+
{
156+
// Prepare the mask
157+
int width = schedulerOptions.GetScaledWidth();
158+
int height = schedulerOptions.GetScaledHeight();
159+
mask.Mutate(x => x.Grayscale());
160+
mask.Mutate(x => x.Resize(new Size(width, height), KnownResamplers.NearestNeighbor, true));
161+
var maskTensor = new DenseTensor<float>(new[] { 1, 4, width, height });
162+
mask.ProcessPixelRows(img =>
163+
{
164+
for (int x = 0; x < width; x++)
165+
{
166+
for (int y = 0; y < height; y++)
167+
{
168+
var pixelSpan = img.GetRowSpan(y);
169+
var value = (float)pixelSpan[x].A / 255.0f;
170+
171+
//TODO: mask = 1 - mask # repaint white, keep black
172+
maskTensor[0, 0, y, x] = 0f;
173+
maskTensor[0, 1, y, x] = 0f; // Needed for shape only
174+
maskTensor[0, 2, y, x] = 0f; // Needed for shape only
175+
maskTensor[0, 3, y, x] = 0f; // Needed for shape only
176+
}
177+
}
178+
});
179+
return maskTensor;
180+
}
181+
}
182+
183+
184+
/// <summary>
185+
/// Applies the masked latents.
186+
/// </summary>
187+
/// <param name="latents">The latents.</param>
188+
/// <param name="initLatentsProper">The initialize latents proper.</param>
189+
/// <param name="mask">The mask.</param>
190+
/// <returns></returns>
191+
private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
192+
{
193+
var result = new DenseTensor<float>(latents.Dimensions);
194+
for (int batch = 0; batch < latents.Dimensions[0]; batch++)
195+
{
196+
for (int channel = 0; channel < latents.Dimensions[1]; channel++)
197+
{
198+
for (int height = 0; height < latents.Dimensions[2]; height++)
199+
{
200+
for (int width = 0; width < latents.Dimensions[3]; width++)
201+
{
202+
float maskValue = mask[batch, 0, height, width];
203+
float latentsValue = latents[batch, channel, height, width];
204+
float initLatentsProperValue = initLatentsProper[batch, channel, height, width];
205+
206+
//TODO: Apply the logic to compute the result based on the mask
207+
float newValue = (initLatentsProperValue * maskValue) + (latentsValue * (1f - maskValue));
208+
result[batch, channel, height, width] = newValue;
209+
}
210+
}
211+
}
212+
}
213+
return result;
214+
}
215+
}
216+
}

OnnxStack.StableDiffusion/Helpers/TensorHelper.cs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,62 @@ public static DenseTensor<float> Abs(this DenseTensor<float> tensor)
280280
}
281281

282282

283+
/// <summary>
284+
/// Multiplies the specified tensor.
285+
/// </summary>
286+
/// <param name="tensor1">The tensor1.</param>
287+
/// <param name="tensor2">The tensor2.</param>
288+
/// <returns></returns>
289+
public static DenseTensor<float> Multiply(this DenseTensor<float> tensor1, DenseTensor<float> tensor2)
290+
{
291+
var result = new DenseTensor<float>(tensor1.Dimensions);
292+
for (int batch = 0; batch < tensor1.Dimensions[0]; batch++)
293+
{
294+
for (int channel = 0; channel < tensor1.Dimensions[1]; channel++)
295+
{
296+
for (int height = 0; height < tensor1.Dimensions[2]; height++)
297+
{
298+
for (int width = 0; width < tensor1.Dimensions[3]; width++)
299+
{
300+
var value1 = tensor1[batch, channel, height, width];
301+
var value2 = tensor2[batch, channel, height, width];
302+
result[batch, channel, height, width] = value1 * value2;
303+
}
304+
}
305+
}
306+
}
307+
return result;
308+
}
309+
310+
311+
/// <summary>
312+
/// Divides the specified tensor.
313+
/// </summary>
314+
/// <param name="tensor1">The tensor1.</param>
315+
/// <param name="tensor2">The tensor2.</param>
316+
/// <returns></returns>
317+
public static DenseTensor<float> Divide(this DenseTensor<float> tensor1, DenseTensor<float> tensor2)
318+
{
319+
var result = new DenseTensor<float>(tensor1.Dimensions);
320+
for (int batch = 0; batch < tensor1.Dimensions[0]; batch++)
321+
{
322+
for (int channel = 0; channel < tensor1.Dimensions[1]; channel++)
323+
{
324+
for (int height = 0; height < tensor1.Dimensions[2]; height++)
325+
{
326+
for (int width = 0; width < tensor1.Dimensions[3]; width++)
327+
{
328+
var value1 = tensor1[batch, channel, height, width];
329+
var value2 = tensor2[batch, channel, height, width];
330+
result[batch, channel, height, width] = value1 / value2;
331+
}
332+
}
333+
}
334+
}
335+
return result;
336+
}
337+
338+
283339
/// <summary>
284340
/// Generate a random Tensor from a normal distribution with mean 0 and variance 1
285341
/// </summary>

OnnxStack.StableDiffusion/Services/DiffuserService.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ public sealed class DiffuserService : IDiffuserService
1515
private readonly IDiffuser _textDiffuser;
1616
private readonly IDiffuser _imageDiffuser;
1717
private readonly IDiffuser _inpaintDiffuser;
18+
private readonly IDiffuser _inpaintLegacyDiffuser;
1819

1920
/// <summary>
2021
/// Initializes a new instance of the <see cref="DiffuserService"/> class.
@@ -26,6 +27,7 @@ public DiffuserService(IOnnxModelService onnxModelService, IPromptService prompt
2627
_textDiffuser = new TextDiffuser(onnxModelService, promptService);
2728
_imageDiffuser = new ImageDiffuser(onnxModelService, promptService);
2829
_inpaintDiffuser = new InpaintDiffuser(onnxModelService, promptService);
30+
_inpaintLegacyDiffuser = new InpaintLegacyDiffuser(onnxModelService, promptService);
2931
}
3032

3133

@@ -41,7 +43,7 @@ public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, Sche
4143
{
4244
ProcessType.TextToImage => await _textDiffuser.DiffuseAsync(promptOptions, schedulerOptions, progress, cancellationToken),
4345
ProcessType.ImageToImage => await _imageDiffuser.DiffuseAsync(promptOptions, schedulerOptions, progress, cancellationToken),
44-
ProcessType.ImageInpaint => await _inpaintDiffuser.DiffuseAsync(promptOptions, schedulerOptions, progress, cancellationToken),
46+
ProcessType.ImageInpaint => await _inpaintLegacyDiffuser.DiffuseAsync(promptOptions, schedulerOptions, progress, cancellationToken),
4547
_ => throw new NotImplementedException()
4648
};
4749
}

OnnxStack.WebUI/wwwroot/js/stableDiffusionImageInpaint.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ const stableDiffusionImageInpaint = () => {
320320

321321

322322
// Map UI Events/Functions
323-
$(".image2image-control").hide();
323+
//$(".image2image-control").hide();
324324
buttonCancel.on("click", cancelDiffusion);
325325
buttonClear.on("click", clearHistory);
326326
buttonExecute.on("click", async () => { await executeDiffusion(); });

0 commit comments

Comments
 (0)