diff --git a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj
index c0335823..e152bbbe 100644
--- a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj
+++ b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj
@@ -21,9 +21,10 @@
-
-
-
+
+
+
+
diff --git a/shell/agents/AIShell.OpenAI.Agent/Agent.cs b/shell/agents/AIShell.OpenAI.Agent/Agent.cs
index 16a86d4e..86ee8edf 100644
--- a/shell/agents/AIShell.OpenAI.Agent/Agent.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/Agent.cs
@@ -1,7 +1,8 @@
+using System.ClientModel;
using System.Text;
using System.Text.Json;
-using Azure.AI.OpenAI;
using AIShell.Abstraction;
+using OpenAI.Chat;
namespace AIShell.OpenAI.Agent;
@@ -106,37 +107,44 @@ public async Task<bool> ChatAsync(string input, IShell shell)
return checkPass;
}
- string responseContent = null;
- StreamingResponse<StreamingChatCompletionsUpdate> response = await host.RunWithSpinnerAsync(
- () => _chatService.GetStreamingChatResponseAsync(input, token)
- ).ConfigureAwait(false);
+ IAsyncEnumerator<StreamingChatCompletionUpdate> response = await host
+ .RunWithSpinnerAsync(
+ () => _chatService.GetStreamingChatResponseAsync(input, token)
+ ).ConfigureAwait(false);
if (response is not null)
{
+ StreamingChatCompletionUpdate update = null;
using var streamingRender = host.NewStreamRender(token);
try
{
- await foreach (StreamingChatCompletionsUpdate chatUpdate in response)
+ do
{
- if (string.IsNullOrEmpty(chatUpdate.ContentUpdate))
+ update = response.Current;
+ if (update.ContentUpdate.Count > 0)
{
- continue;
+ streamingRender.Refresh(update.ContentUpdate[0].Text);
}
-
- streamingRender.Refresh(chatUpdate.ContentUpdate);
}
+ while (await response.MoveNextAsync().ConfigureAwait(continueOnCapturedContext: false));
}
catch (OperationCanceledException)
{
- // Ignore the cancellation exception.
+ update = null;
}
- responseContent = streamingRender.AccumulatedContent;
+ if (update is null)
+ {
+ _chatService.CalibrateChatHistory(usage: null, response: null);
+ }
+ else
+ {
+ string responseContent = streamingRender.AccumulatedContent;
+ _chatService.CalibrateChatHistory(update.Usage, new AssistantChatMessage(responseContent));
+ }
}
- _chatService.AddResponseToHistory(responseContent);
-
return checkPass;
}
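
Note on the streaming rework above: the agent now receives a pre-primed IAsyncEnumerator<StreamingChatCompletionUpdate> instead of the old StreamingResponse enumerable, so the spinner covers the time-to-first-token and the do/while loop drains the remaining updates while rendering. Below is a minimal standalone sketch of that consumption pattern against the OpenAI 2.x ChatClient; it is not the agent code itself, and the names StreamingSketch, StartAsync, and DrainAsync are illustrative only.

using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using OpenAI.Chat;

internal static class StreamingSketch
{
    // Prime the stream so the caller can show a spinner until the first token arrives.
    // Returns null when the stream yields no updates at all.
    internal static async Task<IAsyncEnumerator<StreamingChatCompletionUpdate>> StartAsync(
        ChatClient client, List<ChatMessage> history, ChatCompletionOptions options, CancellationToken token)
    {
        IAsyncEnumerator<StreamingChatCompletionUpdate> enumerator = client
            .CompleteChatStreamingAsync(history, options, token)
            .GetAsyncEnumerator(token);

        return await enumerator.MoveNextAsync().ConfigureAwait(false) ? enumerator : null;
    }

    // Drain the remaining updates with the same do/while shape used in ChatAsync,
    // appending each text delta to build the full response.
    internal static async Task<string> DrainAsync(IAsyncEnumerator<StreamingChatCompletionUpdate> enumerator)
    {
        var text = new StringBuilder();
        do
        {
            StreamingChatCompletionUpdate update = enumerator.Current;
            if (update.ContentUpdate.Count > 0)
            {
                text.Append(update.ContentUpdate[0].Text);
            }
        }
        while (await enumerator.MoveNextAsync().ConfigureAwait(false));

        return text.ToString();
    }
}
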
diff --git a/shell/agents/AIShell.OpenAI.Agent/Helpers.cs b/shell/agents/AIShell.OpenAI.Agent/Helpers.cs
index ad5ac349..cbfccd87 100644
--- a/shell/agents/AIShell.OpenAI.Agent/Helpers.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/Helpers.cs
@@ -3,10 +3,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
-
-using Azure;
-using Azure.Core;
-using Azure.Core.Pipeline;
+using System.ClientModel.Primitives;
namespace AIShell.OpenAI.Agent;
@@ -134,69 +131,25 @@ public override JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions option
}
}
-#nullable enable
-
-/// <summary>
-/// Used for setting user key for the Azure.OpenAI.Client.
-/// </summary>
-internal sealed class UserKeyPolicy : HttpPipelineSynchronousPolicy
-{
- private readonly string _name;
- private readonly AzureKeyCredential _credential;
-
- /// <summary>
- /// Initializes a new instance of the <see cref="UserKeyPolicy"/> class.
- /// </summary>
- /// <param name="credential">The <see cref="AzureKeyCredential"/> used to authenticate requests.</param>
- /// <param name="name">The name of the key header used for the credential.</param>
- public UserKeyPolicy(AzureKeyCredential credential, string name)
- {
- ArgumentNullException.ThrowIfNull(credential);
- ArgumentException.ThrowIfNullOrEmpty(name);
-
- _credential = credential;
- _name = name;
- }
-
- /// <inheritdoc/>
- public override void OnSendingRequest(HttpMessage message)
- {
- base.OnSendingRequest(message);
- message.Request.Headers.SetValue(_name, _credential.Key);
- }
-}
-
/// <summary>
-/// Used for configuring the retry policy for Azure.OpenAI.Client.
+/// Initializes a new instance of the <see cref="ChatRetryPolicy"/> class.
/// </summary>
-internal sealed class ChatRetryPolicy : RetryPolicy
+/// <param name="maxRetries">The maximum number of retries to attempt.</param>
+/// <param name="delayStrategy">The delay to use for computing the interval between retry attempts.</param>
+internal sealed class ChatRetryPolicy(int maxRetries = 2) : ClientRetryPolicy(maxRetries)
{
private const string RetryAfterHeaderName = "Retry-After";
private const string RetryAfterMsHeaderName = "retry-after-ms";
private const string XRetryAfterMsHeaderName = "x-ms-retry-after-ms";
- /// <summary>
- /// Initializes a new instance of the <see cref="ChatRetryPolicy"/> class.
- /// </summary>
- /// <param name="maxRetries">The maximum number of retries to attempt.</param>
- /// <param name="delayStrategy">The delay to use for computing the interval between retry attempts.</param>
- public ChatRetryPolicy(int maxRetries = 2, DelayStrategy? delayStrategy = default) : base(
- maxRetries,
- delayStrategy ?? DelayStrategy.CreateExponentialDelayStrategy(
- initialDelay: TimeSpan.FromSeconds(0.8),
- maxDelay: TimeSpan.FromSeconds(5)))
- {
- // By default, we retry 2 times at most, and use a delay strategy that waits 5 seconds at most between retries.
- }
-
- protected override bool ShouldRetry(HttpMessage message, Exception? exception) => ShouldRetryImpl(message, exception);
- protected override ValueTask<bool> ShouldRetryAsync(HttpMessage message, Exception? exception) => new(ShouldRetryImpl(message, exception));
+ protected override bool ShouldRetry(PipelineMessage message, Exception exception) => ShouldRetryImpl(message, exception);
+ protected override ValueTask<bool> ShouldRetryAsync(PipelineMessage message, Exception exception) => new(ShouldRetryImpl(message, exception));
- private bool ShouldRetryImpl(HttpMessage message, Exception? exception)
+ private bool ShouldRetryImpl(PipelineMessage message, Exception exception)
{
bool result = base.ShouldRetry(message, exception);
- if (result && message.HasResponse)
+ if (result && message.Response is not null)
{
TimeSpan? retryAfter = GetRetryAfterHeaderValue(message.Response.Headers);
if (retryAfter > TimeSpan.FromSeconds(5))
@@ -209,22 +162,22 @@ private bool ShouldRetryImpl(HttpMessage message, Exception? exception)
return result;
}
- private static TimeSpan? GetRetryAfterHeaderValue(ResponseHeaders headers)
+ private static TimeSpan? GetRetryAfterHeaderValue(PipelineResponseHeaders headers)
{
if (headers.TryGetValue(RetryAfterMsHeaderName, out var retryAfterValue) ||
headers.TryGetValue(XRetryAfterMsHeaderName, out retryAfterValue))
{
- if (int.TryParse(retryAfterValue, out var delaySeconds))
+ if (int.TryParse(retryAfterValue, out var delayInMS))
{
- return TimeSpan.FromMilliseconds(delaySeconds);
+ return TimeSpan.FromMilliseconds(delayInMS);
}
}
if (headers.TryGetValue(RetryAfterHeaderName, out retryAfterValue))
{
- if (int.TryParse(retryAfterValue, out var delaySeconds))
+ if (int.TryParse(retryAfterValue, out var delayInSec))
{
- return TimeSpan.FromSeconds(delaySeconds);
+ return TimeSpan.FromSeconds(delayInSec);
}
if (DateTimeOffset.TryParse(retryAfterValue, out DateTimeOffset delayTime))
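
The retry policy above now derives from System.ClientModel's ClientRetryPolicy instead of Azure.Core's RetryPolicy, so it plugs into the new client options the same way for both endpoint types. A hedged wiring sketch follows, mirroring what Service.cs does later in this diff; the endpoint, key, deployment, and model values are placeholders, and it assumes ChatRetryPolicy is visible to the calling code.

using System;
using System.ClientModel;
using Azure.AI.OpenAI;
using OpenAI;
using OpenAI.Chat;

// Azure OpenAI (or APIM) endpoint: attach the custom retry policy via the options.
var azureOptions = new AzureOpenAIClientOptions { RetryPolicy = new ChatRetryPolicy() };
ChatClient azureChat = new AzureOpenAIClient(
    new Uri("https://contoso.openai.azure.com"),
    new ApiKeyCredential("<azure-key>"),
    azureOptions).GetChatClient("<deployment-name>");

// Non-Azure OpenAI endpoint: same policy, different options and client types.
var openAIOptions = new OpenAIClientOptions { RetryPolicy = new ChatRetryPolicy() };
ChatClient openAIChat = new OpenAIClient(
    new ApiKeyCredential("<openai-key>"),
    openAIOptions).GetChatClient("gpt-4o");
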
diff --git a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
index 1c509648..654f090d 100644
--- a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
@@ -1,18 +1,17 @@
-using SharpToken;
+using Microsoft.ML.Tokenizers;
namespace AIShell.OpenAI.Agent;
internal class ModelInfo
{
// Models gpt4, gpt3.5, and the variants of them all use the 'cl100k_base' token encoding.
- // But the gpt-4o model uses the 'o200k_base' token encoding. For reference:
- // https://github.com/openai/tiktoken/blob/5d970c1100d3210b42497203d6b5c1e30cfda6cb/tiktoken/model.py#L7
- // https://github.com/dmitry-brazhenko/SharpToken/blob/main/SharpToken/Lib/Model.cs#L8
+ // But gpt-4o and o1 models use the 'o200k_base' token encoding. For reference:
+ // https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/tiktoken/model.py
private const string Gpt4oEncoding = "o200k_base";
private const string Gpt34Encoding = "cl100k_base";
private static readonly Dictionary<string, ModelInfo> s_modelMap;
- private static readonly Dictionary<string, Task<GptEncoding>> s_encodingMap;
+ private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;
static ModelInfo()
{
@@ -21,6 +20,7 @@ static ModelInfo()
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
s_modelMap = new(StringComparer.OrdinalIgnoreCase)
{
+ ["o1"] = new(tokenLimit: 200_000, encoding: Gpt4oEncoding),
["gpt-4o"] = new(tokenLimit: 128_000, encoding: Gpt4oEncoding),
["gpt-4"] = new(tokenLimit: 8_192),
["gpt-4-32k"] = new(tokenLimit: 32_768),
@@ -35,8 +35,8 @@ static ModelInfo()
// we don't block the startup and the values will be ready when we really need them.
s_encodingMap = new(StringComparer.OrdinalIgnoreCase)
{
- [Gpt34Encoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt34Encoding)),
- [Gpt4oEncoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt4oEncoding))
+ [Gpt34Encoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt34Encoding)),
+ [Gpt4oEncoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt4oEncoding))
};
}
@@ -45,24 +45,24 @@ private ModelInfo(int tokenLimit, string encoding = null)
TokenLimit = tokenLimit;
_encodingName = encoding ?? Gpt34Encoding;
- // For gpt4 and gpt3.5-turbo, the following 2 properties are the same.
+ // For gpt4o, gpt4 and gpt3.5-turbo, the following 2 properties are the same.
// See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
TokensPerMessage = 3;
TokensPerName = 1;
}
private readonly string _encodingName;
- private GptEncoding _gptEncoding;
+ private Tokenizer _gptEncoding;
internal int TokenLimit { get; }
internal int TokensPerMessage { get; }
internal int TokensPerName { get; }
- internal GptEncoding Encoding
+ internal Tokenizer Encoding
{
get {
- _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<GptEncoding> value)
+ _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<Tokenizer> value)
? value.Result
- : GptEncoding.GetEncoding(_encodingName);
+ : TiktokenTokenizer.CreateForEncoding(_encodingName);
return _gptEncoding;
}
}
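
ModelInfo.cs above swaps SharpToken's GptEncoding for Microsoft.ML.Tokenizers. A small sketch of the new counting calls; it assumes the project references the matching data packages (e.g. Microsoft.ML.Tokenizers.Data.O200kBase and .Cl100kBase), since CreateForEncoding loads its vocabulary from them.

using System;
using Microsoft.ML.Tokenizers;

// The agent pre-warms both encodings on background tasks so the first real use is fast;
// here we just create them inline and count tokens for a sample string.
Tokenizer o200k = TiktokenTokenizer.CreateForEncoding("o200k_base");
Tokenizer cl100k = TiktokenTokenizer.CreateForEncoding("cl100k_base");

string sample = "Explain the last error in my session.";
Console.WriteLine($"o200k_base:  {o200k.CountTokens(sample)} tokens");
Console.WriteLine($"cl100k_base: {cl100k.CountTokens(sample)} tokens");
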
diff --git a/shell/agents/AIShell.OpenAI.Agent/Service.cs b/shell/agents/AIShell.OpenAI.Agent/Service.cs
index 22fcc48c..583a90d6 100644
--- a/shell/agents/AIShell.OpenAI.Agent/Service.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/Service.cs
@@ -1,8 +1,9 @@
-using System.Diagnostics;
-using Azure;
-using Azure.Core;
+using System.ClientModel;
+using System.ClientModel.Primitives;
using Azure.AI.OpenAI;
-using SharpToken;
+using Microsoft.ML.Tokenizers;
+using OpenAI;
+using OpenAI.Chat;
namespace AIShell.OpenAI.Agent;
@@ -10,22 +11,34 @@ internal class ChatService
{
// TODO: Maybe expose this to our model registration?
// We can still use 1000 as the default value.
- private const int MaxResponseToken = 1000;
+ private const int MaxResponseToken = 2000;
private readonly string _historyRoot;
- private readonly List<ChatRequestMessage> _chatHistory;
+ private readonly List<ChatMessage> _chatHistory;
+ private readonly List<int> _chatHistoryTokens;
+ private readonly ChatCompletionOptions _chatOptions;
private GPT _gptToUse;
private Settings _settings;
- private OpenAIClient _client;
+ private ChatClient _client;
+ private int _totalInputToken;
internal ChatService(string historyRoot, Settings settings)
{
+ _chatHistory = [];
+ _chatHistoryTokens = [];
_historyRoot = historyRoot;
+
+ _totalInputToken = 0;
_settings = settings;
- _chatHistory = [];
+
+ _chatOptions = new ChatCompletionOptions()
+ {
+ Temperature = 0,
+ MaxOutputTokenCount = MaxResponseToken,
+ };
}
- internal List<ChatRequestMessage> ChatHistory => _chatHistory;
+ internal List<ChatMessage> ChatHistory => _chatHistory;
internal void AddResponseToHistory(string response)
{
@@ -34,7 +47,7 @@ internal void AddResponseToHistory(string response)
return;
}
- _chatHistory.Add(new ChatRequestAssistantMessage(response));
+ _chatHistory.Add(ChatMessage.CreateAssistantMessage(response));
}
internal void RefreshSettings(Settings settings)
@@ -42,6 +55,49 @@ internal void RefreshSettings(Settings settings)
_settings = settings;
}
+ /// <summary>
+ /// It's almost impossible to calculate the token counts of all messages with reasonable
+ /// accuracy, especially when tool calls are involved (tool call definitions and the tool
+ /// call payloads in the AI response).
+ /// So, I decided to leverage the usage report from the AI to track the token count of the
+ /// chat history. It's still an estimate, but more accurate than doing the counting ourselves.
+ /// </summary>
+ internal void CalibrateChatHistory(ChatTokenUsage usage, AssistantChatMessage response)
+ {
+ if (usage is null)
+ {
+ // Response was cancelled and we will remove the last query from history.
+ int index = _chatHistory.Count - 1;
+ _chatHistory.RemoveAt(index);
+ _chatHistoryTokens.RemoveAt(index);
+
+ return;
+ }
+
+ // Every reply is primed with <|start|>assistant<|message|>, so we subtract 3 from the 'InputTokenCount'.
+ int promptTokenCount = usage.InputTokenCount - 3;
+ // 'ReasoningTokenCount' should be 0 for non-o1 models.
+ int reasoningTokenCount = usage.OutputTokenDetails is null ? 0 : usage.OutputTokenDetails.ReasoningTokenCount;
+ int responseTokenCount = usage.OutputTokenCount - reasoningTokenCount;
+
+ if (_totalInputToken is 0)
+ {
+ // It was the first user message, so instead of adjusting the user message token count,
+ // we set the token count for the system message and tool calls.
+ _chatHistoryTokens[0] = promptTokenCount - _chatHistoryTokens[^1];
+ }
+ else
+ {
+ // Adjust the token count of the user message, as our calculation is an estimate.
+ _chatHistoryTokens[^1] = promptTokenCount - _totalInputToken;
+ }
+
+ _chatHistory.Add(response);
+ _chatHistoryTokens.Add(responseTokenCount);
+ _totalInputToken = promptTokenCount + responseTokenCount;
+ }
+
private void RefreshOpenAIClient()
{
if (ReferenceEquals(_gptToUse, _settings.Active))
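
To make the calibration arithmetic in CalibrateChatHistory concrete, here is a small worked example with made-up numbers; it only restates the bookkeeping added in this hunk and is not additional agent code.

// Hypothetical first-turn report from the final streaming update:
//   InputTokenCount = 503, OutputTokenCount = 120, OutputTokenDetails = null (non-o1 model),
//   and the user message was locally estimated at 40 tokens by CountTokenForUserMessage.
int promptTokenCount    = 503 - 3;   // 500: drop the <|start|>assistant<|message|> priming tokens
int reasoningTokenCount = 0;         // no reasoning tokens for non-o1 models
int responseTokenCount  = 120 - reasoningTokenCount;

// First turn (_totalInputToken == 0): the gap between the reported prompt size and the locally
// counted user message is attributed to the system prompt (and any tool definitions), so
// _chatHistoryTokens[0] becomes 500 - 40 = 460.
// The running total then becomes promptTokenCount + responseTokenCount = 620, which PrepareForChat
// later checks against the model's token limit while keeping a 50-token buffer.
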
@@ -53,6 +109,7 @@ private void RefreshOpenAIClient()
GPT old = _gptToUse;
_gptToUse = _settings.Active;
_chatHistory.Clear();
+ _chatHistoryTokens.Clear();
if (old is not null
&& old.Type == _gptToUse.Type
@@ -65,211 +122,127 @@ private void RefreshOpenAIClient()
return;
}
- var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
+ string userKey = Utils.ConvertFromSecureString(_gptToUse.Key);
if (_gptToUse.Type is EndpointType.AzureOpenAI)
{
// Create a client that targets Azure OpenAI service or Azure API Management service.
+ var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
bool isApimEndpoint = _gptToUse.Endpoint.EndsWith(Utils.ApimGatewayDomain);
+
if (isApimEndpoint)
{
- string userkey = Utils.ConvertFromSecureString(_gptToUse.Key);
clientOptions.AddPolicy(
- new UserKeyPolicy(
- new AzureKeyCredential(userkey),
+ ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy(
+ new ApiKeyCredential(userKey),
Utils.ApimAuthorizationHeader),
- HttpPipelinePosition.PerRetry
- );
+ PipelinePosition.PerTry);
}
- string azOpenAIApiKey = isApimEndpoint
- ? "placeholder-api-key"
- : Utils.ConvertFromSecureString(_gptToUse.Key);
+ string azOpenAIApiKey = isApimEndpoint ? "placeholder-api-key" : userKey;
- _client = new OpenAIClient(
+ var aiClient = new AzureOpenAIClient(
new Uri(_gptToUse.Endpoint),
- new AzureKeyCredential(azOpenAIApiKey),
+ new ApiKeyCredential(azOpenAIApiKey),
clientOptions);
+
+ _client = aiClient.GetChatClient(_gptToUse.Deployment);
}
else
{
// Create a client that targets the non-Azure OpenAI service.
- _client = new OpenAIClient(Utils.ConvertFromSecureString(_gptToUse.Key), clientOptions);
+ var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
+ var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions);
+ _client = aiClient.GetChatClient(_gptToUse.ModelName);
}
}
- private int CountTokenForMessages(IEnumerable<ChatRequestMessage> messages)
+ /// <summary>
+ /// Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ /// </summary>
+ private int CountTokenForUserMessage(UserChatMessage message)
{
ModelInfo modelDetail = _gptToUse.ModelInfo;
- GptEncoding encoding = modelDetail.Encoding;
- int tokensPerMessage = modelDetail.TokensPerMessage;
- int tokensPerName = modelDetail.TokensPerName;
+ Tokenizer encoding = modelDetail.Encoding;
- int tokenNumber = 0;
- foreach (ChatRequestMessage message in messages)
+ // Tokens per message plus 1 token for the role.
+ int tokenNumber = modelDetail.TokensPerMessage + 1;
+ foreach (ChatMessageContentPart part in message.Content)
{
- tokenNumber += tokensPerMessage;
- tokenNumber += encoding.Encode(message.Role.ToString()).Count;
-
- switch (message)
- {
- case ChatRequestSystemMessage systemMessage:
- tokenNumber += SimpleCountToken(systemMessage.Content, systemMessage.Name);
- break;
-
- case ChatRequestUserMessage userMessage:
- tokenNumber += SimpleCountToken(userMessage.Content, userMessage.Name);
- break;
-
- case ChatRequestAssistantMessage assistantMessage:
- tokenNumber += SimpleCountToken(assistantMessage.Content, assistantMessage.Name);
- if (assistantMessage.ToolCalls is not null)
- {
- // Count tokens for the tool call's properties
- foreach(ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls)
- {
- if(chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall)
- {
- tokenNumber += encoding.Encode(functionToolCall.Id).Count;
- tokenNumber += encoding.Encode(functionToolCall.Name).Count;
- tokenNumber += encoding.Encode(functionToolCall.Arguments).Count;
- }
- }
- }
- break;
-
- case ChatRequestToolMessage toolMessage:
- tokenNumber += encoding.Encode(toolMessage.ToolCallId).Count;
- tokenNumber += encoding.Encode(toolMessage.Content).Count;
- break;
- // Add cases for other derived types as needed
- }
+ tokenNumber += encoding.CountTokens(part.Text);
}
- // Every reply is primed with <|start|>assistant<|message|>, which takes 3 tokens.
- tokenNumber += 3;
return tokenNumber;
-
- // ----- Local Function -----
- int SimpleCountToken(string content, string name)
- {
- int sum = 0;
- if (!string.IsNullOrEmpty(content))
- {
- sum = encoding.Encode(content).Count;
- }
-
- if (!string.IsNullOrEmpty(name))
- {
- sum += tokensPerName;
- sum += encoding.Encode(name).Count;
- }
-
- return sum;
- }
}
- private void ReduceChatHistoryAsNeeded(List<ChatRequestMessage> history, ChatRequestMessage input)
+ private void PrepareForChat(string input)
{
- bool inputTooLong = false;
- int tokenLimit = _gptToUse.ModelInfo.TokenLimit;
+ // Refresh the client in case the active model was changed.
+ RefreshOpenAIClient();
- do
+ if (_chatHistory.Count is 0)
{
- int totalTokens = CountTokenForMessages(Enumerable.Repeat(input, 1));
- if (totalTokens + MaxResponseToken >= tokenLimit)
- {
- // The input itself already exceeds the token limit.
- inputTooLong = true;
- break;
- }
+ _chatHistory.Add(ChatMessage.CreateSystemMessage(_gptToUse.SystemPrompt));
+ _chatHistoryTokens.Add(0);
+ }
- history.Add(input);
- totalTokens = CountTokenForMessages(history);
+ var userMessage = new UserChatMessage(input);
+ int msgTokenCnt = CountTokenForUserMessage(userMessage);
+ _chatHistory.Add(userMessage);
+ _chatHistoryTokens.Add(msgTokenCnt);
- int index = -1;
- while (totalTokens + MaxResponseToken >= tokenLimit)
+ int inputLimit = _gptToUse.ModelInfo.TokenLimit;
+ // Every reply is primed with <|start|>assistant<|message|>, so adding 3 tokens.
+ int newTotal = _totalInputToken + msgTokenCnt + 3;
+
+ // Shrink the chat history if we have less than 50 free tokens left (50-token buffer).
+ while (inputLimit - newTotal < 50)
+ {
+ // We remove a round of conversation for every trimming operation.
+ int userMsgCnt = 0;
+ List<int> indices = [];
+
+ for (int i = 0; i < _chatHistory.Count; i++)
{
- if (index is -1)
+ if (_chatHistory[i] is UserChatMessage)
{
- // Find the first non-system message.
- for (index = 0; history[index] is ChatRequestSystemMessage; index++);
+ if (userMsgCnt is 1)
+ {
+ break;
+ }
+
+ userMsgCnt++;
}
- if (history[index] == input)
+ if (userMsgCnt is 1)
{
- // The input plus system message exceeds the token limit.
- inputTooLong = true;
- break;
+ indices.Add(i);
}
-
- history.RemoveAt(index);
- totalTokens = CountTokenForMessages(history);
}
- }
- while (false);
- if (inputTooLong)
- {
- var message = $"The input is too long to get a proper response without exceeding the token limit ({tokenLimit}).\nPlease reduce the input and try again.";
- throw new InvalidOperationException(message);
- }
- }
-
- private ChatCompletionsOptions PrepareForChat(string input)
- {
- // Refresh the client in case the active model was changed.
- RefreshOpenAIClient();
-
- // TODO: Shall we expose some of the setting properties to our model registration?
- // - max_tokens
- // - temperature
- // - top_p
- // - presence_penalty
- // - frequency_penalty
- // Those settings seem to be important enough, as the Semantic Kernel plugin specifies
- // those settings (see the URL below). We can use default values when not defined.
- // https://github.com/microsoft/semantic-kernel/blob/main/samples/skills/FunSkill/Joke/config.json
- string deploymentOrModelName = _gptToUse.Type switch
- {
- EndpointType.AzureOpenAI => _gptToUse.Deployment,
- EndpointType.OpenAI => _gptToUse.ModelName,
- _ => throw new UnreachableException(),
- };
-
- ChatCompletionsOptions chatOptions = new()
- {
- DeploymentName = deploymentOrModelName,
- ChoiceCount = 1,
- Temperature = 0,
- MaxTokens = MaxResponseToken,
- };
-
- List<ChatRequestMessage> history = _chatHistory;
- if (history.Count is 0)
- {
- history.Add(new ChatRequestSystemMessage(_gptToUse.SystemPrompt));
- }
+ foreach (int i in indices)
+ {
+ newTotal -= _chatHistoryTokens[i];
+ }
- ReduceChatHistoryAsNeeded(history, new ChatRequestUserMessage(input));
- foreach (ChatRequestMessage message in history)
- {
- chatOptions.Messages.Add(message);
+ _chatHistory.RemoveRange(indices[0], indices.Count);
+ _chatHistoryTokens.RemoveRange(indices[0], indices.Count);
+ _totalInputToken = newTotal - msgTokenCnt;
}
-
- return chatOptions;
}
- public async Task<StreamingResponse<StreamingChatCompletionsUpdate>> GetStreamingChatResponseAsync(string input, CancellationToken cancellationToken = default)
+ public async Task<IAsyncEnumerator<StreamingChatCompletionUpdate>> GetStreamingChatResponseAsync(string input, CancellationToken cancellationToken)
{
try
{
- ChatCompletionsOptions chatOptions = PrepareForChat(input);
- var response = await _client.GetChatCompletionsStreamingAsync(
- chatOptions,
- cancellationToken);
-
- return response;
+ PrepareForChat(input);
+ IAsyncEnumerator<StreamingChatCompletionUpdate> enumerator = _client
+ .CompleteChatStreamingAsync(_chatHistory, _chatOptions, cancellationToken)
+ .GetAsyncEnumerator(cancellationToken);
+
+ return await enumerator
+ .MoveNextAsync()
+ .ConfigureAwait(continueOnCapturedContext: false) ? enumerator : null;
}
catch (OperationCanceledException)
{