Skip to content

Commit d4e1655

Browse files
committed
refactor: centralize retry exception handling across AI models
- Add RetryUtils.execute() method to handle RetryException uniformly - Replace duplicate exception handling code in all model implementations - Affects Anthropic, DeepSeek, ElevenLabs, Google GenAI, MiniMax, Mistral AI, Ollama, OpenAI, Vertex AI, and ZhiPu AI models - Remove unused RetryException imports Signed-off-by: Christian Tzolov <[email protected]>
1 parent d5e92be commit d4e1655

File tree

22 files changed

+169
-406
lines changed

22 files changed

+169
-406
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@
7878
import org.springframework.ai.support.UsageCalculator;
7979
import org.springframework.ai.tool.definition.ToolDefinition;
8080
import org.springframework.ai.util.json.JsonParser;
81-
import org.springframework.core.retry.RetryException;
8281
import org.springframework.core.retry.RetryTemplate;
8382
import org.springframework.http.HttpHeaders;
8483
import org.springframework.http.ResponseEntity;
@@ -194,19 +193,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
194193
this.observationRegistry)
195194
.observe(() -> {
196195

197-
ResponseEntity<ChatCompletionResponse> completionEntity = null;
198-
try {
199-
completionEntity = this.retryTemplate.execute(() -> this.anthropicApi.chatCompletionEntity(request,
200-
this.getAdditionalHttpHeaders(prompt)));
201-
}
202-
catch (RetryException e) {
203-
if (e.getCause() instanceof RuntimeException r) {
204-
throw r;
205-
}
206-
else {
207-
throw new RuntimeException(e.getCause());
208-
}
209-
}
196+
ResponseEntity<ChatCompletionResponse> completionEntity = RetryUtils.execute(this.retryTemplate,
197+
() -> this.anthropicApi.chatCompletionEntity(request, this.getAdditionalHttpHeaders(prompt)));
210198

211199
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
212200
AnthropicApi.Usage usage = completionResponse.usage();

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
import org.springframework.ai.retry.RetryUtils;
6767
import org.springframework.ai.support.UsageCalculator;
6868
import org.springframework.ai.tool.definition.ToolDefinition;
69-
import org.springframework.core.retry.RetryException;
7069
import org.springframework.core.retry.RetryTemplate;
7170
import org.springframework.http.ResponseEntity;
7271
import org.springframework.util.Assert;
@@ -166,18 +165,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
166165
this.observationRegistry)
167166
.observe(() -> {
168167

169-
ResponseEntity<ChatCompletion> completionEntity = null;
170-
try {
171-
completionEntity = this.retryTemplate.execute(() -> this.deepSeekApi.chatCompletionEntity(request));
172-
}
173-
catch (RetryException e) {
174-
if (e.getCause() instanceof RuntimeException r) {
175-
throw r;
176-
}
177-
else {
178-
throw new RuntimeException(e.getCause());
179-
}
180-
}
168+
ResponseEntity<ChatCompletion> completionEntity = RetryUtils.execute(this.retryTemplate,
169+
() -> this.deepSeekApi.chatCompletionEntity(request));
181170

182171
var chatCompletion = completionEntity.getBody();
183172

models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.springframework.ai.audio.tts.TextToSpeechResponse;
2929
import org.springframework.ai.elevenlabs.api.ElevenLabsApi;
3030
import org.springframework.ai.retry.RetryUtils;
31-
import org.springframework.core.retry.RetryException;
3231
import org.springframework.core.retry.RetryTemplate;
3332
import org.springframework.util.Assert;
3433
import org.springframework.util.LinkedMultiValueMap;
@@ -72,26 +71,15 @@ public static Builder builder() {
7271
public TextToSpeechResponse call(TextToSpeechPrompt prompt) {
7372
RequestContext requestContext = prepareRequest(prompt);
7473

75-
byte[] audioData = null;
76-
try {
77-
audioData = this.retryTemplate.execute(() -> {
78-
var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId,
79-
requestContext.queryParameters);
80-
if (response.getBody() == null) {
81-
logger.warn("No speech response returned for request: {}", requestContext.request);
82-
return new byte[0];
83-
}
84-
return response.getBody();
85-
});
86-
}
87-
catch (RetryException e) {
88-
if (e.getCause() instanceof RuntimeException r) {
89-
throw r;
90-
}
91-
else {
92-
throw new RuntimeException(e.getCause());
74+
byte[] audioData = RetryUtils.execute(this.retryTemplate, () -> {
75+
var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId,
76+
requestContext.queryParameters);
77+
if (response.getBody() == null) {
78+
logger.warn("No speech response returned for request: {}", requestContext.request);
79+
return new byte[0];
9380
}
94-
}
81+
return response.getBody();
82+
});
9583

9684
return new TextToSpeechResponse(List.of(new Speech(audioData)));
9785
}
@@ -100,19 +88,10 @@ public TextToSpeechResponse call(TextToSpeechPrompt prompt) {
10088
public Flux<TextToSpeechResponse> stream(TextToSpeechPrompt prompt) {
10189
RequestContext requestContext = prepareRequest(prompt);
10290

103-
try {
104-
return this.retryTemplate.execute(() -> this.elevenLabsApi
105-
.textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters)
106-
.map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody())))));
107-
}
108-
catch (RetryException e) {
109-
if (e.getCause() instanceof RuntimeException r) {
110-
throw r;
111-
}
112-
else {
113-
throw new RuntimeException(e.getCause());
114-
}
115-
}
91+
return RetryUtils.execute(this.retryTemplate,
92+
() -> this.elevenLabsApi
93+
.textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters)
94+
.map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody())))));
11695
}
11796

