Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class ChatMessageImage(
override val latencyMs: Float = 0f,
override val accelerator: String = "",
override val hideSenderLabel: Boolean = false,
/**
* Caches the local absolute file paths to avoid redundant PNG compressions across session
* updates.
*/
var persistedPaths: List<String>? = null,
) :
ChatMessage(
type = ChatMessageType.IMAGE,
Expand All @@ -164,6 +169,7 @@ class ChatMessageImage(
latencyMs = latencyMs,
accelerator = accelerator,
hideSenderLabel = hideSenderLabel,
persistedPaths = persistedPaths?.toList(),
)
}
}
Expand All @@ -174,13 +180,18 @@ class ChatMessageAudioClip(
val sampleRate: Int,
override val side: ChatSide,
override val latencyMs: Float = 0f,
/**
* Caches the local absolute file path to bypass redundant disk write I/O during session saves.
*/
var persistedPath: String? = null,
) : ChatMessage(type = ChatMessageType.AUDIO_CLIP, side = side, latencyMs = latencyMs) {
/**
 * Produces a new [ChatMessageAudioClip] populated with this message's field values.
 *
 * The same `audioData` reference is handed to the copy (no byte copy is made), and the cached
 * `persistedPath` is carried over so the copy keeps pointing at the same on-disk file.
 */
override fun clone(): ChatMessageAudioClip =
  ChatMessageAudioClip(
    audioData = audioData,
    sampleRate = sampleRate,
    side = side,
    latencyMs = latencyMs,
    persistedPath = persistedPath,
  )

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package com.google.ai.edge.gallery.ui.common.chat
// import com.google.ai.edge.gallery.ui.theme.GalleryTheme

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.os.Bundle
import android.util.Log
import androidx.activity.compose.BackHandler
Expand Down Expand Up @@ -80,9 +81,11 @@ import com.google.ai.edge.gallery.firebaseAnalytics
import com.google.ai.edge.gallery.ui.common.ModelPageAppBar
import com.google.ai.edge.gallery.ui.modelmanager.ModelInitializationStatusType
import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel
import java.io.File
import java.util.UUID
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

private const val TAG = "AGChatView"

Expand All @@ -107,7 +110,9 @@ fun ChatView(
onBenchmarkClicked: (Model, ChatMessage, Int, Int) -> Unit,
navigateUp: () -> Unit,
modifier: Modifier = Modifier,
onResetSessionClicked: (Model, List<ChatMessage>, () -> Unit) -> Unit = { _, _, _ -> },
onResetSessionClicked: (Model, List<ChatMessage>, () -> Unit) -> Unit = { _, _, onDone ->
onDone()
},
onStreamImageMessage: (Model, ChatMessageImage) -> Unit = { _, _ -> },
onStopButtonClicked: (Model) -> Unit = {},
onSkillClicked: () -> Unit = {},
Expand Down Expand Up @@ -146,6 +151,7 @@ fun ChatView(
messages = currentMessages,
originalModel = selectedModel.name,
taskId = task.id,
context = context,
)
}
}
Expand Down Expand Up @@ -215,19 +221,23 @@ fun ChatView(
},
)

