diff --git a/OnnxStack.Console/Examples/TextGenerationExample.cs b/OnnxStack.Console/Examples/TextGenerationExample.cs
new file mode 100644
index 0000000..6706d9d
--- /dev/null
+++ b/OnnxStack.Console/Examples/TextGenerationExample.cs
@@ -0,0 +1,60 @@
+using OnnxStack.Core.Config;
+using OnnxStack.TextGeneration.Models;
+using OnnxStack.TextGeneration.Pipelines;
+
+namespace OnnxStack.Console.Runner
+{
+    /// <summary>
+    /// Interactive console example that streams generated tokens for each prompt the user enters.
+    /// </summary>
+    public sealed class TextGenerationExample : IExampleRunner
+    {
+        public TextGenerationExample()
+        {
+        }
+
+        public int Index => 40;
+
+        public string Name => "Text Generation Demo";
+
+        public string Description => "Text Generation Example";
+
+        /// <summary>
+        /// Loads the model, then repeatedly reads a prompt and writes the generated tokens
+        /// until an empty prompt is entered, at which point the pipeline is unloaded.
+        /// </summary>
+        public async Task RunAsync()
+        {
+            var pipeline = TextGenerationPipeline.CreatePipeline("D:\\Repositories\\phi2_onnx", executionProvider: ExecutionProvider.Cuda);
+
+            await pipeline.LoadAsync();
+
+            try
+            {
+                while (true)
+                {
+                    OutputHelpers.WriteConsole("Enter Prompt: ", ConsoleColor.Gray);
+                    var prompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
+
+                    // An empty prompt ends the session instead of looping forever.
+                    if (string.IsNullOrWhiteSpace(prompt))
+                    {
+                        break;
+                    }
+
+                    var promptOptions = new PromptOptionsModel(prompt);
+                    var searchOptions = new SearchOptionsModel();
+                    await foreach (var token in pipeline.RunAsync(promptOptions, searchOptions))
+                    {
+                        OutputHelpers.WriteConsole(token.Content, ConsoleColor.Yellow, false);
+                    }
+                }
+            }
+            finally
+            {
+                // Unload the pipeline even if generation throws, so the model is not left loaded.
+                await pipeline.UnloadAsync();
+            }
+        }
+    }
+}
diff --git a/OnnxStack.Console/OnnxStack.Console.csproj b/OnnxStack.Console/OnnxStack.Console.csproj
index 2b16d82..ed4dfc6 100644
--- a/OnnxStack.Console/OnnxStack.Console.csproj
+++ b/OnnxStack.Console/OnnxStack.Console.csproj
@@ -22,6 +22,7 @@
+
diff --git a/OnnxStack.TextGeneration/Binaries/cuda/Microsoft.ML.OnnxRuntimeGenAI.dll b/OnnxStack.TextGeneration/Binaries/cuda/Microsoft.ML.OnnxRuntimeGenAI.dll
new file mode 100644
index 0000000..426b96c
Binary files /dev/null and b/OnnxStack.TextGeneration/Binaries/cuda/Microsoft.ML.OnnxRuntimeGenAI.dll differ
diff --git a/OnnxStack.TextGeneration/Binaries/cuda/onnxruntime-genai.dll b/OnnxStack.TextGeneration/Binaries/cuda/onnxruntime-genai.dll
new file mode 100644
index 0000000..5e42e9b
Binary files /dev/null and b/OnnxStack.TextGeneration/Binaries/cuda/onnxruntime-genai.dll differ
diff --git a/OnnxStack.TextGeneration/Common/TextGenerationModel.cs b/OnnxStack.TextGeneration/Common/TextGenerationModel.cs
new file mode 100644
index 0000000..92921bd
--- /dev/null
+++ b/OnnxStack.TextGeneration/Common/TextGenerationModel.cs
@@ -0,0 +1,90 @@
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.OnnxRuntimeGenAI;
+using OnnxStack.Core.Config;
+
+namespace OnnxStack.TextGeneration.Common
+{
+    /// <summary>
+    /// Owns the OnnxRuntimeGenAI model and tokenizer created from a <see cref="TextGenerationModelConfig"/>.
+    /// </summary>
+    public class TextGenerationModel : IDisposable //: OnnxModelSession
+    {
+        private Model _model;
+        private Tokenizer _tokenizer;
+        private readonly TextGenerationModelConfig _configuration;
+        public TextGenerationModel(TextGenerationModelConfig configuration)
+        {
+            _configuration = configuration;
+        }
+
+        public Model Model => _model;
+        public Tokenizer Tokenizer => _tokenizer;
+
+
+        /// <summary>
+        /// Loads the model session.
+        /// </summary>
+        public async Task LoadAsync()
+        {
+            if (_model is not null)
+                return; // Already Loaded
+
+            await Task.Run(() =>
+            {
+                // Publish the model only after the tokenizer exists, so a tokenizer failure
+                // cannot leave a half-loaded instance that later calls treat as already loaded.
+                var model = new Model(_configuration.OnnxModelPath);
+                try
+                {
+                    _tokenizer = new Tokenizer(model);
+                }
+                catch
+                {
+                    model.Dispose();
+                    throw;
+                }
+                _model = model;
+            });
+        }
+
+
+        /// <summary>
+        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
+        /// </summary>
+        public void Dispose()
+        {
+            _tokenizer?.Dispose();
+            _model?.Dispose();
+            _model = null;
+            _tokenizer = null;
+        }
+
+
+        /// <summary>
+        /// Creates an unloaded model from an existing configuration.
+        /// </summary>
+        public static TextGenerationModel Create(TextGenerationModelConfig configuration)
+        {
+            return new TextGenerationModel(configuration);
+        }
+
+        /// <summary>
+        /// Creates an unloaded model for the given model path, device and execution provider.
+        /// </summary>
+        public static TextGenerationModel Create(string modelPath, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
+        {
+            var configuration = new TextGenerationModelConfig
+            {
+                DeviceId = deviceId,
+                ExecutionProvider = executionProvider,
+                ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
+                InterOpNumThreads = 0,
+                IntraOpNumThreads = 0,
+                OnnxModelPath = modelPath
+            };
+            return new TextGenerationModel(configuration);
+        }
+
+
+    }
+}
diff --git a/OnnxStack.TextGeneration/Common/TextGenerationModelConfig.cs b/OnnxStack.TextGeneration/Common/TextGenerationModelConfig.cs
new file mode 100644
index 0000000..914d77a
--- /dev/null
+++ b/OnnxStack.TextGeneration/Common/TextGenerationModelConfig.cs
@@ -0,0 +1,9 @@
+using OnnxStack.Core.Config;
+
+namespace OnnxStack.TextGeneration.Common
+{
+    public record TextGenerationModelConfig : OnnxModelConfig
+    {
+
+    }
+}
diff --git a/OnnxStack.TextGeneration/Common/TextGenerationModelSet.cs b/OnnxStack.TextGeneration/Common/TextGenerationModelSet.cs
new file mode 100644
index 0000000..a558231
--- /dev/null
+++ b/OnnxStack.TextGeneration/Common/TextGenerationModelSet.cs
@@ -0,0 +1,20 @@
+using Microsoft.ML.OnnxRuntime;
+using OnnxStack.Core.Config;
+using System.Text.Json.Serialization;
+
+namespace OnnxStack.TextGeneration.Common
+{
+    public record TextGenerationModelSet : IOnnxModelSetConfig
+    {
+        public string Name { get; set; }
+        public bool IsEnabled { get; set; }
+        public int DeviceId { get; set; }
+        public int InterOpNumThreads { get; set; }
+        public int IntraOpNumThreads { get; set; }
+        public ExecutionMode ExecutionMode { get; set; }
+        public ExecutionProvider ExecutionProvider { get; set; }
+
+        [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+        public TextGenerationModelConfig TextGenerationConfig { get; set; }
+    }
+}
diff --git a/OnnxStack.TextGeneration/Models/PromptOptionsModel.cs b/OnnxStack.TextGeneration/Models/PromptOptionsModel.cs
new file mode 100644
index 0000000..508103c
--- /dev/null
+++ b/OnnxStack.TextGeneration/Models/PromptOptionsModel.cs
@@ -0,0 +1,4 @@
+namespace OnnxStack.TextGeneration.Models
+{
+    public record class PromptOptionsModel(string Prompt);
+}
diff --git a/OnnxStack.TextGeneration/Models/SearchOptionsModel.cs b/OnnxStack.TextGeneration/Models/SearchOptionsModel.cs
new file mode 100644
index 0000000..fea74b1
--- /dev/null
+++ b/OnnxStack.TextGeneration/Models/SearchOptionsModel.cs
@@ -0,0 +1,20 @@
+namespace OnnxStack.TextGeneration.Models
+{
+    public class SearchOptionsModel
+    {
+        public int TopK { get; set; } = 50;
+        public float TopP { get; set; } = 0.95f;
+        public float Temperature { get; set; } = 1;
+        public float RepetitionPenalty { get; set; } = 0.9f;
+        public bool PastPresentShareBuffer { get; set; } = false;
+        public int NumReturnSequences { get; set; } = 1;
+        public int NumBeams { get; set; } = 1;
+        public int NoRepeatNgramSize { get; set; } = 0;
+        public int MinLength { get; set; } = 0;
+        public int MaxLength { get; set; } = 512;
+        public float LengthPenalty { get; set; } = 1;
+        public float DiversityPenalty { get; set; } = 0;
+        public bool EarlyStopping { get; set; } = true;
+        public bool DoSample { get; set; } = false;
+    }
+}
\ No newline at end of file
diff --git a/OnnxStack.TextGeneration/Models/TokenModel.cs b/OnnxStack.TextGeneration/Models/TokenModel.cs
new file mode 100644
index 0000000..8c013c7
--- /dev/null
+++ b/OnnxStack.TextGeneration/Models/TokenModel.cs
@@ -0,0 +1,4 @@
+namespace OnnxStack.TextGeneration.Models
+{
+    public readonly record struct TokenModel(int Id, string Content);
+}
\ No newline at end of file
diff --git a/OnnxStack.TextGeneration/OnnxStack.TextGeneration.csproj
b/OnnxStack.TextGeneration/OnnxStack.TextGeneration.csproj
new file mode 100644
index 0000000..ef26c41
--- /dev/null
+++ b/OnnxStack.TextGeneration/OnnxStack.TextGeneration.csproj
@@ -0,0 +1,42 @@
+
+
+
+    net7.0
+    enable
+    disable
+    x64
+    x64
+
+
+
+
+    \
+    True
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+      Binaries\cuda\Microsoft.ML.OnnxRuntimeGenAI.dll
+
+
+
+
+
+      PreserveNewest
+
+
+      PreserveNewest
+
+
+
+
diff --git a/OnnxStack.TextGeneration/Pipelines/TextGenerationPipeline.cs b/OnnxStack.TextGeneration/Pipelines/TextGenerationPipeline.cs
new file mode 100644
index 0000000..8a149a5
--- /dev/null
+++ b/OnnxStack.TextGeneration/Pipelines/TextGenerationPipeline.cs
@@ -0,0 +1,184 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.ML.OnnxRuntimeGenAI;
+using OnnxStack.Core;
+using OnnxStack.Core.Config;
+using OnnxStack.TextGeneration.Common;
+using OnnxStack.TextGeneration.Models;
+using System.Runtime.CompilerServices;
+
+namespace OnnxStack.TextGeneration.Pipelines
+{
+
+    public class TextGenerationPipeline
+    {
+        private readonly string _name;
+        private readonly ILogger _logger;
+        private readonly TextGenerationModel _model;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="TextGenerationPipeline"/> class.
+        /// </summary>
+        /// <param name="name">The name.</param>
+        /// <param name="model">The text generation model.</param>
+        /// <param name="logger">The logger.</param>
+        public TextGenerationPipeline(string name, TextGenerationModel model, ILogger logger = default)
+        {
+            _name = name;
+            _logger = logger;
+            _model = model;
+        }
+
+
+        /// <summary>
+        /// Gets the name.
+        /// </summary>
+        /// <value>The name passed to the constructor.</value>
+        public string Name => _name;
+
+
+        /// <summary>
+        /// Loads the model.
+        /// </summary>
+        /// <returns>The task returned by the underlying model's load operation.</returns>
+        public Task LoadAsync()
+        {
+            return _model.LoadAsync();
+        }
+
+
+        /// <summary>
+        /// Unloads the models.
+        /// </summary>
+        public async Task UnloadAsync()
+        {
+            await Task.Yield();
+            _model?.Dispose();
+        }
+
+
+        /// <summary>
+        /// Runs the text generation pipeline
+        /// </summary>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="searchOptions">The search options.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>An asynchronous stream of generated tokens.</returns>
+        public IAsyncEnumerable<TokenModel> RunAsync(PromptOptionsModel promptOptions, SearchOptionsModel searchOptions, CancellationToken cancellationToken = default)
+        {
+            return RunInternalAsync(promptOptions, searchOptions, cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Runs the text generation pipeline
+        /// </summary>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="searchOptions">The search options.</param>
+        /// <param name="cancellationToken">The cancellation token, checked before each token is generated.</param>
+        /// <returns>An asynchronous stream yielding one token per generation step.</returns>
+        private async IAsyncEnumerable<TokenModel> RunInternalAsync(PromptOptionsModel promptOptions, SearchOptionsModel searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            var timestamp = _logger?.LogBegin("Run text generation pipeline stream...");
+            var sequences = await EncodePrompt(promptOptions, cancellationToken);
+
+            using (var generatorParams = new GeneratorParams(_model.Model))
+            {
+                ApplySearchOptions(generatorParams, searchOptions);
+                generatorParams.SetInputSequences(sequences);
+
+                using (var tokenizerStream = _model.Tokenizer.CreateStream())
+                using (var generator = new Generator(_model.Model, generatorParams))
+                {
+                    while (!generator.IsDone())
+                    {
+                        cancellationToken.ThrowIfCancellationRequested();
+
+                        yield return await Task.Run(() =>
+                        {
+                            generator.ComputeLogits();
+                            generator.GenerateNextTokenTop();
+
+                            var tokenId = generator.GetSequence(0)[^1];
+                            return new TokenModel(tokenId, tokenizerStream.Decode(tokenId));
+                        }, cancellationToken);
+                    }
+                }
+            }
+            _logger?.LogEnd("Text generation pipeline stream complete.", timestamp);
+        }
+
+
+        /// <summary>
+        /// Encodes the prompt.
+        /// </summary>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The token sequences produced by the tokenizer for the prompt.</returns>
+        private async Task<Sequences> EncodePrompt(PromptOptionsModel promptOptions, CancellationToken cancellationToken = default)
+        {
+            return await Task.Run(() => _model.Tokenizer.Encode(promptOptions.Prompt), cancellationToken);
+        }
+
+
+        /// <summary>
+        /// Applies the search options to the GeneratorParams instance.
+        /// </summary>
+        /// <param name="generatorParams">The generator parameters.</param>
+        /// <param name="searchOptions">The search options.</param>
+        private static void ApplySearchOptions(GeneratorParams generatorParams, SearchOptionsModel searchOptions)
+        {
+            generatorParams.SetSearchOption("top_p", searchOptions.TopP);
+            generatorParams.SetSearchOption("top_k", searchOptions.TopK);
+            generatorParams.SetSearchOption("temperature", searchOptions.Temperature);
+            generatorParams.SetSearchOption("repetition_penalty", searchOptions.RepetitionPenalty);
+            generatorParams.SetSearchOption("past_present_share_buffer", searchOptions.PastPresentShareBuffer);
+            generatorParams.SetSearchOption("num_return_sequences", searchOptions.NumReturnSequences);
+            generatorParams.SetSearchOption("no_repeat_ngram_size", searchOptions.NoRepeatNgramSize);
+            generatorParams.SetSearchOption("min_length", searchOptions.MinLength);
+            generatorParams.SetSearchOption("max_length", searchOptions.MaxLength);
+            generatorParams.SetSearchOption("length_penalty", searchOptions.LengthPenalty);
+            generatorParams.SetSearchOption("early_stopping", searchOptions.EarlyStopping);
+            generatorParams.SetSearchOption("do_sample", searchOptions.DoSample);
+            generatorParams.SetSearchOption("diversity_penalty", searchOptions.DiversityPenalty);
+        }
+
+
+        /// <summary>
+        /// Creates the pipeline from a TextGenerationModelSet.
+        /// </summary>
+        /// <param name="modelSet">The model set.</param>
+        /// <param name="logger">The logger.</param>
+        /// <returns>A pipeline named after the model set; the model is not loaded yet.</returns>
+        public static TextGenerationPipeline CreatePipeline(TextGenerationModelSet modelSet, ILogger logger = default)
+        {
+            var textGenerationModel = new TextGenerationModel(modelSet.TextGenerationConfig.ApplyDefaults(modelSet));
+            return new TextGenerationPipeline(modelSet.Name, textGenerationModel, logger);
+        }
+
+
+        /// <summary>
+        /// Creates the pipeline from the specified file.
+        /// </summary>
+        /// <param name="modelFile">The model file.</param>
+        /// <param name="deviceId">The device identifier.</param>
+        /// <param name="executionProvider">The execution provider.</param>
+        /// <param name="logger">The logger.</param>
+        /// <returns>A pipeline named after the model file name without its extension.</returns>
+        public static TextGenerationPipeline CreatePipeline(string modelFile, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
+        {
+            var name = Path.GetFileNameWithoutExtension(modelFile);
+            var configuration = new TextGenerationModelSet
+            {
+                Name = name,
+                IsEnabled = true,
+                DeviceId = deviceId,
+                ExecutionProvider = executionProvider,
+                TextGenerationConfig = new TextGenerationModelConfig
+                {
+                    OnnxModelPath = modelFile
+                }
+            };
+            return CreatePipeline(configuration, logger);
+        }
+    }
+}
diff --git a/OnnxStack.sln b/OnnxStack.sln
index 379223b..fdfc230 100644
--- a/OnnxStack.sln
+++ b/OnnxStack.sln
@@ -17,6 +17,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "OnnxStack.Adapter", "OnnxSt
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OnnxStack.FeatureExtractor", "OnnxStack.FeatureExtractor\OnnxStack.FeatureExtractor.csproj", "{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}"
 EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OnnxStack.TextGeneration", "OnnxStack.TextGeneration\OnnxStack.TextGeneration.csproj", "{82C08941-721A-4662-83D0-6819DE6F6477}"
+EndProject
 Global
 	GlobalSection(SolutionConfigurationPlatforms) = preSolution
 		Debug|x64 = Debug|x64
@@ -81,6 +83,14 @@ Global
 		{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}.Release|x64.Build.0 = Release|x64
 		{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}.Release-Nvidia|x64.ActiveCfg = Release|x64
 		{0E8095A4-83EF-48AD-9BD1-0CC2893050F6}.Release-Nvidia|x64.Build.0 = Release|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Debug|x64.ActiveCfg = Debug|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Debug|x64.Build.0 = Debug|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Debug-Nvidia|x64.ActiveCfg = Debug|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Debug-Nvidia|x64.Build.0 = Debug|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Release|x64.ActiveCfg = Release|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Release|x64.Build.0 = Release|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Release-Nvidia|x64.ActiveCfg = Release|x64
+		{82C08941-721A-4662-83D0-6819DE6F6477}.Release-Nvidia|x64.Build.0 = Release|x64
 	EndGlobalSection
 	GlobalSection(SolutionProperties) = preSolution
 		HideSolutionNode = FALSE