@@ -66,8 +66,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
66
66
{
67
67
// Create random seed if none was set
68
68
schedulerOptions . Seed = schedulerOptions . Seed > 0 ? schedulerOptions . Seed : Random . Shared . Next ( ) ;
69
- Console . WriteLine ( $ "Scheduler: { promptOptions . SchedulerType } , Size: { schedulerOptions . Width } x{ schedulerOptions . Height } , Seed: { schedulerOptions . Seed } , Steps: { schedulerOptions . InferenceSteps } , Guidance: { schedulerOptions . GuidanceScale } ") ;
70
-
69
+
71
70
// Get Scheduler
72
71
using ( var scheduler = GetScheduler ( promptOptions , schedulerOptions ) )
73
72
{
@@ -78,7 +77,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
78
77
var timesteps = GetTimesteps ( promptOptions , schedulerOptions , scheduler ) ;
79
78
80
79
// Create latent sample
81
- var latentSample = PrepareLatents ( promptOptions , schedulerOptions , scheduler , timesteps ) ;
80
+ var latents = PrepareLatents ( promptOptions , schedulerOptions , scheduler , timesteps ) ;
82
81
83
82
// Loop though the timesteps
84
83
var step = 0 ;
@@ -87,8 +86,9 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
87
86
cancellationToken . ThrowIfCancellationRequested ( ) ;
88
87
89
88
// Create input tensor.
90
- var inputTensor = scheduler . ScaleInput ( latentSample . Duplicate ( schedulerOptions . GetScaledDimension ( 2 ) ) , timestep ) ;
89
+ var inputTensor = scheduler . ScaleInput ( latents . Duplicate ( schedulerOptions . GetScaledDimension ( 2 ) ) , timestep ) ;
91
90
91
+ // Create Input Parameters
92
92
var inputNames = _onnxModelService . GetInputNames ( OnnxModelType . Unet ) ;
93
93
var inputParameters = CreateInputParameters (
94
94
NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , inputTensor ) ,
@@ -98,27 +98,24 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
98
98
// Run Inference
99
99
using ( var inferResult = await _onnxModelService . RunInferenceAsync ( OnnxModelType . Unet , inputParameters ) )
100
100
{
101
- var resultTensor = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
102
-
103
- // Split tensors from 2,4,(H/8),(W/8) to 1,4,(H/8),(W/8)
104
- var splitTensors = resultTensor . SplitTensor ( schedulerOptions . GetScaledDimension ( ) , schedulerOptions . GetScaledHeight ( ) , schedulerOptions . GetScaledWidth ( ) ) ;
105
- var noisePred = splitTensors . Item1 ;
106
- var noisePredText = splitTensors . Item2 ;
101
+ var noisePred = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
107
102
108
103
// Perform guidance
109
- noisePred = noisePred . PerformGuidance ( noisePredText , schedulerOptions . GuidanceScale ) ;
104
+ if ( schedulerOptions . GuidanceScale > 1.0f )
105
+ {
106
+ var ( noisePredUncond , noisePredText ) = noisePred . SplitTensor ( schedulerOptions . GetScaledDimension ( ) ) ;
107
+ noisePred = noisePredUncond . PerformGuidance ( noisePredText , schedulerOptions . GuidanceScale ) ;
108
+ }
110
109
111
- // LMS Scheduler Step
112
- latentSample = scheduler . Step ( noisePred , timestep , latentSample ) ;
113
- // ImageHelpers.TensorToImageDebug(latentSample, 64, $@"Examples\StableDebug\Latent_{step}.png");
110
+ // Scheduler Step
111
+ latents = scheduler . Step ( noisePred , timestep , latents ) ;
114
112
}
115
113
116
- Console . WriteLine ( $ "Step: { ++ step } /{ timesteps . Count } ") ;
117
- progress ? . Invoke ( step , timesteps . Count ) ;
114
+ progress ? . Invoke ( ++ step , timesteps . Count ) ;
118
115
}
119
116
120
117
// Decode Latents
121
- return await DecodeLatents ( schedulerOptions , latentSample ) ;
118
+ return await DecodeLatents ( schedulerOptions , latents ) ;
122
119
}
123
120
}
124
121
@@ -192,7 +189,7 @@ protected static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions o
192
189
using ( var image = imageTensor . ToImage ( ) )
193
190
{
194
191
// Resize image
195
- ImageHelpers . Resize ( image , 224 , 224 ) ;
192
+ ImageHelpers . Resize ( image , new [ ] { 1 , 3 , 224 , 224 } ) ;
196
193
197
194
// Preprocess image
198
195
var input = new DenseTensor < float > ( new [ ] { 1 , 3 , 224 , 224 } ) ;
0 commit comments