diff --git a/shell/agents/AIShell.OpenAI.Agent/GPT.cs b/shell/agents/AIShell.OpenAI.Agent/GPT.cs index 8d726cec..239f286b 100644 --- a/shell/agents/AIShell.OpenAI.Agent/GPT.cs +++ b/shell/agents/AIShell.OpenAI.Agent/GPT.cs @@ -9,6 +9,7 @@ internal enum EndpointType { AzureOpenAI, OpenAI, + CompatibleThirdParty, } public class GPT @@ -56,9 +57,16 @@ public GPT( bool noDeployment = string.IsNullOrEmpty(Deployment); Type = noEndpoint && noDeployment ? EndpointType.OpenAI - : !noEndpoint && !noDeployment - ? EndpointType.AzureOpenAI - : throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys."); + : !noEndpoint && noDeployment + ? EndpointType.CompatibleThirdParty + : !noEndpoint && !noDeployment + ? EndpointType.AzureOpenAI + : throw new InvalidOperationException($"Invalid setting: 'Deployment' key present but 'Endpoint' key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys."); + + if (ModelInfo is null && Type is EndpointType.CompatibleThirdParty) + { + ModelInfo = ModelInfo.ThirdPartyModel; + } } /// @@ -142,11 +150,18 @@ private void ShowEndpointInfo(IHost host) new(label: " Model", m => m.ModelName), }, - EndpointType.OpenAI => new CustomElement[] - { + EndpointType.OpenAI => + [ new(label: " Type", m => m.Type.ToString()), new(label: " Model", m => m.ModelName), - }, + ], + + EndpointType.CompatibleThirdParty => + [ + new(label: " Type", m => m.Type.ToString()), + new(label: " Endpoint", m => m.Endpoint), + new(label: " Model", m => m.ModelName), + ], _ => throw new UnreachableException(), }; diff --git a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs index 654f090d..79674e85 100644 --- a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs +++ b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs @@ -13,6 +13,11 @@ internal class ModelInfo private static readonly Dictionary s_modelMap; private static readonly Dictionary> s_encodingMap; + // A rough estimate to cover all third-party models. + // - most popular models today support 32K+ context length; + // - use the gpt-4o encoding as an estimate for token count. + internal static readonly ModelInfo ThirdPartyModel = new(32_000, encoding: Gpt4oEncoding); + static ModelInfo() { // For reference, see https://platform.openai.com/docs/models and the "Counting tokens" section in diff --git a/shell/agents/AIShell.OpenAI.Agent/Service.cs b/shell/agents/AIShell.OpenAI.Agent/Service.cs index 583a90d6..9251a6f6 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Service.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Service.cs @@ -122,9 +122,10 @@ private void RefreshOpenAIClient() return; } + EndpointType type = _gptToUse.Type; string userKey = Utils.ConvertFromSecureString(_gptToUse.Key); - if (_gptToUse.Type is EndpointType.AzureOpenAI) + if (type is EndpointType.AzureOpenAI) { // Create a client that targets Azure OpenAI service or Azure API Management service. var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() }; @@ -152,6 +153,11 @@ private void RefreshOpenAIClient() { // Create a client that targets the non-Azure OpenAI service. var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() }; + if (type is EndpointType.CompatibleThirdParty) + { + clientOptions.Endpoint = new(_gptToUse.Endpoint); + } + var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions); _client = aiClient.GetChatClient(_gptToUse.ModelName); }