Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import android.graphics.BitmapFactory
import android.graphics.Matrix
import android.net.Uri
import android.os.Build
import android.os.Bundle
import android.util.Log
import androidx.compose.foundation.layout.WindowInsets
import androidx.compose.foundation.layout.ime
Expand All @@ -36,7 +37,9 @@ import androidx.compose.ui.focus.onFocusEvent
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalFocusManager
import androidx.exifinterface.media.ExifInterface
import com.google.ai.edge.gallery.GalleryEvent
import com.google.ai.edge.gallery.data.SAMPLE_RATE
import com.google.ai.edge.gallery.firebaseAnalytics
import com.google.gson.Gson
import java.io.File
import java.io.FileInputStream
Expand Down Expand Up @@ -378,3 +381,14 @@ fun isAICoreSupported(allowedDeviceModels: Set<String>?): Boolean {
val currentModel = Build.MODEL?.lowercase() ?: return false
return allowedDeviceModels.contains(currentModel)
}

fun logErrorToFirebase(event: GalleryEvent, errorType: String, errorMessage: String?) {
firebaseAnalytics?.logEvent(
event.id,
Bundle().apply {
putBoolean("success", false)
putString("error_type", errorType)
putString("error_message", errorMessage ?: "Unknown error")
},
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class AgentChatTask @Inject constructor() : CustomTask {
LlmChatModelHelper.initialize(
context = context,
model = model,
taskId = task.id,
supportImage = true,
supportAudio = true,
onDone = onDone,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import kotlinx.coroutines.runBlocking

private const val TAG = "AGAgentTools"

class AgentTools() : ToolSet {
open class AgentTools() : ToolSet {
lateinit var context: Context
lateinit var skillManagerViewModel: SkillManagerViewModel

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.platform.LocalUriHandler
import androidx.compose.ui.res.pluralStringResource
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.semantics.contentDescription
import androidx.compose.ui.semantics.semantics
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextDecoration
import androidx.compose.ui.unit.dp
Expand Down Expand Up @@ -173,7 +175,7 @@ fun SkillManagerBottomSheet(
var addSkillOptionTypeToConfirm by remember { mutableStateOf<AddSkillOptionType?>(null) }
var skillToEditIndex by remember { mutableIntStateOf(-1) }
var searchQuery by remember { mutableStateOf("") }
var savedSelectedSkillsNamesAndDescriptions = remember { "" }
var savedSelectedSkillsNamesAndDescriptions by remember { mutableStateOf("") }
var filteredSkills by remember { mutableStateOf(uiState.skills) }
val listState = rememberLazyListState()
val uriHandler = LocalUriHandler.current
Expand Down Expand Up @@ -860,7 +862,8 @@ private fun SkillItemRow(
Switch(
checked = skill.selected,
onCheckedChange = onSkillEnabledChange,
modifier = Modifier.offset(y = (-4).dp),
modifier =
Modifier.offset(y = (-4).dp).semantics { contentDescription = "Toggle ${skill.name}" },
enabled = !inMultiSelectMode,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,23 @@ val TRYOUT_CHIPS: List<SkillTryOutChip> =
),
)

enum class SkillSource(val sourceName: String) {
BUILTIN("builtin"),
FEATURED("featured"),
REMOTE_URL("remote_url"),
LOCAL_IMPORT("local_import"),
UNKNOWN("unknown"),
}

enum class SkillAction(val value: String) {
ADD("add"),
DELETE("delete"),
ENABLE("enable"),
DISABLE("disable"),
ENABLE_ALL("enable_all"),
DISABLE_ALL("disable_all"),
}

data class SkillState(val skill: Skill)

data class SkillManagerUiState(
Expand Down Expand Up @@ -341,13 +358,7 @@ constructor(
Log.d(TAG, "Successfully added skill from URL: ${skill.name}")
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
Bundle().apply {
putString("action", "add")
putString("source", "remote_url")
putString("skill_name", getSkillNameForLogging(skill))
putBoolean("is_built_in", skill.builtIn)
putString("remote_url", url)
},
getSkillLoggingParams(skill).apply { putString("action", SkillAction.ADD.value) },
)
onSuccess()
}
Expand Down Expand Up @@ -532,12 +543,7 @@ constructor(
Log.d(TAG, "Successfully added skill from local import: ${skillWithDir.name}")
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
Bundle().apply {
putString("action", "add")
putString("source", "local_import")
putString("skill_name", getSkillNameForLogging(skillWithDir))
putBoolean("is_built_in", skillWithDir.builtIn)
},
getSkillLoggingParams(skillWithDir).apply { putString("action", SkillAction.ADD.value) },
)
onSuccess()
}
Expand Down Expand Up @@ -603,15 +609,14 @@ constructor(
return
}

val skillNameToLog = getSkillNameForLogging(skill)
Log.d(TAG, "Analytics: skill_management, action=delete, skill_name=${skillNameToLog}")
val loggingParams = getSkillLoggingParams(skill)
Log.d(
TAG,
"Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams",
)
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
Bundle().apply {
putString("action", "delete")
putString("skill_name", skillNameToLog)
putBoolean("is_built_in", skill.builtIn)
},
loggingParams.apply { putString("action", SkillAction.DELETE.value) },
)

// Update state.
Expand Down Expand Up @@ -643,17 +648,14 @@ constructor(
}

for (skill in skillsToDelete) {
val loggingParams = getSkillLoggingParams(skill)
Log.d(
TAG,
"Analytics: skill_management, action=delete, skill_name=${getSkillNameForLogging(skill)}",
"Analytics: skill_management, action=${SkillAction.DELETE.value}, params=$loggingParams",
)
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
Bundle().apply {
putString("action", "delete")
putString("skill_name", getSkillNameForLogging(skill))
putBoolean("is_built_in", skill.builtIn)
},
loggingParams.apply { putString("action", SkillAction.DELETE.value) },
)
}

Expand Down Expand Up @@ -686,10 +688,8 @@ constructor(

firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
Bundle().apply {
putString("action", if (selected) "enable" else "disable")
putString("skill_name", getSkillNameForLogging(skill.skill))
putBoolean("is_built_in", skill.skill.builtIn)
getSkillLoggingParams(skill.skill).apply {
putString("action", if (selected) SkillAction.ENABLE.value else SkillAction.DISABLE.value)
},
)
val updatedSkills =
Expand Down Expand Up @@ -720,11 +720,16 @@ constructor(

Log.d(
TAG,
"Analytics: skill_management, action=${if (selected) "enable_all" else "disable_all"}",
"Analytics: skill_management, action=${if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value}",
)
firebaseAnalytics?.logEvent(
GalleryEvent.SKILL_MANAGEMENT.id,
Bundle().apply { putString("action", if (selected) "enable_all" else "disable_all") },
Bundle().apply {
putString(
"action",
if (selected) SkillAction.ENABLE_ALL.value else SkillAction.DISABLE_ALL.value,
)
},
)

// Update data store.
Expand Down Expand Up @@ -1166,15 +1171,36 @@ constructor(
dataStoreRepository.setSkills(updatedList)
}

private fun getSkillNameForLogging(skill: Skill): String {
private fun getSkillSource(skill: Skill): SkillSource {
val isFeatured =
skill.skillUrl.isNotEmpty() &&
_uiState.value.featuredSkills.any { it.skillUrl == skill.skillUrl }
return if (skill.builtIn || isFeatured) {
skill.name
} else {
"custom_skill"
return when {
skill.builtIn -> SkillSource.BUILTIN
isFeatured -> SkillSource.FEATURED
skill.skillUrl.isNotEmpty() -> SkillSource.REMOTE_URL
skill.importDirName.isNotEmpty() -> SkillSource.LOCAL_IMPORT
else -> SkillSource.UNKNOWN
}
}

private fun getSkillLoggingParams(skill: Skill): Bundle {
val source = getSkillSource(skill)
val skillName =
if (source == SkillSource.BUILTIN || source == SkillSource.FEATURED) skill.name
else "custom_skill"
val bundle =
Bundle().apply {
putString("source", source.sourceName)
putString("skill_name", skillName)
}
if (
skill.skillUrl.isNotEmpty() &&
(source == SkillSource.REMOTE_URL || source == SkillSource.FEATURED)
) {
bundle.putString("remote_url", skill.skillUrl.take(100))
}
return bundle
}

private fun getSkillDestinationDir(originalImportDirName: String): File {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class MobileActionsTask @Inject constructor() : CustomTask {
LlmChatModelHelper.initialize(
context = context,
model = model,
taskId = task.id,
supportImage = false,
supportAudio = false,
onDone = onDone,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import androidx.core.net.toUri
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.BuiltInTaskId
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.llmchat.LlmChatModelHelper
import com.google.ai.edge.gallery.ui.llmchat.LlmModelInstance
Expand Down Expand Up @@ -211,6 +212,7 @@ constructor(@ApplicationContext private val appContext: Context) : ViewModel() {
LlmChatModelHelper.initialize(
context = context,
model = model,
taskId = BuiltInTaskId.LLM_MOBILE_ACTIONS,
supportImage = false,
supportAudio = false,
onDone = { error ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class TinyGardenTask @Inject constructor() : CustomTask {
LlmChatModelHelper.initialize(
context = context,
model = model,
taskId = task.id,
supportImage = false,
supportAudio = false,
onDone = onDone,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.ai.edge.gallery.R
import com.google.ai.edge.gallery.data.BuiltInTaskId
import com.google.ai.edge.gallery.data.DataStoreRepository
import com.google.ai.edge.gallery.data.Model
import com.google.ai.edge.gallery.ui.common.chat.ChatMessage
Expand Down Expand Up @@ -168,6 +169,7 @@ constructor(
LlmChatModelHelper.initialize(
context = context,
model = model,
taskId = BuiltInTaskId.LLM_TINY_GARDEN,
supportImage = false,
supportAudio = false,
onDone = { error ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ object ConfigKeys {
val SUPPORT_TINY_GARDEN = ConfigKey("support_tiny_garden", "Support tiny garden")
val SUPPORT_MOBILE_ACTIONS = ConfigKey("support_mobile_actions", "Support mobile actions")
val SUPPORT_THINKING = ConfigKey("support_thinking", "Support thinking")
val SUPPORT_SPECULATIVE_DECODING =
ConfigKey("support_speculative_decoding", "Support speculative decoding")
val ENABLE_THINKING = ConfigKey("enable_thinking", "Enable thinking")
val ENABLE_SPECULATIVE_DECODING =
ConfigKey("enable_speculative_decoding", "Enable speculative decoding")
val MAX_RESULT_COUNT = ConfigKey("max_result_count", "Max result count")
val USE_GPU = ConfigKey("use_gpu", "Use GPU")
val ACCELERATOR = ConfigKey("accelerator", "Accelerator")
Expand Down Expand Up @@ -226,6 +230,7 @@ fun createLlmChatConfigs(
defaultTemperature: Float = DEFAULT_TEMPERATURE,
accelerators: List<Accelerator> = DEFAULT_ACCELERATORS,
supportThinking: Boolean = false,
supportSpeculativeDecoding: Boolean = false,
): List<Config> {
var maxTokensConfig: Config =
LabelConfig(key = ConfigKeys.MAX_TOKENS, defaultValue = "$defaultMaxToken")
Expand Down Expand Up @@ -274,6 +279,11 @@ fun createLlmChatConfigs(
if (supportThinking) {
configs.add(BooleanSwitchConfig(key = ConfigKeys.ENABLE_THINKING, defaultValue = false))
}
if (supportSpeculativeDecoding) {
configs.add(
BooleanSwitchConfig(key = ConfigKeys.ENABLE_SPECULATIVE_DECODING, defaultValue = false)
)
}
return configs
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ private val NORMALIZE_NAME_REGEX = Regex("[^a-zA-Z0-9]")
data class PromptTemplate(val title: String, val description: String, val prompt: String)

enum class ModelCapability {
@SerializedName("llm_thinking") LLM_THINKING
@SerializedName("llm_thinking") LLM_THINKING,
@SerializedName("speculative_decoding") SPECULATIVE_DECODING,
}

enum class RuntimeType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ data class AllowedModel(
defaultMaxContextLength = llmMaxContextLength,
accelerators = accelerators,
supportThinking = capabilities?.contains(ModelCapability.LLM_THINKING) == true,
supportSpeculativeDecoding =
capabilities?.contains(ModelCapability.SPECULATIVE_DECODING) == true,
)
})
.toMutableList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ interface LlmModelHelper {
*
* @param context the application context.
* @param model the model to be initialized.
* @param taskId the task id where the model is being used.
* @param supportImage whether to support image input.
* @param supportAudio whether to support audio input.
* @param onDone callback invoked when initialization is completed successfully.
Expand All @@ -51,6 +52,7 @@ interface LlmModelHelper {
fun initialize(
context: Context,
model: Model,
taskId: String,
supportImage: Boolean,
supportAudio: Boolean,
onDone: (String) -> Unit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object AICoreModelHelper : LlmModelHelper {
override fun initialize(
context: Context,
model: Model,
taskId: String,
supportImage: Boolean,
supportAudio: Boolean,
onDone: (String) -> Unit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ fun ModelPageAppBar(
if (!task.allowCapability(ModelCapability.LLM_THINKING, model)) {
modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_THINKING }
}
var supportsSpeculativeDecoding = false
if (
!supportsSpeculativeDecoding ||
!task.allowCapability(ModelCapability.SPECULATIVE_DECODING, model)
) {
modelConfigs.removeIf { it.key == ConfigKeys.ENABLE_SPECULATIVE_DECODING }
}
ConfigDialog(
title = "Configurations",
configs = modelConfigs,
Expand Down
Loading
Loading