11897
private RequestContext prepareRequest(TextToSpeechPrompt prompt) {

models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import org.springframework.ai.model.ModelOptionsUtils;
4747
import org.springframework.ai.observation.conventions.AiProvider;
4848
import org.springframework.ai.retry.RetryUtils;
49-
import org.springframework.core.retry.RetryException;
5049
import org.springframework.core.retry.RetryTemplate;
5150
import org.springframework.util.Assert;
5251
import org.springframework.util.StringUtils;
@@ -169,19 +168,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
169168
}
170169

171170
// Call the embedding API with retry
172-
EmbedContentResponse embeddingResponse = null;
173-
try {
174-
embeddingResponse = this.retryTemplate
175-
.execute(() -> this.genAiClient.models.embedContent(modelName, validTexts, config));
176-
}
177-
catch (RetryException e) {
178-
if (e.getCause() instanceof RuntimeException r) {
179-
throw r;
180-
}
181-
else {
182-
throw new RuntimeException(e.getCause());
183-
}
184-
}
171+
EmbedContentResponse embeddingResponse = RetryUtils.execute(this.retryTemplate,
172+
() -> this.genAiClient.models.embedContent(modelName, validTexts, config));
185173

186174
// Process the response
187175
// Note: We need to handle the case where some texts were filtered out

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787
import org.springframework.ai.support.UsageCalculator;
8888
import org.springframework.ai.tool.definition.ToolDefinition;
8989
import org.springframework.beans.factory.DisposableBean;
90-
import org.springframework.core.retry.RetryException;
9190
import org.springframework.core.retry.RetryTemplate;
9291
import org.springframework.lang.NonNull;
9392
import org.springframework.util.Assert;
@@ -406,39 +405,31 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
406405
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
407406
this.observationRegistry)
408407
.observe(() -> {
409-
try {
410-
return this.retryTemplate.execute(() -> {
411-
412-
var geminiRequest = createGeminiRequest(prompt);
413-
414-
GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest);
415-
416-
List<Generation> generations = generateContentResponse.candidates()
417-
.orElse(List.of())
418-
.stream()
419-
.map(this::responseCandidateToGeneration)
420-
.flatMap(List::stream)
421-
.toList();
422-
423-
var usage = generateContentResponse.usageMetadata();
424-
GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions();
425-
Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options)
426-
: getDefaultUsage(null, options);
427-
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse);
428-
ChatResponse chatResponse = new ChatResponse(generations,
429-
toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get()));
430-
431-
observationContext.setResponse(chatResponse);
432-
return chatResponse;
433-
});
434-
}
435-
catch (RetryException e) {
436-
if (e.getCause() instanceof RuntimeException r) {
437-
throw r;
438-
}
439408

440-
throw new RuntimeException(e);
441-
}
409+
return RetryUtils.execute(this.retryTemplate, () -> {
410+
411+
var geminiRequest = createGeminiRequest(prompt);
412+
413+
GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest);
414+
415+
List<Generation> generations = generateContentResponse.candidates()
416+
.orElse(List.of())
417+
.stream()
418+
.map(this::responseCandidateToGeneration)
419+
.flatMap(List::stream)
420+
.toList();
421+
422+
var usage = generateContentResponse.usageMetadata();
423+
GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions();
424+
Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options)
425+
: getDefaultUsage(null, options);
426+
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse);
427+
ChatResponse chatResponse = new ChatResponse(generations,
428+
toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get()));
429+
430+
observationContext.setResponse(chatResponse);
431+
return chatResponse;
432+
});
442433
});
443434

