Skip to content

Commit 3ce8f17

Browse files
committed
Remove flow_shift
1 parent 7cc9348 commit 3ce8f17

File tree

9 files changed

+46
-23
lines changed

9 files changed

+46
-23
lines changed

TensorStack.Python/Common/PipelineOptions.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ public record PipelineOptions
1515
public int Height { get; set; }
1616
public int Width { get; set; }
1717
public int Frames { get; set; }
18-
public float OutputFrameRate { get; set; }
18+
public float FrameRate { get; set; }
1919
public float Shift { get; set; }
20-
public float FlowShift { get; set; }
2120
public float Strength { get; set; }
2221
public SchedulerType Scheduler { get; set; }
2322
public List<LoraOptions> LoraOptions { get; set; }

TensorStack.Python/Common/PipelineProgress.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ public record PipelineProgress : IRunProgress
1414
public float Downloaded { get; set; }
1515
public float DownloadTotal { get; set; }
1616
public float DownloadSpeed { get; set; }
17+
public bool IsLoading => Process == "Loading";
18+
public bool IsGenerating => Process == "Generate" && Iterations > 0;
1719
public bool IsDownloading => DownloadTotal > 0 || Downloaded > 0 || DownloadSpeed > 0;
1820

1921
public readonly static IProgress<PipelineProgress> ConsoleCallback = new Progress<PipelineProgress>(Console.WriteLine);

TensorStack.Python/Config/EnvironmentConfig.cs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
namespace TensorStack.Python.Config
1+
using System;
2+
3+
namespace TensorStack.Python.Config
24
{
35
public record EnvironmentConfig
46
{
@@ -8,13 +10,30 @@ public record EnvironmentConfig
810
public string Directory { get; set; }
911

1012

13+
public static EnvironmentConfig Default(VendorType vendorType)
14+
{
15+
return vendorType switch
16+
{
17+
VendorType.AMD => DefaultROCM,
18+
VendorType.Nvidia => DefaultCUDA,
19+
_ => DefaultCPU
20+
};
21+
}
22+
23+
public static EnvironmentConfig Default(object vendor)
24+
{
25+
throw new NotImplementedException();
26+
}
27+
1128
public readonly static EnvironmentConfig DefaultCPU = new()
1229
{
1330
Environment = "default-cpu",
1431
Directory = "PythonRuntime",
1532
Requirements =
1633
[
17-
"torchvision==0.22.0",
34+
"torch==2.9.1",
35+
"torchvaudio==2.9.1",
36+
"torchvision==0.24.1",
1837

1938
// Default Packages
2039
"typing",
@@ -24,11 +43,9 @@ public record EnvironmentConfig
2443
"diffusers",
2544
"protobuf",
2645
"sentencepiece",
27-
"pillow",
2846
"ftfy",
2947
"scipy",
30-
"peft",
31-
"pillow"
48+
"peft"
3249
]
3350
};
3451

@@ -39,8 +56,10 @@ public record EnvironmentConfig
3956
Directory = "PythonRuntime",
4057
Requirements =
4158
[
42-
"--extra-index-url https://download.pytorch.org/whl/cu118",
43-
"torchvision==0.22.0+cu118",
59+
"--extra-index-url https://download.pytorch.org/whl/cu128",
60+
"torch==2.9.1+cu128",
61+
"torchaudio==2.9.1+cu128",
62+
"torchvision==0.24.1+cu128",
4463

4564
// Default Packages
4665
"typing",
@@ -50,11 +69,9 @@ public record EnvironmentConfig
5069
"diffusers",
5170
"protobuf",
5271
"sentencepiece",
53-
"pillow",
5472
"ftfy",
5573
"scipy",
56-
"peft",
57-
"pillow"
74+
"peft"
5875
]
5976
};
6077

@@ -80,12 +97,18 @@ public record EnvironmentConfig
8097
"diffusers",
8198
"protobuf",
8299
"sentencepiece",
83-
"pillow",
84100
"ftfy",
85101
"scipy",
86-
"peft",
87-
"pillow"
102+
"peft"
88103
]
89104
};
90105
}
106+
107+
public enum VendorType
108+
{
109+
Unknown = 0,
110+
AMD = 4098,
111+
Nvidia = 4318,
112+
Intel = 32902
113+
}
91114
}