val messages = deserializeProtoMessages(session.messagesList)
onResetSessionClicked(selectedModel, messages) {
for (msg in messages) {
viewModel.addMessage(selectedModel, msg)
scope.launch {
viewModel.setIsResettingSession(true)
val messages =
withContext(Dispatchers.IO) { deserializeProtoMessages(session.messagesList) }
onResetSessionClicked(selectedModel, messages) {
for (msg in messages) {
viewModel.addMessage(selectedModel, msg)
}
viewModel.setIsResettingSession(false)
}
viewModel.currentSessionId = session.sessionId
}

viewModel.currentSessionId = session.sessionId
}
scope.launch { drawerState.close() }
},
onHistoryItemDeleted = { sessionId -> viewModel.deleteSession(sessionId) },
onHistoryItemsDeleteAll = { viewModel.clearAllSessions() },
onHistoryItemDeleted = { sessionId -> viewModel.deleteSession(sessionId, context) },
onHistoryItemsDeleteAll = { viewModel.clearAllSessions(context) },
onNewChatClicked = {
Log.d(
TAG,
Expand Down Expand Up @@ -506,6 +516,37 @@ private fun deserializeProtoMessages(
"INFO" -> ChatMessageInfo(protoMsg.content)
"WARNING" -> ChatMessageWarning(protoMsg.content)
"ERROR" -> ChatMessageError(protoMsg.content)
"IMAGE" -> {
val bitmaps =
protoMsg.imageFilePathsList.mapNotNull { path -> BitmapFactory.decodeFile(path) }
if (bitmaps.isNotEmpty()) {
ChatMessageImage(
bitmaps = bitmaps,
imageBitMaps = bitmaps.map { it.asImageBitmap() },
side = side,
latencyMs = protoMsg.latencyMs,
accelerator = protoMsg.accelerator,
hideSenderLabel = protoMsg.hideSenderLabel,
persistedPaths = protoMsg.imageFilePathsList.toList(),
)
} else null
}
"AUDIO_CLIP" -> {
val firstAudio = protoMsg.audioClipsList.firstOrNull()
if (firstAudio != null) {
try {
ChatMessageAudioClip(
audioData = File(firstAudio.filePath).readBytes(),
sampleRate = firstAudio.sampleRate,
side = side,
latencyMs = protoMsg.latencyMs,
persistedPath = firstAudio.filePath,
)
} catch (e: Exception) {
null
}
} else null
}
else -> null
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.ai.edge.gallery.ui.common.chat

import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.compose.ui.graphics.vector.ImageVector
import androidx.datastore.core.DataStore
Expand All @@ -24,10 +26,13 @@ import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.common.processLlmResponse
import com.google.ai.edge.gallery.data.ConfigKeys
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.proto.AudioMessageProto
import com.google.ai.edge.gallery.proto.ChatMessageProto
import com.google.ai.edge.gallery.proto.ChatSessionProto
import com.google.ai.edge.gallery.proto.ChatSideProto
import com.google.ai.edge.gallery.proto.UserData
import java.io.File
import java.io.FileOutputStream
import java.util.UUID
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
Expand Down Expand Up @@ -414,59 +419,114 @@ abstract class ChatViewModel(val userDataDataStore: DataStore<UserData>? = null)
messages: List<ChatMessage>,
originalModel: String,
taskId: String,
context: Context? = null,
) {
val firstTextMessage = messages.filterIsInstance<ChatMessageText>().firstOrNull()?.content
val title =
firstTextMessage?.take(30)?.let { if (it.length == 30) "$it..." else it }
?: "New Chat Session"

val protoMessages = messages.mapNotNull { msg ->
val builder = ChatMessageProto.newBuilder()
when (msg) {
is ChatMessageText -> {
builder
.setMessageType("TEXT")
.setContent(msg.content)
.setSide(mapChatSide(msg.side))
.setLatencyMs(msg.latencyMs)
.setAccelerator(msg.accelerator)
.setHideSenderLabel(msg.hideSenderLabel)
.setIsMarkdown(msg.isMarkdown)
}
is ChatMessageThinking -> {
builder
.setMessageType("THINKING")
.setContent(msg.content)
.setSide(mapChatSide(msg.side))
.setInProgress(msg.inProgress)
.setAccelerator(msg.accelerator)
.setHideSenderLabel(msg.hideSenderLabel)
}
is ChatMessageInfo -> {
builder.setMessageType("INFO").setContent(msg.content).setSide(mapChatSide(msg.side))
}
is ChatMessageWarning -> {
builder.setMessageType("WARNING").setContent(msg.content).setSide(mapChatSide(msg.side))
}
is ChatMessageError -> {
builder.setMessageType("ERROR").setContent(msg.content).setSide(mapChatSide(msg.side))
val messagesSnapshot = messages.toList()
viewModelScope.launch(Dispatchers.IO) {
val firstTextMessage =
messagesSnapshot.filterIsInstance<ChatMessageText>().firstOrNull()?.content
val title =
firstTextMessage?.take(30)?.let { if (it.length == 30) "$it..." else it }
?: "New Chat Session"

val protoMessages = messagesSnapshot.mapNotNull { msg ->
val builder = ChatMessageProto.newBuilder()
when (msg) {
is ChatMessageText -> {
builder
.setMessageType("TEXT")
.setContent(msg.content)
.setSide(mapChatSide(msg.side))
.setLatencyMs(msg.latencyMs)
.setAccelerator(msg.accelerator)
.setHideSenderLabel(msg.hideSenderLabel)
.setIsMarkdown(msg.isMarkdown)
}
is ChatMessageThinking -> {
builder
.setMessageType("THINKING")
.setContent(msg.content)
.setSide(mapChatSide(msg.side))
.setInProgress(msg.inProgress)
.setAccelerator(msg.accelerator)
.setHideSenderLabel(msg.hideSenderLabel)
}
is ChatMessageInfo -> {
builder.setMessageType("INFO").setContent(msg.content).setSide(mapChatSide(msg.side))
}
is ChatMessageWarning -> {
builder.setMessageType("WARNING").setContent(msg.content).setSide(mapChatSide(msg.side))
}
is ChatMessageError -> {
builder.setMessageType("ERROR").setContent(msg.content).setSide(mapChatSide(msg.side))
}
is ChatMessageImage -> {
builder
.setMessageType("IMAGE")
.setSide(mapChatSide(msg.side))
.setLatencyMs(msg.latencyMs)
synchronized(msg) {
val cachedPaths = msg.persistedPaths
if (cachedPaths != null) {
builder.addAllImageFilePaths(cachedPaths)
} else if (context != null) {
msg.persistedPaths = buildList {
msg.bitmaps.forEachIndexed { index, bitmap ->
val fileName = "img_${sessionId}_${System.currentTimeMillis()}_$index.png"
val file = File(context.cacheDir, fileName)
FileOutputStream(file).use { fos ->
bitmap.compress(Bitmap.CompressFormat.PNG, 100, fos)
}
add(file.absolutePath)
builder.addImageFilePaths(file.absolutePath)
}
}
}
}
}
is ChatMessageAudioClip -> {
builder
.setMessageType("AUDIO_CLIP")
.setSide(mapChatSide(msg.side))
.setLatencyMs(msg.latencyMs)
synchronized(msg) {
val cachedPath = msg.persistedPath
if (cachedPath != null) {
val audioProto =
AudioMessageProto.newBuilder()
.setFilePath(cachedPath)
.setSampleRate(msg.sampleRate)
.build()
builder.addAudioClips(audioProto)
} else if (context != null) {
val fileName = "audio_${sessionId}_${System.currentTimeMillis()}.pcm"
val file = File(context.cacheDir, fileName)
FileOutputStream(file).use { fos -> fos.write(msg.audioData) }
msg.persistedPath = file.absolutePath
val audioProto =
AudioMessageProto.newBuilder()
.setFilePath(file.absolutePath)
.setSampleRate(msg.sampleRate)
.build()
builder.addAudioClips(audioProto)
}
}
}
else -> return@mapNotNull null
}
else -> return@mapNotNull null
builder.build()
}
builder.build()
}

val sessionProto =
ChatSessionProto.newBuilder()
.setSessionId(sessionId)
.setTitle(title)
.setTimestampMs(System.currentTimeMillis())
.setOriginalModel(originalModel)
.setTaskId(taskId)
.addAllMessages(protoMessages)
.build()
val sessionProto =
ChatSessionProto.newBuilder()
.setSessionId(sessionId)
.setTitle(title)
.setTimestampMs(System.currentTimeMillis())
.setOriginalModel(originalModel)
.setTaskId(taskId)
.addAllMessages(protoMessages)
.build()

viewModelScope.launch(Dispatchers.IO) {
userDataDataStore?.updateData { userData ->
val currentSessions = userData.chatSessionsList.toMutableList()
currentSessions.removeAll { it.sessionId == sessionId }
Expand All @@ -481,8 +541,18 @@ abstract class ChatViewModel(val userDataDataStore: DataStore<UserData>? = null)
*
* @param sessionId The ID of the session to delete.
*/
fun deleteSession(sessionId: String) {
fun deleteSession(sessionId: String, context: Context? = null) {
viewModelScope.launch(Dispatchers.IO) {
if (context != null) {
val files = context.cacheDir.listFiles()
files?.forEach { file ->
if (
file.name.startsWith("img_${sessionId}_") || file.name.startsWith("audio_${sessionId}_")
) {
file.delete()
}
}
}
userDataDataStore?.updateData { userData ->
val currentSessions = userData.chatSessionsList.filter { it.sessionId != sessionId }
userData.toBuilder().clearChatSessions().addAllChatSessions(currentSessions).build()
Expand All @@ -491,8 +561,16 @@ abstract class ChatViewModel(val userDataDataStore: DataStore<UserData>? = null)
}

/** Clears all saved chat sessions from persistent storage. */
fun clearAllSessions() {
/**
 * Clears all saved chat sessions from persistent storage and, when [context] is supplied, removes
 * the cached media files from the app's cache directory.
 *
 * Every file in the cache directory whose name starts with `img_` or `audio_` is deleted,
 * regardless of which session it was written for.
 *
 * @param context Used only to locate the cache directory. When null, only the stored session
 *   records are cleared and files on disk are left untouched.
 */
fun clearAllSessions(context: Context? = null) {
  // Keep the application context instead of the caller's (possibly Activity) context so this
  // background job cannot hold a destroyed Activity alive while it runs.
  val appContext = context?.applicationContext ?: context
  viewModelScope.launch(Dispatchers.IO) {
    // Drop the session records first: if this update fails or is cancelled, the media files are
    // still on disk, so the sessions that survive remain fully loadable.
    userDataDataStore?.updateData { userData -> userData.toBuilder().clearChatSessions().build() }

    // Best-effort cleanup; a file that cannot be deleted is simply left in the cache directory.
    appContext
      ?.cacheDir
      ?.listFiles()
      ?.filter { cached -> cached.name.startsWith("img_") || cached.name.startsWith("audio_") }
      ?.forEach { cached -> cached.delete() }
  }
}
Expand Down
12 changes: 12 additions & 0 deletions Android/src/app/src/main/proto/chat_history.proto
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ message ChatMessageProto {
// Whether the message generation is still in progress (e.g., for thinking
// steps).
bool in_progress = 8;
// A list of paths to local PNG image files associated with this message.
repeated string image_file_paths = 9;
// A list of audio recordings associated with this message.
repeated AudioMessageProto audio_clips = 10;
}

// Represents a single audio recording associated with a chat session message.
// Instances are carried in the repeated `audio_clips` field of ChatMessageProto.
message AudioMessageProto {
// The file path to the local audio recording. Only the path is stored in this
// message; the audio bytes themselves are not embedded.
// NOTE(review): presumably an absolute path to raw PCM data in the app cache
// directory — confirm against the code that writes these files.
string file_path = 1;
// The sample rate of the audio data (e.g., 16000; presumably in Hz — confirm).
int32 sample_rate = 2;
}

// Represents a saved chat session containing a history of messages.
Expand Down
Loading