444435
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6969
import org.springframework.ai.retry.RetryUtils;
7070
import org.springframework.ai.tool.definition.ToolDefinition;
71-
import org.springframework.core.retry.RetryException;
7271
import org.springframework.core.retry.RetryTemplate;
7372
import org.springframework.http.ResponseEntity;
7473
import org.springframework.util.Assert;
@@ -254,18 +253,8 @@ public ChatResponse call(Prompt prompt) {
254253
this.observationRegistry)
255254
.observe(() -> {
256255

257-
ResponseEntity<ChatCompletion> completionEntity = null;
258-
try {
259-
completionEntity = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionEntity(request));
260-
}
261-
catch (RetryException e) {
262-
if (e.getCause() instanceof RuntimeException r) {
263-
throw r;
264-
}
265-
else {
266-
throw new RuntimeException(e.getCause());
267-
}
268-
}
256+
ResponseEntity<ChatCompletion> completionEntity = RetryUtils.execute(this.retryTemplate,
257+
() -> this.miniMaxApi.chatCompletionEntity(request));
269258

270259
var chatCompletion = completionEntity.getBody();
271260

@@ -339,18 +328,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
339328
return Flux.deferContextual(contextView -> {
340329
ChatCompletionRequest request = createRequest(requestPrompt, true);
341330

342-
Flux<ChatCompletionChunk> completionChunks = null;
343-
try {
344-
completionChunks = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionStream(request));
345-
}
346-
catch (RetryException e) {
347-
if (e.getCause() instanceof RuntimeException r) {
348-
throw r;
349-
}
350-
else {
351-
throw new RuntimeException(e.getCause());
352-
}
353-
}
331+
Flux<ChatCompletionChunk> completionChunks = RetryUtils.execute(this.retryTemplate,
332+
() -> this.miniMaxApi.chatCompletionStream(request));
354333

355334
// For chunked responses, only the first chunk contains the choice role.
356335
// The rest of the chunks with same ID share the same role.

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.springframework.ai.minimax.api.MiniMaxApiConstants;
4141
import org.springframework.ai.model.ModelOptionsUtils;
4242
import org.springframework.ai.retry.RetryUtils;
43-
import org.springframework.core.retry.RetryException;
4443
import org.springframework.core.retry.RetryTemplate;
4544
import org.springframework.util.Assert;
4645
import org.springframework.util.StringUtils;
@@ -166,19 +165,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
166165
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
167166
this.observationRegistry)
168167
.observe(() -> {
169-
MiniMaxApi.EmbeddingList apiEmbeddingResponse = null;
170-
try {
171-
apiEmbeddingResponse = this.retryTemplate
172-
.execute(() -> this.miniMaxApi.embeddings(apiRequest).getBody());
173-
}
174-
catch (RetryException e) {
175-
if (e.getCause() instanceof RuntimeException r) {
176-
throw r;
177-
}
178-
else {
179-
throw new RuntimeException(e.getCause());
180-
}
181-
}
168+
MiniMaxApi.EmbeddingList apiEmbeddingResponse = RetryUtils.execute(this.retryTemplate,
169+
() -> this.miniMaxApi.embeddings(apiRequest).getBody());
182170

183171
if (apiEmbeddingResponse == null) {
184172
logger.warn("No embeddings returned for request: {}", request);

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
import org.springframework.ai.retry.RetryUtils;
7070
import org.springframework.ai.support.UsageCalculator;
7171
import org.springframework.ai.tool.definition.ToolDefinition;
72-
import org.springframework.core.retry.RetryException;
7372
import org.springframework.core.retry.RetryTemplate;
7473
import org.springframework.http.ResponseEntity;
7574
import org.springframework.util.Assert;
@@ -192,19 +191,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
192191
this.observationRegistry)
193192
.observe(() -> {
194193

195-
ResponseEntity<ChatCompletion> completionEntity = null;
196-
try {
197-
completionEntity = this.retryTemplate
198-
.execute(() -> this.mistralAiApi.chatCompletionEntity(request));
199-
}
200-
catch (RetryException e) {
201-
if (e.getCause() instanceof RuntimeException r) {
202-
throw r;
203-
}
204-
else {
205-
throw new RuntimeException(e.getCause());
206-
}
207-
}
194+
ResponseEntity<ChatCompletion> completionEntity = RetryUtils.execute(this.retryTemplate,
195+
() -> this.mistralAiApi.chatCompletionEntity(request));
208196

209197
ChatCompletion chatCompletion = completionEntity.getBody();
210198

@@ -276,18 +264,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
276264

277265
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
278266

279-
Flux<ChatCompletionChunk> completionChunks = null;
280-
try {
281-
completionChunks = this.retryTemplate.execute(() -> this.mistralAiApi.chatCompletionStream(request));
282-
}
283-
catch (RetryException e) {
284-
if (e.getCause() instanceof RuntimeException r) {
285-
throw r;
286-
}
287-
else {
288-
throw new RuntimeException(e.getCause());
289-
}
290-
}
267+
Flux<ChatCompletionChunk> completionChunks = RetryUtils.execute(this.retryTemplate,
268+
() -> this.mistralAiApi.chatCompletionStream(request));
291269

292270
// For chunked responses, only the first chunk contains the choice role.
293271
// The rest of the chunks with same ID share the same role.

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import org.springframework.ai.mistralai.api.MistralAiApi;
4040
import org.springframework.ai.model.ModelOptionsUtils;
4141
import org.springframework.ai.retry.RetryUtils;
42-
import org.springframework.core.retry.RetryException;
4342
import org.springframework.core.retry.RetryTemplate;
4443
import org.springframework.util.Assert;
4544

@@ -118,19 +117,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
118117
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
119118
this.observationRegistry)
120119
.observe(() -> {
121-
MistralAiApi.EmbeddingList<MistralAiApi.Embedding> apiEmbeddingResponse = null;
122-
try {
123-
apiEmbeddingResponse = this.retryTemplate
124-
.execute(() -> this.mistralAiApi.embeddings(apiRequest).getBody());
125-
}
126-
catch (RetryException e) {
127-
if (e.getCause() instanceof RuntimeException r) {
128-
throw r;
129-
}
130-
else {
131-
throw new RuntimeException(e.getCause());
132-
}
133-
}
120+
MistralAiApi.EmbeddingList<MistralAiApi.Embedding> apiEmbeddingResponse = RetryUtils
121+
.execute(this.retryTemplate, () -> this.mistralAiApi.embeddings(apiRequest).getBody());
134122

135123
if (apiEmbeddingResponse == null) {
136124
logger.warn("No embeddings returned for request: {}", request);

0 commit comments

Comments
 (0)