Skip to content

Commit

Permalink
BIGTOP-4226: Support soft deletion in Chatbot
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Sep 20, 2024
1 parent e2f3202 commit 8df8fa7
Show file tree
Hide file tree
Showing 86 changed files with 2,780 additions and 1,078 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
Expand All @@ -27,7 +27,6 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

Expand All @@ -47,12 +46,10 @@ public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao cha

private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
String sender = chatMessagePO.getSender().toLowerCase();
if (sender.equals(MessageSender.AI.getValue())) {
if (sender.equals(MessageType.AI.getValue())) {
return new AiMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.USER.getValue())) {
} else if (sender.equals(MessageType.USER.getValue())) {
return new UserMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.SYSTEM.getValue())) {
return new SystemMessage(chatMessagePO.getMessage());
} else {
return null;
}
Expand All @@ -61,21 +58,17 @@ private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatThreadId) {
ChatMessagePO chatMessagePO = new ChatMessagePO();
if (chatMessage.type().equals(ChatMessageType.AI)) {
chatMessagePO.setSender(MessageSender.AI.getValue());
chatMessagePO.setSender(MessageType.AI.getValue());
AiMessage aiMessage = (AiMessage) chatMessage;
chatMessagePO.setMessage(aiMessage.text());
} else if (chatMessage.type().equals(ChatMessageType.USER)) {
chatMessagePO.setSender(MessageSender.USER.getValue());
chatMessagePO.setSender(MessageType.USER.getValue());
UserMessage userMessage = (UserMessage) chatMessage;
chatMessagePO.setMessage(userMessage.singleText());
} else if (chatMessage.type().equals(ChatMessageType.SYSTEM)) {
chatMessagePO.setSender(MessageSender.SYSTEM.getValue());
SystemMessage systemMessage = (SystemMessage) chatMessage;
chatMessagePO.setMessage(systemMessage.text());
} else {
chatMessagePO.setSender(chatMessage.type().toString());
return null;
}
ChatThreadPO chatThreadPO = chatThreadDao.findById(chatThreadId);
ChatThreadPO chatThreadPO = chatThreadDao.findByThreadId(chatThreadId);
chatMessagePO.setUserId(chatThreadPO.getUserId());
chatMessagePO.setThreadId(chatThreadId);
return chatMessagePO;
Expand All @@ -94,11 +87,16 @@ public List<ChatMessage> getMessages(Object threadId) {
@Override
public void updateMessages(Object threadId, List<ChatMessage> messages) {
ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId);
if (chatMessagePO == null) {
return;
}
chatMessageDao.save(chatMessagePO);
}

@Override
public void deleteMessages(Object threadId) {
chatMessageDao.deleteByThreadId((Long) threadId);
List<ChatMessagePO> chatMessagePOS = chatMessageDao.findAllByThreadId((Long) threadId);
chatMessagePOS.forEach(chatMessage -> chatMessage.setIsDeleted(true));
chatMessageDao.partialUpdateByIds(chatMessagePOS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,28 @@
import java.util.stream.Collectors;

@Getter
public enum MessageSender {
public enum MessageType {
USER("user"),
AI("ai"),
SYSTEM("system");

private final String value;

MessageSender(String value) {
MessageType(String value) {
this.value = value;
}

public static List<String> getSenders() {
return Arrays.stream(values()).map(item -> item.value).collect(Collectors.toList());
}

public static MessageSender getMessageSender(String value) {
public static MessageType getMessageSender(String value) {
if (Objects.isNull(value) || value.isEmpty()) {
return null;
}
for (MessageSender messageSender : MessageSender.values()) {
if (messageSender.value.equals(value)) {
return messageSender;
for (MessageType messageType : MessageType.values()) {
if (messageType.value.equals(value)) {
return messageType;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.apache.bigtop.manager.ai.dashscope;

import org.apache.bigtop.manager.ai.core.AbstractAIAssistant;
import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;

Expand Down Expand Up @@ -88,13 +88,13 @@ private String getValueFromAssistantStreamMessage(AssistantStreamMessage assista
return streamMessage.toString();
}

private void saveMessage(String message, MessageSender sender) {
private void saveMessage(String message, MessageType sender) {
ChatMessage chatMessage;
if (sender.equals(MessageSender.AI)) {
if (sender.equals(MessageType.AI)) {
chatMessage = new AiMessage(message);
} else if (sender.equals(MessageSender.USER)) {
} else if (sender.equals(MessageType.USER)) {
chatMessage = new UserMessage(message);
} else if (sender.equals(MessageSender.SYSTEM)) {
} else if (sender.equals(MessageType.SYSTEM)) {
chatMessage = new SystemMessage(message);
} else {
return;
Expand Down Expand Up @@ -131,7 +131,7 @@ public void setSystemPrompt(String systemPrompt) {
} catch (NoApiKeyException | InputRequiredException | InvalidateParameter e) {
throw new RuntimeException(e);
}
saveMessage(systemPrompt, MessageSender.SYSTEM);
saveMessage(systemPrompt, MessageType.SYSTEM);
}

public static Builder builder() {
Expand All @@ -140,7 +140,7 @@ public static Builder builder() {

@Override
public Flux<String> streamAsk(String userMessage) {
saveMessage(userMessage, MessageSender.USER);
saveMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -174,13 +174,13 @@ public Flux<String> streamAsk(String userMessage) {
return message;
})
.doOnComplete(() -> {
saveMessage(finalMessage.toString(), MessageSender.AI);
saveMessage(finalMessage.toString(), MessageType.AI);
});
}

@Override
public String ask(String userMessage) {
saveMessage(userMessage, MessageSender.USER);
saveMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -244,7 +244,7 @@ public String ask(String userMessage) {
ContentText contentText = (ContentText) content;
finalMessage.append(contentText.getText().getValue());
}
saveMessage(finalMessage.toString(), MessageSender.AI);
saveMessage(finalMessage.toString(), MessageType.AI);
return finalMessage.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

@Getter
public enum DBType {
MYSQL("mysql", "MYSQL"),
MYSQL("mysql", "MySQL"),
POSTGRESQL("postgresql", "PostgreSQL"),
DM("dm", "DaMeng");

DBType(String code, String desc) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.bigtop.manager.dao.annotations.CreateTime;
import org.apache.bigtop.manager.dao.annotations.UpdateBy;
import org.apache.bigtop.manager.dao.annotations.UpdateTime;
import org.apache.bigtop.manager.dao.po.BasePO;

import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
Expand Down Expand Up @@ -68,22 +69,27 @@ public Object intercept(Invocation invocation) throws Throwable {
Object parameter = invocation.getArgs()[1];
log.debug("sqlCommandType {}", sqlCommandType);

Collection<Object> objects;
if (parameter instanceof MapperMethod.ParamMap) {
MapperMethod.ParamMap<Object> paramMap = ((MapperMethod.ParamMap<Object>) parameter);
if (paramMap.get("param1") instanceof Collection) {
objects = ((Collection<Object>) paramMap.get("param1"));
if (SqlCommandType.INSERT == sqlCommandType || SqlCommandType.UPDATE == sqlCommandType) {
Collection<Object> objects;
if (parameter instanceof MapperMethod.ParamMap) {
MapperMethod.ParamMap<Object> paramMap = ((MapperMethod.ParamMap<Object>) parameter);
if (!paramMap.containsKey("param1") && paramMap.containsKey("arg0")) {
objects = ((Collection<Object>) paramMap.get("arg0"));
} else if (paramMap.get("param1") instanceof Collection) {
objects = ((Collection<Object>) paramMap.get("param1"));
} else {
objects = Collections.singletonList(paramMap.get("param1"));
}
} else {
objects = Collections.singletonList(paramMap.get("param1"));
objects = Collections.singletonList(parameter);
}
} else {
objects = Collections.singletonList(parameter);
}

for (Object o : objects) {
setAuditFields(o, sqlCommandType);
for (Object o : objects) {
if (o instanceof BasePO) {
setAuditFields(o, sqlCommandType);
}
}
}

return invocation.proceed();
}

Expand All @@ -92,26 +98,25 @@ private void setAuditFields(Object object, SqlCommandType sqlCommandType) throws
Timestamp timestamp = new Timestamp(System.currentTimeMillis());

List<Field> fields = ClassUtils.getFields(object.getClass());
if (SqlCommandType.INSERT == sqlCommandType || SqlCommandType.UPDATE == sqlCommandType) {
for (Field field : fields) {
boolean accessible = field.canAccess(object);
field.setAccessible(true);
if (field.isAnnotationPresent(CreateBy.class)
&& SqlCommandType.INSERT == sqlCommandType
&& userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(CreateTime.class) && SqlCommandType.INSERT == sqlCommandType) {
field.set(object, timestamp);
}
if (field.isAnnotationPresent(UpdateBy.class) && userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(UpdateTime.class)) {
field.set(object, timestamp);
}
field.setAccessible(accessible);

for (Field field : fields) {
boolean accessible = field.canAccess(object);
field.setAccessible(true);
if (field.isAnnotationPresent(CreateBy.class)
&& SqlCommandType.INSERT == sqlCommandType
&& userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(CreateTime.class) && SqlCommandType.INSERT == sqlCommandType) {
field.set(object, timestamp);
}
if (field.isAnnotationPresent(UpdateBy.class) && userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(UpdateTime.class)) {
field.set(object, timestamp);
}
field.setAccessible(accessible);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ public class ChatMessagePO extends BasePO implements Serializable {
@Column(name = "message", nullable = false, length = 255)
private String message;

@Column(name = "sender")
@Column(name = "sender", nullable = false)
private String sender;

@Column(name = "user_id")
@Column(name = "user_id", nullable = false)
private Long userId;

@Column(name = "thread_id")
@Column(name = "thread_id", nullable = false)
private Long threadId;

@Column(name = "is_deleted")
private Boolean isDeleted;
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@ public class ChatThreadPO extends BasePO implements Serializable {
@Column(name = "model", nullable = false, length = 255)
private String model;

@Column(name = "thread_info", columnDefinition = "json", nullable = false)
@Column(name = "thread_info", columnDefinition = "json")
private Map<String, String> threadInfo;

@Column(name = "user_id")
@Column(name = "user_id", nullable = false)
private Long userId;

@Column(name = "platform_id")
@Column(name = "platform_id", nullable = false)
private Long platformId;

@Column(name = "is_deleted")
private Boolean isDeleted;
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public class HostPO extends BasePO implements Serializable {
@Column(name = "state")
private String state;

@Column(name = "cluster_id")
@Column(name = "cluster_id", nullable = false)
private Long clusterId;

@Transient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ public class JobPO extends BasePO implements Serializable {
@Column(name = "id")
private Long id;

@Column(name = "state")
@Column(name = "state", nullable = false)
private String state;

@Column(name = "name")
private String name;

@Lob
@Column(name = "context")
@Column(name = "context", nullable = false)
private String context;

@Column(name = "cluster_id")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public class PlatformAuthorizedPO extends BasePO implements Serializable {
@Column(name = "credentials", columnDefinition = "json", nullable = false)
private Map<String, String> credentials;

@Column(name = "platform_id")
@Column(name = "platform_id", nullable = false)
private Long platformId;

@Column(name = "is_deleted")
private Boolean isDeleted;
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ public class RepoPO extends BasePO implements Serializable {
@Column(name = "repo_type")
private String repoType;

@Column(name = "cluster_id")
@Column(name = "cluster_id", nullable = false)
private Long clusterId;
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ public class StackPO extends BasePO implements Serializable {
@Column(name = "id")
private Long id;

@Column(name = "stack_name")
@Column(name = "stack_name", nullable = false)
private String stackName;

@Column(name = "stack_version")
@Column(name = "stack_version", nullable = false)
private String stackVersion;
}
Loading

0 comments on commit 8df8fa7

Please sign in to comment.