diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 5a11bcad999..9446706149a 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -687,8 +687,30 @@ public static Builder builder() { return new Builder(); } + /** + * Returns a builder pre-populated with the current configuration for mutation. + */ + public Builder mutate() { + return new Builder(this); + } + + @Override + public OpenAiChatModel clone() { + return this.mutate().build(); + } + public static final class Builder { + // Copy constructor for mutate() + public Builder(OpenAiChatModel model) { + this.openAiApi = model.openAiApi; + this.defaultOptions = model.defaultOptions; + this.toolCallingManager = model.toolCallingManager; + this.toolExecutionEligibilityPredicate = model.toolExecutionEligibilityPredicate; + this.retryTemplate = model.retryTemplate; + this.observationRegistry = model.observationRegistry; + } + private OpenAiApi openAiApi; private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index a0da1221dbd..1442dbcf0b0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -65,6 +65,13 @@ */ public class OpenAiApi { + /** + * Returns a builder pre-populated with the current configuration for mutation. + */ + public Builder mutate() { + return new Builder(this); + } + public static Builder builder() { return new Builder(); } @@ -75,10 +82,19 @@ public static Builder builder() { private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; + // Store config fields for mutate/copy + private final String baseUrl; + + private final ApiKey apiKey; + + private final MultiValueMap headers; + private final String completionsPath; private final String embeddingsPath; + private final ResponseErrorHandler responseErrorHandler; + private final RestClient restClient; private final WebClient webClient; @@ -99,13 +115,17 @@ public static Builder builder() { public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + this.baseUrl = baseUrl; + this.apiKey = apiKey; + this.headers = headers; + this.completionsPath = completionsPath; + this.embeddingsPath = embeddingsPath; + this.responseErrorHandler = responseErrorHandler; Assert.hasText(completionsPath, "Completions Path must not be null"); Assert.hasText(embeddingsPath, "Embeddings Path must not be null"); Assert.notNull(headers, "Headers must not be null"); - this.completionsPath = completionsPath; - this.embeddingsPath = embeddingsPath; // @formatter:off Consumer finalHeaders = h -> { if (!(apiKey instanceof NoopApiKey)) { @@ -1674,6 +1694,21 @@ public record EmbeddingList(// @formatter:off public static class Builder { + public Builder() { + } + + // Copy constructor for mutate() + public Builder(OpenAiApi api) { + this.baseUrl = api.getBaseUrl(); + this.apiKey = api.getApiKey(); + this.headers = new LinkedMultiValueMap<>(api.getHeaders()); + this.completionsPath = api.getCompletionsPath(); + this.embeddingsPath = api.getEmbeddingsPath(); + this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder(); + this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder(); + this.responseErrorHandler = api.getResponseErrorHandler(); + } + private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL; private ApiKey apiKey; @@ -1752,4 +1787,29 @@ public OpenAiApi build() { } + // Package-private getters for mutate/copy + String getBaseUrl() { + return this.baseUrl; + } + + ApiKey getApiKey() { + return this.apiKey; + } + + MultiValueMap getHeaders() { + return this.headers; + } + + String getCompletionsPath() { + return this.completionsPath; + } + + String getEmbeddingsPath() { + return this.embeddingsPath; + } + + ResponseErrorHandler getResponseErrorHandler() { + return this.responseErrorHandler; + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java new file mode 100644 index 00000000000..659b797683d --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiChatModelMutateTests.java @@ -0,0 +1,113 @@ +/* + * Integration test for mutate/clone functionality on OpenAiApi and OpenAiChatModel. + * This test demonstrates creating multiple ChatClient instances with different endpoints and options + * from a single autoconfigured OpenAiChatModel/OpenAiApi, as per the feature request. + */ +package org.springframework.ai.openai.api; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.util.LinkedMultiValueMap; + +import static org.assertj.core.api.Assertions.assertThat; + +class OpenAiChatModelMutateTests { + + // Simulate autoconfigured base beans (in real usage, these would be @Autowired) + private final OpenAiApi baseApi = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("base-key").build(); + + private final OpenAiChatModel baseModel = OpenAiChatModel.builder() + .openAiApi(baseApi) + .defaultOptions(OpenAiChatOptions.builder().model("gpt-3.5-turbo").build()) + .build(); + + @Test + void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() { + // Mutate for GPT-4 + OpenAiApi gpt4Api = baseApi.mutate().baseUrl("https://api.openai.com").apiKey("your-api-key-for-gpt4").build(); + OpenAiChatModel gpt4Model = baseModel.mutate() + .openAiApi(gpt4Api) + .defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build()) + .build(); + ChatClient gpt4Client = ChatClient.builder(gpt4Model).build(); + + // Mutate for Llama + OpenAiApi llamaApi = baseApi.mutate() + .baseUrl("https://your-custom-endpoint.com") + .apiKey("your-api-key-for-llama") + .build(); + OpenAiChatModel llamaModel = baseModel.mutate() + .openAiApi(llamaApi) + .defaultOptions(OpenAiChatOptions.builder().model("llama-70b").temperature(0.5).build()) + .build(); + ChatClient llamaClient = ChatClient.builder(llamaModel).build(); + + // Assert endpoints and models are different + assertThat(gpt4Model).isNotSameAs(llamaModel); + assertThat(gpt4Api).isNotSameAs(llamaApi); + assertThat(gpt4Model.toString()).contains("gpt-4"); + assertThat(llamaModel.toString()).contains("llama-70b"); + // Optionally, assert endpoints + // (In real usage, you might expose/get the baseUrl for assertion) + } + + @Test + void testCloneCreatesDeepCopy() { + OpenAiChatModel clone = baseModel.clone(); + assertThat(clone).isNotSameAs(baseModel); + assertThat(clone.toString()).isEqualTo(baseModel.toString()); + } + + @Test + void mutateDoesNotAffectOriginal() { + OpenAiChatModel mutated = baseModel.mutate() + .defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build()) + .build(); + assertThat(mutated).isNotSameAs(baseModel); + assertThat(mutated.getDefaultOptions().getModel()).isEqualTo("gpt-4"); + assertThat(baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo"); + } + + @Test + void mutateHeadersCreatesDistinctHeaders() { + OpenAiApi mutatedApi = baseApi.mutate() + .headers(new LinkedMultiValueMap<>(java.util.Map.of("X-Test", java.util.List.of("value")))) + .build(); + + assertThat(mutatedApi.getHeaders()).containsKey("X-Test"); + assertThat(baseApi.getHeaders()).doesNotContainKey("X-Test"); + } + + @Test + void mutateHandlesNullAndDefaults() { + OpenAiApi apiWithDefaults = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("key").build(); + OpenAiApi mutated = apiWithDefaults.mutate().build(); + assertThat(mutated).isNotNull(); + assertThat(mutated.getBaseUrl()).isEqualTo("https://api.openai.com"); + assertThat(mutated.getApiKey().getValue()).isEqualTo("key"); + } + + @Test + void multipleSequentialMutationsProduceDistinctInstances() { + OpenAiChatModel m1 = baseModel.mutate().defaultOptions(OpenAiChatOptions.builder().model("m1").build()).build(); + OpenAiChatModel m2 = m1.mutate().defaultOptions(OpenAiChatOptions.builder().model("m2").build()).build(); + OpenAiChatModel m3 = m2.mutate().defaultOptions(OpenAiChatOptions.builder().model("m3").build()).build(); + assertThat(m1).isNotSameAs(m2); + assertThat(m2).isNotSameAs(m3); + assertThat(m1.getDefaultOptions().getModel()).isEqualTo("m1"); + assertThat(m2.getDefaultOptions().getModel()).isEqualTo("m2"); + assertThat(m3.getDefaultOptions().getModel()).isEqualTo("m3"); + } + + @Test + void mutateAndCloneAreEquivalent() { + OpenAiChatModel mutated = baseModel.mutate().build(); + OpenAiChatModel cloned = baseModel.clone(); + assertThat(mutated.toString()).isEqualTo(cloned.toString()); + assertThat(mutated).isNotSameAs(cloned); + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index cbff7f6cb10..6012ad245b9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java @@ -66,7 +66,7 @@ @SpringBootTest(classes = GroqWithOpenAiChatModelIT.Config.class) @EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+") -@Disabled("Due to rate limiting it is hard to run it in one go") +// @Disabled("Due to rate limiting it is hard to run it in one go") class GroqWithOpenAiChatModelIT { private static final Logger logger = LoggerFactory.getLogger(GroqWithOpenAiChatModelIT.class); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MultiOpenAiClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MultiOpenAiClientIT.java new file mode 100644 index 00000000000..15a49d40afe --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MultiOpenAiClientIT.java @@ -0,0 +1,109 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.chat.proxy; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.test.context.ActiveProfiles; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = MultiOpenAiClientIT.Config.class) +@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +@ActiveProfiles("logging-test") +class MultiOpenAiClientIT { + + private static final Logger logger = LoggerFactory.getLogger(MultiOpenAiClientIT.class); + + @Autowired + private OpenAiChatModel baseChatModel; + + @Autowired + private OpenAiApi baseOpenAiApi; + + @Test + void multiClientFlow() { + // Derive a new OpenAiApi for Groq (Llama3) + OpenAiApi groqApi = baseOpenAiApi.mutate() + .baseUrl("https://api.groq.com/openai") + .apiKey(System.getenv("GROQ_API_KEY")) + .build(); + + // Derive a new OpenAiApi for OpenAI GPT-4 + OpenAiApi gpt4Api = baseOpenAiApi.mutate() + .baseUrl("https://api.openai.com") + .apiKey(System.getenv("OPENAI_API_KEY")) + .build(); + + // Derive a new OpenAiChatModel for Groq + OpenAiChatModel groqModel = baseChatModel.mutate() + .openAiApi(groqApi) + .defaultOptions(OpenAiChatOptions.builder().model("llama3-70b-8192").temperature(0.5).build()) + .build(); + + // Derive a new OpenAiChatModel for GPT-4 + OpenAiChatModel gpt4Model = baseChatModel.mutate() + .openAiApi(gpt4Api) + .defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build()) + .build(); + + // Simple prompt for both models + String prompt = "What is the capital of France?"; + + String groqResponse = ChatClient.builder(groqModel).build().prompt(prompt).call().content(); + String gpt4Response = ChatClient.builder(gpt4Model).build().prompt(prompt).call().content(); + + logger.info("Groq (Llama3) response: {}", groqResponse); + logger.info("OpenAI GPT-4 response: {}", gpt4Response); + + assertThat(groqResponse).containsIgnoringCase("Paris"); + assertThat(gpt4Response).containsIgnoringCase("Paris"); + + logger.info("OpenAI GPT-4 response: {}", gpt4Response); + + assertThat(groqResponse).containsIgnoringCase("Paris"); + assertThat(gpt4Response).containsIgnoringCase("Paris"); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiApi chatCompletionApi() { + return OpenAiApi.builder().baseUrl("foo").apiKey("bar").build(); + } + + @Bean + public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { + return OpenAiChatModel.builder().openAiApi(openAiApi).build(); + } + + } + +}