diff --git a/Android/src/app/build.gradle.kts b/Android/src/app/build.gradle.kts
index ba9759a0e..ff6e6e2cf 100644
--- a/Android/src/app/build.gradle.kts
+++ b/Android/src/app/build.gradle.kts
@@ -36,8 +36,8 @@ android {
applicationId = "com.google.aiedge.gallery"
minSdk = 31
targetSdk = 35
- versionCode = 29
- versionName = "1.0.12"
+ versionCode = 30
+ versionName = "1.0.13"
// Needed for HuggingFace auth workflows.
// Use the scheme of the "Redirect URLs" in HuggingFace app.
diff --git a/Android/src/app/src/main/AndroidManifest.xml b/Android/src/app/src/main/AndroidManifest.xml
index 5a7173c37..15e0de78d 100644
--- a/Android/src/app/src/main/AndroidManifest.xml
+++ b/Android/src/app/src/main/AndroidManifest.xml
@@ -143,10 +143,10 @@
-
+
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
index 3cd13bdf0..d09533a1a 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt
@@ -22,6 +22,7 @@ import android.graphics.BitmapFactory
import android.graphics.Matrix
import android.net.Uri
import android.os.Build
+import android.os.Bundle
import android.util.Log
import androidx.compose.foundation.layout.WindowInsets
import androidx.compose.foundation.layout.ime
@@ -36,7 +37,9 @@ import androidx.compose.ui.focus.onFocusEvent
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalFocusManager
import androidx.exifinterface.media.ExifInterface
+import com.google.ai.edge.gallery.GalleryEvent
import com.google.ai.edge.gallery.data.SAMPLE_RATE
+import com.google.ai.edge.gallery.firebaseAnalytics
import com.google.gson.Gson
import java.io.File
import java.io.FileInputStream
@@ -378,3 +381,14 @@ fun isAICoreSupported(allowedDeviceModels: Set?): Boolean {
val currentModel = Build.MODEL?.lowercase() ?: return false
return allowedDeviceModels.contains(currentModel)
}
+
+fun logErrorToFirebase(event: GalleryEvent, errorType: String, errorMessage: String?) {
+ firebaseAnalytics?.logEvent(
+ event.id,
+ Bundle().apply {
+ putBoolean("success", false)
+ putString("error_type", errorType)
+ putString("error_message", errorMessage ?: "Unknown error")
+ },
+ )
+}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt
index 0dea011ab..2bb63fa08 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatScreen.kt
@@ -203,7 +203,7 @@ fun AgentChatScreen(
showAudioPicker = true,
getActiveSkills = {
skillManagerViewModel.getSelectedSkills().map { skill ->
- if (skill.builtIn) skill.name else "custom_skill"
+ skillManagerViewModel.getSkillShortId(skill)
}
},
composableBelowMessageList = { model ->
@@ -236,6 +236,8 @@ fun AgentChatScreen(
} else {
action.url
}
+ val skill = skillManagerViewModel.getSkill(name = skillName)
+ val skillId = skill?.let { skillManagerViewModel.getSkillShortId(it) } ?: "xxxx"
try {
// Set up a safety net timeout so we NEVER hang the chat or tool execution
launch {
@@ -250,6 +252,7 @@ fun AgentChatScreen(
GalleryEvent.SKILL_EXECUTION.id,
Bundle().apply {
putString("skill_name", skillName)
+ putString("skill_id", skillId)
putBoolean("success", false)
putString("error_type", "timeout")
},
@@ -285,6 +288,7 @@ fun AgentChatScreen(
GalleryEvent.SKILL_EXECUTION.id,
Bundle().apply {
putString("skill_name", skillName)
+ putString("skill_id", skillId)
putBoolean("success", isSuccess)
putString("error_type", errorType)
},
@@ -323,6 +327,7 @@ fun AgentChatScreen(
GalleryEvent.SKILL_EXECUTION.id,
Bundle().apply {
putString("skill_name", skillName)
+ putString("skill_id", skillId)
putBoolean("success", false)
putString("error_type", "exception")
},
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt
index ee8b9b3d0..eceef8b9f 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentChatTaskModule.kt
@@ -81,6 +81,7 @@ class AgentChatTask @Inject constructor() : CustomTask {
LlmChatModelHelper.initialize(
context = context,
model = model,
+ taskId = task.id,
supportImage = true,
supportAudio = true,
onDone = onDone,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt
index c9bd3d2f8..ba47b51f4 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/AgentTools.kt
@@ -37,7 +37,7 @@ import kotlinx.coroutines.runBlocking
private const val TAG = "AGAgentTools"
-class AgentTools() : ToolSet {
+open class AgentTools() : ToolSet {
lateinit var context: Context
lateinit var skillManagerViewModel: SkillManagerViewModel
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt
index 5b18cff85..6d7a6aa67 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerBottomSheet.kt
@@ -100,6 +100,8 @@ import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.platform.LocalUriHandler
import androidx.compose.ui.res.pluralStringResource
import androidx.compose.ui.res.stringResource
+import androidx.compose.ui.semantics.contentDescription
+import androidx.compose.ui.semantics.semantics
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.unit.dp
@@ -173,7 +175,7 @@ fun SkillManagerBottomSheet(
var addSkillOptionTypeToConfirm by remember { mutableStateOf(null) }
var skillToEditIndex by remember { mutableIntStateOf(-1) }
var searchQuery by remember { mutableStateOf("") }
- var savedSelectedSkillsNamesAndDescriptions = remember { "" }
+ var savedSelectedSkillsNamesAndDescriptions by remember { mutableStateOf("") }
var filteredSkills by remember { mutableStateOf(uiState.skills) }
val listState = rememberLazyListState()
val uriHandler = LocalUriHandler.current
@@ -860,7 +862,8 @@ private fun SkillItemRow(
Switch(
checked = skill.selected,
onCheckedChange = onSkillEnabledChange,
- modifier = Modifier.offset(y = (-4).dp),
+ modifier =
+ Modifier.offset(y = (-4).dp).semantics { contentDescription = "Toggle ${skill.name}" },
enabled = !inMultiSelectMode,
)
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt
index 9787657f8..0d9ddb372 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/agentchat/SkillManagerViewModel.kt
@@ -117,6 +117,23 @@ val TRYOUT_CHIPS: List =
),
)
+enum class SkillSource(val sourceName: String) {
+ BUILTIN("builtin"),
+ FEATURED("featured"),
+ REMOTE_URL("remote_url"),
+ LOCAL_IMPORT("local_import"),
+ UNKNOWN("unknown"),
+}
+
+enum class SkillAction(val value: String) {
+ ADD("add"),
+ DELETE("delete"),
+ ENABLE("enable"),
+ DISABLE("disable"),
+ ENABLE_ALL("enable_all"),
+ DISABLE_ALL("disable_all"),
+}
+
data class SkillState(val skill: Skill)
data class SkillManagerUiState(
@@ -341,13 +358,7 @@ constructor(
Log.d(TAG, "Successfully added skill from URL: ${skill.name}")
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
- Bundle().apply {
- putString("action", "add")
- putString("source", "remote_url")
- putString("skill_name", getSkillNameForLogging(skill))
- putBoolean("is_built_in", skill.builtIn)
- putString("remote_url", url)
- },
+ getSkillLoggingParams(skill).apply { putString("action", SkillAction.ADD.value) },
)
onSuccess()
}
@@ -532,12 +543,7 @@ constructor(
Log.d(TAG, "Successfully added skill from local import: ${skillWithDir.name}")
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
- Bundle().apply {
- putString("action", "add")
- putString("source", "local_import")
- putString("skill_name", getSkillNameForLogging(skillWithDir))
- putBoolean("is_built_in", skillWithDir.builtIn)
- },
+ getSkillLoggingParams(skillWithDir).apply { putString("action", SkillAction.ADD.value) },
)
onSuccess()
}
@@ -603,15 +609,14 @@ constructor(
return
}
- val skillNameToLog = getSkillNameForLogging(skill)
- Log.d(TAG, "Analytics: skill_management, action=delete, skill_name=${skillNameToLog}")
+ val loggingParams = getSkillLoggingParams(skill)
+ Log.d(
+ TAG,
+ "Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams",
+ )
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
- Bundle().apply {
- putString("action", "delete")
- putString("skill_name", skillNameToLog)
- putBoolean("is_built_in", skill.builtIn)
- },
+ loggingParams.apply { putString("action", SkillAction.DELETE.value) },
)
// Update state.
@@ -643,17 +648,14 @@ constructor(
}
for (skill in skillsToDelete) {
+ val loggingParams = getSkillLoggingParams(skill)
Log.d(
TAG,
- "Analytics: skill_management, action=delete, skill_name=${getSkillNameForLogging(skill)}",
+ "Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams",
)
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
- Bundle().apply {
- putString("action", "delete")
- putString("skill_name", getSkillNameForLogging(skill))
- putBoolean("is_built_in", skill.builtIn)
- },
+ loggingParams.apply { putString("action", SkillAction.DELETE.value) },
)
}
@@ -686,10 +688,8 @@ constructor(
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
- Bundle().apply {
- putString("action", if (selected) "enable" else "disable")
- putString("skill_name", getSkillNameForLogging(skill.skill))
- putBoolean("is_built_in", skill.skill.builtIn)
+ getSkillLoggingParams(skill.skill).apply {
+ putString("action", if (selected) SkillAction.ENABLE.value else SkillAction.DISABLE.value)
},
)
val updatedSkills =
@@ -720,11 +720,16 @@ constructor(
Log.d(
TAG,
- "Analytics: skill_management, action=${if (selected) "enable_all" else "disable_all"}",
+ "Analytics: skill_management, action=${if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value}",
)
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
- Bundle().apply { putString("action", if (selected) "enable_all" else "disable_all") },
+ Bundle().apply {
+ putString(
+ "action",
+ if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value,
+ )
+ },
)
// Update data store.
@@ -1166,15 +1171,73 @@ constructor(
dataStoreRepository.setSkills(updatedList)
}
- private fun getSkillNameForLogging(skill: Skill): String {
+ private fun getSkillSource(skill: Skill): SkillSource {
val isFeatured =
skill.skillUrl.isNotEmpty() &&
_uiState.value.featuredSkills.any { it.skillUrl == skill.skillUrl }
- return if (skill.builtIn || isFeatured) {
- skill.name
- } else {
- "custom_skill"
+ return when {
+ skill.builtIn -> SkillSource.BUILTIN
+ isFeatured -> SkillSource.FEATURED
+ skill.skillUrl.isNotEmpty() -> SkillSource.REMOTE_URL
+ skill.importDirName.isNotEmpty() -> SkillSource.LOCAL_IMPORT
+ else -> SkillSource.UNKNOWN
+ }
+ }
+
+ /**
+ * Generates a short 4-character hash to act as a stable ID. This solves the 100-character limit
+ * for list logging in GA4 AND allows us to distinguish between different custom skills in
+ * reports. Note: When we migrate to Cleancut or a similar service that doesn't have severe
+ * character limits, we can drop the human-readable skill_name from setup events and rely purely
+ * on this hash ID.
+ */
+ fun getSkillShortId(skill: Skill): String {
+ val source = getSkillSource(skill)
+ val identifier =
+ when (source) {
+ SkillSource.BUILTIN,
+ SkillSource.FEATURED -> skill.name
+ SkillSource.LOCAL_IMPORT -> skill.importDirName
+ else -> skill.skillUrl
+ }
+ if (identifier.isEmpty()) return "xxxx"
+
+ val prefix =
+ when (source) {
+ SkillSource.BUILTIN -> "b_"
+ SkillSource.FEATURED -> "f_"
+ SkillSource.LOCAL_IMPORT -> "l_"
+ else -> "c_"
+ }
+
+ return try {
+ val digest = java.security.MessageDigest.getInstance("SHA-256")
+ val hashBytes = digest.digest(identifier.toByteArray())
+ val hexString = hashBytes.joinToString("") { "%02x".format(it) }
+ prefix + hexString.take(4)
+ } catch (e: Exception) {
+ prefix + "fail"
+ }
+ }
+
+ private fun getSkillLoggingParams(skill: Skill): Bundle {
+ val source = getSkillSource(skill)
+ val skillName =
+ if (source == SkillSource.BUILTIN || source == SkillSource.FEATURED) skill.name
+ else "custom_skill"
+ val bundle =
+ Bundle().apply {
+ putString("source", source.sourceName)
+ putString("skill_name", skillName)
+ putString("skill_id", getSkillShortId(skill))
+ }
+ if (
+ skill.skillUrl.isNotEmpty() &&
+ (source == SkillSource.REMOTE_URL || source == SkillSource.FEATURED)
+ ) {
+ bundle.putString("remote_url", skill.skillUrl.take(100))
}
+ return bundle
}
private fun getSkillDestinationDir(originalImportDirName: String): File {
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt
index 592a60872..8aff4d0ba 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsTask.kt
@@ -74,6 +74,7 @@ class MobileActionsTask @Inject constructor() : CustomTask {
LlmChatModelHelper.initialize(
context = context,
model = model,
+ taskId = task.id,
supportImage = false,
supportAudio = false,
onDone = onDone,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt
index 40457d1a0..a2bd2e060 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/mobileactions/MobileActionsViewModel.kt
@@ -27,6 +27,7 @@ import androidx.core.net.toUri
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.R
+import com.google.ai.edge.gallery.data.BuiltInTaskId
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
@@ -211,6 +212,7 @@ constructor(@ApplicationContext private val appContext: Context) : ViewModel() {
LlmChatModelHelper.initialize(
context = context,
model = model,
+ taskId = BuiltInTaskId.LLM_MOBILE_ACTIONS,
supportImage = false,
supportAudio = false,
onDone = { error ->
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt
index 4dbd8af47..a15d94198 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenTask.kt
@@ -109,6 +109,7 @@ class TinyGardenTask @Inject constructor() : CustomTask {
LlmChatModelHelper.initialize(
context = context,
model = model,
+ taskId = task.id,
supportImage = false,
supportAudio = false,
onDone = onDone,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt
index f13691fc1..01b94638a 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/customtasks/tinygarden/TinyGardenViewModel.kt
@@ -21,6 +21,7 @@ import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.R
+import com.google.ai.edge.gallery.data.BuiltInTaskId
import com.google.ai.edge.gallery.data.DataStoreRepository
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.chat.ChatMessage
@@ -168,6 +169,7 @@ constructor(
LlmChatModelHelper.initialize(
context = context,
model = model,
+ taskId = BuiltInTaskId.LLM_TINY_GARDEN,
supportImage = false,
supportAudio = false,
onDone = { error ->
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
index 4ca4b09d5..d596337bb 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt
@@ -58,7 +58,11 @@ object ConfigKeys {
val SUPPORT_TINY_GARDEN = ConfigKey("support_tiny_garden", "Support tiny garden")
val SUPPORT_MOBILE_ACTIONS = ConfigKey("support_mobile_actions", "Support mobile actions")
val SUPPORT_THINKING = ConfigKey("support_thinking", "Support thinking")
+ val SUPPORT_SPECULATIVE_DECODING =
+ ConfigKey("support_speculative_decoding", "Support speculative decoding")
val ENABLE_THINKING = ConfigKey("enable_thinking", "Enable thinking")
+ val ENABLE_SPECULATIVE_DECODING =
+ ConfigKey("enable_speculative_decoding", "Enable speculative decoding")
val MAX_RESULT_COUNT = ConfigKey("max_result_count", "Max result count")
val USE_GPU = ConfigKey("use_gpu", "Use GPU")
val ACCELERATOR = ConfigKey("accelerator", "Accelerator")
@@ -226,6 +230,7 @@ fun createLlmChatConfigs(
defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List = DEFAULT_ACCELERATORS,
supportThinking: Boolean = false,
+ supportSpeculativeDecoding: Boolean = false,
): List {
var maxTokensConfig: Config =
LabelConfig(key = ConfigKeys.MAX_TOKENS, defaultValue = "$defaultMaxToken")
@@ -274,6 +279,11 @@ fun createLlmChatConfigs(
if (supportThinking) {
configs.add(BooleanSwitchConfig(key = ConfigKeys.ENABLE_THINKING, defaultValue = false))
}
+ if (supportSpeculativeDecoding) {
+ configs.add(
+ BooleanSwitchConfig(key = ConfigKeys.ENABLE_SPECULATIVE_DECODING, defaultValue = false)
+ )
+ }
return configs
}
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
index c62d67023..a0dbe7275 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt
@@ -33,7 +33,8 @@ private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]")
data class PromptTemplate(val title: String, val description: String, val prompt: String)
enum class ModelCapability {
- @SerializedName("llm_thinking") LLM_THINKING
+ @SerializedName("llm_thinking") LLM_THINKING,
+ @SerializedName("speculative_decoding") SPECULATIVE_DECODING,
}
enum class RuntimeType {
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
index 9ae703182..edffcd68b 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt
@@ -176,6 +176,8 @@ data class AllowedModel(
defaultMaxContextLength = llmMaxContextLength,
accelerators = accelerators,
supportThinking = capabilities?.contains(ModelCapability.LLM_THINKING) == true,
+ supportSpeculativeDecoding =
+ capabilities?.contains(ModelCapability.SPECULATIVE_DECODING) == true,
)
})
.toMutableList()
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt
index eee5c0ac6..c3c659676 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/LlmModelHelper.kt
@@ -39,6 +39,7 @@ interface LlmModelHelper {
*
* @param context the application context.
* @param model the model to be initialized.
+ * @param taskId the task id where the model is being used.
* @param supportImage whether to support image input.
* @param supportAudio whether to support audio input.
* @param onDone callback invoked when initialization is completed successfully.
@@ -51,6 +52,7 @@ interface LlmModelHelper {
fun initialize(
context: Context,
model: Model,
+ taskId: String,
supportImage: Boolean,
supportAudio: Boolean,
onDone: (String) -> Unit,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt
index 3a3624ccb..9ba60e392 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/runtime/aicore/AICoreModelHelper.kt
@@ -62,6 +62,7 @@ object AICoreModelHelper : LlmModelHelper {
override fun initialize(
context: Context,
model: Model,
+ taskId: String,
supportImage: Boolean,
supportAudio: Boolean,
onDone: (String) -> Unit,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt
index 2756b3889..c2bdb0c12 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/common/ModelPageAppBar.kt
@@ -216,6 +216,13 @@ fun ModelPageAppBar(
if (!task.allowCapability(ModelCapability.LLM_THINKING, model)) {
modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_THINKING }
}
+ var supportsSpeculativeDecoding = false
+ if (
+ !supportsSpeculativeDecoding ||
+ !task.allowCapability(ModelCapability.SPECULATIVE_DECODING, model)
+ ) {
+ modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_SPECULATIVE_DECODING }
+ }
ConfigDialog(
title = "Configurations",
configs = modelConfigs,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
index b6ee6a242..236da9bce 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt
@@ -28,6 +28,7 @@ import com.google.ai.edge.gallery.data.DEFAULT_TOPK
import com.google.ai.edge.gallery.data.DEFAULT_TOPP
import com.google.ai.edge.gallery.data.DEFAULT_VISION_ACCELERATOR
import com.google.ai.edge.gallery.data.Model
+import com.google.ai.edge.gallery.data.ModelCapability
import com.google.ai.edge.gallery.runtime.CleanUpListener
import com.google.ai.edge.gallery.runtime.LlmModelHelper
import com.google.ai.edge.gallery.runtime.ResultListener
@@ -60,6 +61,7 @@ object LlmChatModelHelper : LlmModelHelper {
override fun initialize(
context: Context,
model: Model,
+ taskId: String,
supportImage: Boolean,
supportAudio: Boolean,
onDone: (String) -> Unit,
@@ -120,9 +122,28 @@ object LlmChatModelHelper : LlmModelHelper {
else null,
)
+ // Check if the model file supports speculative decoding.
+ var supportsSpeculativeDecoding = false
// Create an instance of LiteRT LM engine and conversation.
try {
+ var speculativeDecoding = false
+ // Check if the model supports speculative decoding for the given task type and if the
+ // speculative decoding is enabled in the settings.
+ if (
+ supportsSpeculativeDecoding &&
+ model.capabilityToTaskTypes[ModelCapability.SPECULATIVE_DECODING]?.contains(taskId) ==
+ true
+ ) {
+ speculativeDecoding =
+ model.getBooleanConfigValue(
+ key = ConfigKeys.ENABLE_SPECULATIVE_DECODING,
+ defaultValue = false,
+ )
+ }
+ ExperimentalFlags.enableSpeculativeDecoding = speculativeDecoding
+ Log.d(TAG, "Speculative decoding enabled: $speculativeDecoding")
val engine = Engine(engineConfig)
+ ExperimentalFlags.enableSpeculativeDecoding = false
engine.initialize()
ExperimentalFlags.enableConversationConstrainedDecoding =
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt
index 9bcdb49e3..07ff3e85e 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatTaskModule.kt
@@ -80,6 +80,7 @@ class LlmChatTask @Inject constructor() : CustomTask {
model.runtimeHelper.initialize(
context = context,
model = model,
+ taskId = task.id,
supportImage = false,
supportAudio = false,
onDone = onDone,
@@ -162,6 +163,7 @@ class LlmAskImageTask @Inject constructor() : CustomTask {
model.runtimeHelper.initialize(
context = context,
model = model,
+ taskId = task.id,
supportImage = true,
supportAudio = false,
onDone = onDone,
@@ -227,6 +229,7 @@ class LlmAskAudioTask @Inject constructor() : CustomTask {
model.runtimeHelper.initialize(
context = context,
model = model,
+ taskId = task.id,
supportImage = false,
supportAudio = true,
onDone = onDone,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt
index dab3ea7c1..ccf3122d3 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmsingleturn/LlmSingleTurnTaskModule.kt
@@ -61,6 +61,7 @@ class LlmSingleTurnTask @Inject constructor() : CustomTask {
LlmChatModelHelper.initialize(
context = context,
model = model,
+ taskId = task.id,
supportImage = false,
supportAudio = false,
onDone = onDone,
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt
index 78442588f..d4c19f1e4 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelImportDialog.kt
@@ -94,7 +94,8 @@ private const val TAG = "AGModelImportDialog"
private val SUPPORTED_ACCELERATORS: List =
if (isPixel10()) {
- listOf(Accelerator.CPU)
+ val accelerators = mutableListOf(Accelerator.CPU)
+ accelerators.toList()
} else {
listOf(Accelerator.CPU, Accelerator.GPU, Accelerator.NPU)
}
@@ -136,6 +137,7 @@ private val IMPORT_CONFIGS_LLM: List =
BooleanSwitchConfig(key = ConfigKeys.SUPPORT_TINY_GARDEN, defaultValue = false),
BooleanSwitchConfig(key = ConfigKeys.SUPPORT_MOBILE_ACTIONS, defaultValue = false),
BooleanSwitchConfig(key = ConfigKeys.SUPPORT_THINKING, defaultValue = false),
+ BooleanSwitchConfig(key = ConfigKeys.SUPPORT_SPECULATIVE_DECODING, defaultValue = false),
SegmentedButtonConfig(
key = ConfigKeys.COMPATIBLE_ACCELERATORS,
defaultValue = SUPPORTED_ACCELERATORS[0].label,
@@ -278,6 +280,12 @@ fun ModelImportDialog(
valueType = ValueType.BOOLEAN,
)
as Boolean
+ val supportSpeculativeDecoding =
+ convertValueToTargetType(
+ value = values.get(ConfigKeys.SUPPORT_SPECULATIVE_DECODING.label)!!,
+ valueType = ValueType.BOOLEAN,
+ )
+ as Boolean
val importedModel: ImportedModel =
ImportedModel.newBuilder()
.setFileName(fileName)
@@ -294,6 +302,7 @@ fun ModelImportDialog(
.setSupportMobileActions(supportMobileActions)
.setSupportThinking(supportThinking)
.setSupportTinyGarden(supportTinyGarden)
+ .setSupportSpeculativeDecoding(supportSpeculativeDecoding)
.build()
)
.build()
diff --git a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
index 465d5073d..3cdc49b18 100644
--- a/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
+++ b/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/modelmanager/ModelManagerViewModel.kt
@@ -1189,6 +1189,7 @@ constructor(
val llmSupportTinyGarden = info.llmConfig.supportTinyGarden
val llmSupportMobileActions = info.llmConfig.supportMobileActions
val llmSupportThinking = info.llmConfig.supportThinking
+ val llmSupportSpeculativeDecoding = info.llmConfig.supportSpeculativeDecoding
val configs: MutableList =
createLlmChatConfigs(
defaultMaxToken = llmMaxToken,
@@ -1197,8 +1198,30 @@ constructor(
defaultTemperature = info.llmConfig.defaultTemperature,
accelerators = accelerators,
supportThinking = llmSupportThinking,
+ supportSpeculativeDecoding = llmSupportSpeculativeDecoding,
)
.toMutableList()
+ val capabilities: MutableList = mutableListOf()
+ val capabilityToTaskTypes: MutableMap> = mutableMapOf()
+ if (llmSupportThinking) {
+ capabilities.add(ModelCapability.LLM_THINKING)
+ capabilityToTaskTypes[ModelCapability.LLM_THINKING] =
+ listOf(
+ BuiltInTaskId.LLM_CHAT,
+ BuiltInTaskId.LLM_ASK_IMAGE,
+ BuiltInTaskId.LLM_ASK_AUDIO,
+ )
+ }
+ if (llmSupportSpeculativeDecoding) {
+ capabilities.add(ModelCapability.SPECULATIVE_DECODING)
+ capabilityToTaskTypes[ModelCapability.SPECULATIVE_DECODING] =
+ listOf(
+ BuiltInTaskId.LLM_CHAT,
+ BuiltInTaskId.LLM_ASK_IMAGE,
+ BuiltInTaskId.LLM_ASK_AUDIO,
+ BuiltInTaskId.LLM_PROMPT_LAB,
+ )
+ }
val model =
Model(
name = info.fileName,
@@ -1213,21 +1236,8 @@ constructor(
llmSupportAudio = llmSupportAudio,
llmSupportTinyGarden = llmSupportTinyGarden,
llmSupportMobileActions = llmSupportMobileActions,
- capabilities =
- if (llmSupportThinking) listOf(ModelCapability.LLM_THINKING) else emptyList(),
- capabilityToTaskTypes =
- if (llmSupportThinking) {
- mapOf(
- ModelCapability.LLM_THINKING to
- listOf(
- BuiltInTaskId.LLM_CHAT,
- BuiltInTaskId.LLM_ASK_IMAGE,
- BuiltInTaskId.LLM_ASK_AUDIO,
- )
- )
- } else {
- emptyMap()
- },
+ capabilities = capabilities.toList(),
+ capabilityToTaskTypes = capabilityToTaskTypes.toMap(),
llmMaxToken = llmMaxToken,
accelerators = accelerators,
// We assume all imported models are LLM for now.
diff --git a/Android/src/app/src/main/proto/settings.proto b/Android/src/app/src/main/proto/settings.proto
index 4640bb1b8..c6a505632 100644
--- a/Android/src/app/src/main/proto/settings.proto
+++ b/Android/src/app/src/main/proto/settings.proto
@@ -59,6 +59,7 @@ message LlmConfig {
bool support_tiny_garden = 8;
bool support_mobile_actions = 9;
bool support_thinking = 10;
+ bool support_speculative_decoding = 11;
}
message Settings {