TensorStack.Python/LogParser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ private static PipelineProgress ParsePythonLog(string logEntry)
109109
DownloadTotal = megabytesTotal,
110110
Downloaded = megabytesDownloaded,
111111
DownloadSpeed = megabytesSecond,
112-
Process = "Generate"
112+
Process = messageSection.StartsWith("Loading") ? "Loading" : "Generate"
113113
};
114114

115115
}

TensorStack.Python/Pipelines/ChromaPipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def generate(
115115
scheduler: str,
116116
numFrames: int,
117117
shift: float,
118-
flowShift: float,
119118
strength: float,
120119
loraOptions: Optional[Dict[str, float]] = None,
121120
inputData: Optional[Sequence[float]] = None,

TensorStack.Python/Pipelines/QwenImagePipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def generate(
123123
scheduler: str,
124124
numFrames: int,
125125
shift: float,
126-
flowShift: float,
127126
strength: float,
128127
loraOptions: Optional[Dict[str, float]] = None,
129128
inputData: Optional[Sequence[float]] = None,

TensorStack.Python/Pipelines/WanPipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def generate(
116116
scheduler: str,
117117
numFrames: int,
118118
shift: float,
119-
flowShift: float,
120119
strength: float,
121120
loraOptions: Optional[Dict[str, float]] = None,
122121
inputData: Optional[Sequence[float]] = None,
@@ -127,6 +126,7 @@ def generate(
127126
_reset()
128127

129128
# scheduler
129+
flowShift = 5.0 if height > 480 else 3.0 # 5.0 for 720P, 3.0 for 480P
130130
_pipeline.scheduler = UniPCMultistepScheduler.from_config(_pipeline.scheduler.config, flow_shift=flowShift)
131131

132132
#Lora Adapters

TensorStack.Python/Pipelines/ZImagePipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def generate(
115115
scheduler: str,
116116
numFrames: int,
117117
shift: float,
118-
flowShift: float,
119118
strength: float,
120119
loraOptions: Optional[Dict[str, float]] = None,
121120
inputData: Optional[Sequence[float]] = None,

TensorStack.Python/PythonPipeline.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,11 @@ public Task<Tensor<float>> GenerateAsync(PipelineOptions options, CancellationTo
172172
using (var scheduler = PyObject.From(options.Scheduler.ToString()))
173173
using (var numFrames = PyObject.From(options.Frames))
174174
using (var shift = PyObject.From(options.Shift))
175-
using (var flowShift = PyObject.From(options.FlowShift))
176175
using (var strength = PyObject.From(options.Strength))
177176
using (var loraOptions = PyObject.From(loraConfig))
178177
using (var inputData = PyObject.From(imageInput?.Memory.ToArray()))
179178
using (var inputShape = PyObject.From(imageInput?.Dimensions.ToArray()))
180-
using (var pythonResult = _functionGenerate.Call(prompt, negativePrompt, guidance, guidance2, steps, steps2, height, width, seed, scheduler, numFrames, shift, flowShift, strength, loraOptions, inputData, inputShape))
179+
using (var pythonResult = _functionGenerate.Call(prompt, negativePrompt, guidance, guidance2, steps, steps2, height, width, seed, scheduler, numFrames, shift, strength, loraOptions, inputData, inputShape))
181180
{
182181
var result = pythonResult
183182
.BareImportAs<IPyBuffer, PyObjectImporters.Buffer>()
@@ -330,7 +329,7 @@ private async Task LoggingLoop(int refreshRate)
330329
var logs = await GetLogsAsync();
331330
foreach (var progress in LogParser.ParseLogs(logs))
332331
{
333-
_logger?.LogInformation("[PythonRuntime] {message}", progress.Message);
332+
_logger?.LogInformation("[PythonRuntime] {Message}", progress.Message);
334333
_progressCallback?.Report(progress);
335334
}
336335
await Task.Delay(refreshRate, _progressCancellation.Token);
@@ -347,6 +346,9 @@ private Exception HandlePythonException(PythonInvocationException ex)
347346
{
348347
if (ex.InnerException is PythonRuntimeException pyex)
349348
{
349+
if (ex.InnerException.Message.Equals("Operation Canceled"))
350+
return new OperationCanceledException();
351+
350352
_logger?.LogError(pyex, "{PythonExceptionType} exception occured", ex.PythonExceptionType);
351353
if (!pyex.PythonStackTrace.IsNullOrEmpty())
352354
_logger?.LogError(string.Join(Environment.NewLine, pyex.PythonStackTrace));

0 commit comments

Comments
 (0)