Skip to content

feat(ai): add support for setting a thinking budget #6999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
3 changes: 2 additions & 1 deletion firebase-ai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Unreleased

* [feature] Added support for configuring the "thinking" budget when using Gemini
2.5 series models. (#6990)

# 16.2.0
* [changed] Deprecate the `totalBillableCharacters` field (only usable with pre-2.0 models). (#7042)
Expand Down
19 changes: 18 additions & 1 deletion firebase-ai/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ package com.google.firebase.ai.type {
method public com.google.firebase.ai.type.GenerationConfig.Builder setResponseSchema(com.google.firebase.ai.type.Schema? responseSchema);
method public com.google.firebase.ai.type.GenerationConfig.Builder setStopSequences(java.util.List<java.lang.String>? stopSequences);
method public com.google.firebase.ai.type.GenerationConfig.Builder setTemperature(Float? temperature);
method public com.google.firebase.ai.type.GenerationConfig.Builder setThinkingConfig(com.google.firebase.ai.type.ThinkingConfig? thinkingConfig);
method public com.google.firebase.ai.type.GenerationConfig.Builder setTopK(Integer? topK);
method public com.google.firebase.ai.type.GenerationConfig.Builder setTopP(Float? topP);
field public Integer? candidateCount;
Expand All @@ -373,6 +374,7 @@ package com.google.firebase.ai.type {
field public com.google.firebase.ai.type.Schema? responseSchema;
field public java.util.List<java.lang.String>? stopSequences;
field public Float? temperature;
field public com.google.firebase.ai.type.ThinkingConfig? thinkingConfig;
field public Integer? topK;
field public Float? topP;
}
Expand Down Expand Up @@ -933,6 +935,19 @@ package com.google.firebase.ai.type {
property public final String text;
}

public final class ThinkingConfig {
}

public static final class ThinkingConfig.Builder {
ctor public ThinkingConfig.Builder();
method public com.google.firebase.ai.type.ThinkingConfig build();
method public com.google.firebase.ai.type.ThinkingConfig.Builder setThinkingBudget(int thinkingBudget);
}

public final class ThinkingConfigKt {
method public static com.google.firebase.ai.type.ThinkingConfig thinkingConfig(kotlin.jvm.functions.Function1<? super com.google.firebase.ai.type.ThinkingConfig.Builder,kotlin.Unit> init);
}

public final class Tool {
method public static com.google.firebase.ai.type.Tool functionDeclarations(java.util.List<com.google.firebase.ai.type.FunctionDeclaration> functionDeclarations);
field public static final com.google.firebase.ai.type.Tool.Companion Companion;
Expand All @@ -953,16 +968,18 @@ package com.google.firebase.ai.type {
}

public final class UsageMetadata {
ctor public UsageMetadata(int promptTokenCount, Integer? candidatesTokenCount, int totalTokenCount, java.util.List<com.google.firebase.ai.type.ModalityTokenCount> promptTokensDetails, java.util.List<com.google.firebase.ai.type.ModalityTokenCount> candidatesTokensDetails);
ctor public UsageMetadata(int promptTokenCount, Integer? candidatesTokenCount, int totalTokenCount, java.util.List<com.google.firebase.ai.type.ModalityTokenCount> promptTokensDetails, java.util.List<com.google.firebase.ai.type.ModalityTokenCount> candidatesTokensDetails, int thoughtsTokenCount);
method public Integer? getCandidatesTokenCount();
method public java.util.List<com.google.firebase.ai.type.ModalityTokenCount> getCandidatesTokensDetails();
method public int getPromptTokenCount();
method public java.util.List<com.google.firebase.ai.type.ModalityTokenCount> getPromptTokensDetails();
method public int getThoughtsTokenCount();
method public int getTotalTokenCount();
property public final Integer? candidatesTokenCount;
property public final java.util.List<com.google.firebase.ai.type.ModalityTokenCount> candidatesTokensDetails;
property public final int promptTokenCount;
property public final java.util.List<com.google.firebase.ai.type.ModalityTokenCount> promptTokensDetails;
property public final int thoughtsTokenCount;
property public final int totalTokenCount;
}

Expand Down
2 changes: 1 addition & 1 deletion firebase-ai/gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

version=16.2.1
version=17.0.0
latestReleasedVersion=16.2.0
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ private constructor(
internal val responseMimeType: String?,
internal val responseSchema: Schema?,
internal val responseModalities: List<ResponseModality>?,
internal val thinkingConfig: ThinkingConfig?,
) {

/**
Expand Down Expand Up @@ -135,6 +136,7 @@ private constructor(
@JvmField public var responseMimeType: String? = null
@JvmField public var responseSchema: Schema? = null
@JvmField public var responseModalities: List<ResponseModality>? = null
@JvmField public var thinkingConfig: ThinkingConfig? = null

public fun setTemperature(temperature: Float?): Builder = apply {
this.temperature = temperature
Expand Down Expand Up @@ -165,6 +167,9 @@ private constructor(
public fun setResponseModalities(responseModalities: List<ResponseModality>?): Builder = apply {
this.responseModalities = responseModalities
}
public fun setThinkingConfig(thinkingConfig: ThinkingConfig?): Builder = apply {
this.thinkingConfig = thinkingConfig
}

/** Create a new [GenerationConfig] with the attached arguments. */
public fun build(): GenerationConfig =
Expand All @@ -179,7 +184,8 @@ private constructor(
frequencyPenalty = frequencyPenalty,
responseMimeType = responseMimeType,
responseSchema = responseSchema,
responseModalities = responseModalities
responseModalities = responseModalities,
thinkingConfig = thinkingConfig
)
}

Expand All @@ -195,7 +201,8 @@ private constructor(
presencePenalty = presencePenalty,
responseMimeType = responseMimeType,
responseSchema = responseSchema?.toInternal(),
responseModalities = responseModalities?.map { it.toInternal() }
responseModalities = responseModalities?.map { it.toInternal() },
thinkingConfig = thinkingConfig?.toInternal()
)

@Serializable
Expand All @@ -210,7 +217,8 @@ private constructor(
@SerialName("presence_penalty") val presencePenalty: Float? = null,
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
@SerialName("response_schema") val responseSchema: Schema.Internal? = null,
@SerialName("response_modalities") val responseModalities: List<String>? = null
@SerialName("response_modalities") val responseModalities: List<String>? = null,
@SerialName("thinking_config") val thinkingConfig: ThinkingConfig.Internal? = null
)

public companion object {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.ai.type

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
 * Configuration parameters for thinking features.
 *
 * Instances are created through [Builder] or the [thinkingConfig] DSL helper.
 */
public class ThinkingConfig
private constructor(
  internal val thinkingBudget: Int? = null,
) {

  /** Builder for creating a [ThinkingConfig]. */
  public class Builder() {
    /** The thinking budget in tokens; `null` until a budget has been set on this builder. */
    @JvmField
    @set:JvmSynthetic // hide void setter from Java
    public var thinkingBudget: Int? = null

    /**
     * Indicates the thinking budget in tokens. 0 is DISABLED. -1 is AUTOMATIC. The default values
     * and allowed ranges are model dependent.
     *
     * @param thinkingBudget the budget, in tokens, to store on this builder.
     * @return this builder, to allow call chaining.
     */
    public fun setThinkingBudget(thinkingBudget: Int): Builder = apply {
      this.thinkingBudget = thinkingBudget
    }

    /** Create a new [ThinkingConfig] with the attached arguments. */
    public fun build(): ThinkingConfig = ThinkingConfig(thinkingBudget = thinkingBudget)
  }

  /** Converts this configuration into its serializable [Internal] form. */
  internal fun toInternal(): Internal = Internal(thinkingBudget)

  /**
   * Serializable form of [ThinkingConfig].
   *
   * The budget defaults to `null` so that kotlinx.serialization treats it as an optional field
   * rather than a required one, matching the other nullable request fields (for example those of
   * `GenerationConfig.Internal`).
   */
  @Serializable
  internal data class Internal(@SerialName("thinking_budget") val thinkingBudget: Int? = null)
}

/**
 * Builds a [ThinkingConfig] from a DSL-style configuration block.
 *
 * Example Usage:
 * ```
 * thinkingConfig {
 *   thinkingBudget = 0 // disable thinking
 * }
 * ```
 *
 * @param init configuration block invoked with a fresh [ThinkingConfig.Builder] as its receiver.
 * @return the [ThinkingConfig] built from that builder once [init] has run.
 */
public fun thinkingConfig(init: ThinkingConfig.Builder.() -> Unit): ThinkingConfig =
  ThinkingConfig.Builder().apply(init).build()
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ import kotlinx.serialization.Serializable
* prompt.
* @param candidatesTokensDetails The breakdown, by modality, of how many tokens are consumed by the
* candidates.
* @param thoughtsTokenCount The number of tokens used by the model's internal "thinking" process.
*/
public class UsageMetadata(
public val promptTokenCount: Int,
public val candidatesTokenCount: Int?,
public val totalTokenCount: Int,
public val promptTokensDetails: List<ModalityTokenCount>,
public val candidatesTokensDetails: List<ModalityTokenCount>,
public val thoughtsTokenCount: Int,
) {

@Serializable
Expand All @@ -44,6 +46,7 @@ public class UsageMetadata(
val totalTokenCount: Int? = null,
val promptTokensDetails: List<ModalityTokenCount.Internal>? = null,
val candidatesTokensDetails: List<ModalityTokenCount.Internal>? = null,
val thoughtsTokenCount: Int? = null,
) {

internal fun toPublic(): UsageMetadata =
Expand All @@ -52,7 +55,8 @@ public class UsageMetadata(
candidatesTokenCount ?: 0,
totalTokenCount ?: 0,
promptTokensDetails = promptTokensDetails?.map { it.toPublic() } ?: emptyList(),
candidatesTokensDetails = candidatesTokensDetails?.map { it.toPublic() } ?: emptyList()
candidatesTokensDetails = candidatesTokensDetails?.map { it.toPublic() } ?: emptyList(),
thoughtsTokenCount ?: 0
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.ai.type

import io.kotest.assertions.json.shouldEqualJson
import io.kotest.matchers.equals.shouldBeEqual
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import org.junit.Test

/** Unit tests for [ThinkingConfig], its builder, and the `thinkingConfig` DSL helper. */
internal class ThinkingConfigTest {

  /** A budget set through the builder is serialized under the `thinking_budget` key. */
  @Test
  fun `Basic ThinkingConfig`() {
    val thinkingConfig = ThinkingConfig.Builder().setThinkingBudget(1024).build()

    val expectedJson =
      """
      {
        "thinking_budget": 1024
      }
      """
        .trimIndent()

    Json.encodeToString(thinkingConfig.toInternal()).shouldEqualJson(expectedJson)
  }

  /** The DSL helper and the explicit builder produce the same thinking budget. */
  @Test
  fun `thinkingConfig DSL correctly delegates to ThinkingConfig#Builder`() {
    val fromBuilder = ThinkingConfig.Builder().setThinkingBudget(1024).build()

    val fromDsl = thinkingConfig { thinkingBudget = 1024 }

    // Unwrap both sides explicitly: a safe call (`?.`) on the left-hand side would silently skip
    // the assertion, and let the test pass, whenever the builder-produced budget is null.
    val builderBudget =
      checkNotNull(fromBuilder.thinkingBudget) { "Builder did not retain the thinking budget" }
    val dslBudget =
      checkNotNull(fromDsl.thinkingBudget) { "DSL did not retain the thinking budget" }

    builderBudget.shouldBeEqual(dslBudget)
  }
}
2 changes: 1 addition & 1 deletion firebase-ai/update_responses.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# This script replaces mock response files for Vertex AI unit tests with a fresh
# clone of the shared repository of Vertex AI test data.

RESPONSES_VERSION='v13.*' # The major version of mock responses to use
RESPONSES_VERSION='v14.*' # The major version of mock responses to use
REPO_NAME="vertexai-sdk-test-data"
REPO_LINK="https://github.com/FirebaseExtended/$REPO_NAME.git"

Expand Down
Loading