diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelFeedback.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelFeedback.kt new file mode 100644 index 000000000..c4b907673 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelFeedback.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.data + +import com.google.gson.annotations.SerializedName +import javax.inject.Qualifier + +/** Hilt qualifier annotation for feedback API key binding. */ +@Qualifier @Retention(AnnotationRetention.BINARY) annotation class FeedbackApiKey + +/** Interface to fetch OAuth Bearer credentials on-device. */ +interface AuthTokenProvider { + suspend fun getAuthToken(scope: String): String? +} + +/** Enum matching Feedback Oneplatform MicrofeedbackScore values for lightweight sentiment. */ +enum class MicrofeedbackScore { + @SerializedName("SCORE_UNSPECIFIED") SCORE_UNSPECIFIED, + @SerializedName("SCORE0") SCORE0, + @SerializedName("SCORE1") SCORE1, + @SerializedName("SCORE2") SCORE2, + @SerializedName("SCORE3") SCORE3, + @SerializedName("SCORE4") SCORE4, + @SerializedName("SCORE5") SCORE5, +} + +/** A key-value pair for Product Specific Data (PSD) metadata attachment. */ +data class ModelFeedbackPsdData( + @SerializedName("key") val key: String, + @SerializedName("value") val value: String, +) + +/** Product metadata and environment info where the feedback was collected. */ +data class ModelFeedbackProductInfo( + @SerializedName("ui_language") val uiLanguage: String = "en-US", + @SerializedName("product_version") val productVersion: String, + @SerializedName("product_specific_data") val productSpecificData: List, +) + +/** Core user entry details, including comment text and lightweight sentiment scores. */ +data class ModelFeedbackDataPayload( + @SerializedName("description") val description: String, + @SerializedName("microfeedback_score") val microfeedbackScore: MicrofeedbackScore, +) + +/** DTO request body for the Feedback Oneplatform SubmitFeedback RPC public endpoint. */ +data class ModelFeedbackRequest( + @SerializedName("product_id") val productId: Int = 5372309, + @SerializedName("bucket_id") val bucketId: String = "android-agent-chat-feedback", + @SerializedName("product_info") val productInfo: ModelFeedbackProductInfo, + @SerializedName("feedback_data") val feedbackData: ModelFeedbackDataPayload, +) diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelFeedbackRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelFeedbackRepository.kt new file mode 100644 index 000000000..1aa5dca73 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelFeedbackRepository.kt @@ -0,0 +1,169 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.data + +import android.util.Log +import com.google.ai.edge.gallery.BuildConfig +import com.google.gson.Gson +import java.io.OutputStreamWriter +import java.net.HttpURLConnection +import java.net.URL +import javax.inject.Inject +import javax.inject.Singleton +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext + +private const val TAG = "AGModelFeedbackRepo" + +/** + * Repository for packaging and submitting user feedback on model responses to the Oneplatform API. + */ +@Singleton +class ModelFeedbackRepository +@Inject +constructor( + private val authTokenProvider: AuthTokenProvider, + @FeedbackApiKey private val apiKey: String, +) { + + /** + * Submits user rating and conversational metadata to the Feedback Oneplatform service. + * + * @param isPositive True if Thumbs Up ( SCORE5 ), false if Thumbs Down ( SCORE0 ). + * @param description Free text user comment entered in the dialog. + * @param selectedChips Categorical taxonomical chips chosen by the user. + * @param userPrompt Prompt that triggered the response. + * @param modelResponse Agent answer being rated. + * @param modelId Unique name of the model. + * @param modelVersion Active version identifier of the model. + * @param temperature Generative temperature model parameter. + * @param topK Top K model parameter. + * @param topP Top P model parameter. + * @param extraPsd Map of any additional key-value pairs specific to the feature (e.g. + * feature_card). + * @param conversationHistory Full formatted conversation logs up to the rated agent answer. + */ + @Suppress("AndroidLintDispatcherUsage") + suspend fun submitFeedback( + isPositive: Boolean, + description: String, + selectedChips: List, + userPrompt: String, + modelResponse: String, + modelId: String, + modelVersion: String, + temperature: String, + topK: String, + topP: String, + extraPsd: Map = emptyMap(), + conversationHistory: String, + ): Result = + withContext(Dispatchers.IO) { + try { + // Retrieve the OAuth Bearer Token with the supportcontent scope + val scope = "oauth2:https://www.googleapis.com/auth/supportcontent" + val token = authTokenProvider.getAuthToken(scope) + Log.d(TAG, "Fetched OAuth token present: ${token != null} (scope: $scope)") + + // TODO: Remove this short-circuit block once we configure an active FeedbackApiKey in + // AppModule.kt + if (token == null && apiKey.isEmpty()) { + Log.w( + TAG, + "No OAuth token or API Key provided. Short-circuiting to simulate successful sandbox submission for local developer testing.", + ) + return@withContext Result.success(Unit) + } + + val score = if (isPositive) MicrofeedbackScore.SCORE5 else MicrofeedbackScore.SCORE0 + + // Construct tabular metadata key-value pairs + val psdList = + mutableListOf( + ModelFeedbackPsdData("model_id", modelId), + ModelFeedbackPsdData("model_version", modelVersion), + ModelFeedbackPsdData("temperature", temperature), + ModelFeedbackPsdData("top_k", topK), + ModelFeedbackPsdData("top_p", topP), + ModelFeedbackPsdData("selected_chips", selectedChips.joinToString(",")), + ModelFeedbackPsdData("app_version", BuildConfig.VERSION_NAME), + ModelFeedbackPsdData("user_prompt", userPrompt), + ModelFeedbackPsdData("model_response", modelResponse), + ModelFeedbackPsdData("conversation_history", conversationHistory), + ) + + // Merge extra PSD fields + for ((key, value) in extraPsd) { + psdList.add(ModelFeedbackPsdData(key, value)) + } + + val productInfo = + ModelFeedbackProductInfo( + uiLanguage = "en-US", + productVersion = BuildConfig.VERSION_NAME, + productSpecificData = psdList, + ) + + val payload = + ModelFeedbackDataPayload(description = description, microfeedbackScore = score) + + val request = + ModelFeedbackRequest( + productId = 5372309, + bucketId = "android-agent-chat-feedback", + productInfo = productInfo, + feedbackData = payload, + ) + + // Staging public network REST submission endpoint + var urlString = + "https://stagingqual-feedback-pa-googleapis.sandbox.google.com/v1/feedback/products/5372309:submit" + if (token == null && apiKey.isNotEmpty()) { + urlString += "?key=$apiKey" + } + val url = URL(urlString) + val connection = url.openConnection() as HttpURLConnection + connection.requestMethod = "POST" + connection.doOutput = true + connection.setRequestProperty("Content-Type", "application/json; charset=utf-8") + if (token != null) { + connection.setRequestProperty("Authorization", "Bearer $token") + } + + val json = Gson().toJson(request) + Log.d(TAG, "Feedback JSON Request Payload: $json") + OutputStreamWriter(connection.outputStream, "UTF-8").use { writer -> + writer.write(json) + writer.flush() + } + + val responseCode = connection.responseCode + Log.d(TAG, "Feedback submission HTTP Response Code: $responseCode") + if (responseCode in 200..299) { + Result.success(Unit) + } else { + val errorMsg = connection.errorStream?.bufferedReader()?.readText() ?: "Unknown error" + Result.failure( + Exception("Feedback submission failed with response code: $responseCode - $errorMsg") + ) + } + } catch (e: Exception) { + Log.e(TAG, "Error occurred during feedback submission", e) + Result.failure(e) + } + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt index 5ddb0f7f6..6929aef43 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/di/AppModule.kt @@ -28,10 +28,12 @@ import com.google.ai.edge.gallery.GalleryLifecycleProvider import com.google.ai.edge.gallery.SettingsSerializer import com.google.ai.edge.gallery.SkillsSerializer import com.google.ai.edge.gallery.UserDataSerializer +import com.google.ai.edge.gallery.data.AuthTokenProvider import com.google.ai.edge.gallery.data.DataStoreRepository import com.google.ai.edge.gallery.data.DefaultDataStoreRepository import com.google.ai.edge.gallery.data.DefaultDownloadRepository import com.google.ai.edge.gallery.data.DownloadRepository +import com.google.ai.edge.gallery.data.FeedbackApiKey import com.google.ai.edge.gallery.proto.BenchmarkResults import com.google.ai.edge.gallery.proto.CutoutCollection import com.google.ai.edge.gallery.proto.Settings @@ -183,4 +185,24 @@ internal object AppModule { ): DownloadRepository { return DefaultDownloadRepository(context, lifecycleProvider) } + + // Provides AuthTokenProvider stub implementation + @Provides + @Singleton + fun provideAuthTokenProvider(): AuthTokenProvider { + return object : AuthTokenProvider { + override suspend fun getAuthToken(scope: String): String? { + return null + } + } + } + + // Provides FeedbackApiKey + @Provides + @Singleton + @FeedbackApiKey + fun provideFeedbackApiKey(): String { + // TODO: Add the staging/sandbox Listnr API key here for anonymous feedback submissions + return "" + } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt index b6dffbc47..c2758872c 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatMessage.kt @@ -63,15 +63,20 @@ open class ChatMessage( open val hideSenderLabel: Boolean = false, open val disableBubbleShape: Boolean = false, ) { + var feedbackRating: Boolean? = null + open fun clone(): ChatMessage { - return ChatMessage( - type = type, - side = side, - latencyMs = latencyMs, - accelerator = accelerator, - hideSenderLabel = hideSenderLabel, - disableBubbleShape = disableBubbleShape, - ) + val cloned = + ChatMessage( + type = type, + side = side, + latencyMs = latencyMs, + accelerator = accelerator, + hideSenderLabel = hideSenderLabel, + disableBubbleShape = disableBubbleShape, + ) + cloned.feedbackRating = feedbackRating + return cloned } } @@ -126,16 +131,19 @@ open class ChatMessageText( hideSenderLabel = hideSenderLabel, ) { override fun clone(): ChatMessageText { - return ChatMessageText( - content = content, - side = side, - latencyMs = latencyMs, - accelerator = accelerator, - isMarkdown = isMarkdown, - llmBenchmarkResult = llmBenchmarkResult, - hideSenderLabel = hideSenderLabel, - data = data, - ) + val cloned = + ChatMessageText( + content = content, + side = side, + latencyMs = latencyMs, + accelerator = accelerator, + isMarkdown = isMarkdown, + llmBenchmarkResult = llmBenchmarkResult, + hideSenderLabel = hideSenderLabel, + data = data, + ) + cloned.feedbackRating = feedbackRating + return cloned } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt index 4401f3484..117a1b0c3 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatPanel.kt @@ -17,18 +17,16 @@ package com.google.ai.edge.gallery.ui.common.chat import android.graphics.Bitmap +import android.util.Log import androidx.compose.animation.AnimatedVisibility -import androidx.compose.animation.core.FastOutSlowInEasing import androidx.compose.animation.core.Spring import androidx.compose.animation.core.VisibilityThreshold import androidx.compose.animation.core.spring -import androidx.compose.animation.core.tween import androidx.compose.animation.fadeIn import androidx.compose.animation.fadeOut import androidx.compose.animation.scaleIn import androidx.compose.animation.scaleOut import androidx.compose.animation.slideInVertically -import androidx.compose.foundation.ScrollState import androidx.compose.foundation.background import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box @@ -42,14 +40,21 @@ import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.imePadding import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.wrapContentWidth import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.verticalScroll import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.ThumbDown +import androidx.compose.material.icons.filled.ThumbUp +import androidx.compose.material.icons.outlined.ThumbDown +import androidx.compose.material.icons.outlined.ThumbUp import androidx.compose.material.icons.outlined.Timer import androidx.compose.material.icons.rounded.Refresh import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton import androidx.compose.material3.MaterialTheme import androidx.compose.material3.SnackbarHost import androidx.compose.material3.SnackbarHostState @@ -77,7 +82,7 @@ import androidx.compose.ui.graphics.graphicsLayer import androidx.compose.ui.input.nestedscroll.NestedScrollConnection import androidx.compose.ui.input.nestedscroll.NestedScrollSource import androidx.compose.ui.input.nestedscroll.nestedScroll -import androidx.compose.ui.layout.onSizeChanged +import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalDensity import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.res.dimensionResource @@ -103,10 +108,8 @@ import com.google.ai.edge.gallery.ui.modelmanager.ModelManagerViewModel import com.google.ai.edge.gallery.ui.theme.customColors import kotlinx.coroutines.android.awaitFrame import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.collectLatest -import kotlinx.coroutines.launch -private const val SCROLL_ANIMATION_DURATION_MS = 300 +private const val TAG = "AGChatPanel" /** Composable function for the main chat panel, displaying messages and handling user input. */ @OptIn(ExperimentalMaterial3Api::class) @@ -131,6 +134,12 @@ fun ChatPanel( showImagePicker: Boolean = false, showAudioPicker: Boolean = false, emptyStateComposable: @Composable (Model) -> Unit = {}, + onFeedbackSubmitted: + ( + isPositive: Boolean, comment: String, selectedChips: List, agentMessageIndex: Int, + ) -> Unit = + { _, _, _, _ -> + }, ) { val uiState by viewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() @@ -167,6 +176,7 @@ fun ChatPanel( var curMessage by remember { mutableStateOf("") } // Correct state val focusManager = LocalFocusManager.current + val context = LocalContext.current // List state to control scrolling. val listState = rememberScrollState() @@ -175,6 +185,9 @@ fun ChatPanel( val benchmarkMessage: MutableState = remember { mutableStateOf(null) } var showErrorDialog by remember { mutableStateOf(false) } + var showFeedbackDialog by remember { mutableStateOf(false) } + var isPositiveFeedback by remember { mutableStateOf(true) } + var feedbackMessageIndex by remember { mutableIntStateOf(-1) } var showAudioRecorder by remember { mutableStateOf(false) } var curAmplitude by remember { mutableIntStateOf(0) } @@ -514,6 +527,50 @@ fun ChatPanel( horizontalArrangement = Arrangement.spacedBy(8.dp), ) { LatencyText(message = message) + if (message is ChatMessageText && !uiState.inProgress) { + val isUpHighlighted = message.feedbackRating == true + val isDownHighlighted = message.feedbackRating == false + + IconButton( + onClick = { + Log.d(TAG, "Thumbs Up clicked on response at index: $index") + isPositiveFeedback = true + feedbackMessageIndex = index + showFeedbackDialog = true + }, + modifier = Modifier.size(28.dp), + ) { + Icon( + imageVector = + if (isUpHighlighted) Icons.Filled.ThumbUp else Icons.Outlined.ThumbUp, + contentDescription = "Thumbs Up", + tint = + if (isUpHighlighted) MaterialTheme.colorScheme.primary + else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.6f), + modifier = Modifier.size(18.dp), + ) + } + IconButton( + onClick = { + Log.d(TAG, "Thumbs Down clicked on response at index: $index") + isPositiveFeedback = false + feedbackMessageIndex = index + showFeedbackDialog = true + }, + modifier = Modifier.size(28.dp), + ) { + Icon( + imageVector = + if (isDownHighlighted) Icons.Filled.ThumbDown + else Icons.Outlined.ThumbDown, + contentDescription = "Thumbs Down", + tint = + if (isDownHighlighted) MaterialTheme.colorScheme.primary + else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.6f), + modifier = Modifier.size(18.dp), + ) + } + } } } else if (message.side == ChatSide.USER) { Row( @@ -692,19 +749,52 @@ fun ChatPanel( }, ) } -} -private suspend fun scrollToBottom( - listState: ScrollState, - animate: Boolean = false, - animationDurationMs: Int = SCROLL_ANIMATION_DURATION_MS, -) { - if (animate) { - listState.animateScrollTo( - listState.maxValue, - animationSpec = tween(durationMillis = animationDurationMs, easing = FastOutSlowInEasing), + // Feedback dialog. + if (showFeedbackDialog) { + val positiveChips = + if (task.id == BuiltInTaskId.LLM_CHAT) { + listOf("Accurate", "Easy to Understand", "Creative", "Informative", "Other") + } else { + listOf("Correct Action", "Helpful", "Fast", "Successful", "Other") + } + val negativeChips = + if (task.id == BuiltInTaskId.LLM_CHAT) { + listOf( + "Factually Inaccurate", + "Offensive/Unsafe", + "Didn't follow instructions", + "Too long", + "Other", + ) + } else { + listOf("Incorrect Action", "Failed execution", "Slow", "Confusing Behavior", "Other") + } + val chips = if (isPositiveFeedback) positiveChips else negativeChips + + FeedbackDialog( + isPositive = isPositiveFeedback, + chips = chips, + onDismiss = { showFeedbackDialog = false }, + onSubmit = { comment, selectedChips -> + Log.d( + TAG, + "FeedbackDialog onSubmit. Comment length: ${comment.length}, selected chips: ${selectedChips.joinToString(",")}", + ) + showFeedbackDialog = false + onFeedbackSubmitted(isPositiveFeedback, comment, selectedChips, feedbackMessageIndex) + }, ) - } else { - listState.scrollTo(listState.maxValue) + } +} + +private suspend fun scrollToBottom(listState: LazyListState, animate: Boolean = false) { + val itemCount = listState.layoutInfo.totalItemsCount + if (itemCount > 0) { + if (animate) { + listState.animateScrollToItem(itemCount - 1, scrollOffset = 1000000) + } else { + listState.scrollToItem(itemCount - 1, scrollOffset = 1000000) + } } } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt index d8cf3790c..c3c885ef9 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ChatView.kt @@ -118,6 +118,12 @@ fun ChatView( curSystemPrompt: String = "", onSystemPromptChanged: (String) -> Unit = {}, sendMessageTrigger: SendMessageTrigger? = null, + onFeedbackSubmitted: + ( + isPositive: Boolean, comment: String, selectedChips: List, agentMessageIndex: Int, + ) -> Unit = + { _, _, _, _ -> + }, ) { val uiState by viewModel.uiState.collectAsState() val modelManagerUiState by modelManagerViewModel.uiState.collectAsState() @@ -349,6 +355,7 @@ fun ChatView( showImagePicker = showImagePicker, showAudioPicker = showAudioPicker, emptyStateComposable = emptyStateComposable, + onFeedbackSubmitted = onFeedbackSubmitted, ) // Model download false -> diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ModelFeedbackDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ModelFeedbackDialog.kt new file mode 100644 index 000000000..9f073fe34 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/chat/ModelFeedbackDialog.kt @@ -0,0 +1,136 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.ui.common.chat + +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.ExperimentalLayoutApi +import androidx.compose.foundation.layout.FlowRow +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.material3.Button +import androidx.compose.material3.Checkbox +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.FilterChip +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.ModalBottomSheet +import androidx.compose.material3.OutlinedTextField +import androidx.compose.material3.Text +import androidx.compose.material3.TextButton +import androidx.compose.material3.rememberModalBottomSheetState +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateListOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.unit.dp + +/** + * A Jetpack Compose bottom sheet for collecting specific user feedback on model responses. + * + * @param isPositive True if the user clicked "Thumbs Up", false if "Thumbs Down". + * @param chips Categorical selection chips (preset issues). + * @param onDismiss Callback invoked when the bottom sheet is dismissed or canceled. + * @param onSubmit Callback invoked when the user submits their feedback. + */ +@OptIn(ExperimentalLayoutApi::class, ExperimentalMaterial3Api::class) +@Composable +fun FeedbackDialog( + isPositive: Boolean, + chips: List, + onDismiss: () -> Unit, + onSubmit: (comment: String, selectedChips: List) -> Unit, + modifier: Modifier = Modifier, +) { + var comment by remember { mutableStateOf("") } + val selectedChips = remember { mutableStateListOf() } + var legalChecked by remember { mutableStateOf(false) } + + // Enabled only after at least one chip is selected OR text is provided in open-comment field + // AND the legal disclosure consent checkbox is checked. + val isSubmitEnabled = (selectedChips.isNotEmpty() || comment.isNotBlank()) && legalChecked + + val sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true) + + ModalBottomSheet(onDismissRequest = onDismiss, sheetState = sheetState, modifier = modifier) { + Column( + modifier = + Modifier.fillMaxWidth().padding(horizontal = 24.dp).padding(top = 8.dp, bottom = 32.dp), + verticalArrangement = Arrangement.spacedBy(16.dp), + ) { + Text(text = "Why did you choose this rating?", style = MaterialTheme.typography.titleMedium) + + // Taxonomical issue chips + FlowRow( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalArrangement = Arrangement.spacedBy(8.dp), + modifier = Modifier.fillMaxWidth(), + ) { + chips.forEach { chipText -> + val isSelected = selectedChips.contains(chipText) + FilterChip( + selected = isSelected, + onClick = { + if (isSelected) { + selectedChips.remove(chipText) + } else { + selectedChips.add(chipText) + } + }, + label = { Text(chipText) }, + ) + } + } + + // Open-text comment field + OutlinedTextField( + value = comment, + onValueChange = { comment = it }, + label = { Text("Add comments (optional)") }, + placeholder = { Text("Tell us more about your experience...") }, + modifier = Modifier.fillMaxWidth(), + maxLines = 4, + ) + + // Legal disclosure checkbox + Row(verticalAlignment = Alignment.Top, modifier = Modifier.fillMaxWidth()) { + Checkbox(checked = legalChecked, onCheckedChange = { legalChecked = it }) + Text( + text = "[legal disclosure text]", + style = MaterialTheme.typography.bodySmall, + modifier = Modifier.padding(start = 8.dp, top = 4.dp), + ) + } + + // Action buttons + Row(horizontalArrangement = Arrangement.End, modifier = Modifier.fillMaxWidth()) { + TextButton(onClick = onDismiss) { Text("Cancel") } + Button( + onClick = { onSubmit(comment, selectedChips.toList()) }, + enabled = isSubmitEnabled, + modifier = Modifier.padding(start = 8.dp), + ) { + Text("Submit") + } + } + } + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt index 27bf70fa7..fd5a19ea1 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatScreen.kt @@ -323,5 +323,16 @@ fun ChatViewWrapper( onSystemPromptChanged = onSystemPromptChanged, sendMessageTrigger = sendMessageTrigger, showAudioPicker = showAudioPicker, + onFeedbackSubmitted = { isPositive, comment, selectedChips, agentMessageIndex -> + val model = modelManagerViewModel.uiState.value.selectedModel + viewModel.submitFeedback( + task = task, + model = model, + isPositive = isPositive, + comment = comment, + selectedChips = selectedChips, + agentMessageIndex = agentMessageIndex, + ) + }, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt index 8597c16b3..7fc526c1f 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatViewModel.kt @@ -19,11 +19,13 @@ package com.google.ai.edge.gallery.ui.llmchat import android.content.Context import android.graphics.Bitmap import android.util.Log +import android.widget.Toast import androidx.datastore.core.DataStore import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.common.SystemPromptHelper import com.google.ai.edge.gallery.data.ConfigKeys import com.google.ai.edge.gallery.data.Model +import com.google.ai.edge.gallery.data.ModelFeedbackRepository import com.google.ai.edge.gallery.data.SystemPromptRepository import com.google.ai.edge.gallery.data.Task import com.google.ai.edge.gallery.proto.UserData @@ -43,6 +45,7 @@ import com.google.ai.edge.litertlm.Contents import com.google.ai.edge.litertlm.ExperimentalApi import com.google.ai.edge.litertlm.ToolProvider import dagger.hilt.android.lifecycle.HiltViewModel +import dagger.hilt.android.qualifiers.ApplicationContext import javax.inject.Inject import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay @@ -56,6 +59,8 @@ private const val TAG = "AGLlmChatViewModel" open class LlmChatViewModelBase( private val systemPromptRepository: SystemPromptRepository? = null, userDataDataStore: DataStore? = null, + private val modelFeedbackRepository: ModelFeedbackRepository? = null, + protected val context: Context? = null, ) : ChatViewModel(userDataDataStore) { private val _uiSystemPrompt = MutableStateFlow("") val uiSystemPrompt = _uiSystemPrompt.asStateFlow() @@ -409,6 +414,116 @@ open class LlmChatViewModelBase( ) } } + + fun submitFeedback( + task: Task, + model: Model, + isPositive: Boolean, + comment: String, + selectedChips: List, + agentMessageIndex: Int, + ) { + val modelFeedbackRepository = modelFeedbackRepository ?: return + val messages = uiState.value.messagesByModel[model.name] ?: return + if (agentMessageIndex < 0 || agentMessageIndex >= messages.size) return + + viewModelScope.launch(Dispatchers.IO) { + try { + Log.d( + TAG, + "submitFeedback triggered in VM. isPositive: $isPositive, comment length: ${comment.length}, chips: ${selectedChips.joinToString(",")}", + ) + // Scan conversation history backwards to locate the corresponding user prompt. + var userPrompt = "" + for (j in (agentMessageIndex - 1) downTo 0) { + val msg = messages[j] + if (msg.side == ChatSide.USER && msg is ChatMessageText) { + userPrompt = msg.content + break + } + } + + // Format conversation history up to the rated agent message index + val chatHistoryText = StringBuilder() + for (j in 0..agentMessageIndex) { + val msg = messages[j] + val sender = + when (msg.side) { + ChatSide.USER -> "User" + ChatSide.AGENT -> "Agent" + ChatSide.SYSTEM -> "System" + } + if (msg is ChatMessageText) { + chatHistoryText.append("$sender: ${msg.content}\n\n") + } + } + + val modelResponse = (messages[agentMessageIndex] as? ChatMessageText)?.content ?: "" + + val temperature = model.getFloatConfigValue(ConfigKeys.TEMPERATURE, 0.7f).toString() + val topK = model.getIntConfigValue(ConfigKeys.TOPK, 40).toString() + val topP = model.getFloatConfigValue(ConfigKeys.TOPP, 0.9f).toString() + val modelVersion = "1.0" + + Log.d( + TAG, + "Submitting feedback. userPrompt length: ${userPrompt.length}, modelResponse length: ${modelResponse.length}, history length: ${chatHistoryText.length}, temperature: $temperature, topK: $topK, topP: $topP", + ) + + val result = + modelFeedbackRepository.submitFeedback( + isPositive = isPositive, + description = comment, + selectedChips = selectedChips, + userPrompt = userPrompt, + modelResponse = modelResponse, + modelId = model.name, + modelVersion = modelVersion, + temperature = temperature, + topK = topK, + topP = topP, + extraPsd = mapOf("feature_card" to task.id), + conversationHistory = chatHistoryText.toString(), + ) + + result + .onSuccess { + Log.i(TAG, "Feedback submitted successfully!") + viewModelScope.launch(Dispatchers.Main) { + context?.let { + Toast.makeText(it, "Thank you for submitting feedback (test)", Toast.LENGTH_SHORT) + .show() + } + } + val currentMessages = uiState.value.messagesByModel[model.name] + if ( + currentMessages != null && + agentMessageIndex >= 0 && + agentMessageIndex < currentMessages.size + ) { + val ratedMessage = currentMessages[agentMessageIndex].clone() + ratedMessage.feedbackRating = isPositive + replaceMessage(model = model, index = agentMessageIndex, message = ratedMessage) + } + } + .onFailure { e -> + Log.e(TAG, "Feedback submission failed", e) + viewModelScope.launch(Dispatchers.Main) { + context?.let { + Toast.makeText( + it, + "Network connection failed. Feedback not submitted.", + Toast.LENGTH_SHORT, + ) + .show() + } + } + } + } catch (e: Exception) { + Log.e(TAG, "Error submitting feedback", e) + } + } + } } @HiltViewModel @@ -417,7 +532,10 @@ class LlmChatViewModel constructor( systemPromptRepository: SystemPromptRepository, userDataDataStore: DataStore, -) : LlmChatViewModelBase(systemPromptRepository, userDataDataStore) + modelFeedbackRepository: ModelFeedbackRepository, + @ApplicationContext context: Context, +) : + LlmChatViewModelBase(systemPromptRepository, userDataDataStore, modelFeedbackRepository, context) @HiltViewModel class LlmAskImageViewModel @@ -425,7 +543,10 @@ class LlmAskImageViewModel constructor( systemPromptRepository: SystemPromptRepository, userDataDataStore: DataStore, -) : LlmChatViewModelBase(systemPromptRepository, userDataDataStore) + modelFeedbackRepository: ModelFeedbackRepository, + @ApplicationContext context: Context, +) : + LlmChatViewModelBase(systemPromptRepository, userDataDataStore, modelFeedbackRepository, context) @HiltViewModel class LlmAskAudioViewModel @@ -433,4 +554,7 @@ class LlmAskAudioViewModel constructor( systemPromptRepository: SystemPromptRepository, userDataDataStore: DataStore, -) : LlmChatViewModelBase(systemPromptRepository, userDataDataStore) + modelFeedbackRepository: ModelFeedbackRepository, + @ApplicationContext context: Context, +) : + LlmChatViewModelBase(systemPromptRepository, userDataDataStore, modelFeedbackRepository, context)