Skip to content

Provide the ability to configure OpenAI client read timeout #365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;
import java.util.List;

import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
@@ -33,8 +34,12 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -60,8 +65,7 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper
List<FunctionCallback> toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {

var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
var openAiApi = openAiApi(commonProperties, chatProperties, restClientBuilder, responseErrorHandler);

if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
@@ -78,23 +82,22 @@ public OpenAiEmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties co
OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {

var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
var openAiApi = openAiApi(commonProperties, embeddingProperties, restClientBuilder, responseErrorHandler);

return new OpenAiEmbeddingClient(openAiApi, embeddingProperties.getMetadataMode(),
embeddingProperties.getOptions(), retryTemplate);
}

private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
private <T extends OpenAiParentProperties> OpenAiApi openAiApi(OpenAiConnectionProperties commonProperties,
T specificProperties, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {

String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
Assert.hasText(resolvedBaseUrl, "OpenAI base URL must be set");
OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties,
specificProperties);
RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder,
overridenCommonProperties);

String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
Assert.hasText(resolvedApiKey, "OpenAI API key must be set");

return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler);
return new OpenAiApi(overridenCommonProperties.getBaseUrl(), overridenCommonProperties.getApiKey(),
overrideRestClientBuilder, responseErrorHandler);
}

@Bean
@@ -105,41 +108,32 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp
OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler) {

String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey()
: commonProperties.getApiKey();

String baseUrl = StringUtils.hasText(imageProperties.getBaseUrl()) ? imageProperties.getBaseUrl()
: commonProperties.getBaseUrl();

Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties,
imageProperties);
RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder,
overridenCommonProperties);

var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler);
var openAiImageApi = new OpenAiImageApi(overridenCommonProperties.getBaseUrl(),
overridenCommonProperties.getApiKey(), overrideRestClientBuilder, responseErrorHandler);

return new OpenAiImageClient(openAiImageApi, imageProperties.getOptions(), retryTemplate);
}

@Bean
@ConditionalOnMissingBean
public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConnectionProperties commonProperties,
OpenAiAudioTranscriptionProperties transcriptionProperties, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler) {

String apiKey = StringUtils.hasText(transcriptionProperties.getApiKey()) ? transcriptionProperties.getApiKey()
: commonProperties.getApiKey();

String baseUrl = StringUtils.hasText(transcriptionProperties.getBaseUrl())
? transcriptionProperties.getBaseUrl() : commonProperties.getBaseUrl();

Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
OpenAiAudioTranscriptionProperties transcriptionProperties, RestClient.Builder restClientBuilder,
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {

var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler);
OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties,
transcriptionProperties);
RestClient.Builder overrideRestClientBuilder = overrideRestClientBuilder(restClientBuilder,
overridenCommonProperties);

OpenAiAudioTranscriptionClient openAiChatClient = new OpenAiAudioTranscriptionClient(openAiAudioApi,
transcriptionProperties.getOptions(), retryTemplate);
var openAiAudioApi = new OpenAiAudioApi(overridenCommonProperties.getBaseUrl(),
overridenCommonProperties.getApiKey(), overrideRestClientBuilder, responseErrorHandler);

return openAiChatClient;
return new OpenAiAudioTranscriptionClient(openAiAudioApi, transcriptionProperties.getOptions(), retryTemplate);
}

@Bean
@@ -150,4 +144,37 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex
return manager;
}

private static <T extends OpenAiParentProperties> OpenAiConnectionProperties checkAndOverrideProperties(
OpenAiConnectionProperties commonProperties, T specificProperties) {

String apiKey = StringUtils.hasText(specificProperties.getApiKey()) ? specificProperties.getApiKey()
: commonProperties.getApiKey();

String baseUrl = StringUtils.hasText(specificProperties.getBaseUrl()) ? specificProperties.getBaseUrl()
: commonProperties.getBaseUrl();

Duration readTimeout = specificProperties.getReadTimeout() != null ? specificProperties.getReadTimeout()
: commonProperties.getReadTimeout();

Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
Assert.notNull(readTimeout, "OpenAI base read timeout must be set");

OpenAiConnectionProperties overridenCommonProperties = new OpenAiConnectionProperties();
overridenCommonProperties.setApiKey(apiKey);
overridenCommonProperties.setBaseUrl(baseUrl);
overridenCommonProperties.setReadTimeout(readTimeout);

return overridenCommonProperties;

}

private static RestClient.Builder overrideRestClientBuilder(RestClient.Builder restClientBuilder,
OpenAiConnectionProperties overridenCommonProperties) {
ClientHttpRequestFactorySettings requestFactorySettings = new ClientHttpRequestFactorySettings(
Duration.ofHours(1l), overridenCommonProperties.getReadTimeout(), SslBundle.of(null));
ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(requestFactorySettings);
return restClientBuilder.clone().requestFactory(requestFactory);
}

}
Original file line number Diff line number Diff line change
@@ -15,6 +15,8 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;

