Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 4f0ae87

Browse files
committed
Wrap batch SchedulerOptions with result
1 parent c222d4d commit 4f0ae87

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core.Image;
3+
using OnnxStack.Core.Video;
24
using OnnxStack.StableDiffusion.Config;
35

46
namespace OnnxStack.StableDiffusion.Common
57
{
68
public record BatchResult(SchedulerOptions SchedulerOptions, DenseTensor<float> Result);
9+
public record BatchImageResult(SchedulerOptions SchedulerOptions, OnnxImage Result);
10+
public record BatchVideoResult(SchedulerOptions SchedulerOptions, OnnxVideo Result);
711
}

OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public interface IPipeline
101101
/// <param name="progressCallback">The progress callback.</param>
102102
/// <param name="cancellationToken">The cancellation token.</param>
103103
/// <returns></returns>
104-
IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
104+
IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
105105

106106

107107
/// <summary>
@@ -126,6 +126,6 @@ public interface IPipeline
126126
/// <param name="progressCallback">The progress callback.</param>
127127
/// <param name="cancellationToken">The cancellation token.</param>
128128
/// <returns></returns>
129-
IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
129+
IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
130130
}
131131
}

OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
133133
/// <param name="progressCallback">The progress callback.</param>
134134
/// <param name="cancellationToken">The cancellation token.</param>
135135
/// <returns></returns>
136-
public abstract IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
136+
public abstract IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
137137

138138

139139
/// <summary>
@@ -158,7 +158,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
158158
/// <param name="progressCallback">The progress callback.</param>
159159
/// <param name="cancellationToken">The cancellation token.</param>
160160
/// <returns></returns>
161-
public abstract IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
161+
public abstract IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
162162

163163

164164
/// <summary>

OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ public override async Task<OnnxImage> GenerateImageAsync(PromptOptions promptOpt
291291
/// <param name="progressCallback">The progress callback.</param>
292292
/// <param name="cancellationToken">The cancellation token.</param>
293293
/// <returns></returns>
294-
public override async IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
294+
public override async IAsyncEnumerable<BatchImageResult> GenerateImageBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
295295
{
296296
var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting...");
297297
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
@@ -316,7 +316,7 @@ public override async IAsyncEnumerable<OnnxImage> GenerateImageBatchAsync(BatchO
316316
foreach (var batchSchedulerOption in batchSchedulerOptions)
317317
{
318318
var tensorResult = await DiffuseImageAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
319-
yield return new OnnxImage(tensorResult);
319+
yield return new BatchImageResult(batchSchedulerOption, new OnnxImage(tensorResult));
320320
batchIndex++;
321321
}
322322

@@ -367,7 +367,7 @@ public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOpt
367367
/// <param name="progressCallback">The progress callback.</param>
368368
/// <param name="cancellationToken">The cancellation token.</param>
369369
/// <returns></returns>
370-
public override async IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
370+
public override async IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync(BatchOptions batchOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
371371
{
372372
var diffuseBatchTime = _logger?.LogBegin("Batch Diffuser starting...");
373373
var options = GetSchedulerOptionsOrDefault(schedulerOptions);
@@ -392,11 +392,11 @@ public override async IAsyncEnumerable<OnnxVideo> GenerateVideoBatchAsync(BatchO
392392
foreach (var batchSchedulerOption in batchSchedulerOptions)
393393
{
394394
var frames = new List<OnnxImage>();
395-
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, options, promptEmbeddings, performGuidance, progressCallback, cancellationToken))
395+
await foreach (var frameTensor in DiffuseVideoAsync(diffuser, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, progressCallback, cancellationToken))
396396
{
397397
frames.Add(new OnnxImage(frameTensor));
398398
}
399-
yield return new OnnxVideo(promptOptions.InputVideo.Info, frames);
399+
yield return new BatchVideoResult(batchSchedulerOption, new OnnxVideo(promptOptions.InputVideo.Info, frames));
400400
batchIndex++;
401401
}
402402

0 commit comments

Comments
 (0)