From 51a7a136aec4d3ef854ff649b3c75e7fe6a80eb2 Mon Sep 17 00:00:00 2001 From: Zichuan Wei Date: Sat, 9 May 2026 14:32:04 -0700 Subject: [PATCH] update the load history logic PiperOrigin-RevId: 913067368 --- .../customtasks/agentchat/AgentChatScreen.kt | 16 +++++- .../ai/edge/gallery/runtime/LlmModelHelper.kt | 2 + .../runtime/aicore/AICoreModelHelper.kt | 8 +++ .../edge/gallery/ui/common/chat/ChatView.kt | 49 +++---------------- .../gallery/ui/llmchat/LlmChatModelHelper.kt | 2 + .../edge/gallery/ui/llmchat/LlmChatScreen.kt | 27 ++++++++-- .../gallery/ui/llmchat/LlmChatViewModel.kt | 3 ++ 7 files changed, 60 insertions(+), 47 deletions(-) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt index 3cd91f6f3..71a306837 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt @@ -84,6 +84,7 @@ import com.google.ai.edge.gallery.firebaseAnalytics import com.google.ai.edge.gallery.ui.common.BaseGalleryWebViewClient import com.google.ai.edge.gallery.ui.common.GalleryWebView import com.google.ai.edge.gallery.ui.common.buildTrackableUrlAnnotatedString +import com.google.ai.edge.gallery.ui.common.chat.ChatMessage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageCollapsableProgressPanel import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText @@ -97,6 +98,7 @@ import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel +import com.google.ai.edge.litertlm.Message import com.google.ai.edge.litertlm.tool import java.lang.Exception import kotlin.coroutines.resume @@ -190,7 +192,7 @@ fun AgentChatScreen( updateProgressPanel(viewModel = viewModel, model = model, agentTools = agentTools) }, - onResetSessionClickedOverride = { task, model -> + onResetSessionClickedOverride = { task, _, initialMessages -> resetSessionWithCurrentSkills( viewModel, modelManagerViewModel, @@ -198,6 +200,7 @@ fun AgentChatScreen( task, curSystemPrompt, agentTools, + initialMessages = initialMessages, ) }, onSkillClicked = { showSkillManagerBottomSheet = true }, @@ -596,8 +599,18 @@ private fun resetSessionWithCurrentSkills( curSystemPrompt: String, agentTools: AgentTools, onDone: (Model) -> Unit = {}, + initialMessages: List = listOf(), ) { val model = modelManagerViewModel.uiState.value.selectedModel + val litertMessages = initialMessages.mapNotNull { chatMessage -> + if (chatMessage is ChatMessageText) { + if (chatMessage.side == ChatSide.USER) { + Message.user(chatMessage.content) + } else { + Message.model(chatMessage.content) + } + } else null + } viewModel.resetSession( task = task, model = model, @@ -607,6 +620,7 @@ private fun resetSessionWithCurrentSkills( supportAudio = true, onDone = { onDone(model) }, enableConversationConstrainedDecoding = true, + initialMessages = litertMessages, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt index c3c659676..b5cb4b40e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt @@ -20,6 +20,7 @@ import android.content.Context import android.graphics.Bitmap import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.litertlm.Contents +import com.google.ai.edge.litertlm.Message import com.google.ai.edge.litertlm.ToolProvider import kotlinx.coroutines.CoroutineScope @@ -79,6 +80,7 @@ interface LlmModelHelper { systemInstruction: Contents? = null, tools: List = listOf(), enableConversationConstrainedDecoding: Boolean = false, + initialMessages: List = listOf(), ) /** diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt index 9ba60e392..84e476310 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt @@ -30,6 +30,8 @@ import com.google.ai.edge.gallery.runtime.CleanUpListener import com.google.ai.edge.gallery.runtime.LlmModelHelper import com.google.ai.edge.gallery.runtime.ResultListener import com.google.ai.edge.litertlm.Contents +import com.google.ai.edge.litertlm.Message +import com.google.ai.edge.litertlm.Role import com.google.ai.edge.litertlm.ToolProvider import com.google.mlkit.genai.common.DownloadStatus import com.google.mlkit.genai.common.FeatureStatus @@ -195,10 +197,16 @@ object AICoreModelHelper : LlmModelHelper { systemInstruction: Contents?, tools: List, enableConversationConstrainedDecoding: Boolean, + initialMessages: List, ) { Log.d(TAG, "Resetting conversation for model '${model.name}'") val instance = model.instance as? AICoreModelInstance ?: return instance.chatHistory.clear() + for (msg in initialMessages) { + instance.chatHistory.add( + AICoreChatMessage(isUser = (msg.role == Role.USER), text = msg.contents.toString()) + ) + } Log.d(TAG, "Resetting done") } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt index 8258c7ddf..9b07a6865 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt @@ -82,7 +82,6 @@ import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import java.util.UUID import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.delay import kotlinx.coroutines.launch private const val TAG = "AGChatView" @@ -108,7 +107,7 @@ fun ChatView( onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit, navigateUp: () -> Unit, modifier: Modifier = Modifier, - onResetSessionClicked: (Model) -> Unit = {}, + onResetSessionClicked: (Model, List, () -> Unit) -> Unit = { _, _, _ -> }, onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> }, onStopButtonClicked: (Model) -> Unit = {}, onSkillClicked: () -> Unit = {}, @@ -138,7 +137,6 @@ fun ChatView( remember(allHistorySessions, task.id) { allHistorySessions.filter { it.taskId == task.id } } val context = LocalContext.current - var feedFullHistoryOnNextMessage by remember { mutableStateOf(false) } val currentMessages = uiState.messagesByModel[selectedModel.name] ?: emptyList() LaunchedEffect(uiState.inProgress) { @@ -217,16 +215,14 @@ fun ChatView( }, ) - onResetSessionClicked(selectedModel) - viewModel.clearAllMessages(selectedModel) - val messages = deserializeProtoMessages(session.messagesList) - for (msg in messages) { - viewModel.addMessage(selectedModel, msg) + onResetSessionClicked(selectedModel, messages) { + for (msg in messages) { + viewModel.addMessage(selectedModel, msg) + } } viewModel.currentSessionId = session.sessionId - feedFullHistoryOnNextMessage = true } scope.launch { drawerState.close() } }, @@ -247,7 +243,7 @@ fun ChatView( }, ) - onResetSessionClicked(selectedModel) + onResetSessionClicked(selectedModel, emptyList()) {} viewModel.currentSessionId = UUID.randomUUID().toString() scope.launch { drawerState.close() } }, @@ -339,38 +335,7 @@ fun ChatView( viewModel = viewModel, innerPadding = innerPadding, navigateUp = navigateUp, - // TODO(zichuanwei): Update the logic here to use the proper litertlm api. - // the current logic is to be compatible with AICore logic, as AI core doesn't - // support message preloading or multi-turn conversations. - onSendMessage = { model, messages -> - if (feedFullHistoryOnNextMessage) { - feedFullHistoryOnNextMessage = false - val history = uiState.messagesByModel[model.name] ?: emptyList() - val originalShortMessage = messages.lastOrNull() as? ChatMessageText - val combinedMessage = - if (originalShortMessage != null) { - buildFirstMessageWithHistory(history, originalShortMessage) - } else null - if (combinedMessage != null) { - val modifiedList = messages.dropLast(1) + combinedMessage - onSendMessage(model, modifiedList) - - // Revert the visible UI message back to the short one - scope.launch(Dispatchers.Default) { - delay(100) - viewModel.replaceLastMessage( - model, - originalShortMessage!!, - ChatMessageType.TEXT, - ) - } - } else { - onSendMessage(model, messages) - } - } else { - onSendMessage(model, messages) - } - }, + onSendMessage = { model, messages -> onSendMessage(model, messages) }, onRunAgainClicked = onRunAgainClicked, onBenchmarkClicked = onBenchmarkClicked, onStreamImageMessage = onStreamImageMessage, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt index 352136e8c..577da922e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -190,6 +190,7 @@ object LlmChatModelHelper : LlmModelHelper { systemInstruction: Contents?, tools: List, enableConversationConstrainedDecoding: Boolean, + initialMessages: List, ) { try { Log.d(TAG, "Resetting conversation for model '${model.name}'") @@ -228,6 +229,7 @@ object LlmChatModelHelper : LlmModelHelper { }, systemInstruction = systemInstruction, tools = tools, + initialMessages = initialMessages, ) ) ExperimentalFlags.enableConversationConstrainedDecoding = false diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt index 27bf70fa7..0003f5d4e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt @@ -43,15 +43,18 @@ import com.google.ai.edge.gallery.data.ModelCapability import com.google.ai.edge.gallery.data.RuntimeType import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.firebaseAnalytics +import com.google.ai.edge.gallery.ui.common.chat.ChatMessage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText +import com.google.ai.edge.gallery.ui.common.chat.ChatSide import com.google.ai.edge.gallery.ui.common.chat.ChatView import com.google.ai.edge.gallery.ui.common.chat.SendMessageTrigger import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.theme.emptyStateContent import com.google.ai.edge.gallery.ui.theme.emptyStateTitle import com.google.ai.edge.litertlm.Contents +import com.google.ai.edge.litertlm.Message private const val TAG = "AGLlmChatScreen" @@ -64,7 +67,7 @@ fun LlmChatScreen( onFirstToken: (Model) -> Unit = {}, onGenerateResponseDone: (Model) -> Unit = {}, onSkillClicked: () -> Unit = {}, - onResetSessionClickedOverride: ((Task, Model) -> Unit)? = null, + onResetSessionClickedOverride: ((Task, Model, List) -> Unit)? = null, composableBelowMessageList: @Composable (Model) -> Unit = {}, viewModel: LlmChatViewModel = hiltViewModel(), allowEditingSystemPrompt: Boolean = false, @@ -198,7 +201,7 @@ fun ChatViewWrapper( onSkillClicked: () -> Unit = {}, onFirstToken: (Model) -> Unit = {}, onGenerateResponseDone: (Model) -> Unit = {}, - onResetSessionClickedOverride: ((Task, Model) -> Unit)? = null, + onResetSessionClickedOverride: ((Task, Model, List) -> Unit)? = null, composableBelowMessageList: @Composable (Model) -> Unit = {}, emptyStateComposable: @Composable (Model) -> Unit = {}, allowEditingSystemPrompt: Boolean = false, @@ -297,9 +300,11 @@ fun ChatViewWrapper( } }, onBenchmarkClicked = { _, _, _, _ -> }, - onResetSessionClicked = { model -> + onResetSessionClicked = { model, chatMessages, onDone -> + val litertMessages = chatMessages.mapNotNull { convertToLitertMessage(it) } if (onResetSessionClickedOverride != null) { - onResetSessionClickedOverride(task, model) + onResetSessionClickedOverride(task, model, chatMessages) + onDone() } else { viewModel.resetSession( task = task, @@ -307,6 +312,8 @@ fun ChatViewWrapper( systemInstruction = Contents.of(curSystemPrompt), supportImage = showImagePicker, supportAudio = showAudioPicker, + initialMessages = litertMessages, + onDone = onDone, ) } }, @@ -325,3 +332,15 @@ fun ChatViewWrapper( showAudioPicker = showAudioPicker, ) } + +private fun convertToLitertMessage(chatMessage: ChatMessage): Message? { + if (chatMessage is ChatMessageText) { + return when (chatMessage.side) { + ChatSide.USER -> Message.user(chatMessage.content) + ChatSide.AGENT -> Message.model(chatMessage.content) + ChatSide.SYSTEM -> + null // TODO: Support SYSTEM role once we can decide on which system prompt to use. + } + } + return null +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt index 8597c16b3..a17579e57 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -41,6 +41,7 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.litertlm.Contents import com.google.ai.edge.litertlm.ExperimentalApi +import com.google.ai.edge.litertlm.Message import com.google.ai.edge.litertlm.ToolProvider import dagger.hilt.android.lifecycle.HiltViewModel import javax.inject.Inject @@ -324,6 +325,7 @@ open class LlmChatViewModelBase( supportAudio: Boolean = false, onDone: () -> Unit = {}, enableConversationConstrainedDecoding: Boolean = false, + initialMessages: List = listOf(), ) { viewModelScope.launch(Dispatchers.Default) { setIsResettingSession(true) @@ -339,6 +341,7 @@ open class LlmChatViewModelBase( systemInstruction = systemInstruction, tools = tools, enableConversationConstrainedDecoding = enableConversationConstrainedDecoding, + initialMessages = initialMessages, ) break } catch (e: Exception) {