import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(OpenAiConnectionProperties.CONFIG_PREFIX)
@@ -24,8 +26,11 @@ public class OpenAiConnectionProperties extends OpenAiParentProperties {

public static final String DEFAULT_BASE_URL = "https://api.openai.com";

public static final Duration DEFAULT_READ_TIMEOUT = Duration.ofMinutes(1);

public OpenAiConnectionProperties() {
super.setBaseUrl(DEFAULT_BASE_URL);
super.setReadTimeout(DEFAULT_READ_TIMEOUT);
}

}
Original file line number Diff line number Diff line change
@@ -15,6 +15,8 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;

/**
* Internal parent properties for the OpenAI properties.
*
@@ -27,6 +29,8 @@ class OpenAiParentProperties {

private String baseUrl;

private Duration readTimeout;

public String getApiKey() {
return apiKey;
}
@@ -43,4 +47,12 @@ public void setBaseUrl(String baseUrl) {
this.baseUrl = baseUrl;
}

public Duration getReadTimeout() {
return readTimeout;
}

public void setReadTimeout(Duration readTimeout) {
this.readTimeout = readTimeout;
}

}
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
*/
package org.springframework.ai.autoconfigure.openai;

import java.time.Duration;
import org.junit.jupiter.api.Test;
import org.skyscreamer.jsonassert.JSONAssert;
import org.skyscreamer.jsonassert.JSONCompareMode;
@@ -50,6 +51,7 @@ public void chatProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
@@ -61,9 +63,11 @@ public void chatProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(chatProperties.getApiKey()).isNull();
assertThat(chatProperties.getBaseUrl()).isNull();
assertThat(chatProperties.getReadTimeout()).isNull();

assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
@@ -104,8 +108,10 @@ public void chatOverrideConnectionProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.chat.base-url=TEST_BASE_URL2",
"spring.ai.openai.chat.api-key=456",
"spring.ai.openai.chat.read-timeout=5m",
"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
@@ -117,9 +123,11 @@ public void chatOverrideConnectionProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(chatProperties.getApiKey()).isEqualTo("456");
assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
assertThat(chatProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5));

assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f);
@@ -162,6 +170,7 @@ public void embeddingProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.embedding.options.model=MODEL_XYZ")
// @formatter:on
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
@@ -172,9 +181,11 @@ public void embeddingProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getApiKey()).isNull();
assertThat(embeddingProperties.getBaseUrl()).isNull();
assertThat(embeddingProperties.getReadTimeout()).isNull();

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
});
@@ -187,8 +198,10 @@ public void embeddingOverrideConnectionProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.embedding.base-url=TEST_BASE_URL2",
"spring.ai.openai.embedding.api-key=456",
"spring.ai.openai.embedding.read-timeout=5m",
"spring.ai.openai.embedding.options.model=MODEL_XYZ")
// @formatter:on
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
@@ -199,9 +212,11 @@ public void embeddingOverrideConnectionProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getApiKey()).isEqualTo("456");
assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
assertThat(embeddingProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5));

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
});
@@ -213,6 +228,7 @@ public void imageProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.image.options.model=MODEL_XYZ",
"spring.ai.openai.image.options.n=3")
// @formatter:on
@@ -224,9 +240,11 @@ public void imageProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(imageProperties.getApiKey()).isNull();
assertThat(imageProperties.getBaseUrl()).isNull();
assertThat(imageProperties.getReadTimeout()).isNull();

assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
@@ -239,8 +257,10 @@ public void imageOverrideConnectionProperties() {
// @formatter:off
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.read-timeout=2m",
"spring.ai.openai.image.base-url=TEST_BASE_URL2",
"spring.ai.openai.image.api-key=456",
"spring.ai.openai.image.read-timeout=5m",
"spring.ai.openai.image.options.model=MODEL_XYZ",
"spring.ai.openai.image.options.n=3")
// @formatter:on
@@ -252,9 +272,11 @@ public void imageOverrideConnectionProperties() {

assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(imageProperties.getApiKey()).isEqualTo("456");
assertThat(imageProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
assertThat(imageProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5));

assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
@@ -268,6 +290,7 @@ public void chatOptionsTest() {
// @formatter:off
"spring.ai.openai.api-key=API_KEY",
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.read-timeout=2m",

"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.frequencyPenalty=-1.5",
@@ -322,6 +345,7 @@ public void chatOptionsTest() {

assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002");

@@ -395,6 +419,7 @@ public void embeddingOptionsTest() {
// @formatter:off
"spring.ai.openai.api-key=API_KEY",
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.read-timeout=2m",

"spring.ai.openai.embedding.options.model=MODEL_XYZ",
"spring.ai.openai.embedding.options.encodingFormat=MyEncodingFormat",
@@ -409,6 +434,7 @@ public void embeddingOptionsTest() {

assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
assertThat(embeddingProperties.getOptions().getEncodingFormat()).isEqualTo("MyEncodingFormat");
@@ -422,6 +448,7 @@ public void imageOptionsTest() {
// @formatter:off
"spring.ai.openai.api-key=API_KEY",
"spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.read-timeout=2m",

"spring.ai.openai.image.options.n=3",
"spring.ai.openai.image.options.model=MODEL_XYZ",
@@ -442,6 +469,7 @@ public void imageOptionsTest() {

assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2));

assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");