Skip to content

Support DocumentPostProcessors in RAG Advisor #3031

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ PromptTemplate customPromptTemplate = PromptTemplate.builder()

NOTE: The `QuestionAnswerAdvisor.Builder.userTextAdvise()` method is deprecated in favor of using `.promptTemplate()` for more flexible customization.

=== RetrievalAugmentationAdvisor (Incubating)
=== RetrievalAugmentationAdvisor

Spring AI includes a xref:api/retrieval-augmented-generation.adoc#modules[library of RAG modules] that you can use to build your own RAG flows.
The `RetrievalAugmentationAdvisor` is an `Advisor` providing an out-of-the-box implementation for the most common RAG flows,
Expand Down Expand Up @@ -211,6 +211,8 @@ String answer = chatClient.prompt()
.content();
----

You can also use the `DocumentPostProcessor` API to post-process the retrieved documents before passing them to the model. For example, you can use such an interface to perform re-ranking of the retrieved documents based on their relevance to the query, remove irrelevant or redundant documents, or compress the content of each document to reduce noise and redundancy.

[[modules]]
== Modules

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,32 @@ void ragWithMultiQuery() {
evaluateRelevancy(question, chatResponse);
}

/**
 * Exercises a RAG flow in which a document post-processor is registered on the
 * advisor. The post-processor used here ignores both the query and the retrieved
 * documents and always returns a single fixed document, so the model's answer is
 * expected to mention the location named in that document.
 */
@Test
void ragWithDocumentPostProcessor() {
	String userQuestion = "Where does the adventure of Anacletus and Birba take place?";

	VectorStoreDocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
		.vectorStore(this.pgVectorStore)
		.build();

	// The lambda replaces whatever was retrieved with one hard-coded document.
	RetrievalAugmentationAdvisor advisor = RetrievalAugmentationAdvisor.builder()
		.documentRetriever(retriever)
		.documentPostProcessors((query, documents) -> List
			.of(Document.builder().text("The adventure of Anacletus and Birba takes place in Molise").build()))
		.build();

	ChatClient chatClient = ChatClient.builder(this.openAiChatModel).build();
	ChatResponse chatResponse = chatClient.prompt(userQuestion).advisors(advisor).call().chatResponse();

	assertThat(chatResponse).isNotNull();

	String answerText = chatResponse.getResult().getOutput().getText();
	System.out.println(answerText);
	// The location only appears in the document injected by the post-processor above.
	assertThat(answerText).containsIgnoringCase("Molise");

	evaluateRelevancy(userQuestion, chatResponse);
}

private void evaluateRelevancy(String question, ChatResponse chatResponse) {
EvaluationRequest evaluationRequest = new EvaluationRequest(question,
chatResponse.getMetadata().get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
Expand Down Expand Up @@ -70,6 +71,8 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

private final DocumentJoiner documentJoiner;

private final List<DocumentPostProcessor> documentPostProcessors;

private final QueryAugmenter queryAugmenter;

private final TaskExecutor taskExecutor;
Expand All @@ -80,14 +83,16 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

private RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
@Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
@Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
@Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
@Nullable DocumentJoiner documentJoiner, @Nullable List<DocumentPostProcessor> documentPostProcessors,
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
@Nullable Integer order) {
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
this.queryExpander = queryExpander;
this.documentRetriever = documentRetriever;
this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
this.documentPostProcessors = documentPostProcessors != null ? documentPostProcessors : List.of();
this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
Expand Down Expand Up @@ -130,6 +135,11 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable A
// 4. Combine documents retrieved based on multiple queries and from multiple data
// sources.
List<Document> documents = this.documentJoiner.join(documentsForQuery);

// 5. Post-process the documents.
for (var documentPostProcessor : this.documentPostProcessors) {
documents = documentPostProcessor.process(originalQuery, documents);
}
context.put(DOCUMENT_CONTEXT, documents);

// 6. Augment user query with the document contextual data.
Expand Down Expand Up @@ -197,6 +207,8 @@ public static final class Builder {

private DocumentJoiner documentJoiner;

private List<DocumentPostProcessor> documentPostProcessors;

private QueryAugmenter queryAugmenter;

private TaskExecutor taskExecutor;
Expand All @@ -209,11 +221,14 @@ private Builder() {
}

/**
 * Sets the query transformers to pass to the advisor.
 * <p>
 * An immutable snapshot of the list is stored, so changes made to the caller's
 * list after this call cannot introduce elements that were never validated.
 * @param queryTransformers the transformers to use; a {@code null} list is
 * tolerated and stored as {@code null}
 * @return this builder
 * @throws IllegalArgumentException if the list contains a {@code null} element
 */
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
	Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
	// Copy only after validation: List.copyOf itself rejects nulls with a
	// NullPointerException, and callers expect the IllegalArgumentException above.
	this.queryTransformers = queryTransformers != null ? List.copyOf(queryTransformers) : null;
	return this;
}

/**
 * Sets the query transformers to pass to the advisor.
 * <p>
 * The array contents are copied into an immutable list, so writing to the array
 * after this call does not affect the builder.
 * @param queryTransformers the transformers to use; neither the array nor any of
 * its elements may be {@code null}
 * @return this builder
 * @throws IllegalArgumentException if the array is {@code null} or contains a
 * {@code null} element
 */
public Builder queryTransformers(QueryTransformer... queryTransformers) {
	Assert.notNull(queryTransformers, "queryTransformers cannot be null");
	Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
	// List.of copies the array; Arrays.asList would remain a view backed by it.
	this.queryTransformers = List.of(queryTransformers);
	return this;
}
Expand All @@ -233,6 +248,19 @@ public Builder documentJoiner(DocumentJoiner documentJoiner) {
return this;
}

/**
 * Sets the document post-processors to pass to the advisor.
 * <p>
 * An immutable snapshot of the list is stored, so changes made to the caller's
 * list after this call cannot introduce elements that were never validated.
 * @param documentPostProcessors the post-processors to use; a {@code null} list
 * is tolerated and stored as {@code null}
 * @return this builder
 * @throws IllegalArgumentException if the list contains a {@code null} element
 */
public Builder documentPostProcessors(List<DocumentPostProcessor> documentPostProcessors) {
	Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
	// Copy only after validation: List.copyOf itself rejects nulls with a
	// NullPointerException, and callers expect the IllegalArgumentException above.
	this.documentPostProcessors = documentPostProcessors != null ? List.copyOf(documentPostProcessors) : null;
	return this;
}

/**
 * Sets the document post-processors to pass to the advisor.
 * <p>
 * The array contents are copied into an immutable list, so writing to the array
 * after this call does not affect the builder.
 * @param documentPostProcessors the post-processors to use; neither the array
 * nor any of its elements may be {@code null}
 * @return this builder
 * @throws IllegalArgumentException if the array is {@code null} or contains a
 * {@code null} element
 */
public Builder documentPostProcessors(DocumentPostProcessor... documentPostProcessors) {
	Assert.notNull(documentPostProcessors, "documentPostProcessors cannot be null");
	Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
	// List.of copies the array; Arrays.asList would remain a view backed by it.
	this.documentPostProcessors = List.of(documentPostProcessors);
	return this;
}

public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
this.queryAugmenter = queryAugmenter;
return this;
Expand All @@ -255,7 +283,8 @@ public Builder order(Integer order) {

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever,
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
this.documentJoiner, this.documentPostProcessors, this.queryAugmenter, this.taskExecutor,
this.scheduler, this.order);
}

}
Expand Down