Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ import com.google.ai.edge.gallery.firebaseAnalytics
import com.google.ai.edge.gallery.ui.common.BaseGalleryWebViewClient
import com.google.ai.edge.gallery.ui.common.GalleryWebView
import com.google.ai.edge.gallery.ui.common.buildTrackableUrlAnnotatedString
import com.google.ai.edge.gallery.ui.common.chat.ChatMessage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageCollapsableProgressPanel
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
Expand All @@ -97,6 +98,7 @@ import com.google.ai.edge.gallery.ui.llmchat.LlmChatScreen
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.litertlm.Message
import com.google.ai.edge.litertlm.tool
import java.lang.Exception
import kotlin.coroutines.resume
Expand Down Expand Up @@ -190,14 +192,15 @@ fun AgentChatScreen(

updateProgressPanel(viewModel = viewModel, model = model, agentTools = agentTools)
},
onResetSessionClickedOverride = { task, model ->
onResetSessionClickedOverride = { task, _, initialMessages ->
resetSessionWithCurrentSkills(
viewModel,
modelManagerViewModel,
skillManagerViewModel,
task,
curSystemPrompt,
agentTools,
initialMessages = initialMessages,
)
},
onSkillClicked = { showSkillManagerBottomSheet = true },
Expand Down Expand Up @@ -596,8 +599,18 @@ private fun resetSessionWithCurrentSkills(
curSystemPrompt: String,
agentTools: AgentTools,
onDone: (Model) -> Unit = {},
initialMessages: List<ChatMessage> = listOf(),
) {
val model = modelManagerViewModel.uiState.value.selectedModel
val litertMessages = initialMessages.mapNotNull { chatMessage ->
if (chatMessage is ChatMessageText) {
if (chatMessage.side == ChatSide.USER) {
Message.user(chatMessage.content)
} else {
Message.model(chatMessage.content)
}
} else null
}
viewModel.resetSession(
task = task,
model = model,
Expand All @@ -607,6 +620,7 @@ private fun resetSessionWithCurrentSkills(
supportAudio = true,
onDone = { onDone(model) },
enableConversationConstrainedDecoding = true,
initialMessages = litertMessages,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import android.content.Context
import android.graphics.Bitmap
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.litertlm.Contents
import com.google.ai.edge.litertlm.Message
import com.google.ai.edge.litertlm.ToolProvider
import kotlinx.coroutines.CoroutineScope

Expand Down Expand Up @@ -79,6 +80,7 @@ interface LlmModelHelper {
systemInstruction: Contents? = null,
tools: List<ToolProvider> = listOf(),
enableConversationConstrainedDecoding: Boolean = false,
initialMessages: List<Message> = listOf(),
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import com.google.ai.edge.gallery.runtime.CleanUpListener
import com.google.ai.edge.gallery.runtime.LlmModelHelper
import com.google.ai.edge.gallery.runtime.ResultListener
import com.google.ai.edge.litertlm.Contents
import com.google.ai.edge.litertlm.Message
import com.google.ai.edge.litertlm.Role
import com.google.ai.edge.litertlm.ToolProvider
import com.google.mlkit.genai.common.DownloadStatus
import com.google.mlkit.genai.common.FeatureStatus
Expand Down Expand Up @@ -195,10 +197,16 @@ object AICoreModelHelper : LlmModelHelper {
systemInstruction: Contents?,
tools: List<ToolProvider>,
enableConversationConstrainedDecoding: Boolean,
initialMessages: List<Message>,
) {
Log.d(TAG, "Resetting conversation for model '${model.name}'")
val instance = model.instance as? AICoreModelInstance ?: return
instance.chatHistory.clear()
for (msg in initialMessages) {
instance.chatHistory.add(
AICoreChatMessage(isUser = (msg.role == Role.USER), text = msg.contents.toString())
)
}
Log.d(TAG, "Resetting done")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import java.util.UUID
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch

private const val TAG = "AGChatView"
Expand All @@ -108,7 +107,7 @@ fun ChatView(
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
onResetSessionClicked: (Model) -> Unit = {},
onResetSessionClicked: (Model, List<ChatMessage>, () -> Unit) -> Unit = { _, _, _ -> },
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStopButtonClicked: (Model) -> Unit = {},
onSkillClicked: () -> Unit = {},
Expand Down Expand Up @@ -138,7 +137,6 @@ fun ChatView(
remember(allHistorySessions, task.id) { allHistorySessions.filter { it.taskId == task.id } }

val context = LocalContext.current
var feedFullHistoryOnNextMessage by remember { mutableStateOf(false) }

val currentMessages = uiState.messagesByModel[selectedModel.name] ?: emptyList()
LaunchedEffect(uiState.inProgress) {
Expand Down Expand Up @@ -217,16 +215,14 @@ fun ChatView(
},
)

onResetSessionClicked(selectedModel)
viewModel.clearAllMessages(selectedModel)

val messages = deserializeProtoMessages(session.messagesList)
for (msg in messages) {
viewModel.addMessage(selectedModel, msg)
onResetSessionClicked(selectedModel, messages) {
for (msg in messages) {
viewModel.addMessage(selectedModel, msg)
}
}

viewModel.currentSessionId = session.sessionId
feedFullHistoryOnNextMessage = true
}
scope.launch { drawerState.close() }
},
Expand All @@ -247,7 +243,7 @@ fun ChatView(
},
)

onResetSessionClicked(selectedModel)
onResetSessionClicked(selectedModel, emptyList()) {}
viewModel.currentSessionId = UUID.randomUUID().toString()
scope.launch { drawerState.close() }
},
Expand Down Expand Up @@ -339,38 +335,7 @@ fun ChatView(
viewModel = viewModel,
innerPadding = innerPadding,
navigateUp = navigateUp,
// TODO(zichuanwei): Update the logic here to use the proper litertlm api.
// the current logic is to be compatible with AICore logic, as AI core doesn't
// support message preloading or multi-turn conversations.
onSendMessage = { model, messages ->
if (feedFullHistoryOnNextMessage) {
feedFullHistoryOnNextMessage = false
val history = uiState.messagesByModel[model.name] ?: emptyList()
val originalShortMessage = messages.lastOrNull() as? ChatMessageText
val combinedMessage =
if (originalShortMessage != null) {
buildFirstMessageWithHistory(history, originalShortMessage)
} else null
if (combinedMessage != null) {
val modifiedList = messages.dropLast(1) + combinedMessage
onSendMessage(model, modifiedList)

// Revert the visible UI message back to the short one
scope.launch(Dispatchers.Default) {
delay(100)
viewModel.replaceLastMessage(
model,
originalShortMessage!!,
ChatMessageType.TEXT,
)
}
} else {
onSendMessage(model, messages)
}
} else {
onSendMessage(model, messages)
}
},
onSendMessage = { model, messages -> onSendMessage(model, messages) },
onRunAgainClicked = onRunAgainClicked,
onBenchmarkClicked = onBenchmarkClicked,
onStreamImageMessage = onStreamImageMessage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ object LlmChatModelHelper : LlmModelHelper {
systemInstruction: Contents?,
tools: List<ToolProvider>,
enableConversationConstrainedDecoding: Boolean,
initialMessages: List<Message>,
) {
try {
Log.d(TAG, "Resetting conversation for model '${model.name}'")
Expand Down Expand Up @@ -228,6 +229,7 @@ object LlmChatModelHelper : LlmModelHelper {
},
systemInstruction = systemInstruction,
tools = tools,
initialMessages = initialMessages,
)
)
ExperimentalFlags.enableConversationConstrainedDecoding = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,18 @@ import com.google.ai.edge.gallery.data.ModelCapability
import com.google.ai.edge.gallery.data.RuntimeType
import com.google.ai.edge.gallery.data.Task
import com.google.ai.edge.gallery.firebaseAnalytics
import com.google.ai.edge.gallery.ui.common.chat.ChatMessage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageAudioClip
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageImage
import com.google.ai.edge.gallery.ui.common.chat.ChatMessageText
import com.google.ai.edge.gallery.ui.common.chat.ChatSide
import com.google.ai.edge.gallery.ui.common.chat.ChatView
import com.google.ai.edge.gallery.ui.common.chat.SendMessageTrigger
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.gallery.ui.theme.emptyStateContent
import com.google.ai.edge.gallery.ui.theme.emptyStateTitle
import com.google.ai.edge.litertlm.Contents
import com.google.ai.edge.litertlm.Message

private const val TAG = "AGLlmChatScreen"

Expand All @@ -64,7 +67,7 @@ fun LlmChatScreen(
onFirstToken: (Model) -> Unit = {},
onGenerateResponseDone: (Model) -> Unit = {},
onSkillClicked: () -> Unit = {},
onResetSessionClickedOverride: ((Task, Model) -> Unit)? = null,
onResetSessionClickedOverride: ((Task, Model, List<ChatMessage>) -> Unit)? = null,
composableBelowMessageList: @Composable (Model) -> Unit = {},
viewModel: LlmChatViewModel = hiltViewModel(),
allowEditingSystemPrompt: Boolean = false,
Expand Down Expand Up @@ -198,7 +201,7 @@ fun ChatViewWrapper(
onSkillClicked: () -> Unit = {},
onFirstToken: (Model) -> Unit = {},
onGenerateResponseDone: (Model) -> Unit = {},
onResetSessionClickedOverride: ((Task, Model) -> Unit)? = null,
onResetSessionClickedOverride: ((Task, Model, List<ChatMessage>) -> Unit)? = null,
composableBelowMessageList: @Composable (Model) -> Unit = {},
emptyStateComposable: @Composable (Model) -> Unit = {},
allowEditingSystemPrompt: Boolean = false,
Expand Down Expand Up @@ -297,16 +300,20 @@ fun ChatViewWrapper(
}
},
onBenchmarkClicked = { _, _, _, _ -> },
onResetSessionClicked = { model ->
onResetSessionClicked = { model, chatMessages, onDone ->
val litertMessages = chatMessages.mapNotNull { convertToLitertMessage(it) }
if (onResetSessionClickedOverride != null) {
onResetSessionClickedOverride(task, model)
onResetSessionClickedOverride(task, model, chatMessages)
onDone()
} else {
viewModel.resetSession(
task = task,
model = model,
systemInstruction = Contents.of(curSystemPrompt),
supportImage = showImagePicker,
supportAudio = showAudioPicker,
initialMessages = litertMessages,
onDone = onDone,
)
}
},
Expand All @@ -325,3 +332,15 @@ fun ChatViewWrapper(
showAudioPicker = showAudioPicker,
)
}

private fun convertToLitertMessage(chatMessage: ChatMessage): Message? {
if (chatMessage is ChatMessageText) {
return when (chatMessage.side) {
ChatSide.USER -> Message.user(chatMessage.content)
ChatSide.AGENT -> Message.model(chatMessage.content)
ChatSide.SYSTEM ->
null // TODO: Support SYSTEM role once we can decide on which system prompt to use.
}
}
return null
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import com.google.ai.edge.gallery.ui.common.chat.ChatViewModel
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import com.google.ai.edge.litertlm.Contents
import com.google.ai.edge.litertlm.ExperimentalApi
import com.google.ai.edge.litertlm.Message
import com.google.ai.edge.litertlm.ToolProvider
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
Expand Down Expand Up @@ -324,6 +325,7 @@ open class LlmChatViewModelBase(
supportAudio: Boolean = false,
onDone: () -> Unit = {},
enableConversationConstrainedDecoding: Boolean = false,
initialMessages: List<Message> = listOf(),
) {
viewModelScope.launch(Dispatchers.Default) {
setIsResettingSession(true)
Expand All @@ -339,6 +341,7 @@ open class LlmChatViewModelBase(
systemInstruction = systemInstruction,
tools = tools,
enableConversationConstrainedDecoding = enableConversationConstrainedDecoding,
initialMessages = initialMessages,
)
break
} catch (e: Exception) {
Expand Down
Loading