From 733957094bf63c51188257f76cef1760dd269d81 Mon Sep 17 00:00:00 2001 From: Dongbo Wang Date: Fri, 17 Jan 2025 15:42:47 -0800 Subject: [PATCH 1/2] Refactoring work to move to Azure.AI.OpenAI v2.1.0 --- .../AIShell.OpenAI.Agent.csproj | 7 +- shell/agents/AIShell.OpenAI.Agent/Agent.cs | 36 ++- shell/agents/AIShell.OpenAI.Agent/Helpers.cs | 75 +---- .../agents/AIShell.OpenAI.Agent/ModelInfo.cs | 18 +- shell/agents/AIShell.OpenAI.Agent/Service.cs | 302 ++++++++---------- 5 files changed, 186 insertions(+), 252 deletions(-) diff --git a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj index c0335823..e152bbbe 100644 --- a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj +++ b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj @@ -21,9 +21,10 @@ - - - + + + + diff --git a/shell/agents/AIShell.OpenAI.Agent/Agent.cs b/shell/agents/AIShell.OpenAI.Agent/Agent.cs index 16a86d4e..86ee8edf 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Agent.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Agent.cs @@ -1,7 +1,8 @@ +using System.ClientModel; using System.Text; using System.Text.Json; -using Azure.AI.OpenAI; using AIShell.Abstraction; +using OpenAI.Chat; namespace AIShell.OpenAI.Agent; @@ -106,37 +107,44 @@ public async Task ChatAsync(string input, IShell shell) return checkPass; } - string responseContent = null; - StreamingResponse response = await host.RunWithSpinnerAsync( - () => _chatService.GetStreamingChatResponseAsync(input, token) - ).ConfigureAwait(false); + IAsyncEnumerator response = await host + .RunWithSpinnerAsync( + () => _chatService.GetStreamingChatResponseAsync(input, token) + ).ConfigureAwait(false); if (response is not null) { + StreamingChatCompletionUpdate update = null; using var streamingRender = host.NewStreamRender(token); try { - await foreach (StreamingChatCompletionsUpdate chatUpdate in response) + do { - if (string.IsNullOrEmpty(chatUpdate.ContentUpdate)) + update = response.Current; + if (update.ContentUpdate.Count > 0) { - continue; + streamingRender.Refresh(update.ContentUpdate[0].Text); } - - streamingRender.Refresh(chatUpdate.ContentUpdate); } + while (await response.MoveNextAsync().ConfigureAwait(continueOnCapturedContext: false)); } catch (OperationCanceledException) { - // Ignore the cancellation exception. + update = null; } - responseContent = streamingRender.AccumulatedContent; + if (update is null) + { + _chatService.CalibrateChatHistory(usage: null, response: null); + } + else + { + string responseContent = streamingRender.AccumulatedContent; + _chatService.CalibrateChatHistory(update.Usage, new AssistantChatMessage(responseContent)); + } } - _chatService.AddResponseToHistory(responseContent); - return checkPass; } diff --git a/shell/agents/AIShell.OpenAI.Agent/Helpers.cs b/shell/agents/AIShell.OpenAI.Agent/Helpers.cs index ad5ac349..cbfccd87 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Helpers.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Helpers.cs @@ -3,10 +3,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Text.Json.Serialization.Metadata; - -using Azure; -using Azure.Core; -using Azure.Core.Pipeline; +using System.ClientModel.Primitives; namespace AIShell.OpenAI.Agent; @@ -134,69 +131,25 @@ public override JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions option } } -#nullable enable - -/// -/// Used for setting user key for the Azure.OpenAI.Client. 
-/// -internal sealed class UserKeyPolicy : HttpPipelineSynchronousPolicy -{ - private readonly string _name; - private readonly AzureKeyCredential _credential; - - /// - /// Initializes a new instance of the class. - /// - /// The used to authenticate requests. - /// The name of the key header used for the credential. - public UserKeyPolicy(AzureKeyCredential credential, string name) - { - ArgumentNullException.ThrowIfNull(credential); - ArgumentException.ThrowIfNullOrEmpty(name); - - _credential = credential; - _name = name; - } - - /// - public override void OnSendingRequest(HttpMessage message) - { - base.OnSendingRequest(message); - message.Request.Headers.SetValue(_name, _credential.Key); - } -} - /// -/// Used for configuring the retry policy for Azure.OpenAI.Client. +/// Initializes a new instance of the class. /// -internal sealed class ChatRetryPolicy : RetryPolicy +/// The maximum number of retries to attempt. +/// The delay to use for computing the interval between retry attempts. +internal sealed class ChatRetryPolicy(int maxRetries = 2) : ClientRetryPolicy(maxRetries) { private const string RetryAfterHeaderName = "Retry-After"; private const string RetryAfterMsHeaderName = "retry-after-ms"; private const string XRetryAfterMsHeaderName = "x-ms-retry-after-ms"; - /// - /// Initializes a new instance of the class. - /// - /// The maximum number of retries to attempt. - /// The delay to use for computing the interval between retry attempts. - public ChatRetryPolicy(int maxRetries = 2, DelayStrategy? delayStrategy = default) : base( - maxRetries, - delayStrategy ?? DelayStrategy.CreateExponentialDelayStrategy( - initialDelay: TimeSpan.FromSeconds(0.8), - maxDelay: TimeSpan.FromSeconds(5))) - { - // By default, we retry 2 times at most, and use a delay strategy that waits 5 seconds at most between retries. - } - - protected override bool ShouldRetry(HttpMessage message, Exception? exception) => ShouldRetryImpl(message, exception); - protected override ValueTask ShouldRetryAsync(HttpMessage message, Exception? exception) => new(ShouldRetryImpl(message, exception)); + protected override bool ShouldRetry(PipelineMessage message, Exception exception) => ShouldRetryImpl(message, exception); + protected override ValueTask ShouldRetryAsync(PipelineMessage message, Exception exception) => new(ShouldRetryImpl(message, exception)); - private bool ShouldRetryImpl(HttpMessage message, Exception? exception) + private bool ShouldRetryImpl(PipelineMessage message, Exception exception) { bool result = base.ShouldRetry(message, exception); - if (result && message.HasResponse) + if (result && message.Response is not null) { TimeSpan? retryAfter = GetRetryAfterHeaderValue(message.Response.Headers); if (retryAfter > TimeSpan.FromSeconds(5)) @@ -209,22 +162,22 @@ private bool ShouldRetryImpl(HttpMessage message, Exception? exception) return result; } - private static TimeSpan? GetRetryAfterHeaderValue(ResponseHeaders headers) + private static TimeSpan? 
GetRetryAfterHeaderValue(PipelineResponseHeaders headers)
     {
         if (headers.TryGetValue(RetryAfterMsHeaderName, out var retryAfterValue) ||
             headers.TryGetValue(XRetryAfterMsHeaderName, out retryAfterValue))
         {
-            if (int.TryParse(retryAfterValue, out var delaySeconds))
+            if (int.TryParse(retryAfterValue, out var delayInMS))
             {
-                return TimeSpan.FromMilliseconds(delaySeconds);
+                return TimeSpan.FromMilliseconds(delayInMS);
             }
         }
 
         if (headers.TryGetValue(RetryAfterHeaderName, out retryAfterValue))
         {
-            if (int.TryParse(retryAfterValue, out var delaySeconds))
+            if (int.TryParse(retryAfterValue, out var delayInSec))
             {
-                return TimeSpan.FromSeconds(delaySeconds);
+                return TimeSpan.FromSeconds(delayInSec);
             }
 
             if (DateTimeOffset.TryParse(retryAfterValue, out DateTimeOffset delayTime))
diff --git a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
index 1c509648..5e04417f 100644
--- a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
@@ -1,4 +1,4 @@
-using SharpToken;
+using Microsoft.ML.Tokenizers;
 
 namespace AIShell.OpenAI.Agent;
 
@@ -12,7 +12,7 @@ internal class ModelInfo
     private const string Gpt34Encoding = "cl100k_base";
 
     private static readonly Dictionary<string, ModelInfo> s_modelMap;
-    private static readonly Dictionary<string, Task<GptEncoding>> s_encodingMap;
+    private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;
 
     static ModelInfo()
     {
@@ -35,8 +35,8 @@ static ModelInfo()
         // we don't block the startup and the values will be ready when we really need them.
         s_encodingMap = new(StringComparer.OrdinalIgnoreCase)
         {
-            [Gpt34Encoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt34Encoding)),
-            [Gpt4oEncoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt4oEncoding))
+            [Gpt34Encoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt34Encoding)),
+            [Gpt4oEncoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt4oEncoding))
         };
     }
 
@@ -45,24 +45,24 @@ private ModelInfo(int tokenLimit, string encoding = null)
     {
         TokenLimit = tokenLimit;
         _encodingName = encoding ?? Gpt34Encoding;
 
-        // For gpt4 and gpt3.5-turbo, the following 2 properties are the same.
+        // For gpt4o, gpt4 and gpt3.5-turbo, the following 2 properties are the same.
         // See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
         TokensPerMessage = 3;
         TokensPerName = 1;
     }
 
     private readonly string _encodingName;
-    private GptEncoding _gptEncoding;
+    private Tokenizer _gptEncoding;
 
     internal int TokenLimit { get; }
     internal int TokensPerMessage { get; }
     internal int TokensPerName { get; }
 
-    internal GptEncoding Encoding
+    internal Tokenizer Encoding
     {
         get
         {
-            _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<GptEncoding> value)
+            _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<Tokenizer> value)
                ? value.Result
-                : GptEncoding.GetEncoding(_encodingName);
+                : TiktokenTokenizer.CreateForEncoding(_encodingName);
             return _gptEncoding;
         }
     }
diff --git a/shell/agents/AIShell.OpenAI.Agent/Service.cs b/shell/agents/AIShell.OpenAI.Agent/Service.cs
index 22fcc48c..fa96f2ed 100644
--- a/shell/agents/AIShell.OpenAI.Agent/Service.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/Service.cs
@@ -1,8 +1,9 @@
-using System.Diagnostics;
-using Azure;
-using Azure.Core;
+using System.ClientModel;
+using System.ClientModel.Primitives;
 using Azure.AI.OpenAI;
-using SharpToken;
+using Microsoft.ML.Tokenizers;
+using OpenAI;
+using OpenAI.Chat;
 
 namespace AIShell.OpenAI.Agent;
 
@@ -10,22 +11,34 @@ internal class ChatService
 {
     // TODO: Maybe expose this to our model registration?
     // We can still use 1000 as the default value.
-    private const int MaxResponseToken = 1000;
+    private const int MaxResponseToken = 2000;
     private readonly string _historyRoot;
-    private readonly List<ChatRequestMessage> _chatHistory;
+    private readonly List<ChatMessage> _chatHistory;
+    private readonly List<int> _chatHistoryTokens;
+    private readonly ChatCompletionOptions _chatOptions;
 
     private GPT _gptToUse;
     private Settings _settings;
-    private OpenAIClient _client;
+    private ChatClient _client;
+    private int _totalInputToken;
 
     internal ChatService(string historyRoot, Settings settings)
     {
+        _chatHistory = [];
+        _chatHistoryTokens = [];
         _historyRoot = historyRoot;
+
+        _totalInputToken = 0;
         _settings = settings;
-        _chatHistory = [];
+
+        _chatOptions = new ChatCompletionOptions()
+        {
+            Temperature = 0,
+            MaxOutputTokenCount = MaxResponseToken,
+        };
     }
 
-    internal List<ChatRequestMessage> ChatHistory => _chatHistory;
+    internal List<ChatMessage> ChatHistory => _chatHistory;
 
     internal void AddResponseToHistory(string response)
     {
@@ -34,7 +47,7 @@ internal void AddResponseToHistory(string response)
             return;
         }
 
-        _chatHistory.Add(new ChatRequestAssistantMessage(response));
+        _chatHistory.Add(ChatMessage.CreateAssistantMessage(response));
     }
 
     internal void RefreshSettings(Settings settings)
@@ -42,6 +55,48 @@ internal void RefreshSettings(Settings settings)
         _settings = settings;
    }
 
+    /// <summary>
+    /// It's almost impossible to calculate the token counts of all messages with
+    /// reasonable accuracy, especially when tool calls are involved (tool call
+    /// definitions and the tool call payloads in the AI response).
+    /// So, I decided to leverage the usage report from the AI to track the token
+    /// count of the chat history. It's still an estimate, but more accurate than
+    /// doing the counting ourselves.
+    /// </summary>
+    internal void CalibrateChatHistory(ChatTokenUsage usage, AssistantChatMessage response)
+    {
+        if (usage is null)
+        {
+            // Response was cancelled, so we remove the last query from the history.
+            int index = _chatHistory.Count - 1;
+            _chatHistory.RemoveAt(index);
+            _chatHistoryTokens.RemoveAt(index);
+
+            return;
+        }
+
+        // Every reply is primed with <|start|>assistant<|message|>, so we subtract 3 from the 'InputTokenCount'.
+        int promptTokenCount = usage.InputTokenCount - 3;
+        // 'ReasoningTokenCount' should be 0 for non-o1 models.
+        int responseTokenCount = usage.OutputTokenCount - usage.OutputTokenDetails.ReasoningTokenCount;
+
+        if (_totalInputToken is 0)
+        {
+            // It was the first user message, so instead of adjusting the user message token count,
+            // we set the token count for the system message and tool calls.
+            _chatHistoryTokens[0] = promptTokenCount - _chatHistoryTokens[^1];
+        }
+        else
+        {
+            // Adjust the token count of the user message, as our calculation is an estimate.
+ _chatHistoryTokens[^1] = promptTokenCount - _totalInputToken; + } + + _chatHistory.Add(response); + _chatHistoryTokens.Add(responseTokenCount); + _totalInputToken = promptTokenCount + responseTokenCount; + } + private void RefreshOpenAIClient() { if (ReferenceEquals(_gptToUse, _settings.Active)) @@ -53,6 +108,7 @@ private void RefreshOpenAIClient() GPT old = _gptToUse; _gptToUse = _settings.Active; _chatHistory.Clear(); + _chatHistoryTokens.Clear(); if (old is not null && old.Type == _gptToUse.Type @@ -65,211 +121,127 @@ private void RefreshOpenAIClient() return; } - var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() }; + string userKey = Utils.ConvertFromSecureString(_gptToUse.Key); if (_gptToUse.Type is EndpointType.AzureOpenAI) { // Create a client that targets Azure OpenAI service or Azure API Management service. + var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() }; bool isApimEndpoint = _gptToUse.Endpoint.EndsWith(Utils.ApimGatewayDomain); + if (isApimEndpoint) { - string userkey = Utils.ConvertFromSecureString(_gptToUse.Key); clientOptions.AddPolicy( - new UserKeyPolicy( - new AzureKeyCredential(userkey), + ApiKeyAuthenticationPolicy.CreateHeaderApiKeyPolicy( + new ApiKeyCredential(userKey), Utils.ApimAuthorizationHeader), - HttpPipelinePosition.PerRetry - ); + PipelinePosition.PerTry); } - string azOpenAIApiKey = isApimEndpoint - ? "placeholder-api-key" - : Utils.ConvertFromSecureString(_gptToUse.Key); + string azOpenAIApiKey = isApimEndpoint ? "placeholder-api-key" : userKey; - _client = new OpenAIClient( + var aiClient = new AzureOpenAIClient( new Uri(_gptToUse.Endpoint), - new AzureKeyCredential(azOpenAIApiKey), + new ApiKeyCredential(azOpenAIApiKey), clientOptions); + + _client = aiClient.GetChatClient(_gptToUse.Deployment); } else { // Create a client that targets the non-Azure OpenAI service. - _client = new OpenAIClient(Utils.ConvertFromSecureString(_gptToUse.Key), clientOptions); + var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() }; + var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions); + _client = aiClient.GetChatClient(_gptToUse.ModelName); } } - private int CountTokenForMessages(IEnumerable messages) + /// + /// Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + /// + private int CountTokenForUserMessage(UserChatMessage message) { ModelInfo modelDetail = _gptToUse.ModelInfo; - GptEncoding encoding = modelDetail.Encoding; - int tokensPerMessage = modelDetail.TokensPerMessage; - int tokensPerName = modelDetail.TokensPerName; + Tokenizer encoding = modelDetail.Encoding; - int tokenNumber = 0; - foreach (ChatRequestMessage message in messages) + // Tokens per message plus 1 token for the role. 
+ int tokenNumber = modelDetail.TokensPerMessage + 1; + foreach (ChatMessageContentPart part in message.Content) { - tokenNumber += tokensPerMessage; - tokenNumber += encoding.Encode(message.Role.ToString()).Count; - - switch (message) - { - case ChatRequestSystemMessage systemMessage: - tokenNumber += SimpleCountToken(systemMessage.Content, systemMessage.Name); - break; - - case ChatRequestUserMessage userMessage: - tokenNumber += SimpleCountToken(userMessage.Content, userMessage.Name); - break; - - case ChatRequestAssistantMessage assistantMessage: - tokenNumber += SimpleCountToken(assistantMessage.Content, assistantMessage.Name); - if (assistantMessage.ToolCalls is not null) - { - // Count tokens for the tool call's properties - foreach(ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls) - { - if(chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall) - { - tokenNumber += encoding.Encode(functionToolCall.Id).Count; - tokenNumber += encoding.Encode(functionToolCall.Name).Count; - tokenNumber += encoding.Encode(functionToolCall.Arguments).Count; - } - } - } - break; - - case ChatRequestToolMessage toolMessage: - tokenNumber += encoding.Encode(toolMessage.ToolCallId).Count; - tokenNumber += encoding.Encode(toolMessage.Content).Count; - break; - // Add cases for other derived types as needed - } + tokenNumber += encoding.CountTokens(part.Text); } - // Every reply is primed with <|start|>assistant<|message|>, which takes 3 tokens. - tokenNumber += 3; return tokenNumber; - - // ----- Local Function ----- - int SimpleCountToken(string content, string name) - { - int sum = 0; - if (!string.IsNullOrEmpty(content)) - { - sum = encoding.Encode(content).Count; - } - - if (!string.IsNullOrEmpty(name)) - { - sum += tokensPerName; - sum += encoding.Encode(name).Count; - } - - return sum; - } } - private void ReduceChatHistoryAsNeeded(List history, ChatRequestMessage input) + private void PrepareForChat(string input) { - bool inputTooLong = false; - int tokenLimit = _gptToUse.ModelInfo.TokenLimit; + // Refresh the client in case the active model was changed. + RefreshOpenAIClient(); - do + if (_chatHistory.Count is 0) { - int totalTokens = CountTokenForMessages(Enumerable.Repeat(input, 1)); - if (totalTokens + MaxResponseToken >= tokenLimit) - { - // The input itself already exceeds the token limit. - inputTooLong = true; - break; - } + _chatHistory.Add(ChatMessage.CreateSystemMessage(_gptToUse.SystemPrompt)); + _chatHistoryTokens.Add(0); + } - history.Add(input); - totalTokens = CountTokenForMessages(history); + var userMessage = new UserChatMessage(input); + int msgTokenCnt = CountTokenForUserMessage(userMessage); + _chatHistory.Add(userMessage); + _chatHistoryTokens.Add(msgTokenCnt); - int index = -1; - while (totalTokens + MaxResponseToken >= tokenLimit) + int inputLimit = _gptToUse.ModelInfo.TokenLimit; + // Every reply is primed with <|start|>assistant<|message|>, so adding 3 tokens. + int newTotal = _totalInputToken + msgTokenCnt + 3; + + // Shrink the chat history if we have less than 50 free tokens left (50-token buffer). + while (inputLimit - newTotal < 50) + { + // We remove a round of conversation for every trimming operation. + int userMsgCnt = 0; + List indices = []; + + for (int i = 0; i < _chatHistory.Count; i++) { - if (index is -1) + if (_chatHistory[i] is UserChatMessage) { - // Find the first non-system message. 
- for (index = 0; history[index] is ChatRequestSystemMessage; index++); + if (userMsgCnt is 1) + { + break; + } + + userMsgCnt++; } - if (history[index] == input) + if (userMsgCnt is 1) { - // The input plus system message exceeds the token limit. - inputTooLong = true; - break; + indices.Add(i); } - - history.RemoveAt(index); - totalTokens = CountTokenForMessages(history); } - } - while (false); - if (inputTooLong) - { - var message = $"The input is too long to get a proper response without exceeding the token limit ({tokenLimit}).\nPlease reduce the input and try again."; - throw new InvalidOperationException(message); - } - } - - private ChatCompletionsOptions PrepareForChat(string input) - { - // Refresh the client in case the active model was changed. - RefreshOpenAIClient(); - - // TODO: Shall we expose some of the setting properties to our model registration? - // - max_tokens - // - temperature - // - top_p - // - presence_penalty - // - frequency_penalty - // Those settings seem to be important enough, as the Semantic Kernel plugin specifies - // those settings (see the URL below). We can use default values when not defined. - // https://github.com/microsoft/semantic-kernel/blob/main/samples/skills/FunSkill/Joke/config.json - string deploymentOrModelName = _gptToUse.Type switch - { - EndpointType.AzureOpenAI => _gptToUse.Deployment, - EndpointType.OpenAI => _gptToUse.ModelName, - _ => throw new UnreachableException(), - }; - - ChatCompletionsOptions chatOptions = new() - { - DeploymentName = deploymentOrModelName, - ChoiceCount = 1, - Temperature = 0, - MaxTokens = MaxResponseToken, - }; - - List history = _chatHistory; - if (history.Count is 0) - { - history.Add(new ChatRequestSystemMessage(_gptToUse.SystemPrompt)); - } + foreach (int i in indices) + { + newTotal -= _chatHistoryTokens[i]; + } - ReduceChatHistoryAsNeeded(history, new ChatRequestUserMessage(input)); - foreach (ChatRequestMessage message in history) - { - chatOptions.Messages.Add(message); + _chatHistory.RemoveRange(indices[0], indices.Count); + _chatHistoryTokens.RemoveRange(indices[0], indices.Count); + _totalInputToken = newTotal - msgTokenCnt; } - - return chatOptions; } - public async Task> GetStreamingChatResponseAsync(string input, CancellationToken cancellationToken = default) + public async Task> GetStreamingChatResponseAsync(string input, CancellationToken cancellationToken) { try { - ChatCompletionsOptions chatOptions = PrepareForChat(input); - var response = await _client.GetChatCompletionsStreamingAsync( - chatOptions, - cancellationToken); - - return response; + PrepareForChat(input); + IAsyncEnumerator enumerator = _client + .CompleteChatStreamingAsync(_chatHistory, _chatOptions, cancellationToken) + .GetAsyncEnumerator(cancellationToken); + + return await enumerator + .MoveNextAsync() + .ConfigureAwait(continueOnCapturedContext: false) ? 
enumerator : null; } catch (OperationCanceledException) { From d49d5ca0c8fa9f274a49eeb9674c10872836de2d Mon Sep 17 00:00:00 2001 From: Dongbo Wang Date: Fri, 17 Jan 2025 16:46:35 -0800 Subject: [PATCH 2/2] Fix an error and add model info for the `o1` model --- shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs | 6 +++--- shell/agents/AIShell.OpenAI.Agent/Service.cs | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs index 5e04417f..654f090d 100644 --- a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs +++ b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs @@ -5,9 +5,8 @@ namespace AIShell.OpenAI.Agent; internal class ModelInfo { // Models gpt4, gpt3.5, and the variants of them all use the 'cl100k_base' token encoding. - // But the gpt-4o model uses the 'o200k_base' token encoding. For reference: - // https://github.com/openai/tiktoken/blob/5d970c1100d3210b42497203d6b5c1e30cfda6cb/tiktoken/model.py#L7 - // https://github.com/dmitry-brazhenko/SharpToken/blob/main/SharpToken/Lib/Model.cs#L8 + // But gpt-4o and o1 models use the 'o200k_base' token encoding. For reference: + // https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/tiktoken/model.py private const string Gpt4oEncoding = "o200k_base"; private const string Gpt34Encoding = "cl100k_base"; @@ -21,6 +20,7 @@ static ModelInfo() // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb s_modelMap = new(StringComparer.OrdinalIgnoreCase) { + ["o1"] = new(tokenLimit: 200_000, encoding: Gpt4oEncoding), ["gpt-4o"] = new(tokenLimit: 128_000, encoding: Gpt4oEncoding), ["gpt-4"] = new(tokenLimit: 8_192), ["gpt-4-32k"] = new(tokenLimit: 32_768), diff --git a/shell/agents/AIShell.OpenAI.Agent/Service.cs b/shell/agents/AIShell.OpenAI.Agent/Service.cs index fa96f2ed..583a90d6 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Service.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Service.cs @@ -78,7 +78,8 @@ internal void CalibrateChatHistory(ChatTokenUsage usage, AssistantChatMessage re // Every reply is primed with <|start|>assistant<|message|>, so we subtract 3 from the 'InputTokenCount'. int promptTokenCount = usage.InputTokenCount - 3; // 'ReasoningTokenCount' should be 0 for non-o1 models. - int responseTokenCount = usage.OutputTokenCount - usage.OutputTokenDetails.ReasoningTokenCount; + int reasoningTokenCount = usage.OutputTokenDetails is null ? 0 : usage.OutputTokenDetails.ReasoningTokenCount; + int responseTokenCount = usage.OutputTokenCount - reasoningTokenCount; if (_totalInputToken is 0) {
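
For context, the following is a minimal, self-contained sketch (not part of the patch) of how the refactored pieces are expected to fit together with the v2 SDK: an AzureOpenAIClient produces a ChatClient for a deployment, the conversation is streamed with CompleteChatStreamingAsync, content deltas are accumulated, and the ChatTokenUsage carried by the final streaming update drives the history calibration, mirroring Agent.cs and Service.cs above. The endpoint, deployment name, and environment variable are placeholder assumptions.

using System;
using System.ClientModel;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using OpenAI.Chat;

internal static class StreamingChatSketch
{
    private static async Task Main()
    {
        // Placeholder endpoint, deployment, and key; replace with real values.
        var endpoint = new Uri("https://contoso.openai.azure.com");
        string deployment = "gpt-4o";
        string apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY");

        // AzureOpenAIClient is the v2 entry point; ChatClient is bound to one deployment.
        ChatClient chatClient = new AzureOpenAIClient(endpoint, new ApiKeyCredential(apiKey))
            .GetChatClient(deployment);

        var history = new List<ChatMessage>
        {
            ChatMessage.CreateSystemMessage("You are a helpful assistant."),
            new UserChatMessage("Explain what a retry policy does in one sentence."),
        };

        var options = new ChatCompletionOptions
        {
            Temperature = 0,
            MaxOutputTokenCount = 1000,
        };

        var buffer = new StringBuilder();
        ChatTokenUsage usage = null;

        // Stream the completion; as the patch assumes, the usage report shows up on the final update.
        await foreach (StreamingChatCompletionUpdate update in
            chatClient.CompleteChatStreamingAsync(history, options))
        {
            if (update.ContentUpdate.Count > 0)
            {
                buffer.Append(update.ContentUpdate[0].Text);
            }

            if (update.Usage is not null)
            {
                usage = update.Usage;
            }
        }

        // Mirror of the calibration idea: trust the server-reported token counts
        // instead of recounting the whole history locally.
        if (usage is not null)
        {
            history.Add(ChatMessage.CreateAssistantMessage(buffer.ToString()));
            Console.WriteLine($"input={usage.InputTokenCount}, output={usage.OutputTokenCount}");
        }

        Console.WriteLine(buffer.ToString());
    }
}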