Skip to content

Commit

Permalink
fix prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Sep 20, 2024
1 parent 5b560e1 commit 39e2394
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
Expand Down Expand Up @@ -70,7 +71,10 @@ public AIAssistant createWithPrompt(
case QIANFAN -> QianFanAssistant.builder();
};
AIAssistant aiAssistant = builder.id(id)
.memoryStore((id == null) ? new InMemoryChatMemoryStore() : chatMemoryStore)
.memoryStore(
(id == null)
? new InMemoryChatMemoryStore()
: ((PersistentChatMemoryStore) chatMemoryStore).clone())
.withConfigProvider(assistantConfig)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class PersistentChatMemoryStore implements ChatMemoryStore {

private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

private final List<ChatMessage> systemMessages = new ArrayList<>();

public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
Expand Down Expand Up @@ -78,19 +79,25 @@ private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatT
@Override
public List<ChatMessage> getMessages(Object threadId) {
List<ChatMessagePO> chatMessages = chatMessageDao.findAllByThreadId((Long) threadId);
if (chatMessages.isEmpty()) {
return new ArrayList<>();
} else {
return chatMessages.stream()
List<ChatMessage> allChatMessages = new ArrayList<>(systemMessages);
if (!chatMessages.isEmpty()) {
allChatMessages.addAll(chatMessages.stream()
.map(this::convertToChatMessage)
.filter(Objects::nonNull)
.collect(Collectors.toList());
.toList());
}
return allChatMessages;
}

@Override
public void updateMessages(Object threadId, List<ChatMessage> messages) {
ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId);
ChatMessage newMessage = messages.get(messages.size() - 1);
if (newMessage.type().equals(ChatMessageType.SYSTEM)) {
SystemMessage systemMessage = (SystemMessage) newMessage;
systemMessages.add(systemMessage);
return;
}
ChatMessagePO chatMessagePO = convertToChatMessagePO(newMessage, (Long) threadId);
if (chatMessagePO == null) {
return;
}
Expand All @@ -103,4 +110,8 @@ public void deleteMessages(Object threadId) {
chatMessagePOS.forEach(chatMessage -> chatMessage.setIsDeleted(true));
chatMessageDao.partialUpdateByIds(chatMessagePOS);
}

public PersistentChatMemoryStore clone() {
return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private String getValueFromAssistantStreamMessage(AssistantStreamMessage assista
return streamMessage.toString();
}

private void saveMessage(String message, MessageType sender) {
private void addMessage(String message, MessageType sender) {
ChatMessage chatMessage;
if (sender.equals(MessageType.AI)) {
chatMessage = new AiMessage(message);
Expand Down Expand Up @@ -131,7 +131,7 @@ public void setSystemPrompt(String systemPrompt) {
} catch (NoApiKeyException | InputRequiredException | InvalidateParameter e) {
throw new RuntimeException(e);
}
saveMessage(systemPrompt, MessageType.SYSTEM);
addMessage(systemPrompt, MessageType.SYSTEM);
}

public static Builder builder() {
Expand All @@ -140,7 +140,7 @@ public static Builder builder() {

@Override
public Flux<String> streamAsk(String userMessage) {
saveMessage(userMessage, MessageType.USER);
addMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -174,13 +174,13 @@ public Flux<String> streamAsk(String userMessage) {
return message;
})
.doOnComplete(() -> {
saveMessage(finalMessage.toString(), MessageType.AI);
addMessage(finalMessage.toString(), MessageType.AI);
});
}

@Override
public String ask(String userMessage) {
saveMessage(userMessage, MessageType.USER);
addMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -244,7 +244,7 @@ public String ask(String userMessage) {
ContentText contentText = (ContentText) content;
finalMessage.append(contentText.getText().getValue());
}
saveMessage(finalMessage.toString(), MessageType.AI);
addMessage(finalMessage.toString(), MessageType.AI);
return finalMessage.toString();
}

Expand Down

0 comments on commit 39e2394

Please sign in to comment.