Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/main/java/com/example/carina/config/CarinaConfig.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package com.example.carina.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.retriever.VectorStoreRetriever;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -21,8 +19,11 @@ public VectorStore vectorStore(EmbeddingClient embeddingClient, JdbcTemplate jdb
}

@Bean
public VectorStoreRetriever vectorStoreRetriever(VectorStore vectorStore) {
return new VectorStoreRetriever(vectorStore, 4, 0.75);
public SearchRequest searchRequest() {
SearchRequest searchRequest = SearchRequest.defaults();
searchRequest.withTopK(4);
searchRequest.withSimilarityThreshold(0.75);
return searchRequest;
}


Expand Down
37 changes: 22 additions & 15 deletions src/main/java/com/example/carina/qa/QAService.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.client.AiClient;
import org.springframework.ai.client.AiResponse;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.ai.retriever.VectorStoreRetriever;

import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
Expand All @@ -30,14 +33,17 @@ public class QAService {
@Value("classpath:/prompts/system-chatbot.st")
private Resource chatbotSystemPromptResource;

private final AiClient aiClient;
private final ChatClient chatClient;

private final VectorStore vectorStore;

private final VectorStoreRetriever vectorStoreRetriever;
private final SearchRequest searchRequest;

@Autowired
public QAService(AiClient aiClient, VectorStoreRetriever vectorStoreRetriever) {
this.aiClient = aiClient;
this.vectorStoreRetriever = vectorStoreRetriever;
public QAService(ChatClient chatClient, VectorStore vectorStore, SearchRequest searchRequest) {
this.chatClient = chatClient;
this.vectorStore = vectorStore;
this.searchRequest = searchRequest;
}

public String generate(String message, boolean stuffit) {
Expand All @@ -46,15 +52,16 @@ public String generate(String message, boolean stuffit) {
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));

logger.info("Asking AI model to reply to question.");
AiResponse aiResponse = aiClient.generate(prompt);
ChatResponse chatResponse = chatClient.call(prompt);
logger.info("AI responded.");
return aiResponse.getGeneration().getContent();
return chatResponse.getResult().getOutput().getContent();
}

private Message getSystemMessage(String message, boolean stuffit) {
if (stuffit) {
logger.info("Retrieving relevant documents");
List<Document> similarDocuments = vectorStoreRetriever.retrieve(message);
searchRequest.withQuery(message);
List<Document> similarDocuments = vectorStore.similaritySearch(searchRequest);
logger.info(String.format("Found %s relevant documents.", similarDocuments.size()));
String documents = similarDocuments.stream().map(entry -> entry.getContent()).collect(Collectors.joining("\n"));
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.qaSystemPromptResource);
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/com/example/carina/simple/SimpleAiController.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.example.carina.simple;

import org.springframework.ai.client.AiClient;
import org.springframework.ai.chat.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
Expand All @@ -9,16 +9,16 @@
@RestController
public class SimpleAiController {

private final AiClient aiClient;
private final ChatClient chatClient;

@Autowired
public SimpleAiController(AiClient aiClient) {
this.aiClient = aiClient;
public SimpleAiController(ChatClient chatClient) {
this.chatClient = chatClient;
}

@GetMapping("/ai/simple")
public Completion completion(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
return new Completion(aiClient.generate(message));
return new Completion(chatClient.call(message));
}

}