diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts index ba9759a0e..ff6e6e2cf 100644 --- a/Android/src/app/build.gradle.kts +++ b/Android/src/app/build.gradle.kts @@ -36,8 +36,8 @@ android { applicationId = "com.google.aiedge.gallery" minSdk = 31 targetSdk = 35 - versionCode = 29 - versionName = "1.0.12" + versionCode = 30 + versionName = "1.0.13" // Needed for HuggingFace auth workflows. // Use the scheme of the "Redirect URLs" in HuggingFace app. diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml index 5a7173c37..15e0de78d 100644 --- a/Android/src/app/src/main/AndroidManifest.xml +++ b/Android/src/app/src/main/AndroidManifest.xml @@ -143,10 +143,10 @@ - + diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/SystemPromptHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/SystemPromptHelper.kt new file mode 100644 index 000000000..121b03e48 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/SystemPromptHelper.kt @@ -0,0 +1,42 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.common + +import com.google.ai.edge.gallery.data.SystemPromptRepository +import com.google.ai.edge.gallery.data.Task +import kotlinx.coroutines.flow.firstOrNull + +/** Helper object for system prompt retrieval and compilation. 
*/ +object SystemPromptHelper { + + /** + * Retrieves the effective system prompt for the given [Task]. + * + * Returns the user-defined custom prompt from the [SystemPromptRepository] if available; + * otherwise, falls back to the task's default system prompt. + * + * @param repo The optional [SystemPromptRepository] for custom overrides. If null, returns the + * default. + * @param task The target [Task] containing the identifier and the default fallback system prompt. + * @return A [String] representing the effective system prompt instructions. + */ + suspend fun getEffectiveSystemPrompt(repo: SystemPromptRepository?, task: Task): String { + if (repo == null) return task.defaultSystemPrompt + val customPrompt = repo.getCustomSystemPrompt(task.id).firstOrNull() + return customPrompt ?: task.defaultSystemPrompt + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt index 3cd13bdf0..d09533a1a 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt @@ -22,6 +22,7 @@ import android.graphics.BitmapFactory import android.graphics.Matrix import android.net.Uri import android.os.Build +import android.os.Bundle import android.util.Log import androidx.compose.foundation.layout.WindowInsets import androidx.compose.foundation.layout.ime @@ -36,7 +37,9 @@ import androidx.compose.ui.focus.onFocusEvent import androidx.compose.ui.platform.LocalDensity import androidx.compose.ui.platform.LocalFocusManager import androidx.exifinterface.media.ExifInterface +import com.google.ai.edge.gallery.GalleryEvent import com.google.ai.edge.gallery.data.SAMPLE_RATE +import com.google.ai.edge.gallery.firebaseAnalytics import com.google.gson.Gson import java.io.File import java.io.FileInputStream @@ -378,3 +381,14 @@ fun isAICoreSupported(allowedDeviceModels: Set?): Boolean { 
val currentModel = Build.MODEL?.lowercase() ?: return false return allowedDeviceModels.contains(currentModel) } + +fun logErrorToFirebase(event: GalleryEvent, errorType: String, errorMessage: String?) { + firebaseAnalytics?.logEvent( + event.id, + Bundle().apply { + putBoolean("success", false) + putString("error_type", errorType) + putString("error_message", errorMessage ?: "Unknown error") + }, + ) +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AddOrEditSkillBottomSheet.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AddOrEditSkillBottomSheet.kt index 8318902f2..14b97d8a0 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AddOrEditSkillBottomSheet.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AddOrEditSkillBottomSheet.kt @@ -83,6 +83,7 @@ import androidx.compose.ui.platform.LocalClipboard import androidx.compose.ui.res.stringResource import androidx.compose.ui.unit.dp import com.google.ai.edge.gallery.R +import com.google.ai.edge.gallery.ui.common.CursorTrackingTextField import com.google.ai.edge.gallery.ui.theme.customColors import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt index 0dea011ab..2bb63fa08 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt @@ -203,7 +203,7 @@ fun AgentChatScreen( showAudioPicker = true, getActiveSkills = { skillManagerViewModel.getSelectedSkills().map { skill -> - if (skill.builtIn) skill.name else "custom_skill" + skillManagerViewModel.getSkillShortId(skill) } }, 
composableBelowMessageList = { model -> @@ -236,6 +236,8 @@ fun AgentChatScreen( } else { action.url } + val skill = skillManagerViewModel.getSkill(name = skillName) + val skillId = skill?.let { skillManagerViewModel.getSkillShortId(it) } ?: "xxxx" try { // Set up a safety net timeout so we NEVER hang the chat or tool execution launch { @@ -250,6 +252,7 @@ fun AgentChatScreen( GalleryEvent.SKILL_EXECUTION.id, Bundle().apply { putString("skill_name", skillName) + putString("skill_id", skillId) putBoolean("success", false) putString("error_type", "timeout") }, @@ -285,6 +288,7 @@ fun AgentChatScreen( GalleryEvent.SKILL_EXECUTION.id, Bundle().apply { putString("skill_name", skillName) + putString("skill_id", skillId) putBoolean("success", isSuccess) putString("error_type", errorType) }, @@ -323,6 +327,7 @@ fun AgentChatScreen( GalleryEvent.SKILL_EXECUTION.id, Bundle().apply { putString("skill_name", skillName) + putString("skill_id", skillId) putBoolean("success", false) putString("error_type", "exception") }, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt index ee8b9b3d0..eceef8b9f 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt @@ -81,6 +81,7 @@ class AgentChatTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = true, supportAudio = true, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt index c9bd3d2f8..ba47b51f4 100644 --- 
a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt @@ -37,7 +37,7 @@ import kotlinx.coroutines.runBlocking private const val TAG = "AGAgentTools" -class AgentTools() : ToolSet { +open class AgentTools() : ToolSet { lateinit var context: Context lateinit var skillManagerViewModel: SkillManagerViewModel diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/GenerateLlmPromptBottomSheet.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/GenerateLlmPromptBottomSheet.kt index 53290fd7f..01e725d5e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/GenerateLlmPromptBottomSheet.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/GenerateLlmPromptBottomSheet.kt @@ -45,6 +45,7 @@ import androidx.compose.ui.platform.LocalClipboard import androidx.compose.ui.res.stringResource import androidx.compose.ui.unit.dp import com.google.ai.edge.gallery.R +import com.google.ai.edge.gallery.ui.common.CursorTrackingTextField import kotlinx.coroutines.launch private val PROMPT_TEMPLATE = diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt index 5b18cff85..6d7a6aa67 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt @@ -100,6 +100,8 @@ import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalUriHandler import androidx.compose.ui.res.pluralStringResource import androidx.compose.ui.res.stringResource +import 
androidx.compose.ui.semantics.contentDescription +import androidx.compose.ui.semantics.semantics import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextDecoration import androidx.compose.ui.unit.dp @@ -173,7 +175,7 @@ fun SkillManagerBottomSheet( var addSkillOptionTypeToConfirm by remember { mutableStateOf(null) } var skillToEditIndex by remember { mutableIntStateOf(-1) } var searchQuery by remember { mutableStateOf("") } - var savedSelectedSkillsNamesAndDescriptions = remember { "" } + var savedSelectedSkillsNamesAndDescriptions by remember { mutableStateOf("") } var filteredSkills by remember { mutableStateOf(uiState.skills) } val listState = rememberLazyListState() val uriHandler = LocalUriHandler.current @@ -860,7 +862,8 @@ private fun SkillItemRow( Switch( checked = skill.selected, onCheckedChange = onSkillEnabledChange, - modifier = Modifier.offset(y = (-4).dp), + modifier = + Modifier.offset(y = (-4).dp).semantics { contentDescription = "Toggle ${skill.name}" }, enabled = !inMultiSelectMode, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt index 9787657f8..0d9ddb372 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt @@ -117,6 +117,23 @@ val TRYOUT_CHIPS: List = ), ) +enum class SkillSource(val sourceName: String) { + BUILTIN("builtin"), + FEATURED("featured"), + REMOTE_URL("remote_url"), + LOCAL_IMPORT("local_import"), + UNKNOWN("unknown"), +} + +enum class SkillAction(val value: String) { + ADD("add"), + DELETE("delete"), + ENABLE("enable"), + DISABLE("disable"), + ENABLE_ALL("enable_all"), + DISABLE_ALL("disable_all"), +} + data class SkillState(val skill: Skill) data class 
SkillManagerUiState( @@ -341,13 +358,7 @@ constructor( Log.d(TAG, "Successfully added skill from URL: ${skill.name}") firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "add") - putString("source", "remote_url") - putString("skill_name", getSkillNameForLogging(skill)) - putBoolean("is_built_in", skill.builtIn) - putString("remote_url", url) - }, + getSkillLoggingParams(skill).apply { putString("action", SkillAction.ADD.value) }, ) onSuccess() } @@ -532,12 +543,7 @@ constructor( Log.d(TAG, "Successfully added skill from local import: ${skillWithDir.name}") firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "add") - putString("source", "local_import") - putString("skill_name", getSkillNameForLogging(skillWithDir)) - putBoolean("is_built_in", skillWithDir.builtIn) - }, + getSkillLoggingParams(skillWithDir).apply { putString("action", SkillAction.ADD.value) }, ) onSuccess() } @@ -603,15 +609,14 @@ constructor( return } - val skillNameToLog = getSkillNameForLogging(skill) - Log.d(TAG, "Analytics: skill_management, action=delete, skill_name=${skillNameToLog}") + val loggingParams = getSkillLoggingParams(skill) + Log.d( + TAG, + "Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams", + ) firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "delete") - putString("skill_name", skillNameToLog) - putBoolean("is_built_in", skill.builtIn) - }, + loggingParams.apply { putString("action", SkillAction.DELETE.value) }, ) // Update state. 
@@ -643,17 +648,14 @@ constructor( } for (skill in skillsToDelete) { + val loggingParams = getSkillLoggingParams(skill) Log.d( TAG, - "Analytics: skill_management, action=delete, skill_name=${getSkillNameForLogging(skill)}", + "Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams", ) firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "delete") - putString("skill_name", getSkillNameForLogging(skill)) - putBoolean("is_built_in", skill.builtIn) - }, + loggingParams.apply { putString("action", SkillAction.DELETE.value) }, ) } @@ -686,10 +688,8 @@ constructor( firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", if (selected) "enable" else "disable") - putString("skill_name", getSkillNameForLogging(skill.skill)) - putBoolean("is_built_in", skill.skill.builtIn) + getSkillLoggingParams(skill.skill).apply { + putString("action", if (selected) SkillAction.ENABLE.value else SkillAction.DISABLE.value) }, ) val updatedSkills = @@ -720,11 +720,16 @@ constructor( Log.d( TAG, - "Analytics: skill_management, action=${if (selected) "enable_all" else "disable_all"}", + "Analytics: skill_management, action=${if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value}", ) firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { putString("action", if (selected) "enable_all" else "disable_all") }, + Bundle().apply { + putString( + "action", + if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value, + ) + }, ) // Update data store. 
@@ -1166,15 +1171,73 @@ constructor( dataStoreRepository.setSkills(updatedList) } - private fun getSkillNameForLogging(skill: Skill): String { + private fun getSkillSource(skill: Skill): SkillSource { val isFeatured = skill.skillUrl.isNotEmpty() && _uiState.value.featuredSkills.any { it.skillUrl == skill.skillUrl } - return if (skill.builtIn || isFeatured) { - skill.name - } else { - "custom_skill" + return when { + skill.builtIn -> SkillSource.BUILTIN + isFeatured -> SkillSource.FEATURED + skill.skillUrl.isNotEmpty() -> SkillSource.REMOTE_URL + skill.importDirName.isNotEmpty() -> SkillSource.LOCAL_IMPORT + else -> SkillSource.UNKNOWN + } + } + + /** + * Generates a short 4-character hash to act as a stable ID. This solves the 100-character limit + * for list logging in GA4 AND allows us to distinguish between different custom skills in + * reports. Note: When we migrate to Cleancut or a similar service that doesn't have severe + * character limits, we can drop the human-readable skill_name from setup events and rely purely + * on this hash ID. 
+ */ + fun getSkillShortId(skill: Skill): String { + val source = getSkillSource(skill) + val identifier = + when (source) { + SkillSource.BUILTIN, + SkillSource.FEATURED -> skill.name + SkillSource.LOCAL_IMPORT -> skill.importDirName + else -> skill.skillUrl + } + if (identifier.isEmpty()) return "xxxx" + + val prefix = + when (source) { + SkillSource.BUILTIN -> "b_" + SkillSource.FEATURED -> "f_" + SkillSource.LOCAL_IMPORT -> "l_" + else -> "c_" + } + + return try { + val digest = java.security.MessageDigest.getInstance("SHA-256") + val hashBytes = digest.digest(identifier.toByteArray()) + val hexString = hashBytes.joinToString("") { "%02x".format(it) } + prefix + hexString.take(4) + } catch (e: Exception) { + prefix + "fail" + } + } + + private fun getSkillLoggingParams(skill: Skill): Bundle { + val source = getSkillSource(skill) + val skillName = + if (source == SkillSource.BUILTIN || source == SkillSource.FEATURED) skill.name + else "custom_skill" + val bundle = + Bundle().apply { + putString("source", source.sourceName) + putString("skill_name", skillName) + putString("skill_id", getSkillShortId(skill)) + } + if ( + skill.skillUrl.isNotEmpty() && + (source == SkillSource.REMOTE_URL || source == SkillSource.FEATURED) + ) { + bundle.putString("remote_url", skill.skillUrl.take(100)) } + return bundle } private fun getSkillDestinationDir(originalImportDirName: String): File { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/common/SteadinessMonitor.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/common/SteadinessMonitor.kt new file mode 100644 index 000000000..c95929414 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/common/SteadinessMonitor.kt @@ -0,0 +1,76 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.customtasks.common + +import android.content.Context +import android.hardware.Sensor +import android.hardware.SensorEvent +import android.hardware.SensorEventListener +import android.hardware.SensorManager +import kotlin.math.sqrt +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow + +class SteadinessMonitor(context: Context, private val steadyDurationMs: Long = 2000L) : + SensorEventListener { + private val sensorManager = context.getSystemService(Context.SENSOR_SERVICE) as SensorManager + private val gyroSensor: Sensor? = sensorManager.getDefaultSensor(Sensor.TYPE_GYROSCOPE) + + // If gyroscope is not available, default to stable. + private val _isStable = MutableStateFlow(gyroSensor == null) + val isStable: StateFlow = _isStable + + // Threshold: 0.1 rad/s is quite steady. + private val STABILITY_THRESHOLD = 0.1f + + fun start() { + gyroSensor?.let { sensorManager.registerListener(this, it, SensorManager.SENSOR_DELAY_UI) } + } + + fun stop() { + sensorManager.unregisterListener(this) + } + + private var steadyStartTime: Long? = null + + override fun onSensorChanged(event: SensorEvent?) 
{ + if (event?.sensor?.type == Sensor.TYPE_GYROSCOPE) { + val x = event.values[0] + val y = event.values[1] + val z = event.values[2] + + val magnitude = sqrt(x * x + y * y + z * z) + + if (magnitude < STABILITY_THRESHOLD) { + if (steadyStartTime == null) { + steadyStartTime = System.currentTimeMillis() + } + val start = steadyStartTime + if (start != null && System.currentTimeMillis() - start >= steadyDurationMs) { + _isStable.value = true + } else { + _isStable.value = false + } + } else { + steadyStartTime = null + _isStable.value = false + } + } + } + + override fun onAccuracyChanged(sensor: Sensor?, accuracy: Int) {} +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt index 592a60872..8aff4d0ba 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt @@ -74,6 +74,7 @@ class MobileActionsTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt index 40457d1a0..a2bd2e060 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt @@ -27,6 +27,7 @@ import androidx.core.net.toUri import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.R 
+import com.google.ai.edge.gallery.data.BuiltInTaskId import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance @@ -211,6 +212,7 @@ constructor(@ApplicationContext private val appContext: Context) : ViewModel() { LlmChatModelHelper.initialize( context = context, model = model, + taskId = BuiltInTaskId.LLM_MOBILE_ACTIONS, supportImage = false, supportAudio = false, onDone = { error -> diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt index 4dbd8af47..a15d94198 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt @@ -109,6 +109,7 @@ class TinyGardenTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt index f13691fc1..01b94638a 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt @@ -21,6 +21,7 @@ import android.util.Log import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.R +import com.google.ai.edge.gallery.data.BuiltInTaskId import com.google.ai.edge.gallery.data.DataStoreRepository import com.google.ai.edge.gallery.data.Model import 
com.google.ai.edge.gallery.ui.common.chat.ChatMessage @@ -168,6 +169,7 @@ constructor( LlmChatModelHelper.initialize( context = context, model = model, + taskId = BuiltInTaskId.LLM_TINY_GARDEN, supportImage = false, supportAudio = false, onDone = { error -> diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt index 4ca4b09d5..d596337bb 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt @@ -58,7 +58,11 @@ object ConfigKeys { val SUPPORT_TINY_GARDEN = ConfigKey("support_tiny_garden", "Support tiny garden") val SUPPORT_MOBILE_ACTIONS = ConfigKey("support_mobile_actions", "Support mobile actions") val SUPPORT_THINKING = ConfigKey("support_thinking", "Support thinking") + val SUPPORT_SPECULATIVE_DECODING = + ConfigKey("support_speculative_decoding", "Support speculative decoding") val ENABLE_THINKING = ConfigKey("enable_thinking", "Enable thinking") + val ENABLE_SPECULATIVE_DECODING = + ConfigKey("enable_speculative_decoding", "Enable speculative decoding") val MAX_RESULT_COUNT = ConfigKey("max_result_count", "Max result count") val USE_GPU = ConfigKey("use_gpu", "Use GPU") val ACCELERATOR = ConfigKey("accelerator", "Accelerator") @@ -226,6 +230,7 @@ fun createLlmChatConfigs( defaultTemperature: Float = DEFAULT_TEMPERATURE, accelerators: List = DEFAULT_ACCELERATORS, supportThinking: Boolean = false, + supportSpeculativeDecoding: Boolean = false, ): List { var maxTokensConfig: Config = LabelConfig(key = ConfigKeys.MAX_TOKENS, defaultValue = "$defaultMaxToken") @@ -274,6 +279,11 @@ fun createLlmChatConfigs( if (supportThinking) { configs.add(BooleanSwitchConfig(key = ConfigKeys.ENABLE_THINKING, defaultValue = false)) } + if (supportSpeculativeDecoding) { + configs.add( + BooleanSwitchConfig(key = ConfigKeys.ENABLE_SPECULATIVE_DECODING, defaultValue = false) 
+ ) + } return configs } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt index c62d67023..a0dbe7275 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt @@ -33,7 +33,8 @@ private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]") data class PromptTemplate(val title: String, val description: String, val prompt: String) enum class ModelCapability { - @SerializedName("llm_thinking") LLM_THINKING + @SerializedName("llm_thinking") LLM_THINKING, + @SerializedName("speculative_decoding") SPECULATIVE_DECODING, } enum class RuntimeType { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt index 9ae703182..edffcd68b 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt @@ -176,6 +176,8 @@ data class AllowedModel( defaultMaxContextLength = llmMaxContextLength, accelerators = accelerators, supportThinking = capabilities?.contains(ModelCapability.LLM_THINKING) == true, + supportSpeculativeDecoding = + capabilities?.contains(ModelCapability.SPECULATIVE_DECODING) == true, ) }) .toMutableList() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/SystemPromptRepository.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/SystemPromptRepository.kt new file mode 100644 index 000000000..95b2a4902 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/SystemPromptRepository.kt @@ -0,0 +1,78 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.edge.gallery.data + +import androidx.datastore.core.DataStore +import com.google.ai.edge.gallery.proto.UserData +import javax.inject.Inject +import javax.inject.Singleton +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map + +/** Repository for managing custom system prompts per task. */ +@Singleton +open class SystemPromptRepository +@Inject +constructor(private val userDataDataStore: DataStore) { + + private fun getKey(taskId: String): String = "system_prompt_$taskId" + + /** + * Updates the user-defined custom system prompt for the given [taskId] directly in the DataStore. + * + * @param taskId The ID of the task. + * @param newPrompt The new custom system prompt string. + */ + suspend fun updateSystemPrompt(taskId: String, newPrompt: String) { + userDataDataStore.updateData { userData -> + userData.toBuilder().putSecrets(getKey(taskId), newPrompt).build() + } + } + + /** + * Retrieves a Flow of the user-defined custom system prompt for the given [taskId] directly from + * the DataStore. + * + * This method returns ONLY the user's saved prompt, or null if no custom prompt has been set. It + * does NOT include any fallback to default system prompts. + * + * Most call sites should prefer using [SystemPromptHelper.getEffectiveSystemPrompt], which + * includes the necessary fallback logic. Direct use of this method is rare. + * + * @param taskId The ID of the task. + * @return A Flow emitting the custom system prompt string or null. 
+ */ + fun getCustomSystemPrompt(taskId: String): Flow { + return userDataDataStore.data.map { it.secretsMap[getKey(taskId)] } + } + + /** + * Clears the user-defined custom system prompt for the given [taskId] directly from the + * DataStore. + * + * @param taskId The ID of the task. + */ + suspend fun clearCustomSystemPrompt(taskId: String) { + userDataDataStore.updateData { userData -> + if (userData.secretsMap.containsKey(getKey(taskId))) { + userData.toBuilder().removeSecrets(getKey(taskId)).build() + } else { + userData + } + } + } +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt index eee5c0ac6..c3c659676 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt @@ -39,6 +39,7 @@ interface LlmModelHelper { * * @param context the application context. * @param model the model to be initialized. + * @param taskId the task id where the model is being used. * @param supportImage whether to support image input. * @param supportAudio whether to support audio input. * @param onDone callback invoked when initialization is completed successfully. 
@@ -51,6 +52,7 @@ interface LlmModelHelper { fun initialize( context: Context, model: Model, + taskId: String, supportImage: Boolean, supportAudio: Boolean, onDone: (String) -> Unit, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt index 3a3624ccb..9ba60e392 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt @@ -62,6 +62,7 @@ object AICoreModelHelper : LlmModelHelper { override fun initialize( context: Context, model: Model, + taskId: String, supportImage: Boolean, supportAudio: Boolean, onDone: (String) -> Unit, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/CursorTrackingTextField.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/CursorTrackingTextField.kt similarity index 83% rename from Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/CursorTrackingTextField.kt rename to Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/CursorTrackingTextField.kt index 40c9f5d48..596887e65 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/CursorTrackingTextField.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/CursorTrackingTextField.kt @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.google.ai.edge.gallery.customtasks.agentchat +package com.google.ai.edge.gallery.ui.common import androidx.annotation.StringRes import androidx.compose.foundation.interaction.MutableInteractionSource @@ -37,27 +37,36 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.SolidColor import androidx.compose.ui.platform.LocalDensity import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.SpanStyle import androidx.compose.ui.text.TextLayoutResult +import androidx.compose.ui.text.buildAnnotatedString import androidx.compose.ui.text.coerceIn import androidx.compose.ui.text.font.FontFamily import androidx.compose.ui.text.input.TextFieldValue import androidx.compose.ui.text.input.VisualTransformation +import androidx.compose.ui.text.withStyle import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.dp import kotlinx.coroutines.launch +/** + * A [BasicTextField] that automatically tracks the cursor position and ensures it remains visible + * within the scrollable area, especially useful for multi-line text fields. + */ @Composable fun CursorTrackingTextField( - @StringRes labelResId: Int? = null, - @StringRes supportingTextResId: Int? = null, initialValue: String, onValueChange: (String) -> Unit, modifier: Modifier = Modifier, + @StringRes labelResId: Int? = null, + @StringRes supportingTextResId: Int? = null, + @StringRes placeholderResId: Int? 
= null, enabled: Boolean = true, minLines: Int = 1, extraOffset: Dp = 56.dp, monoFont: Boolean = false, extraBottomComposable: @Composable () -> Unit = {}, + trailingIcon: @Composable () -> Unit = {}, ) { val interactionSource = remember { MutableInteractionSource() } var textFieldValue by remember { mutableStateOf(TextFieldValue(initialValue)) } @@ -119,9 +128,31 @@ fun CursorTrackingTextField( innerTextField = innerTextField, enabled = true, singleLine = false, + placeholder = + if (placeholderResId != null) { + { + Text( + text = + buildAnnotatedString { + withStyle( + style = SpanStyle(fontSize = MaterialTheme.typography.bodyMedium.fontSize) + ) { + append(stringResource(placeholderResId)) + } + } + ) + } + } else { + null + }, visualTransformation = VisualTransformation.None, interactionSource = interactionSource, - label = { if (labelResId != null) Text(stringResource(labelResId)) }, + label = + if (labelResId != null) { + { Text(stringResource(labelResId)) } + } else { + null + }, supportingText = { if (supportingTextResId != null) { Column() { @@ -130,6 +161,7 @@ fun CursorTrackingTextField( } } }, + trailingIcon = trailingIcon, // The ContainerBox draws the actual border/outline container = { OutlinedTextFieldDefaults.Container( diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt index 2756b3889..c2bdb0c12 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt @@ -216,6 +216,13 @@ fun ModelPageAppBar( if (!task.allowCapability(ModelCapability.LLM_THINKING, model)) { modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_THINKING } } + var supportsSpeculativeDecoding = false + if ( + !supportsSpeculativeDecoding || + !task.allowCapability(ModelCapability.SPECULATIVE_DECODING, model) + ) { + 
modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_SPECULATIVE_DECODING } + } ConfigDialog( title = "Configurations", configs = modelConfigs, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/SmallFilledTonalButton.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/SmallFilledTonalButton.kt new file mode 100644 index 000000000..eff41d890 --- /dev/null +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/SmallFilledTonalButton.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.google.ai.edge.gallery.ui.common
+
+import androidx.compose.foundation.layout.PaddingValues
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.size
+import androidx.compose.material3.FilledTonalButton
+import androidx.compose.material3.Icon
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.graphics.vector.ImageVector
+import androidx.compose.ui.res.stringResource
+import androidx.compose.ui.res.vectorResource
+import androidx.compose.ui.unit.Dp
+import androidx.compose.ui.unit.dp
+
+/** Inner padding of the button content; tighter than the Material default to fit a 32.dp height. */
+private val BUTTON_CONTENT_PADDING =
+  PaddingValues(start = 12.dp, top = 2.dp, end = 12.dp, bottom = 2.dp)
+
+/**
+ * A small FilledTonalButton composable with a label and an optional icon.
+ *
+ * The button has a fixed height of 32.dp. If both [imageVector] and [iconResId] are provided,
+ * [imageVector] takes precedence. The icon is decorative (no content description).
+ *
+ * @param onClick Invoked when the button is clicked.
+ * @param labelResId String resource id of the label.
+ * @param imageVector Optional icon to show before the label.
+ * @param iconResId Optional vector drawable resource id, used when [imageVector] is null.
+ * @param size Size of the icon.
+ * @param modifier Modifier applied to the button before its fixed height.
+ */
+@Composable
+fun SmallFilledTonalButton(
+  onClick: () -> Unit,
+  labelResId: Int,
+  imageVector: ImageVector? = null,
+  iconResId: Int? = null,
+  size: Dp = 18.dp,
+  modifier: Modifier = Modifier,
+) {
+  val hasIcon = imageVector != null || iconResId != null
+  FilledTonalButton(
+    onClick = onClick,
+    modifier = modifier.height(32.dp),
+    contentPadding = BUTTON_CONTENT_PADDING,
+  ) {
+    if (imageVector != null) {
+      Icon(imageVector = imageVector, contentDescription = null, modifier = Modifier.size(size))
+    } else if (iconResId != null) {
+      Icon(
+        ImageVector.vectorResource(iconResId),
+        contentDescription = null,
+        modifier = Modifier.size(size),
+      )
+    }
+    Text(
+      stringResource(labelResId),
+      style = MaterialTheme.typography.labelMedium,
+      // The 4.dp gap separates the label from the icon; without an icon it would only push the
+      // label off-centre, so it is skipped.
+      modifier = if (hasIcon) Modifier.padding(start = 4.dp) else Modifier,
+    )
+  }
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/SmallOutlinedButton.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/SmallOutlinedButton.kt
new file mode 100644
index 000000000..c770015bd
--- /dev/null
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/SmallOutlinedButton.kt
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2026 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.edge.gallery.ui.common
+
+import androidx.compose.foundation.layout.PaddingValues
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.size
+import androidx.compose.material3.Icon
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.OutlinedButton
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.graphics.vector.ImageVector
+import androidx.compose.ui.res.stringResource
+import androidx.compose.ui.res.vectorResource
+import androidx.compose.ui.unit.Dp
+import androidx.compose.ui.unit.dp
+
+/** Inner padding of the button content; tighter than the Material default to fit a 32.dp height. */
+private val BUTTON_CONTENT_PADDING =
+  PaddingValues(start = 12.dp, top = 2.dp, end = 12.dp, bottom = 2.dp)
+
+/**
+ * A small OutlinedButton composable with a label and an optional icon.
+ *
+ * The button has a fixed height of 32.dp. If both [imageVector] and [iconResId] are provided,
+ * [imageVector] takes precedence. The icon is decorative (no content description).
+ *
+ * @param onClick Invoked when the button is clicked.
+ * @param labelResId String resource id of the label.
+ * @param imageVector Optional icon to show before the label.
+ * @param iconResId Optional vector drawable resource id, used when [imageVector] is null.
+ * @param size Size of the icon.
+ * @param modifier Modifier applied to the button before its fixed height.
+ */
+@Composable
+fun SmallOutlinedButton(
+  onClick: () -> Unit,
+  labelResId: Int,
+  imageVector: ImageVector? = null,
+  iconResId: Int? = null,
+  size: Dp = 18.dp,
+  modifier: Modifier = Modifier,
+) {
+  val hasIcon = imageVector != null || iconResId != null
+  OutlinedButton(
+    onClick = onClick,
+    modifier = modifier.height(32.dp),
+    contentPadding = BUTTON_CONTENT_PADDING,
+  ) {
+    if (imageVector != null) {
+      Icon(imageVector = imageVector, contentDescription = null, modifier = Modifier.size(size))
+    } else if (iconResId != null) {
+      Icon(
+        ImageVector.vectorResource(iconResId),
+        contentDescription = null,
+        modifier = Modifier.size(size),
+      )
+    }
+    Text(
+      stringResource(labelResId),
+      style = MaterialTheme.typography.labelMedium,
+      // The 4.dp gap separates the label from the icon; without an icon it would only push the
+      // label off-centre, so it is skipped.
+      modifier = if (hasIcon) Modifier.padding(start = 4.dp) else Modifier,
+    )
+  }
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
index b6ee6a242..4a61b5bf7 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
@@ -28,6 +28,7 @@ import com.google.ai.edge.gallery.data.DEFAULT_TOPK
 import com.google.ai.edge.gallery.data.DEFAULT_TOPP
 import com.google.ai.edge.gallery.data.DEFAULT_VISION_ACCELERATOR
 import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.ModelCapability
 import com.google.ai.edge.gallery.runtime.CleanUpListener
 import com.google.ai.edge.gallery.runtime.LlmModelHelper
 import com.google.ai.edge.gallery.runtime.ResultListener
@@ -60,6 +61,7 @@ object LlmChatModelHelper : LlmModelHelper {
   override fun initialize(
     context: Context,
     model: Model,
+    taskId: String,
     supportImage: Boolean,
     supportAudio: Boolean,
     onDone: (String) -> Unit,
@@ -120,10 +122,29 @@ object LlmChatModelHelper : LlmModelHelper {
         else null,
       )
 
+    // Check if the model file supports speculative decoding.
+    var supportsSpeculativeDecoding = false
     // Create an instance of LiteRT LM engine and conversation.
try { + var speculativeDecoding = false + // Check if the model supports speculative decoding for the given task type and if the + // speculative decoding is enabled in the settings. + if ( + supportsSpeculativeDecoding && + model.capabilityToTaskTypes[ModelCapability.SPECULATIVE_DECODING]?.contains(taskId) == + true + ) { + speculativeDecoding = + model.getBooleanConfigValue( + key = ConfigKeys.ENABLE_SPECULATIVE_DECODING, + defaultValue = false, + ) + } + ExperimentalFlags.enableSpeculativeDecoding = speculativeDecoding + Log.d(TAG, "Speculative decoding enabled: $speculativeDecoding") val engine = Engine(engineConfig) engine.initialize() + ExperimentalFlags.enableSpeculativeDecoding = false ExperimentalFlags.enableConversationConstrainedDecoding = enableConversationConstrainedDecoding diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt index 9bcdb49e3..07ff3e85e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt @@ -80,6 +80,7 @@ class LlmChatTask @Inject constructor() : CustomTask { model.runtimeHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, @@ -162,6 +163,7 @@ class LlmAskImageTask @Inject constructor() : CustomTask { model.runtimeHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = true, supportAudio = false, onDone = onDone, @@ -227,6 +229,7 @@ class LlmAskAudioTask @Inject constructor() : CustomTask { model.runtimeHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = true, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt 
b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt index dab3ea7c1..ccf3122d3 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt @@ -61,6 +61,7 @@ class LlmSingleTurnTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt index 78442588f..d4c19f1e4 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt @@ -94,7 +94,8 @@ private const val TAG = "AGModelImportDialog" private val SUPPORTED_ACCELERATORS: List = if (isPixel10()) { - listOf(Accelerator.CPU) + val accelerators = mutableListOf(Accelerator.CPU) + accelerators.toList() } else { listOf(Accelerator.CPU, Accelerator.GPU, Accelerator.NPU) } @@ -136,6 +137,7 @@ private val IMPORT_CONFIGS_LLM: List = BooleanSwitchConfig(key = ConfigKeys.SUPPORT_TINY_GARDEN, defaultValue = false), BooleanSwitchConfig(key = ConfigKeys.SUPPORT_MOBILE_ACTIONS, defaultValue = false), BooleanSwitchConfig(key = ConfigKeys.SUPPORT_THINKING, defaultValue = false), + BooleanSwitchConfig(key = ConfigKeys.SUPPORT_SPECULATIVE_DECODING, defaultValue = false), SegmentedButtonConfig( key = ConfigKeys.COMPATIBLE_ACCELERATORS, defaultValue = SUPPORTED_ACCELERATORS[0].label, @@ -278,6 +280,12 @@ fun ModelImportDialog( valueType = ValueType.BOOLEAN, ) as Boolean + val supportSpeculativeDecoding = + convertValueToTargetType( + value = 
values.get(ConfigKeys.SUPPORT_SPECULATIVE_DECODING.label)!!, + valueType = ValueType.BOOLEAN, + ) + as Boolean val importedModel: ImportedModel = ImportedModel.newBuilder() .setFileName(fileName) @@ -294,6 +302,7 @@ fun ModelImportDialog( .setSupportMobileActions(supportMobileActions) .setSupportThinking(supportThinking) .setSupportTinyGarden(supportTinyGarden) + .setSupportSpeculativeDecoding(supportSpeculativeDecoding) .build() ) .build() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt index 465d5073d..3cdc49b18 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -1189,6 +1189,7 @@ constructor( val llmSupportTinyGarden = info.llmConfig.supportTinyGarden val llmSupportMobileActions = info.llmConfig.supportMobileActions val llmSupportThinking = info.llmConfig.supportThinking + val llmSupportSpeculativeDecoding = info.llmConfig.supportSpeculativeDecoding val configs: MutableList = createLlmChatConfigs( defaultMaxToken = llmMaxToken, @@ -1197,8 +1198,30 @@ constructor( defaultTemperature = info.llmConfig.defaultTemperature, accelerators = accelerators, supportThinking = llmSupportThinking, + supportSpeculativeDecoding = llmSupportSpeculativeDecoding, ) .toMutableList() + val capabilities: MutableList = mutableListOf() + val capabilityToTaskTypes: MutableMap> = mutableMapOf() + if (llmSupportThinking) { + capabilities.add(ModelCapability.LLM_THINKING) + capabilityToTaskTypes[ModelCapability.LLM_THINKING] = + listOf( + BuiltInTaskId.LLM_CHAT, + BuiltInTaskId.LLM_ASK_IMAGE, + BuiltInTaskId.LLM_ASK_AUDIO, + ) + } + if (llmSupportSpeculativeDecoding) { + capabilities.add(ModelCapability.SPECULATIVE_DECODING) + 
capabilityToTaskTypes[ModelCapability.SPECULATIVE_DECODING] = + listOf( + BuiltInTaskId.LLM_CHAT, + BuiltInTaskId.LLM_ASK_IMAGE, + BuiltInTaskId.LLM_ASK_AUDIO, + BuiltInTaskId.LLM_PROMPT_LAB, + ) + } val model = Model( name = info.fileName, @@ -1213,21 +1236,8 @@ constructor( llmSupportAudio = llmSupportAudio, llmSupportTinyGarden = llmSupportTinyGarden, llmSupportMobileActions = llmSupportMobileActions, - capabilities = - if (llmSupportThinking) listOf(ModelCapability.LLM_THINKING) else emptyList(), - capabilityToTaskTypes = - if (llmSupportThinking) { - mapOf( - ModelCapability.LLM_THINKING to - listOf( - BuiltInTaskId.LLM_CHAT, - BuiltInTaskId.LLM_ASK_IMAGE, - BuiltInTaskId.LLM_ASK_AUDIO, - ) - ) - } else { - emptyMap() - }, + capabilities = capabilities.toList(), + capabilityToTaskTypes = capabilityToTaskTypes.toMap(), llmMaxToken = llmMaxToken, accelerators = accelerators, // We assume all imported models are LLM for now. diff --git a/Android/src/app/src/main/proto/settings.proto b/Android/src/app/src/main/proto/settings.proto index 4640bb1b8..c6a505632 100644 --- a/Android/src/app/src/main/proto/settings.proto +++ b/Android/src/app/src/main/proto/settings.proto @@ -59,6 +59,7 @@ message LlmConfig { bool support_tiny_garden = 8; bool support_mobile_actions = 9; bool support_thinking = 10; + bool support_speculative_decoding = 11; } message Settings { diff --git a/Android/src/app/src/main/res/drawable/analyze.xml b/Android/src/app/src/main/res/drawable/analyze.xml new file mode 100644 index 000000000..7819f10f4 --- /dev/null +++ b/Android/src/app/src/main/res/drawable/analyze.xml @@ -0,0 +1,27 @@ + + + + + + + \ No newline at end of file diff --git a/Android/src/app/src/main/res/drawable/live.xml b/Android/src/app/src/main/res/drawable/live.xml new file mode 100644 index 000000000..e4fac372d --- /dev/null +++ b/Android/src/app/src/main/res/drawable/live.xml @@ -0,0 +1,27 @@ + + + + + + + \ No newline at end of file