diff --git a/completion.go b/completion.go index 21d4897c4..eb3831e32 100644 --- a/completion.go +++ b/completion.go @@ -166,6 +166,10 @@ func checkEndpointSupportsModel(endpoint, model string) bool { return !disabledModelsForEndpoints[endpoint][model] } +func RegisterSupportsModel(endpoint, model string) { + disabledModelsForEndpoints[endpoint][model] = true +} + func checkPromptType(prompt any) bool { _, isString := prompt.(string) _, isStringSlice := prompt.([]string) diff --git a/completion_test.go b/completion_test.go index 27e2d150e..73fdd9519 100644 --- a/completion_test.go +++ b/completion_test.go @@ -300,3 +300,27 @@ func TestCompletionWithGPT4oModels(t *testing.T) { }) } } + +func TestRegisterSupportsModel(t *testing.T) { + type args struct { + endpoint string + model string + } + tests := []struct { + name string + args args + }{ + { + name: "Register model ", + args: args{ + endpoint: "/chat/completions", + model: "local-model-3.5", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(_ *testing.T) { + openai.RegisterSupportsModel(tt.args.endpoint, tt.args.model) + }) + } +}