Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 8df8fa7
Author: lhpqaq <[email protected]>
Date:   Fri Sep 20 18:25:07 2024 +0800

    BIGTOP-4226: Support soft deletion in Chatbot
  • Loading branch information
lhpqaq committed Sep 20, 2024
1 parent 4cce064 commit 3acc04a
Show file tree
Hide file tree
Showing 23 changed files with 148 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
Expand All @@ -27,7 +27,6 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

Expand All @@ -47,12 +46,10 @@ public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao cha

private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
String sender = chatMessagePO.getSender().toLowerCase();
if (sender.equals(MessageSender.AI.getValue())) {
if (sender.equals(MessageType.AI.getValue())) {
return new AiMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.USER.getValue())) {
} else if (sender.equals(MessageType.USER.getValue())) {
return new UserMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.SYSTEM.getValue())) {
return new SystemMessage(chatMessagePO.getMessage());
} else {
return null;
}
Expand All @@ -61,21 +58,17 @@ private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatThreadId) {
ChatMessagePO chatMessagePO = new ChatMessagePO();
if (chatMessage.type().equals(ChatMessageType.AI)) {
chatMessagePO.setSender(MessageSender.AI.getValue());
chatMessagePO.setSender(MessageType.AI.getValue());
AiMessage aiMessage = (AiMessage) chatMessage;
chatMessagePO.setMessage(aiMessage.text());
} else if (chatMessage.type().equals(ChatMessageType.USER)) {
chatMessagePO.setSender(MessageSender.USER.getValue());
chatMessagePO.setSender(MessageType.USER.getValue());
UserMessage userMessage = (UserMessage) chatMessage;
chatMessagePO.setMessage(userMessage.singleText());
} else if (chatMessage.type().equals(ChatMessageType.SYSTEM)) {
chatMessagePO.setSender(MessageSender.SYSTEM.getValue());
SystemMessage systemMessage = (SystemMessage) chatMessage;
chatMessagePO.setMessage(systemMessage.text());
} else {
chatMessagePO.setSender(chatMessage.type().toString());
return null;
}
ChatThreadPO chatThreadPO = chatThreadDao.findById(chatThreadId);
ChatThreadPO chatThreadPO = chatThreadDao.findByThreadId(chatThreadId);
chatMessagePO.setUserId(chatThreadPO.getUserId());
chatMessagePO.setThreadId(chatThreadId);
return chatMessagePO;
Expand All @@ -94,11 +87,16 @@ public List<ChatMessage> getMessages(Object threadId) {
@Override
public void updateMessages(Object threadId, List<ChatMessage> messages) {
ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId);
if (chatMessagePO == null) {
return;
}
chatMessageDao.save(chatMessagePO);
}

@Override
public void deleteMessages(Object threadId) {
chatMessageDao.deleteByThreadId((Long) threadId);
List<ChatMessagePO> chatMessagePOS = chatMessageDao.findAllByThreadId((Long) threadId);
chatMessagePOS.forEach(chatMessage -> chatMessage.setIsDeleted(true));
chatMessageDao.partialUpdateByIds(chatMessagePOS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,28 @@
import java.util.stream.Collectors;

@Getter
public enum MessageSender {
public enum MessageType {
USER("user"),
AI("ai"),
SYSTEM("system");

private final String value;

MessageSender(String value) {
MessageType(String value) {
this.value = value;
}

public static List<String> getSenders() {
return Arrays.stream(values()).map(item -> item.value).collect(Collectors.toList());
}

public static MessageSender getMessageSender(String value) {
public static MessageType getMessageSender(String value) {
if (Objects.isNull(value) || value.isEmpty()) {
return null;
}
for (MessageSender messageSender : MessageSender.values()) {
if (messageSender.value.equals(value)) {
return messageSender;
for (MessageType messageType : MessageType.values()) {
if (messageType.value.equals(value)) {
return messageType;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.apache.bigtop.manager.ai.dashscope;

import org.apache.bigtop.manager.ai.core.AbstractAIAssistant;
import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;

Expand Down Expand Up @@ -88,13 +88,13 @@ private String getValueFromAssistantStreamMessage(AssistantStreamMessage assista
return streamMessage.toString();
}

private void saveMessage(String message, MessageSender sender) {
private void saveMessage(String message, MessageType sender) {
ChatMessage chatMessage;
if (sender.equals(MessageSender.AI)) {
if (sender.equals(MessageType.AI)) {
chatMessage = new AiMessage(message);
} else if (sender.equals(MessageSender.USER)) {
} else if (sender.equals(MessageType.USER)) {
chatMessage = new UserMessage(message);
} else if (sender.equals(MessageSender.SYSTEM)) {
} else if (sender.equals(MessageType.SYSTEM)) {
chatMessage = new SystemMessage(message);
} else {
return;
Expand Down Expand Up @@ -131,7 +131,7 @@ public void setSystemPrompt(String systemPrompt) {
} catch (NoApiKeyException | InputRequiredException | InvalidateParameter e) {
throw new RuntimeException(e);
}
saveMessage(systemPrompt, MessageSender.SYSTEM);
saveMessage(systemPrompt, MessageType.SYSTEM);
}

public static Builder builder() {
Expand All @@ -140,7 +140,7 @@ public static Builder builder() {

@Override
public Flux<String> streamAsk(String userMessage) {
saveMessage(userMessage, MessageSender.USER);
saveMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -174,13 +174,13 @@ public Flux<String> streamAsk(String userMessage) {
return message;
})
.doOnComplete(() -> {
saveMessage(finalMessage.toString(), MessageSender.AI);
saveMessage(finalMessage.toString(), MessageType.AI);
});
}

@Override
public String ask(String userMessage) {
saveMessage(userMessage, MessageSender.USER);
saveMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -244,7 +244,7 @@ public String ask(String userMessage) {
ContentText contentText = (ContentText) content;
finalMessage.append(contentText.getText().getValue());
}
saveMessage(finalMessage.toString(), MessageSender.AI);
saveMessage(finalMessage.toString(), MessageType.AI);
return finalMessage.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,7 @@ public class ChatMessagePO extends BasePO implements Serializable {

@Column(name = "thread_id", nullable = false)
private Long threadId;

@Column(name = "is_deleted")
private Boolean isDeleted;
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,7 @@ public class ChatThreadPO extends BasePO implements Serializable {

@Column(name = "platform_id", nullable = false)
private Long platformId;

@Column(name = "is_deleted")
private Boolean isDeleted;
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,7 @@ public class PlatformAuthorizedPO extends BasePO implements Serializable {

@Column(name = "platform_id", nullable = false)
private Long platformId;

@Column(name = "is_deleted")
private Boolean isDeleted;
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,4 @@

public interface ChatMessageDao extends BaseDao<ChatMessagePO> {
List<ChatMessagePO> findAllByThreadId(@Param("threadId") Long threadId);

void deleteByThreadId(@Param("threadId") Long threadId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
public interface ChatThreadDao extends BaseDao<ChatThreadPO> {
List<ChatThreadPO> findAllByUserId(@Param("userId") Long userId);

ChatThreadPO findById(Long id);

ChatThreadPO findByThreadId(@Param("id") Long id);

List<ChatThreadPO> findAllByPlatformAuthorizedIdAndUserId(
@Param("platformId") Long platformAuthorizedId, @Param("userId") Long userId);

void saveWithThreadInfo(ChatThreadPO chatThreadPO);

List<ChatThreadPO> findAllByPlatformId(@Param("platformId") Long platformId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@

import org.apache.ibatis.annotations.Param;

import java.util.List;

public interface PlatformAuthorizedDao extends BaseDao<PlatformAuthorizedPO> {
PlatformAuthorizedPO findByPlatformId(@Param("id") Long platformId);

void saveWithCredentials(PlatformAuthorizedPO platformAuthorizedPO);

List<PlatformAuthorizedPO> findAllPlatform();
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@
SELECT *
FROM llm_chat_message
WHERE thread_id = #{threadId}
AND is_deleted = 0
</select>

<delete id="deleteByThreadId">
DELETE FROM llm_chat_message
WHERE thread_id = #{threadId}
</delete>

</mapper>
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@
SELECT *
FROM llm_chat_thread
WHERE user_id = #{userId}
AND is_deleted = 0
</select>

<select id="findAllByPlatformAuthorizedIdAndUserId" resultType="org.apache.bigtop.manager.dao.po.ChatThreadPO">
SELECT *
FROM llm_chat_thread
WHERE platform_id = #{platformId} AND user_id = #{userId}
WHERE platform_id = #{platformId} AND user_id = #{userId} AND is_deleted = 0
</select>

<select id="findByThreadId" resultMap="ChatThreadResultMap">
SELECT * FROM llm_chat_thread WHERE id = #{id}
SELECT * FROM llm_chat_thread WHERE id = #{id} AND is_deleted = 0
</select>

<insert id="saveWithThreadInfo" parameterType="org.apache.bigtop.manager.dao.po.ChatThreadPO" useGeneratedKeys="true" keyProperty="id">
Expand All @@ -55,4 +56,10 @@
thread_info = VALUES(thread_info)
</insert>

<select id="findAllByPlatformId" resultType="org.apache.bigtop.manager.dao.po.ChatThreadPO">
SELECT *
FROM llm_chat_thread
WHERE platform_id = #{platformId} AND is_deleted = 0
</select>

</mapper>
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,19 @@
</resultMap>

<select id="findByPlatformId" resultMap="PlatformAuthorizedResultMap">
SELECT * FROM llm_platform_authorized WHERE id = #{id}
SELECT * FROM llm_platform_authorized WHERE id = #{id} AND is_deleted = 0
</select>

<select id="findAllPlatform" resultMap="PlatformAuthorizedResultMap">
SELECT * FROM llm_platform_authorized WHERE is_deleted = 0
</select>

<insert id="saveWithCredentials" parameterType="org.apache.bigtop.manager.dao.po.PlatformAuthorizedPO" useGeneratedKeys="true" keyProperty="id">
INSERT INTO llm_platform_authorized (platform_id, credentials)
VALUES (#{platformId}, #{credentials, typeHandler=org.apache.bigtop.manager.dao.handler.JsonTypeHandler})
ON DUPLICATE KEY UPDATE
platform_id = VALUES(platform_id),
credentials = VALUES(credentials)
ON DUPLICATE KEY UPDATE
platform_id = VALUES(platform_id),
credentials = VALUES(credentials)
</insert>

</mapper>
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@
SELECT *
FROM llm_chat_message
WHERE thread_id = #{threadId}
AND is_deleted = false
</select>

<delete id="deleteByThreadId">
DELETE FROM llm_chat_message
WHERE thread_id = #{threadId}
</delete>

</mapper>
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@
SELECT *
FROM llm_chat_thread
WHERE user_id = #{userId}
AND is_deleted = false
</select>

<select id="findAllByPlatformAuthorizedIdAndUserId" resultType="org.apache.bigtop.manager.dao.po.ChatThreadPO">
SELECT *
FROM llm_chat_thread
WHERE platform_id = #{platformId} AND user_id = #{userId}
WHERE platform_id = #{platformId} AND user_id = #{userId} AND is_deleted = false
</select>

<select id="findByThreadId" resultMap="ChatThreadResultMap">
SELECT * FROM llm_chat_thread WHERE id = #{id}
SELECT * FROM llm_chat_thread WHERE id = #{id} AND is_deleted = false
</select>

<insert id="saveWithThreadInfo" parameterType="org.apache.bigtop.manager.dao.po.ChatThreadPO" useGeneratedKeys="true" keyProperty="id">
Expand All @@ -55,4 +56,10 @@
thread_info = VALUES(thread_info)
</insert>

<select id="findAllByPlatformId" resultType="org.apache.bigtop.manager.dao.po.ChatThreadPO">
SELECT *
FROM llm_chat_thread
WHERE platform_id = #{platformId} AND is_deleted = 0
</select>

</mapper>
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,19 @@
</resultMap>

<select id="findByPlatformId" resultMap="PlatformAuthorizedResultMap">
SELECT * FROM llm_platform_authorized WHERE id = #{id}
SELECT * FROM llm_platform_authorized WHERE id = #{id} AND is_deleted = false
</select>

<select id="findAllPlatform" resultMap="PlatformAuthorizedResultMap">
SELECT * FROM llm_platform_authorized WHERE is_deleted = false
</select>

<insert id="saveWithCredentials" parameterType="org.apache.bigtop.manager.dao.po.PlatformAuthorizedPO" useGeneratedKeys="true" keyProperty="id">
INSERT INTO llm_platform_authorized (platform_id, credentials)
VALUES (#{platformId}, #{credentials, typeHandler=org.apache.bigtop.manager.dao.handler.JsonTypeHandler})
ON DUPLICATE KEY UPDATE
platform_id = VALUES(platform_id),
credentials = VALUES(credentials)
platform_id = VALUES(platform_id),
credentials = VALUES(credentials)
</insert>

</mapper>
Loading

0 comments on commit 3acc04a

Please sign in to comment.