diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts index ba9759a0e..ff6e6e2cf 100644 --- a/Android/src/app/build.gradle.kts +++ b/Android/src/app/build.gradle.kts @@ -36,8 +36,8 @@ android { applicationId = "com.google.aiedge.gallery" minSdk = 31 targetSdk = 35 - versionCode = 29 - versionName = "1.0.12" + versionCode = 30 + versionName = "1.0.13" // Needed for HuggingFace auth workflows. // Use the scheme of the "Redirect URLs" in HuggingFace app. diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml index 5a7173c37..15e0de78d 100644 --- a/Android/src/app/src/main/AndroidManifest.xml +++ b/Android/src/app/src/main/AndroidManifest.xml @@ -143,10 +143,10 @@ - + diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt index 3cd13bdf0..d09533a1a 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt @@ -22,6 +22,7 @@ import android.graphics.BitmapFactory import android.graphics.Matrix import android.net.Uri import android.os.Build +import android.os.Bundle import android.util.Log import androidx.compose.foundation.layout.WindowInsets import androidx.compose.foundation.layout.ime @@ -36,7 +37,9 @@ import androidx.compose.ui.focus.onFocusEvent import androidx.compose.ui.platform.LocalDensity import androidx.compose.ui.platform.LocalFocusManager import androidx.exifinterface.media.ExifInterface +import com.google.ai.edge.gallery.GalleryEvent import com.google.ai.edge.gallery.data.SAMPLE_RATE +import com.google.ai.edge.gallery.firebaseAnalytics import com.google.gson.Gson import java.io.File import java.io.FileInputStream @@ -378,3 +381,14 @@ fun isAICoreSupported(allowedDeviceModels: Set?): Boolean { val currentModel = Build.MODEL?.lowercase() ?: return false return allowedDeviceModels.contains(currentModel) } + +fun logErrorToFirebase(event: GalleryEvent, errorType: String, errorMessage: String?) { + firebaseAnalytics?.logEvent( + event.id, + Bundle().apply { + putBoolean("success", false) + putString("error_type", errorType) + putString("error_message", errorMessage ?: "Unknown error") + }, + ) +} diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt index 0dea011ab..2bb63fa08 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt @@ -203,7 +203,7 @@ fun AgentChatScreen( showAudioPicker = true, getActiveSkills = { skillManagerViewModel.getSelectedSkills().map { skill -> - if (skill.builtIn) skill.name else "custom_skill" + skillManagerViewModel.getSkillShortId(skill) } }, composableBelowMessageList = { model -> @@ -236,6 +236,8 @@ fun AgentChatScreen( } else { action.url } + val skill = skillManagerViewModel.getSkill(name = skillName) + val skillId = skill?.let { skillManagerViewModel.getSkillShortId(it) } ?: "xxxx" try { // Set up a safety net timeout so we NEVER hang the chat or tool execution launch { @@ -250,6 +252,7 @@ fun AgentChatScreen( GalleryEvent.SKILL_EXECUTION.id, Bundle().apply { putString("skill_name", skillName) + putString("skill_id", skillId) putBoolean("success", false) putString("error_type", "timeout") }, @@ -285,6 +288,7 @@ fun AgentChatScreen( GalleryEvent.SKILL_EXECUTION.id, Bundle().apply { putString("skill_name", skillName) + putString("skill_id", skillId) putBoolean("success", isSuccess) putString("error_type", errorType) }, @@ -323,6 +327,7 @@ fun AgentChatScreen( GalleryEvent.SKILL_EXECUTION.id, Bundle().apply { putString("skill_name", skillName) + putString("skill_id", skillId) putBoolean("success", false) putString("error_type", "exception") }, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt index ee8b9b3d0..eceef8b9f 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt @@ -81,6 +81,7 @@ class AgentChatTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = true, supportAudio = true, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt index c9bd3d2f8..ba47b51f4 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt @@ -37,7 +37,7 @@ import kotlinx.coroutines.runBlocking private const val TAG = "AGAgentTools" -class AgentTools() : ToolSet { +open class AgentTools() : ToolSet { lateinit var context: Context lateinit var skillManagerViewModel: SkillManagerViewModel diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt index 5b18cff85..6d7a6aa67 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt @@ -100,6 +100,8 @@ import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalUriHandler import androidx.compose.ui.res.pluralStringResource import androidx.compose.ui.res.stringResource +import androidx.compose.ui.semantics.contentDescription +import androidx.compose.ui.semantics.semantics import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextDecoration import androidx.compose.ui.unit.dp @@ -173,7 +175,7 @@ fun SkillManagerBottomSheet( var addSkillOptionTypeToConfirm by remember { mutableStateOf(null) } var skillToEditIndex by remember { mutableIntStateOf(-1) } var searchQuery by remember { mutableStateOf("") } - var savedSelectedSkillsNamesAndDescriptions = remember { "" } + var savedSelectedSkillsNamesAndDescriptions by remember { mutableStateOf("") } var filteredSkills by remember { mutableStateOf(uiState.skills) } val listState = rememberLazyListState() val uriHandler = LocalUriHandler.current @@ -860,7 +862,8 @@ private fun SkillItemRow( Switch( checked = skill.selected, onCheckedChange = onSkillEnabledChange, - modifier = Modifier.offset(y = (-4).dp), + modifier = + Modifier.offset(y = (-4).dp).semantics { contentDescription = "Toggle ${skill.name}" }, enabled = !inMultiSelectMode, ) } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt index 9787657f8..0d9ddb372 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt @@ -117,6 +117,23 @@ val TRYOUT_CHIPS: List = ), ) +enum class SkillSource(val sourceName: String) { + BUILTIN("builtin"), + FEATURED("featured"), + REMOTE_URL("remote_url"), + LOCAL_IMPORT("local_import"), + UNKNOWN("unknown"), +} + +enum class SkillAction(val value: String) { + ADD("add"), + DELETE("delete"), + ENABLE("enable"), + DISABLE("disable"), + ENABLE_ALL("enable_all"), + DISABLE_ALL("disable_all"), +} + data class SkillState(val skill: Skill) data class SkillManagerUiState( @@ -341,13 +358,7 @@ constructor( Log.d(TAG, "Successfully added skill from URL: ${skill.name}") firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "add") - putString("source", "remote_url") - putString("skill_name", getSkillNameForLogging(skill)) - putBoolean("is_built_in", skill.builtIn) - putString("remote_url", url) - }, + getSkillLoggingParams(skill).apply { putString("action", SkillAction.ADD.value) }, ) onSuccess() } @@ -532,12 +543,7 @@ constructor( Log.d(TAG, "Successfully added skill from local import: ${skillWithDir.name}") firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "add") - putString("source", "local_import") - putString("skill_name", getSkillNameForLogging(skillWithDir)) - putBoolean("is_built_in", skillWithDir.builtIn) - }, + getSkillLoggingParams(skillWithDir).apply { putString("action", SkillAction.ADD.value) }, ) onSuccess() } @@ -603,15 +609,14 @@ constructor( return } - val skillNameToLog = getSkillNameForLogging(skill) - Log.d(TAG, "Analytics: skill_management, action=delete, skill_name=${skillNameToLog}") + val loggingParams = getSkillLoggingParams(skill) + Log.d( + TAG, + "Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams", + ) firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "delete") - putString("skill_name", skillNameToLog) - putBoolean("is_built_in", skill.builtIn) - }, + loggingParams.apply { putString("action", SkillAction.DELETE.value) }, ) // Update state. @@ -643,17 +648,14 @@ constructor( } for (skill in skillsToDelete) { + val loggingParams = getSkillLoggingParams(skill) Log.d( TAG, - "Analytics: skill_management, action=delete, skill_name=${getSkillNameForLogging(skill)}", + "Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams", ) firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", "delete") - putString("skill_name", getSkillNameForLogging(skill)) - putBoolean("is_built_in", skill.builtIn) - }, + loggingParams.apply { putString("action", SkillAction.DELETE.value) }, ) } @@ -686,10 +688,8 @@ constructor( firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { - putString("action", if (selected) "enable" else "disable") - putString("skill_name", getSkillNameForLogging(skill.skill)) - putBoolean("is_built_in", skill.skill.builtIn) + getSkillLoggingParams(skill.skill).apply { + putString("action", if (selected) SkillAction.ENABLE.value else SkillAction.DISABLE.value) }, ) val updatedSkills = @@ -720,11 +720,16 @@ constructor( Log.d( TAG, - "Analytics: skill_management, action=${if (selected) "enable_all" else "disable_all"}", + "Analytics: skill_management, action=${if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value}", ) firebaseAnalytics?.logEvent( GalleryEvent.SKILL_MANAGEMENT.id, - Bundle().apply { putString("action", if (selected) "enable_all" else "disable_all") }, + Bundle().apply { + putString( + "action", + if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value, + ) + }, ) // Update data store. @@ -1166,15 +1171,73 @@ constructor( dataStoreRepository.setSkills(updatedList) } - private fun getSkillNameForLogging(skill: Skill): String { + private fun getSkillSource(skill: Skill): SkillSource { val isFeatured = skill.skillUrl.isNotEmpty() && _uiState.value.featuredSkills.any { it.skillUrl == skill.skillUrl } - return if (skill.builtIn || isFeatured) { - skill.name - } else { - "custom_skill" + return when { + skill.builtIn -> SkillSource.BUILTIN + isFeatured -> SkillSource.FEATURED + skill.skillUrl.isNotEmpty() -> SkillSource.REMOTE_URL + skill.importDirName.isNotEmpty() -> SkillSource.LOCAL_IMPORT + else -> SkillSource.UNKNOWN + } + } + + /** + * Generates a short 4-character hash to act as a stable ID. This solves the 100-character limit + * for list logging in GA4 AND allows us to distinguish between different custom skills in + * reports. Note: When we migrate to Cleancut or a similar service that doesn't have severe + * character limits, we can drop the human-readable skill_name from setup events and rely purely + * on this hash ID. + */ + fun getSkillShortId(skill: Skill): String { + val source = getSkillSource(skill) + val identifier = + when (source) { + SkillSource.BUILTIN, + SkillSource.FEATURED -> skill.name + SkillSource.LOCAL_IMPORT -> skill.importDirName + else -> skill.skillUrl + } + if (identifier.isEmpty()) return "xxxx" + + val prefix = + when (source) { + SkillSource.BUILTIN -> "b_" + SkillSource.FEATURED -> "f_" + SkillSource.LOCAL_IMPORT -> "l_" + else -> "c_" + } + + return try { + val digest = java.security.MessageDigest.getInstance("SHA-256") + val hashBytes = digest.digest(identifier.toByteArray()) + val hexString = hashBytes.joinToString("") { "%02x".format(it) } + prefix + hexString.take(4) + } catch (e: Exception) { + prefix + "fail" + } + } + + private fun getSkillLoggingParams(skill: Skill): Bundle { + val source = getSkillSource(skill) + val skillName = + if (source == SkillSource.BUILTIN || source == SkillSource.FEATURED) skill.name + else "custom_skill" + val bundle = + Bundle().apply { + putString("source", source.sourceName) + putString("skill_name", skillName) + putString("skill_id", getSkillShortId(skill)) + } + if ( + skill.skillUrl.isNotEmpty() && + (source == SkillSource.REMOTE_URL || source == SkillSource.FEATURED) + ) { + bundle.putString("remote_url", skill.skillUrl.take(100)) } + return bundle } private fun getSkillDestinationDir(originalImportDirName: String): File { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt index 592a60872..8aff4d0ba 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt @@ -74,6 +74,7 @@ class MobileActionsTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt index 40457d1a0..a2bd2e060 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt @@ -27,6 +27,7 @@ import androidx.core.net.toUri import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.R +import com.google.ai.edge.gallery.data.BuiltInTaskId import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance @@ -211,6 +212,7 @@ constructor(@ApplicationContext private val appContext: Context) : ViewModel() { LlmChatModelHelper.initialize( context = context, model = model, + taskId = BuiltInTaskId.LLM_MOBILE_ACTIONS, supportImage = false, supportAudio = false, onDone = { error -> diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt index 4dbd8af47..a15d94198 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt @@ -109,6 +109,7 @@ class TinyGardenTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt index f13691fc1..01b94638a 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt @@ -21,6 +21,7 @@ import android.util.Log import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.google.ai.edge.gallery.R +import com.google.ai.edge.gallery.data.BuiltInTaskId import com.google.ai.edge.gallery.data.DataStoreRepository import com.google.ai.edge.gallery.data.Model import com.google.ai.edge.gallery.ui.common.chat.ChatMessage @@ -168,6 +169,7 @@ constructor( LlmChatModelHelper.initialize( context = context, model = model, + taskId = BuiltInTaskId.LLM_TINY_GARDEN, supportImage = false, supportAudio = false, onDone = { error -> diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt index 4ca4b09d5..d596337bb 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt @@ -58,7 +58,11 @@ object ConfigKeys { val SUPPORT_TINY_GARDEN = ConfigKey("support_tiny_garden", "Support tiny garden") val SUPPORT_MOBILE_ACTIONS = ConfigKey("support_mobile_actions", "Support mobile actions") val SUPPORT_THINKING = ConfigKey("support_thinking", "Support thinking") + val SUPPORT_SPECULATIVE_DECODING = + ConfigKey("support_speculative_decoding", "Support speculative decoding") val ENABLE_THINKING = ConfigKey("enable_thinking", "Enable thinking") + val ENABLE_SPECULATIVE_DECODING = + ConfigKey("enable_speculative_decoding", "Enable speculative decoding") val MAX_RESULT_COUNT = ConfigKey("max_result_count", "Max result count") val USE_GPU = ConfigKey("use_gpu", "Use GPU") val ACCELERATOR = ConfigKey("accelerator", "Accelerator") @@ -226,6 +230,7 @@ fun createLlmChatConfigs( defaultTemperature: Float = DEFAULT_TEMPERATURE, accelerators: List = DEFAULT_ACCELERATORS, supportThinking: Boolean = false, + supportSpeculativeDecoding: Boolean = false, ): List { var maxTokensConfig: Config = LabelConfig(key = ConfigKeys.MAX_TOKENS, defaultValue = "$defaultMaxToken") @@ -274,6 +279,11 @@ fun createLlmChatConfigs( if (supportThinking) { configs.add(BooleanSwitchConfig(key = ConfigKeys.ENABLE_THINKING, defaultValue = false)) } + if (supportSpeculativeDecoding) { + configs.add( + BooleanSwitchConfig(key = ConfigKeys.ENABLE_SPECULATIVE_DECODING, defaultValue = false) + ) + } return configs } diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt index c62d67023..a0dbe7275 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt @@ -33,7 +33,8 @@ private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]") data class PromptTemplate(val title: String, val description: String, val prompt: String) enum class ModelCapability { - @SerializedName("llm_thinking") LLM_THINKING + @SerializedName("llm_thinking") LLM_THINKING, + @SerializedName("speculative_decoding") SPECULATIVE_DECODING, } enum class RuntimeType { diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt index 9ae703182..edffcd68b 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt @@ -176,6 +176,8 @@ data class AllowedModel( defaultMaxContextLength = llmMaxContextLength, accelerators = accelerators, supportThinking = capabilities?.contains(ModelCapability.LLM_THINKING) == true, + supportSpeculativeDecoding = + capabilities?.contains(ModelCapability.SPECULATIVE_DECODING) == true, ) }) .toMutableList() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt index eee5c0ac6..c3c659676 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt @@ -39,6 +39,7 @@ interface LlmModelHelper { * * @param context the application context. * @param model the model to be initialized. + * @param taskId the task id where the model is being used. * @param supportImage whether to support image input. * @param supportAudio whether to support audio input. * @param onDone callback invoked when initialization is completed successfully. @@ -51,6 +52,7 @@ interface LlmModelHelper { fun initialize( context: Context, model: Model, + taskId: String, supportImage: Boolean, supportAudio: Boolean, onDone: (String) -> Unit, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt index 3a3624ccb..9ba60e392 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt @@ -62,6 +62,7 @@ object AICoreModelHelper : LlmModelHelper { override fun initialize( context: Context, model: Model, + taskId: String, supportImage: Boolean, supportAudio: Boolean, onDone: (String) -> Unit, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt index 2756b3889..c2bdb0c12 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt @@ -216,6 +216,13 @@ fun ModelPageAppBar( if (!task.allowCapability(ModelCapability.LLM_THINKING, model)) { modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_THINKING } } + var supportsSpeculativeDecoding = false + if ( + !supportsSpeculativeDecoding || + !task.allowCapability(ModelCapability.SPECULATIVE_DECODING, model) + ) { + modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_SPECULATIVE_DECODING } + } ConfigDialog( title = "Configurations", configs = modelConfigs, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt index b6ee6a242..236da9bce 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt @@ -28,6 +28,7 @@ import com.google.ai.edge.gallery.data.DEFAULT_TOPK import com.google.ai.edge.gallery.data.DEFAULT_TOPP import com.google.ai.edge.gallery.data.DEFAULT_VISION_ACCELERATOR import com.google.ai.edge.gallery.data.Model +import com.google.ai.edge.gallery.data.ModelCapability import com.google.ai.edge.gallery.runtime.CleanUpListener import com.google.ai.edge.gallery.runtime.LlmModelHelper import com.google.ai.edge.gallery.runtime.ResultListener @@ -60,6 +61,7 @@ object LlmChatModelHelper : LlmModelHelper { override fun initialize( context: Context, model: Model, + taskId: String, supportImage: Boolean, supportAudio: Boolean, onDone: (String) -> Unit, @@ -120,9 +122,28 @@ object LlmChatModelHelper : LlmModelHelper { else null, ) + // Check if the model file supports speculative decoding. + var supportsSpeculativeDecoding = false // Create an instance of LiteRT LM engine and conversation. try { + var speculativeDecoding = false + // Check if the model supports speculative decoding for the given task type and if the + // speculative decoding is enabled in the settings. + if ( + supportsSpeculativeDecoding && + model.capabilityToTaskTypes[ModelCapability.SPECULATIVE_DECODING]?.contains(taskId) == + true + ) { + speculativeDecoding = + model.getBooleanConfigValue( + key = ConfigKeys.ENABLE_SPECULATIVE_DECODING, + defaultValue = false, + ) + } + ExperimentalFlags.enableSpeculativeDecoding = speculativeDecoding + Log.d(TAG, "Speculative decoding enabled: $speculativeDecoding") val engine = Engine(engineConfig) + ExperimentalFlags.enableSpeculativeDecoding = false engine.initialize() ExperimentalFlags.enableConversationConstrainedDecoding = diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt index 9bcdb49e3..07ff3e85e 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt @@ -80,6 +80,7 @@ class LlmChatTask @Inject constructor() : CustomTask { model.runtimeHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, @@ -162,6 +163,7 @@ class LlmAskImageTask @Inject constructor() : CustomTask { model.runtimeHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = true, supportAudio = false, onDone = onDone, @@ -227,6 +229,7 @@ class LlmAskAudioTask @Inject constructor() : CustomTask { model.runtimeHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = true, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt index dab3ea7c1..ccf3122d3 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt @@ -61,6 +61,7 @@ class LlmSingleTurnTask @Inject constructor() : CustomTask { LlmChatModelHelper.initialize( context = context, model = model, + taskId = task.id, supportImage = false, supportAudio = false, onDone = onDone, diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt index 78442588f..d4c19f1e4 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt @@ -94,7 +94,8 @@ private const val TAG = "AGModelImportDialog" private val SUPPORTED_ACCELERATORS: List = if (isPixel10()) { - listOf(Accelerator.CPU) + val accelerators = mutableListOf(Accelerator.CPU) + accelerators.toList() } else { listOf(Accelerator.CPU, Accelerator.GPU, Accelerator.NPU) } @@ -136,6 +137,7 @@ private val IMPORT_CONFIGS_LLM: List = BooleanSwitchConfig(key = ConfigKeys.SUPPORT_TINY_GARDEN, defaultValue = false), BooleanSwitchConfig(key = ConfigKeys.SUPPORT_MOBILE_ACTIONS, defaultValue = false), BooleanSwitchConfig(key = ConfigKeys.SUPPORT_THINKING, defaultValue = false), + BooleanSwitchConfig(key = ConfigKeys.SUPPORT_SPECULATIVE_DECODING, defaultValue = false), SegmentedButtonConfig( key = ConfigKeys.COMPATIBLE_ACCELERATORS, defaultValue = SUPPORTED_ACCELERATORS[0].label, @@ -278,6 +280,12 @@ fun ModelImportDialog( valueType = ValueType.BOOLEAN, ) as Boolean + val supportSpeculativeDecoding = + convertValueToTargetType( + value = values.get(ConfigKeys.SUPPORT_SPECULATIVE_DECODING.label)!!, + valueType = ValueType.BOOLEAN, + ) + as Boolean val importedModel: ImportedModel = ImportedModel.newBuilder() .setFileName(fileName) @@ -294,6 +302,7 @@ fun ModelImportDialog( .setSupportMobileActions(supportMobileActions) .setSupportThinking(supportThinking) .setSupportTinyGarden(supportTinyGarden) + .setSupportSpeculativeDecoding(supportSpeculativeDecoding) .build() ) .build() diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt index 465d5073d..3cdc49b18 100644 --- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt +++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt @@ -1189,6 +1189,7 @@ constructor( val llmSupportTinyGarden = info.llmConfig.supportTinyGarden val llmSupportMobileActions = info.llmConfig.supportMobileActions val llmSupportThinking = info.llmConfig.supportThinking + val llmSupportSpeculativeDecoding = info.llmConfig.supportSpeculativeDecoding val configs: MutableList = createLlmChatConfigs( defaultMaxToken = llmMaxToken, @@ -1197,8 +1198,30 @@ constructor( defaultTemperature = info.llmConfig.defaultTemperature, accelerators = accelerators, supportThinking = llmSupportThinking, + supportSpeculativeDecoding = llmSupportSpeculativeDecoding, ) .toMutableList() + val capabilities: MutableList = mutableListOf() + val capabilityToTaskTypes: MutableMap> = mutableMapOf() + if (llmSupportThinking) { + capabilities.add(ModelCapability.LLM_THINKING) + capabilityToTaskTypes[ModelCapability.LLM_THINKING] = + listOf( + BuiltInTaskId.LLM_CHAT, + BuiltInTaskId.LLM_ASK_IMAGE, + BuiltInTaskId.LLM_ASK_AUDIO, + ) + } + if (llmSupportSpeculativeDecoding) { + capabilities.add(ModelCapability.SPECULATIVE_DECODING) + capabilityToTaskTypes[ModelCapability.SPECULATIVE_DECODING] = + listOf( + BuiltInTaskId.LLM_CHAT, + BuiltInTaskId.LLM_ASK_IMAGE, + BuiltInTaskId.LLM_ASK_AUDIO, + BuiltInTaskId.LLM_PROMPT_LAB, + ) + } val model = Model( name = info.fileName, @@ -1213,21 +1236,8 @@ constructor( llmSupportAudio = llmSupportAudio, llmSupportTinyGarden = llmSupportTinyGarden, llmSupportMobileActions = llmSupportMobileActions, - capabilities = - if (llmSupportThinking) listOf(ModelCapability.LLM_THINKING) else emptyList(), - capabilityToTaskTypes = - if (llmSupportThinking) { - mapOf( - ModelCapability.LLM_THINKING to - listOf( - BuiltInTaskId.LLM_CHAT, - BuiltInTaskId.LLM_ASK_IMAGE, - BuiltInTaskId.LLM_ASK_AUDIO, - ) - ) - } else { - emptyMap() - }, + capabilities = capabilities.toList(), + capabilityToTaskTypes = capabilityToTaskTypes.toMap(), llmMaxToken = llmMaxToken, accelerators = accelerators, // We assume all imported models are LLM for now. diff --git a/Android/src/app/src/main/proto/settings.proto b/Android/src/app/src/main/proto/settings.proto index 4640bb1b8..c6a505632 100644 --- a/Android/src/app/src/main/proto/settings.proto +++ b/Android/src/app/src/main/proto/settings.proto @@ -59,6 +59,7 @@ message LlmConfig { bool support_tiny_garden = 8; bool support_mobile_actions = 9; bool support_thinking = 10; + bool support_speculative_decoding = 11; } message Settings {