Skip to content

Commit 5527d03

Browse files
ThomasVitalemarkpollack
authored andcommitted
Configure TemplateRenderer in ChatClient
- Extend the ChatClient with a new templateRenderer() method to pass a custom TemplateRenderer object used to render user and system templates. - Evolve the QuestionAnswerAdvisor to accept a PromptTemplate for customising the RAG prompt and templating logic while maintaining backward compatibility. - Introduce integration tests for the QuestionAnswerAdvisor. - Document the TemplateRenderer API and how to use it to build PromptTemplate with custom templating logic. - Document how to customise the templating logic used internally by the ChatClient via the TemplateRendererAPI. Add validation tests and improve PromptTemplate resource handling Enhance robustness and reliability of the PromptTemplate class with better resource handling and comprehensive input validation: - Add dedicated validation tests for builder methods with null/invalid inputs - Improve renderResource method to gracefully handle edge cases: - Null resources return empty string - ByteArrayResource handling with proper charset (UTF-8) - Empty resources check with proper existence test - Better error handling with logging instead of exception propagation - Add input validation assertions to all Builder methods - Fix typo in deprecated annotation comment ("fahvor" → "favor") Update documentation to clarify template rendering in different contexts: - Add clear notes about TemplateRenderer usage in ChatClient vs Advisors - Document how advisor template customization differs from ChatClient template rendering - Add comprehensive API upgrade notes for template-related deprecations - Include detailed migration examples for PromptTemplate and QuestionAnswerAdvisor Fixes gh-355, gh-1687, gh-2448, gh-1849, gh-1428 Signed-off-by: Thomas Vitale <[email protected]>
1 parent b0d6719 commit 5527d03

File tree

20 files changed

+1005
-100
lines changed

20 files changed

+1005
-100
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java

+51-27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -39,6 +39,7 @@
3939
import org.springframework.ai.vectorstore.VectorStore;
4040
import org.springframework.ai.vectorstore.filter.Filter;
4141
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
42+
import org.springframework.lang.Nullable;
4243
import org.springframework.util.Assert;
4344
import org.springframework.util.StringUtils;
4445

@@ -49,6 +50,7 @@
4950
* @author Christian Tzolov
5051
* @author Timo Salm
5152
* @author Ilayaperumal Gopinathan
53+
* @author Thomas Vitale
5254
* @since 1.0.0
5355
*/
5456
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
@@ -57,7 +59,7 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
5759

5860
public static final String FILTER_EXPRESSION = "qa_filter_expression";
5961

60-
private static final String DEFAULT_USER_TEXT_ADVISE = """
62+
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
6163
6264
Context information is below, surrounded by ---------------------
6365
@@ -68,13 +70,13 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
6870
Given the context and provided history information and not prior knowledge,
6971
reply to the user comment. If the answer is not in the context, inform
7072
the user that you can't answer the question.
71-
""";
73+
""");
7274

7375
private static final int DEFAULT_ORDER = 0;
7476

7577
private final VectorStore vectorStore;
7678

77-
private final String userTextAdvise;
79+
private final PromptTemplate promptTemplate;
7880

7981
private final SearchRequest searchRequest;
8082

@@ -88,7 +90,7 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
8890
* @param vectorStore The vector store to use
8991
*/
9092
public QuestionAnswerAdvisor(VectorStore vectorStore) {
91-
this(vectorStore, SearchRequest.builder().build(), DEFAULT_USER_TEXT_ADVISE);
93+
this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
9294
}
9395

9496
/**
@@ -97,9 +99,11 @@ public QuestionAnswerAdvisor(VectorStore vectorStore) {
9799
* @param vectorStore The vector store to use
98100
* @param searchRequest The search request defined using the portable filter
99101
* expression syntax
102+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
100103
*/
104+
@Deprecated
101105
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) {
102-
this(vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE);
106+
this(vectorStore, searchRequest, DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
103107
}
104108

105109
/**
@@ -110,9 +114,12 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
110114
* expression syntax
111115
* @param userTextAdvise The user text to append to the existing user prompt. The text
112116
* should contain a placeholder named "question_answer_context".
117+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
113118
*/
119+
@Deprecated
114120
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
115-
this(vectorStore, searchRequest, userTextAdvise, true);
121+
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), true,
122+
DEFAULT_ORDER);
116123
}
117124

118125
/**
@@ -127,10 +134,13 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
127134
* blocking threads. If false the advisor will not protect the execution from blocking
128135
* threads. This is useful when the advisor is used in a non-blocking environment. It
129136
* is true by default.
137+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
130138
*/
139+
@Deprecated
131140
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
132141
boolean protectFromBlocking) {
133-
this(vectorStore, searchRequest, userTextAdvise, protectFromBlocking, DEFAULT_ORDER);
142+
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
143+
DEFAULT_ORDER);
134144
}
135145

136146
/**
@@ -146,17 +156,23 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
146156
* threads. This is useful when the advisor is used in a non-blocking environment. It
147157
* is true by default.
148158
* @param order The order of the advisor.
159+
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
149160
*/
161+
@Deprecated
150162
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
151163
boolean protectFromBlocking, int order) {
164+
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
165+
order);
166+
}
152167

153-
Assert.notNull(vectorStore, "The vectorStore must not be null!");
154-
Assert.notNull(searchRequest, "The searchRequest must not be null!");
155-
Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
168+
QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, @Nullable PromptTemplate promptTemplate,
169+
boolean protectFromBlocking, int order) {
170+
Assert.notNull(vectorStore, "vectorStore cannot be null");
171+
Assert.notNull(searchRequest, "searchRequest cannot be null");
156172

157173
this.vectorStore = vectorStore;
158174
this.searchRequest = searchRequest;
159-
this.userTextAdvise = userTextAdvise;
175+
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
160176
this.protectFromBlocking = protectFromBlocking;
161177
this.order = order;
162178
}
@@ -212,32 +228,30 @@ private AdvisedRequest before(AdvisedRequest request) {
212228

213229
var context = new HashMap<>(request.adviseContext());
214230

215-
// 1. Advise the system text.
216-
String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;
217-
218-
// 2. Search for similar documents in the vector store.
219-
String query = new PromptTemplate(request.userText(), request.userParams()).render();
231+
// 1. Search for similar documents in the vector store.
220232
var searchRequestToUse = SearchRequest.from(this.searchRequest)
221-
.query(query)
233+
.query(request.userText())
222234
.filterExpression(doGetFilterExpression(context))
223235
.build();
224236

225237
List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);
226238

227-
// 3. Create the context from the documents.
239+
// 2. Create the context from the documents.
228240
context.put(RETRIEVED_DOCUMENTS, documents);
229241

230242
String documentContext = documents.stream()
231243
.map(Document::getText)
232244
.collect(Collectors.joining(System.lineSeparator()));
233245

234-
// 4. Advise the user parameters.
235-
Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
236-
advisedUserParams.put("question_answer_context", documentContext);
246+
// 3. Augment the user prompt with the document context.
247+
String augmentedUserText = this.promptTemplate.mutate()
248+
.template(request.userText() + System.lineSeparator() + this.promptTemplate.getTemplate())
249+
.variables(Map.of("question_answer_context", documentContext))
250+
.build()
251+
.render();
237252

238253
AdvisedRequest advisedRequest = AdvisedRequest.from(request)
239-
.userText(advisedUserText)
240-
.userParams(advisedUserParams)
254+
.userText(augmentedUserText)
241255
.adviseContext(context)
242256
.build();
243257

@@ -266,7 +280,7 @@ public static final class Builder {
266280

267281
private SearchRequest searchRequest = SearchRequest.builder().build();
268282

269-
private String userTextAdvise = DEFAULT_USER_TEXT_ADVISE;
283+
private PromptTemplate promptTemplate;
270284

271285
private boolean protectFromBlocking = true;
272286

@@ -277,15 +291,25 @@ private Builder(VectorStore vectorStore) {
277291
this.vectorStore = vectorStore;
278292
}
279293

294+
public Builder promptTemplate(PromptTemplate promptTemplate) {
295+
Assert.notNull(promptTemplate, "promptTemplate cannot be null");
296+
this.promptTemplate = promptTemplate;
297+
return this;
298+
}
299+
280300
public Builder searchRequest(SearchRequest searchRequest) {
281301
Assert.notNull(searchRequest, "The searchRequest must not be null!");
282302
this.searchRequest = searchRequest;
283303
return this;
284304
}
285305

306+
/**
307+
* @deprecated in favour of {@link #promptTemplate(PromptTemplate)}
308+
*/
309+
@Deprecated
286310
public Builder userTextAdvise(String userTextAdvise) {
287311
Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
288-
this.userTextAdvise = userTextAdvise;
312+
this.promptTemplate = PromptTemplate.builder().template(userTextAdvise).build();
289313
return this;
290314
}
291315

@@ -300,7 +324,7 @@ public Builder order(int order) {
300324
}
301325

302326
public QuestionAnswerAdvisor build() {
303-
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.userTextAdvise,
327+
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate,
304328
this.protectFromBlocking, this.order);
305329
}
306330

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

+120-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@
4343
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
4444
import org.springframework.ai.openai.api.tool.MockWeatherService;
4545
import org.springframework.ai.openai.testutils.AbstractIT;
46+
import org.springframework.ai.template.st.StTemplateRenderer;
4647
import org.springframework.ai.test.CurlyBracketEscaper;
4748
import org.springframework.ai.tool.function.FunctionToolCallback;
4849
import org.springframework.beans.factory.annotation.Value;
@@ -378,6 +379,124 @@ void multiModalityAudioResponse() {
378379
logger.info("Response: " + response);
379380
}
380381

382+
@Test
383+
void customTemplateRendererWithCall() {
384+
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
385+
386+
// @formatter:off
387+
String result = ChatClient.create(this.chatModel).prompt()
388+
.user(u -> u
389+
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
390+
+ "<format>")
391+
.param("format", outputConverter.getFormat()))
392+
.templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build())
393+
.call()
394+
.content();
395+
// @formatter:on
396+
397+
assertThat(result).isNotEmpty();
398+
ActorsFilms actorsFilms = outputConverter.convert(result);
399+
400+
logger.info("" + actorsFilms);
401+
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
402+
assertThat(actorsFilms.movies()).hasSize(5);
403+
}
404+
405+
@Test
406+
void customTemplateRendererWithCallAndAdvisor() {
407+
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
408+
409+
// @formatter:off
410+
String result = ChatClient.create(this.chatModel).prompt()
411+
.advisors(new SimpleLoggerAdvisor())
412+
.user(u -> u
413+
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
414+
+ "<format>")
415+
.param("format", outputConverter.getFormat()))
416+
.templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build())
417+
.call()
418+
.content();
419+
// @formatter:on
420+
421+
assertThat(result).isNotEmpty();
422+
ActorsFilms actorsFilms = outputConverter.convert(result);
423+
424+
logger.info("" + actorsFilms);
425+
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
426+
assertThat(actorsFilms.movies()).hasSize(5);
427+
}
428+
429+
@Test
430+
void customTemplateRendererWithStream() {
431+
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
432+
433+
// @formatter:off
434+
Flux<ChatResponse> chatResponse = ChatClient.create(this.chatModel)
435+
.prompt()
436+
.options(OpenAiChatOptions.builder().streamUsage(true).build())
437+
.user(u -> u
438+
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
439+
+ "<format>")
440+
.param("format", outputConverter.getFormat()))
441+
.templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build())
442+
.stream()
443+
.chatResponse();
444+
445+
List<ChatResponse> chatResponses = chatResponse.collectList()
446+
.block()
447+
.stream()
448+
.toList();
449+
450+
String generationTextFromStream = chatResponses
451+
.stream()
452+
.filter(cr -> cr.getResult() != null)
453+
.map(cr -> cr.getResult().getOutput().getText())
454+
.collect(Collectors.joining());
455+
// @formatter:on
456+
457+
ActorsFilms actorsFilms = outputConverter.convert(generationTextFromStream);
458+
459+
logger.info("" + actorsFilms);
460+
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
461+
assertThat(actorsFilms.movies()).hasSize(5);
462+
}
463+
464+
@Test
465+
void customTemplateRendererWithStreamAndAdvisor() {
466+
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
467+
468+
// @formatter:off
469+
Flux<ChatResponse> chatResponse = ChatClient.create(this.chatModel)
470+
.prompt()
471+
.options(OpenAiChatOptions.builder().streamUsage(true).build())
472+
.advisors(new SimpleLoggerAdvisor())
473+
.user(u -> u
474+
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
475+
+ "<format>")
476+
.param("format", outputConverter.getFormat()))
477+
.templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build())
478+
.stream()
479+
.chatResponse();
480+
481+
List<ChatResponse> chatResponses = chatResponse.collectList()
482+
.block()
483+
.stream()
484+
.toList();
485+
486+
String generationTextFromStream = chatResponses
487+
.stream()
488+
.filter(cr -> cr.getResult() != null)
489+
.map(cr -> cr.getResult().getOutput().getText())
490+
.collect(Collectors.joining());
491+
// @formatter:on
492+
493+
ActorsFilms actorsFilms = outputConverter.convert(generationTextFromStream);
494+
495+
logger.info("" + actorsFilms);
496+
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
497+
assertThat(actorsFilms.movies()).hasSize(5);
498+
}
499+
381500
record ActorsFilms(String actor, List<String> movies) {
382501

383502
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java

+5
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.chat.prompt.Prompt;
3535
import org.springframework.ai.content.Media;
3636
import org.springframework.ai.converter.StructuredOutputConverter;
37+
import org.springframework.ai.template.TemplateRenderer;
3738
import org.springframework.ai.tool.ToolCallback;
3839
import org.springframework.ai.tool.ToolCallbackProvider;
3940
import org.springframework.core.ParameterizedTypeReference;
@@ -247,6 +248,8 @@ interface ChatClientRequestSpec {
247248

248249
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
249250

251+
ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer);
252+
250253
CallResponseSpec call();
251254

252255
StreamResponseSpec stream();
@@ -282,6 +285,8 @@ interface Builder {
282285

283286
Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);
284287

288+
Builder defaultTemplateRenderer(TemplateRenderer templateRenderer);
289+
285290
Builder defaultTools(String... toolNames);
286291

287292
Builder defaultTools(ToolCallback... toolCallbacks);

0 commit comments

Comments
 (0)