Skip to content

feat: add mutate functionality for OpenAiApi and OpenAiChatModel #3037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -687,8 +687,30 @@ public static Builder builder() {
return new Builder();
}

/**
* Returns a builder pre-populated with the current configuration for mutation.
*/
public Builder mutate() {
return new Builder(this);
}

@Override
public OpenAiChatModel clone() {
return this.mutate().build();
}

public static final class Builder {

// Copy constructor for mutate()
public Builder(OpenAiChatModel model) {
this.openAiApi = model.openAiApi;
this.defaultOptions = model.defaultOptions;
this.toolCallingManager = model.toolCallingManager;
this.toolExecutionEligibilityPredicate = model.toolExecutionEligibilityPredicate;
this.retryTemplate = model.retryTemplate;
this.observationRegistry = model.observationRegistry;
}

private OpenAiApi openAiApi;

private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
Original file line number Diff line number Diff line change
@@ -65,6 +65,13 @@
*/
public class OpenAiApi {

/**
* Returns a builder pre-populated with the current configuration for mutation.
*/
public Builder mutate() {
return new Builder(this);
}

public static Builder builder() {
return new Builder();
}
@@ -75,10 +82,19 @@ public static Builder builder() {

private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;

// Store config fields for mutate/copy
private final String baseUrl;

private final ApiKey apiKey;

private final MultiValueMap<String, String> headers;

private final String completionsPath;

private final String embeddingsPath;

private final ResponseErrorHandler responseErrorHandler;

private final RestClient restClient;

private final WebClient webClient;
@@ -99,13 +115,17 @@ public static Builder builder() {
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this.baseUrl = baseUrl;
this.apiKey = apiKey;
this.headers = headers;
this.completionsPath = completionsPath;
this.embeddingsPath = embeddingsPath;
this.responseErrorHandler = responseErrorHandler;

Assert.hasText(completionsPath, "Completions Path must not be null");
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
Assert.notNull(headers, "Headers must not be null");

this.completionsPath = completionsPath;
this.embeddingsPath = embeddingsPath;
// @formatter:off
Consumer<HttpHeaders> finalHeaders = h -> {
if (!(apiKey instanceof NoopApiKey)) {
@@ -1674,6 +1694,21 @@ public record EmbeddingList<T>(// @formatter:off

public static class Builder {

public Builder() {
}

// Copy constructor for mutate()
public Builder(OpenAiApi api) {
this.baseUrl = api.getBaseUrl();
this.apiKey = api.getApiKey();
this.headers = new LinkedMultiValueMap<>(api.getHeaders());
this.completionsPath = api.getCompletionsPath();
this.embeddingsPath = api.getEmbeddingsPath();
this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder();
this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder();
this.responseErrorHandler = api.getResponseErrorHandler();
}

private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL;

private ApiKey apiKey;
@@ -1752,4 +1787,29 @@ public OpenAiApi build() {

}

// Package-private getters for mutate/copy
String getBaseUrl() {
return this.baseUrl;
}

ApiKey getApiKey() {
return this.apiKey;
}

MultiValueMap<String, String> getHeaders() {
return this.headers;
}

String getCompletionsPath() {
return this.completionsPath;
}

String getEmbeddingsPath() {
return this.embeddingsPath;
}

ResponseErrorHandler getResponseErrorHandler() {
return this.responseErrorHandler;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Integration test for mutate/clone functionality on OpenAiApi and OpenAiChatModel.
* This test demonstrates creating multiple ChatClient instances with different endpoints and options
* from a single autoconfigured OpenAiChatModel/OpenAiApi, as per the feature request.
*/
package org.springframework.ai.openai.api;

import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.util.LinkedMultiValueMap;

import static org.assertj.core.api.Assertions.assertThat;

class OpenAiChatModelMutateTests {

// Simulate autoconfigured base beans (in real usage, these would be @Autowired)
private final OpenAiApi baseApi = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("base-key").build();

private final OpenAiChatModel baseModel = OpenAiChatModel.builder()
.openAiApi(baseApi)
.defaultOptions(OpenAiChatOptions.builder().model("gpt-3.5-turbo").build())
.build();

@Test
void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
// Mutate for GPT-4
OpenAiApi gpt4Api = baseApi.mutate().baseUrl("https://api.openai.com").apiKey("your-api-key-for-gpt4").build();
OpenAiChatModel gpt4Model = baseModel.mutate()
.openAiApi(gpt4Api)
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
.build();
ChatClient gpt4Client = ChatClient.builder(gpt4Model).build();

// Mutate for Llama
OpenAiApi llamaApi = baseApi.mutate()
.baseUrl("https://your-custom-endpoint.com")
.apiKey("your-api-key-for-llama")
.build();
OpenAiChatModel llamaModel = baseModel.mutate()
.openAiApi(llamaApi)
.defaultOptions(OpenAiChatOptions.builder().model("llama-70b").temperature(0.5).build())
.build();
ChatClient llamaClient = ChatClient.builder(llamaModel).build();

// Assert endpoints and models are different
assertThat(gpt4Model).isNotSameAs(llamaModel);
assertThat(gpt4Api).isNotSameAs(llamaApi);
assertThat(gpt4Model.toString()).contains("gpt-4");
assertThat(llamaModel.toString()).contains("llama-70b");
// Optionally, assert endpoints
// (In real usage, you might expose/get the baseUrl for assertion)
}

@Test
void testCloneCreatesDeepCopy() {
OpenAiChatModel clone = baseModel.clone();
assertThat(clone).isNotSameAs(baseModel);
assertThat(clone.toString()).isEqualTo(baseModel.toString());
}

@Test
void mutateDoesNotAffectOriginal() {
OpenAiChatModel mutated = baseModel.mutate()
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build())
.build();
assertThat(mutated).isNotSameAs(baseModel);
assertThat(mutated.getDefaultOptions().getModel()).isEqualTo("gpt-4");
assertThat(baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo");
}

@Test
void mutateHeadersCreatesDistinctHeaders() {
OpenAiApi mutatedApi = baseApi.mutate()
.headers(new LinkedMultiValueMap<>(java.util.Map.of("X-Test", java.util.List.of("value"))))
.build();

assertThat(mutatedApi.getHeaders()).containsKey("X-Test");
assertThat(baseApi.getHeaders()).doesNotContainKey("X-Test");
}

@Test
void mutateHandlesNullAndDefaults() {
OpenAiApi apiWithDefaults = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("key").build();
OpenAiApi mutated = apiWithDefaults.mutate().build();
assertThat(mutated).isNotNull();
assertThat(mutated.getBaseUrl()).isEqualTo("https://api.openai.com");
assertThat(mutated.getApiKey().getValue()).isEqualTo("key");
}

@Test
void multipleSequentialMutationsProduceDistinctInstances() {
OpenAiChatModel m1 = baseModel.mutate().defaultOptions(OpenAiChatOptions.builder().model("m1").build()).build();
OpenAiChatModel m2 = m1.mutate().defaultOptions(OpenAiChatOptions.builder().model("m2").build()).build();
OpenAiChatModel m3 = m2.mutate().defaultOptions(OpenAiChatOptions.builder().model("m3").build()).build();
assertThat(m1).isNotSameAs(m2);
assertThat(m2).isNotSameAs(m3);
assertThat(m1.getDefaultOptions().getModel()).isEqualTo("m1");
assertThat(m2.getDefaultOptions().getModel()).isEqualTo("m2");
assertThat(m3.getDefaultOptions().getModel()).isEqualTo("m3");
}

@Test
void mutateAndCloneAreEquivalent() {
OpenAiChatModel mutated = baseModel.mutate().build();
OpenAiChatModel cloned = baseModel.clone();
assertThat(mutated.toString()).isEqualTo(cloned.toString());
assertThat(mutated).isNotSameAs(cloned);
}

}
Original file line number Diff line number Diff line change
@@ -66,7 +66,7 @@

@SpringBootTest(classes = GroqWithOpenAiChatModelIT.Config.class)
@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+")
@Disabled("Due to rate limiting it is hard to run it in one go")
// @Disabled("Due to rate limiting it is hard to run it in one go")
class GroqWithOpenAiChatModelIT {

private static final Logger logger = LoggerFactory.getLogger(GroqWithOpenAiChatModelIT.class);
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.openai.chat.proxy;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.test.context.ActiveProfiles;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(classes = MultiOpenAiClientIT.Config.class)
@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@ActiveProfiles("logging-test")
class MultiOpenAiClientIT {

private static final Logger logger = LoggerFactory.getLogger(MultiOpenAiClientIT.class);

@Autowired
private OpenAiChatModel baseChatModel;

@Autowired
private OpenAiApi baseOpenAiApi;

@Test
void multiClientFlow() {
// Derive a new OpenAiApi for Groq (Llama3)
OpenAiApi groqApi = baseOpenAiApi.mutate()
.baseUrl("https://api.groq.com/openai")
.apiKey(System.getenv("GROQ_API_KEY"))
.build();

// Derive a new OpenAiApi for OpenAI GPT-4
OpenAiApi gpt4Api = baseOpenAiApi.mutate()
.baseUrl("https://api.openai.com")
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();

// Derive a new OpenAiChatModel for Groq
OpenAiChatModel groqModel = baseChatModel.mutate()
.openAiApi(groqApi)
.defaultOptions(OpenAiChatOptions.builder().model("llama3-70b-8192").temperature(0.5).build())
.build();

// Derive a new OpenAiChatModel for GPT-4
OpenAiChatModel gpt4Model = baseChatModel.mutate()
.openAiApi(gpt4Api)
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
.build();

// Simple prompt for both models
String prompt = "What is the capital of France?";

String groqResponse = ChatClient.builder(groqModel).build().prompt(prompt).call().content();
String gpt4Response = ChatClient.builder(gpt4Model).build().prompt(prompt).call().content();

logger.info("Groq (Llama3) response: {}", groqResponse);
logger.info("OpenAI GPT-4 response: {}", gpt4Response);

assertThat(groqResponse).containsIgnoringCase("Paris");
assertThat(gpt4Response).containsIgnoringCase("Paris");

logger.info("OpenAI GPT-4 response: {}", gpt4Response);

assertThat(groqResponse).containsIgnoringCase("Paris");
assertThat(gpt4Response).containsIgnoringCase("Paris");
}

@SpringBootConfiguration
static class Config {

@Bean
public OpenAiApi chatCompletionApi() {
return OpenAiApi.builder().baseUrl("foo").apiKey("bar").build();
}

@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
}

}

}