diff --git a/OnnxStack.Console/Examples/TextGenerationExample.cs b/OnnxStack.Console/Examples/TextGenerationExample.cs
new file mode 100644
index 0000000..6706d9d
--- /dev/null
+++ b/OnnxStack.Console/Examples/TextGenerationExample.cs
@@ -0,0 +1,37 @@
+using OnnxStack.Core.Config;
+using OnnxStack.TextGeneration.Models;
+using OnnxStack.TextGeneration.Pipelines;
+
+namespace OnnxStack.Console.Runner
+{
+    public sealed class TextGenerationExample : IExampleRunner
+    {
+        public TextGenerationExample()
+        {
+        }
+
+        public int Index => 40;
+
+        public string Name => "Text Generation Demo";
+
+        public string Description => "Enter a prompt and stream the generated tokens";
+
+        public async Task RunAsync()
+        {
+            var pipeline = TextGenerationPipeline.CreatePipeline("D:\\Repositories\\phi2_onnx", executionProvider: ExecutionProvider.Cuda);
+
+            await pipeline.LoadAsync();
+
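+            // Interactive demo loop: read a prompt from the console and write each token as it is generated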
+            while (true)
+            {
+                OutputHelpers.WriteConsole("Enter Prompt: ", ConsoleColor.Gray);
+                var promptOptions = new PromptOptionsModel(OutputHelpers.ReadConsole(ConsoleColor.Cyan));
+                var searchOptions = new SearchOptionsModel();
+                await foreach (var token in pipeline.RunAsync(promptOptions, searchOptions))
+                {
+                    OutputHelpers.WriteConsole(token.Content, ConsoleColor.Yellow, false);
+                }
+            }
+        }
+    }
+}
diff --git a/OnnxStack.Console/OnnxStack.Console.csproj b/OnnxStack.Console/OnnxStack.Console.csproj
index 2b16d82..ed4dfc6 100644
--- a/OnnxStack.Console/OnnxStack.Console.csproj
+++ b/OnnxStack.Console/OnnxStack.Console.csproj
@@ -22,6 +22,7 @@
+    <ProjectReference Include="..\OnnxStack.TextGeneration\OnnxStack.TextGeneration.csproj" />
diff --git a/OnnxStack.TextGeneration/Binaries/cuda/Microsoft.ML.OnnxRuntimeGenAI.dll b/OnnxStack.TextGeneration/Binaries/cuda/Microsoft.ML.OnnxRuntimeGenAI.dll
new file mode 100644
index 0000000..426b96c
Binary files /dev/null and b/OnnxStack.TextGeneration/Binaries/cuda/Microsoft.ML.OnnxRuntimeGenAI.dll differ
diff --git a/OnnxStack.TextGeneration/Binaries/cuda/onnxruntime-genai.dll b/OnnxStack.TextGeneration/Binaries/cuda/onnxruntime-genai.dll
new file mode 100644
index 0000000..5e42e9b
Binary files /dev/null and b/OnnxStack.TextGeneration/Binaries/cuda/onnxruntime-genai.dll differ
diff --git a/OnnxStack.TextGeneration/Common/TextGenerationModel.cs b/OnnxStack.TextGeneration/Common/TextGenerationModel.cs
new file mode 100644
index 0000000..92921bd
--- /dev/null
+++ b/OnnxStack.TextGeneration/Common/TextGenerationModel.cs
@@ -0,0 +1,70 @@
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.OnnxRuntimeGenAI;
+using OnnxStack.Core.Config;
+
+namespace OnnxStack.TextGeneration.Common
+{
+    public class TextGenerationModel : IDisposable //: OnnxModelSession
+    {
+        private Model _model;
+        private Tokenizer _tokenizer;
+        private readonly TextGenerationModelConfig _configuration;
+
+        public TextGenerationModel(TextGenerationModelConfig configuration)
+        {
+            _configuration = configuration;
+        }
+
+        public Model Model => _model;
+        public Tokenizer Tokenizer => _tokenizer;
+
+
+        /// <summary>
+        /// Loads the model session.
+        /// </summary>
+        public async Task LoadAsync()
+        {
+            if (_model is not null)
+                return; // Already Loaded
+
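+            // Constructing the Model and Tokenizer blocks on native loading, so run it on a background thread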
+            await Task.Run(() =>
+            {
+                _model = new Model(_configuration.OnnxModelPath);
+                _tokenizer = new Tokenizer(_model);
+            });
+        }
+
+
+        /// <summary>
+        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// </summary>
+        public void Dispose()
+        {
+            _tokenizer?.Dispose();
+            _model?.Dispose();
+            _model = null;
+            _tokenizer = null;
+        }
+
+
+        public static TextGenerationModel Create(TextGenerationModelConfig configuration)
+        {
+            return new TextGenerationModel(configuration);
+        }
+
+        public static TextGenerationModel Create(string modelPath, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
+        {
+            var configuration = new TextGenerationModelConfig
+            {
+                DeviceId = deviceId,
+                ExecutionProvider = executionProvider,
+                ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
+                InterOpNumThreads = 0,
+                IntraOpNumThreads = 0,
+                OnnxModelPath = modelPath
+            };
+            return new TextGenerationModel(configuration);
+        }
+    }
+}
diff --git a/OnnxStack.TextGeneration/Common/TextGenerationModelConfig.cs b/OnnxStack.TextGeneration/Common/TextGenerationModelConfig.cs
new file mode 100644
index 0000000..914d77a
--- /dev/null
+++ b/OnnxStack.TextGeneration/Common/TextGenerationModelConfig.cs
@@ -0,0 +1,9 @@
+using OnnxStack.Core.Config;
+
+namespace OnnxStack.TextGeneration.Common
+{
+    public record TextGenerationModelConfig : OnnxModelConfig
+    {
+
+    }
+}
diff --git a/OnnxStack.TextGeneration/Common/TextGenerationModelSet.cs b/OnnxStack.TextGeneration/Common/TextGenerationModelSet.cs
new file mode 100644
index 0000000..a558231
--- /dev/null
+++ b/OnnxStack.TextGeneration/Common/TextGenerationModelSet.cs
@@ -0,0 +1,20 @@
+using Microsoft.ML.OnnxRuntime;
+using OnnxStack.Core.Config;
+using System.Text.Json.Serialization;
+
+namespace OnnxStack.TextGeneration.Common
+{
+    public record TextGenerationModelSet : IOnnxModelSetConfig
+    {
+        public string Name { get; set; }
+        public bool IsEnabled { get; set; }
+        public int DeviceId { get; set; }
+        public int InterOpNumThreads { get; set; }
+        public int IntraOpNumThreads { get; set; }
+        public ExecutionMode ExecutionMode { get; set; }
+        public ExecutionProvider ExecutionProvider { get; set; }
+
+        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+        public TextGenerationModelConfig TextGenerationConfig { get; set; }
+    }
+}
diff --git a/OnnxStack.TextGeneration/Models/PromptOptionsModel.cs b/OnnxStack.TextGeneration/Models/PromptOptionsModel.cs
new file mode 100644
index 0000000..508103c
--- /dev/null
+++ b/OnnxStack.TextGeneration/Models/PromptOptionsModel.cs
@@ -0,0 +1,4 @@
+namespace OnnxStack.TextGeneration.Models
+{
+    public record class PromptOptionsModel(string Prompt);
+}
diff --git a/OnnxStack.TextGeneration/Models/SearchOptionsModel.cs b/OnnxStack.TextGeneration/Models/SearchOptionsModel.cs
new file mode 100644
index 0000000..fea74b1
--- /dev/null
+++ b/OnnxStack.TextGeneration/Models/SearchOptionsModel.cs
@@ -0,0 +1,20 @@
+namespace OnnxStack.TextGeneration.Models
+{
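+    /// <summary>
+    /// Search options applied to the GeneratorParams; each property maps to a snake_case option key set in TextGenerationPipeline.ApplySearchOptions.
+    /// </summary>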
+    public class SearchOptionsModel
+    {
+        public int TopK { get; set; } = 50;
+        public float TopP { get; set; } = 0.95f;
+        public float Temperature { get; set; } = 1;
+        public float RepetitionPenalty { get; set; } = 0.9f;
+        public bool PastPresentShareBuffer { get; set; } = false;
+        public int NumReturnSequences { get; set; } = 1;
+        public int NumBeams { get; set; } = 1;
+        public int NoRepeatNgramSize { get; set; } = 0;
+        public int MinLength { get; set; } = 0;
+        public int MaxLength { get; set; } = 512;
+        public float LengthPenalty { get; set; } = 1;
+        public float DiversityPenalty { get; set; } = 0;
+        public bool EarlyStopping { get; set; } = true;
+        public bool DoSample { get; set; } = false;
+    }
+}
\ No newline at end of file
diff --git a/OnnxStack.TextGeneration/Models/TokenModel.cs b/OnnxStack.TextGeneration/Models/TokenModel.cs
new file mode 100644
index 0000000..8c013c7
--- /dev/null
+++ b/OnnxStack.TextGeneration/Models/TokenModel.cs
@@ -0,0 +1,4 @@
+namespace OnnxStack.TextGeneration.Models
+{
+    public readonly record struct TokenModel(int Id, string Content);
+}
\ No newline at end of file
diff --git a/OnnxStack.TextGeneration/OnnxStack.TextGeneration.csproj b/OnnxStack.TextGeneration/OnnxStack.TextGeneration.csproj
new file mode 100644
index 0000000..ef26c41
--- /dev/null
+++ b/OnnxStack.TextGeneration/OnnxStack.TextGeneration.csproj
@@ -0,0 +1,42 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>net7.0</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>disable</Nullable>
+    <Platforms>x64</Platforms>
+    <PlatformTarget>x64</PlatformTarget>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <None Include="…">
+      <PackagePath>\</PackagePath>
+      <Pack>True</Pack>
+    </None>
+  </ItemGroup>
+
+  <ItemGroup>
+    <ProjectReference Include="..\OnnxStack.Core\OnnxStack.Core.csproj" />
+  </ItemGroup>
+
+  <ItemGroup>
+    <Reference Include="Microsoft.ML.OnnxRuntimeGenAI">
+      <HintPath>Binaries\cuda\Microsoft.ML.OnnxRuntimeGenAI.dll</HintPath>
+    </Reference>
+  </ItemGroup>
+
+  <ItemGroup>
+    <None Update="Binaries\cuda\Microsoft.ML.OnnxRuntimeGenAI.dll">
+      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+    </None>
+    <None Update="Binaries\cuda\onnxruntime-genai.dll">
+      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+    </None>
+  </ItemGroup>
+
+</Project>
diff --git a/OnnxStack.TextGeneration/Pipelines/TextGenerationPipeline.cs b/OnnxStack.TextGeneration/Pipelines/TextGenerationPipeline.cs
new file mode 100644
index 0000000..8a149a5
--- /dev/null
+++ b/OnnxStack.TextGeneration/Pipelines/TextGenerationPipeline.cs
@@ -0,0 +1,182 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntimeGenAI;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.TextGeneration.Common;
+using OnnxStack.TextGeneration.Models;
+using System.Runtime.CompilerServices;
+
+namespace OnnxStack.TextGeneration.Pipelines
+{
+
+    public class TextGenerationPipeline
+    {
+        private readonly string _name;
+        private readonly ILogger _logger;
+        private readonly TextGenerationModel _model;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="TextGenerationPipeline"/> class.
+        /// </summary>
+        /// <param name="name">The name.</param>
+        /// <param name="model">The text generation model.</param>
+        /// <param name="logger">The logger.</param>
+        public TextGenerationPipeline(string name, TextGenerationModel model, ILogger logger = default)
+        {
+            _name = name;
+            _logger = logger;
+            _model = model;
+        }
+
+
+        /// <summary>
+        /// Gets the name.
+        /// </summary>
+        public string Name => _name;
+
+
+        /// <summary>
+        /// Loads the model.
+        /// </summary>
+        public Task LoadAsync()
+        {
+            return _model.LoadAsync();
+        }
+
+
+        /// <summary>
+        /// Unloads the models.
+        /// </summary>
+        public async Task UnloadAsync()
+        {
+            await Task.Yield();
+            _model?.Dispose();
+        }
+
+
+        /// <summary>
+        /// Runs the text generation pipeline.
+        /// </summary>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="searchOptions">The search options.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        public IAsyncEnumerable<TokenModel> RunAsync(PromptOptionsModel promptOptions, SearchOptionsModel searchOptions, CancellationToken cancellationToken = default)
+        {
+            return RunInternalAsync(promptOptions, searchOptions, cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Runs the text generation pipeline.
+        /// </summary>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="searchOptions">The search options.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        private async IAsyncEnumerable<TokenModel> RunInternalAsync(PromptOptionsModel promptOptions, SearchOptionsModel searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            var timestamp = _logger?.LogBegin("Run text generation pipeline stream...");
+            var sequences = await EncodePrompt(promptOptions, cancellationToken);
+
+            using (var generatorParams = new GeneratorParams(_model.Model))
+            {
+                ApplySearchOptions(generatorParams, searchOptions);
+                generatorParams.SetInputSequences(sequences);
+
+                using (var tokenizerStream = _model.Tokenizer.CreateStream())
+                using (var generator = new Generator(_model.Model, generatorParams))
+                {
+                    while (!generator.IsDone())
+                    {
+                        cancellationToken.ThrowIfCancellationRequested();
+
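+                        // Compute logits for the current sequence, sample the next token, and decode it to text incrementally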
+                        yield return await Task.Run(() =>
+                        {
+                            generator.ComputeLogits();
+                            generator.GenerateNextTokenTop();
+
+                            var tokenId = generator.GetSequence(0)[^1];
+                            return new TokenModel(tokenId, tokenizerStream.Decode(tokenId));
+                        }, cancellationToken);
+                    }
+                }
+            }
+            _logger?.LogEnd("Text generation pipeline stream complete.", timestamp);
+        }
+
+
+        /// <summary>
+        /// Encodes the prompt.
+        /// </summary>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        private async Task<Sequences> EncodePrompt(PromptOptionsModel promptOptions, CancellationToken cancellationToken = default)
+        {
+            return await Task.Run(() => _model.Tokenizer.Encode(promptOptions.Prompt), cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Applies the search options to the GeneratorParams instance.
+        /// </summary>
+        /// <param name="generatorParams">The generator parameters.</param>
+        /// <param name="searchOptions">The search options.</param>
+        private static void ApplySearchOptions(GeneratorParams generatorParams, SearchOptionsModel searchOptions)
+        {
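+            // Option keys use the snake_case names expected by onnxruntime-genai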
+            generatorParams.SetSearchOption("top_p", searchOptions.TopP);
+            generatorParams.SetSearchOption("top_k", searchOptions.TopK);
+            generatorParams.SetSearchOption("temperature", searchOptions.Temperature);
+            generatorParams.SetSearchOption("repetition_penalty", searchOptions.RepetitionPenalty);
+            generatorParams.SetSearchOption("past_present_share_buffer", searchOptions.PastPresentShareBuffer);
+            generatorParams.SetSearchOption("num_return_sequences", searchOptions.NumReturnSequences);
+            generatorParams.SetSearchOption("num_beams", searchOptions.NumBeams);
+            generatorParams.SetSearchOption("no_repeat_ngram_size", searchOptions.NoRepeatNgramSize);
+            generatorParams.SetSearchOption("min_length", searchOptions.MinLength);
+            generatorParams.SetSearchOption("max_length", searchOptions.MaxLength);
+            generatorParams.SetSearchOption("length_penalty", searchOptions.LengthPenalty);
+            generatorParams.SetSearchOption("early_stopping", searchOptions.EarlyStopping);
+            generatorParams.SetSearchOption("do_sample", searchOptions.DoSample);
+            generatorParams.SetSearchOption("diversity_penalty", searchOptions.DiversityPenalty);
+        }
+
+
+        /// <summary>
+        /// Creates the pipeline from a TextGenerationModelSet.
+        /// </summary>
+        /// <param name="modelSet">The model set.</param>
+        /// <param name="logger">The logger.</param>
+        public static TextGenerationPipeline CreatePipeline(TextGenerationModelSet modelSet, ILogger logger = default)
+        {
+            var textGenerationModel = new TextGenerationModel(modelSet.TextGenerationConfig.ApplyDefaults(modelSet));
+            return new TextGenerationPipeline(modelSet.Name, textGenerationModel, logger);
+        }
+
+
+        /// <summary>
+        /// Creates the pipeline from the specified file.
+        /// </summary>
+        /// <param name="modelFile">The model file.</param>
+        /// <param name="deviceId">The device identifier.</param>
+        /// <param name="executionProvider">The execution provider.</param>
+        /// <param name="logger">The logger.</param>
+        public static TextGenerationPipeline CreatePipeline(string modelFile, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+        {
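+            // The model path may be an onnxruntime-genai model folder; its last segment becomes the pipeline name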
+            var name = Path.GetFileNameWithoutExtension(modelFile);
+            var configuration = new TextGenerationModelSet
+            {
+                Name = name,
+                IsEnabled = true,
+                DeviceId = deviceId,
+                ExecutionProvider = executionProvider,
+                TextGenerationConfig = new TextGenerationModelConfig
+                {
+                    OnnxModelPath = modelFile
+                }
+            };
+            return CreatePipeline(configuration, logger);
+        }
+    }
+}
diff --git a/OnnxStack.sln b/OnnxStack.sln
index 379223b..fdfc230 100644
--- a/OnnxStack.sln
+++ b/OnnxStack.sln
@@ -17,6 +17,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "OnnxStack.Adapter", "OnnxSt
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OnnxStack.FeatureExtractor", "OnnxStack.FeatureExtractor\OnnxStack.FeatureExtractor.csproj", "{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OnnxStack.TextGeneration", "OnnxStack.TextGeneration\OnnxStack.TextGeneration.csproj", "{82C08941-721A-4662-83D0-6819DE6F6477}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|x64 = Debug|x64
@@ -81,6 +83,14 @@ Global
{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}.Release|x64.Build.0 = Release|x64
{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}.Release-Nvidia|x64.ActiveCfg = Release|x64
{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}.Release-Nvidia|x64.Build.0 = Release|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Debug|x64.ActiveCfg = Debug|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Debug|x64.Build.0 = Debug|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Debug-Nvidia|x64.ActiveCfg = Debug|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Debug-Nvidia|x64.Build.0 = Debug|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Release|x64.ActiveCfg = Release|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Release|x64.Build.0 = Release|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Release-Nvidia|x64.ActiveCfg = Release|x64
+ {82C08941-721A-4662-83D0-6819DE6F6477}.Release-Nvidia|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE