diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc index fdafefc3de..f431d34a40 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc @@ -119,7 +119,7 @@ PromptTemplate customPromptTemplate = PromptTemplate.builder() NOTE: The `QuestionAnswerAdvisor.Builder.userTextAdvise()` method is deprecated in favor of using `.promptTemplate()` for more flexible customization. -=== RetrievalAugmentationAdvisor (Incubating) +=== RetrievalAugmentationAdvisor Spring AI includes a xref:api/retrieval-augmented-generation.adoc#modules[library of RAG modules] that you can use to build your own RAG flows. The `RetrievalAugmentationAdvisor` is an `Advisor` providing an out-of-the-box implementation for the most common RAG flows, @@ -211,6 +211,8 @@ String answer = chatClient.prompt() .content(); ---- +You can also use the `DocumentPostProcessor` API to post-process the retrieved documents before passing them to the model. For example, you can use such an interface to perform re-ranking of the retrieved documents based on their relevance to the query, remove irrelevant or redundant documents, or compress the content of each document to reduce noise and redundancy. + [[modules]] == Modules diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java index da17ed0edd..2a2229d681 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java @@ -261,6 +261,32 @@ void ragWithMultiQuery() { evaluateRelevancy(question, chatResponse); } + @Test + void ragWithDocumentPostProcessor() { + String question = "Where does the adventure of Anacletus and Birba take place?"; + + RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build()) + .documentPostProcessors((query, documents) -> List + .of(Document.builder().text("The adventure of Anacletus and Birba takes place in Molise").build())) + .build(); + + ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel) + .build() + .prompt(question) + .advisors(ragAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + + String response = chatResponse.getResult().getOutput().getText(); + System.out.println(response); + assertThat(response).containsIgnoringCase("Molise"); + + evaluateRelevancy(question, chatResponse); + } + private void evaluateRelevancy(String question, ChatResponse chatResponse) { EvaluationRequest evaluationRequest = new EvaluationRequest(question, chatResponse.getMetadata().get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT), diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java index 69521638d8..c2dbbb6f57 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java @@ -34,6 +34,7 @@ import org.springframework.ai.rag.Query; import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; import org.springframework.ai.rag.generation.augmentation.QueryAugmenter; +import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor; import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander; import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner; @@ -70,6 +71,8 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor { private final DocumentJoiner documentJoiner; + private final List documentPostProcessors; + private final QueryAugmenter queryAugmenter; private final TaskExecutor taskExecutor; @@ -80,14 +83,16 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor { private RetrievalAugmentationAdvisor(@Nullable List queryTransformers, @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever, - @Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter, - @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) { + @Nullable DocumentJoiner documentJoiner, @Nullable List documentPostProcessors, + @Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, + @Nullable Integer order) { Assert.notNull(documentRetriever, "documentRetriever cannot be null"); Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); this.queryTransformers = queryTransformers != null ? queryTransformers : List.of(); this.queryExpander = queryExpander; this.documentRetriever = documentRetriever; this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner(); + this.documentPostProcessors = documentPostProcessors != null ? documentPostProcessors : List.of(); this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build(); this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor(); this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER; @@ -130,6 +135,11 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable A // 4. Combine documents retrieved based on multiple queries and from multiple data // sources. List documents = this.documentJoiner.join(documentsForQuery); + + // 5. Post-process the documents. + for (var documentPostProcessor : this.documentPostProcessors) { + documents = documentPostProcessor.process(originalQuery, documents); + } context.put(DOCUMENT_CONTEXT, documents); // 5. Augment user query with the document contextual data. @@ -197,6 +207,8 @@ public static final class Builder { private DocumentJoiner documentJoiner; + private List documentPostProcessors; + private QueryAugmenter queryAugmenter; private TaskExecutor taskExecutor; @@ -209,11 +221,14 @@ private Builder() { } public Builder queryTransformers(List queryTransformers) { + Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); this.queryTransformers = queryTransformers; return this; } public Builder queryTransformers(QueryTransformer... queryTransformers) { + Assert.notNull(queryTransformers, "queryTransformers cannot be null"); + Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); this.queryTransformers = Arrays.asList(queryTransformers); return this; } @@ -233,6 +248,19 @@ public Builder documentJoiner(DocumentJoiner documentJoiner) { return this; } + public Builder documentPostProcessors(List documentPostProcessors) { + Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements"); + this.documentPostProcessors = documentPostProcessors; + return this; + } + + public Builder documentPostProcessors(DocumentPostProcessor... documentPostProcessors) { + Assert.notNull(documentPostProcessors, "documentPostProcessors cannot be null"); + Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements"); + this.documentPostProcessors = Arrays.asList(documentPostProcessors); + return this; + } + public Builder queryAugmenter(QueryAugmenter queryAugmenter) { this.queryAugmenter = queryAugmenter; return this; @@ -255,7 +283,8 @@ public Builder order(Integer order) { public RetrievalAugmentationAdvisor build() { return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever, - this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order); + this.documentJoiner, this.documentPostProcessors, this.queryAugmenter, this.taskExecutor, + this.scheduler, this.order); } }