|
14 | 14 | import org.springframework.web.reactive.function.client.WebClient;
|
15 | 15 |
|
16 | 16 | /**
|
17 |
| - * A Reranker implementation that integrates with Cohere's Rerank API. |
18 |
| - * This component reorders retrieved documents based on semantic relevance to the input query. |
| 17 | + * A Reranker implementation that integrates with Cohere's Rerank API. This component |
| 18 | + * reorders retrieved documents based on semantic relevance to the input query. |
19 | 19 | *
|
20 | 20 | * @author KoreaNirsa
|
21 |
| - * @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API Documentation</a> |
| 21 | + * @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API |
| 22 | + * Documentation</a> |
22 | 23 | */
|
23 | 24 | public class CohereReranker {
|
| 25 | + |
24 | 26 | private static final String COHERE_RERANK_ENDPOINT = "https://api.cohere.ai/v1/rerank";
|
25 | 27 |
|
26 | 28 | private static final Logger logger = LoggerFactory.getLogger(CohereReranker.class);
|
27 |
| - |
| 29 | + |
28 | 30 | private static final int MAX_DOCUMENTS = 1000;
|
29 | 31 |
|
30 | 32 | private final WebClient webClient;
|
31 | 33 |
|
32 | 34 | /**
|
33 | 35 | * Constructs a CohereReranker that communicates with the Cohere Rerank API.
|
34 | 36 | * Initializes the internal WebClient with the provided API key for authorization.
|
35 |
| - * |
36 |
| - * @param cohereApi the API configuration object containing the required API key (must not be null) |
| 37 | + * @param cohereApi the API configuration object containing the required API key (must |
| 38 | + * not be null) |
37 | 39 | * @throws IllegalArgumentException if cohereApi is null
|
38 | 40 | */
|
39 |
| - CohereReranker(CohereApi cohereApi) { |
40 |
| - if (cohereApi == null) { |
41 |
| - throw new IllegalArgumentException("CohereApi must not be null"); |
42 |
| - } |
43 |
| - |
44 |
| - this.webClient = WebClient.builder() |
45 |
| - .baseUrl(COHERE_RERANK_ENDPOINT) |
46 |
| - .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) |
47 |
| - .build(); |
48 |
| - } |
49 |
| - |
50 |
| - /** |
51 |
| - * Reranks a list of documents based on the provided query using the Cohere API. |
52 |
| - * |
53 |
| - * @param query The user input query. |
54 |
| - * @param documents The list of documents to rerank. |
55 |
| - * @param topN The number of top results to return (at most). |
56 |
| - * @return A reranked list of documents. If the API fails, returns the original list. |
57 |
| - */ |
58 |
| - public List<Document> rerank(String query, List<Document> documents, int topN) { |
59 |
| - if (topN < 1) { |
60 |
| - throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); |
61 |
| - } |
62 |
| - |
63 |
| - if (documents == null || documents.isEmpty()) { |
64 |
| - logger.warn("Empty document list provided. Skipping rerank."); |
65 |
| - return Collections.emptyList(); |
66 |
| - } |
67 |
| - |
68 |
| - if (documents.size() > MAX_DOCUMENTS) { |
69 |
| - logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", MAX_DOCUMENTS); |
70 |
| - return documents; |
71 |
| - } |
72 |
| - |
73 |
| - int adjustedTopN = Math.min(topN, documents.size()); |
74 |
| - |
75 |
| - Map<String, Object> payload = Map.of( |
76 |
| - "query", query, |
77 |
| - "documents", documents.stream().map(Document::getText).toList(), |
78 |
| - "top_n", adjustedTopN |
79 |
| - ); |
80 |
| - |
81 |
| - // Call the API and process the result |
82 |
| - return sendRerankRequest(payload) |
83 |
| - .map(results -> results.stream() |
84 |
| - .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) |
85 |
| - .map(r -> { |
86 |
| - Document original = documents.get(r.getIndex()); |
87 |
| - Map<String, Object> metadata = new HashMap<>(original.getMetadata()); |
88 |
| - metadata.put("score", String.format("%.4f", r.getRelevanceScore())); |
89 |
| - return new Document(original.getText(), metadata); |
90 |
| - }) |
91 |
| - .toList()) |
92 |
| - .orElseGet(() -> { |
93 |
| - logger.warn("Cohere response is null or invalid"); |
94 |
| - return documents; |
95 |
| - }); |
96 |
| - } |
97 |
| - |
98 |
| - /** |
99 |
| - * Sends a rerank request to the Cohere API and returns the result list. |
100 |
| - * |
101 |
| - * @param payload The request body including query, documents, and top_n. |
102 |
| - * @return An Optional list of reranked results, or empty if failed. |
103 |
| - */ |
104 |
| - private Optional<List<RerankResponse.Result>> sendRerankRequest(Map<String, Object> payload) { |
105 |
| - try { |
106 |
| - RerankResponse response = webClient.post() |
107 |
| - .bodyValue(payload) |
108 |
| - .retrieve() |
109 |
| - .bodyToMono(RerankResponse.class) |
110 |
| - .block(); |
111 |
| - |
112 |
| - return Optional.ofNullable(response) |
113 |
| - .map(RerankResponse::getResults); |
114 |
| - } catch (Exception e) { |
115 |
| - logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); |
116 |
| - return Optional.empty(); |
117 |
| - } |
118 |
| - } |
| 41 | + CohereReranker(CohereApi cohereApi) { |
| 42 | + if (cohereApi == null) { |
| 43 | + throw new IllegalArgumentException("CohereApi must not be null"); |
| 44 | + } |
| 45 | + |
| 46 | + this.webClient = WebClient.builder() |
| 47 | + .baseUrl(COHERE_RERANK_ENDPOINT) |
| 48 | + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereApi.getApiKey()) |
| 49 | + .build(); |
| 50 | + } |
| 51 | + |
| 52 | + /** |
| 53 | + * Reranks a list of documents based on the provided query using the Cohere API. |
| 54 | + * @param query The user input query. |
| 55 | + * @param documents The list of documents to rerank. |
| 56 | + * @param topN The number of top results to return (at most). |
| 57 | + * @return A reranked list of documents. If the API fails, returns the original list. |
| 58 | + */ |
| 59 | + public List<Document> rerank(String query, List<Document> documents, int topN) { |
| 60 | + if (topN < 1) { |
| 61 | + throw new IllegalArgumentException("topN must be ≥ 1. Provided: " + topN); |
| 62 | + } |
| 63 | + |
| 64 | + if (documents == null || documents.isEmpty()) { |
| 65 | + logger.warn("Empty document list provided. Skipping rerank."); |
| 66 | + return Collections.emptyList(); |
| 67 | + } |
| 68 | + |
| 69 | + if (documents.size() > MAX_DOCUMENTS) { |
| 70 | + logger.warn("Cohere recommends ≤ {} documents per rerank request. Larger sizes may cause errors.", |
| 71 | + MAX_DOCUMENTS); |
| 72 | + return documents; |
| 73 | + } |
| 74 | + |
| 75 | + int adjustedTopN = Math.min(topN, documents.size()); |
| 76 | + |
| 77 | + Map<String, Object> payload = Map.of("query", query, "documents", |
| 78 | + documents.stream().map(Document::getText).toList(), "top_n", adjustedTopN); |
| 79 | + |
| 80 | + // Call the API and process the result |
| 81 | + return sendRerankRequest(payload).map(results -> results.stream() |
| 82 | + .sorted(Comparator.comparingDouble(RerankResponse.Result::getRelevanceScore).reversed()) |
| 83 | + .map(r -> { |
| 84 | + Document original = documents.get(r.getIndex()); |
| 85 | + Map<String, Object> metadata = new HashMap<>(original.getMetadata()); |
| 86 | + metadata.put("score", String.format("%.4f", r.getRelevanceScore())); |
| 87 | + return new Document(original.getText(), metadata); |
| 88 | + }) |
| 89 | + .toList()).orElseGet(() -> { |
| 90 | + logger.warn("Cohere response is null or invalid"); |
| 91 | + return documents; |
| 92 | + }); |
| 93 | + } |
| 94 | + |
| 95 | + /** |
| 96 | + * Sends a rerank request to the Cohere API and returns the result list. |
| 97 | + * @param payload The request body including query, documents, and top_n. |
| 98 | + * @return An Optional list of reranked results, or empty if failed. |
| 99 | + */ |
| 100 | + private Optional<List<RerankResponse.Result>> sendRerankRequest(Map<String, Object> payload) { |
| 101 | + try { |
| 102 | + RerankResponse response = webClient.post() |
| 103 | + .bodyValue(payload) |
| 104 | + .retrieve() |
| 105 | + .bodyToMono(RerankResponse.class) |
| 106 | + .block(); |
| 107 | + |
| 108 | + return Optional.ofNullable(response).map(RerankResponse::getResults); |
| 109 | + } |
| 110 | + catch (Exception e) { |
| 111 | + logger.error("Cohere rerank failed, fallback to original order: {}", e.getMessage(), e); |
| 112 | + return Optional.empty(); |
| 113 | + } |
| 114 | + } |
| 115 | + |
119 | 116 | }
|
0 commit comments