19 changes: 19 additions & 0 deletions src/google/adk/cli/adk_web_server.py
@@ -64,6 +64,7 @@
from ..evaluation.eval_metrics import EvalMetric
from ..evaluation.eval_metrics import EvalMetricResult
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
from ..evaluation.eval_metrics import MetricInfo
from ..evaluation.eval_result import EvalSetResult
from ..evaluation.eval_set_results_manager import EvalSetResultsManager
from ..evaluation.eval_sets_manager import EvalSetsManager
@@ -697,6 +698,24 @@ def list_eval_results(app_name: str) -> list[str]:
"""Lists all eval results for the given app."""
return self.eval_set_results_manager.list_eval_set_results(app_name)

@app.get(
"/apps/{app_name}/eval_metrics",
response_model_exclude_none=True,
)
def list_eval_metrics(app_name: str) -> list[MetricInfo]:
"""Lists all eval metrics for the given app."""
try:
from ..evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY

# Right now we ignore the app_name as eval metrics are not tied to the
# app_name, but they could be moving forward.
return DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics()
except ModuleNotFoundError as e:
logger.exception("%s\n%s", MISSING_EVAL_DEPENDENCIES_MESSAGE, e)
raise HTTPException(
status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE
) from e

@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
async def delete_session(app_name: str, user_id: str, session_id: str):
await self.session_service.delete_session(
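For reference, a minimal client-side sketch of calling the new `GET /apps/{app_name}/eval_metrics` route; the base URL, port, and app name below are placeholder assumptions rather than values taken from this PR:

```python
# Hypothetical client call against a locally running ADK web server.
# Host, port, and app name are assumptions for illustration only.
import requests

BASE_URL = "http://localhost:8000"  # assumed local server address
APP_NAME = "my_agent_app"           # placeholder app name

response = requests.get(f"{BASE_URL}/apps/{APP_NAME}/eval_metrics")
response.raise_for_status()

for metric in response.json():
  # Keys are expected in camelCase (e.g. "metricName") because MetricInfo
  # uses the to_camel alias generator.
  print(metric)
```

A 400 response carrying the missing-eval-dependencies message indicates the optional eval extras are not installed.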
109 changes: 93 additions & 16 deletions src/google/adk/evaluation/eval_metrics.py
@@ -49,16 +49,22 @@ class JudgeModelOptions(BaseModel):

judge_model: str = Field(
default="gemini-2.5-flash",
description="""The judge model to use for evaluation. It can be a model name.""",
description=(
"The judge model to use for evaluation. It can be a model name."
),
)

judge_model_config: Optional[genai_types.GenerateContentConfig] = Field(
default=None, description="""The configuration for the judge model."""
default=None,
description="The configuration for the judge model.",
)

num_samples: Optional[int] = Field(
default=None,
description="""The number of times to sample the model for each invocation evaluation.""",
description=(
"The number of times to sample the model for each invocation"
" evaluation."
),
)


@@ -70,15 +76,20 @@ class EvalMetric(BaseModel):
populate_by_name=True,
)

metric_name: str
"""The name of the metric."""
metric_name: str = Field(
description="The name of the metric.",
)

threshold: float
"""A threshold value. Each metric decides how to interpret this threshold."""
threshold: float = Field(
description=(
"A threshold value. Each metric decides how to interpret this"
" threshold."
),
)

judge_model_options: Optional[JudgeModelOptions] = Field(
default=None,
description="""Options for the judge model.""",
description="Options for the judge model.",
)


@@ -90,8 +101,14 @@ class EvalMetricResult(EvalMetric):
populate_by_name=True,
)

score: Optional[float] = None
eval_status: EvalStatus
score: Optional[float] = Field(
default=None,
description=(
"Score obtained after evaluating the metric. Optional, as evaluation"
" might not have happened."
),
)
eval_status: EvalStatus = Field(description="The status of this evaluation.")


class EvalMetricResultPerInvocation(BaseModel):
@@ -102,11 +119,71 @@ class EvalMetricResultPerInvocation(BaseModel):
populate_by_name=True,
)

actual_invocation: Invocation
"""The actual invocation, usually obtained by inferencing the agent."""
actual_invocation: Invocation = Field(
description=(
"The actual invocation, usually obtained by inferencing the agent."
)
)

expected_invocation: Invocation = Field(
description=(
"The expected invocation, usually the reference or golden invocation."
)
)

expected_invocation: Invocation
"""The expected invocation, usually the reference or golden invocation."""
eval_metric_results: list[EvalMetricResult] = Field(
default=[],
description="Eval resutls for each applicable metric.",
)


class Interval(BaseModel):
"""Represents a range of numeric values, e.g. [0 ,1] or (2,3) or [-1, 6)."""

min_value: float = Field(description="The smaller end of the interval.")

open_at_min: bool = Field(
default=False,
description=(
"The interval is Open on the min end. The default value is False,"
" which means that we assume that the interval is Closed."
),
)

max_value: float = Field(description="The larger end of the interval.")

open_at_max: bool = Field(
default=False,
description=(
"The interval is Open on the max end. The default value is False,"
" which means that we assume that the interval is Closed."
),
)

eval_metric_results: list[EvalMetricResult] = []
"""Eval resutls for each applicable metric."""

class MetricValueInfo(BaseModel):
"""Information about the type of metric value."""

interval: Optional[Interval] = Field(
default=None,
description="The values represented by the metric are of type interval.",
)


class MetricInfo(BaseModel):
"""Information about the metric that are used for Evals."""

model_config = ConfigDict(
alias_generator=alias_generators.to_camel,
populate_by_name=True,
)

metric_name: str = Field(description="The name of the metric.")

description: Optional[str] = Field(
default=None, description="A 2 to 3 line description of the metric."
)

metric_value_info: MetricValueInfo = Field(
description="Information on the nature of values supported by the metric."
)
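For illustration, a short sketch of constructing the new models and serializing them; the metric name and description are placeholders, not values defined in this PR:

```python
# Sketch only: building a MetricInfo for a hypothetical metric on a [0, 1] scale.
from google.adk.evaluation.eval_metrics import Interval
from google.adk.evaluation.eval_metrics import MetricInfo
from google.adk.evaluation.eval_metrics import MetricValueInfo

info = MetricInfo(
    metric_name="example_score",  # placeholder metric name
    description="Example metric that scores responses on a [0, 1] scale.",
    metric_value_info=MetricValueInfo(
        interval=Interval(min_value=0.0, max_value=1.0)
    ),
)

# MetricInfo configures a to_camel alias generator, so its top-level keys
# serialize as camelCase when dumped by alias.
print(info.model_dump(by_alias=True, exclude_none=True))
```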
23 changes: 22 additions & 1 deletion src/google/adk/evaluation/final_response_match_v1.py
@@ -22,18 +22,39 @@

from .eval_case import Invocation
from .eval_metrics import EvalMetric
from .eval_metrics import Interval
from .eval_metrics import MetricInfo
from .eval_metrics import MetricValueInfo
from .eval_metrics import PrebuiltMetrics
from .evaluator import EvalStatus
from .evaluator import EvaluationResult
from .evaluator import Evaluator
from .evaluator import PerInvocationResult


class RougeEvaluator(Evaluator):
"""Calculates the ROUGE-1 metric to compare responses."""
"""Evaluates if agent's final response matches a golden/expected final response using Rouge_1 metric.

Value range for this metric is [0,1], with values closer to 1 more desirable.
"""

def __init__(self, eval_metric: EvalMetric):
self._eval_metric = eval_metric

@staticmethod
def get_metric_info() -> MetricInfo:
return MetricInfo(
metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE.value,
description=(
"This metric evaluates if the agent's final response matches a"
" golden/expected final response using Rouge_1 metric. Value range"
" for this metric is [0,1], with values closer to 1 more desirable."
),
metric_value_info=MetricValueInfo(
interval=Interval(min_value=0.0, max_value=1.0)
),
)

@override
def evaluate_invocations(
self,
21 changes: 19 additions & 2 deletions src/google/adk/evaluation/final_response_match_v2.py
@@ -24,6 +24,10 @@
from ..utils.feature_decorator import experimental
from .eval_case import Invocation
from .eval_metrics import EvalMetric
from .eval_metrics import Interval
from .eval_metrics import MetricInfo
from .eval_metrics import MetricValueInfo
from .eval_metrics import PrebuiltMetrics
from .evaluator import EvalStatus
from .evaluator import EvaluationResult
from .evaluator import PerInvocationResult
@@ -146,6 +150,20 @@ def __init__(
if self._eval_metric.judge_model_options.num_samples is None:
self._eval_metric.judge_model_options.num_samples = _DEFAULT_NUM_SAMPLES

@staticmethod
def get_metric_info() -> MetricInfo:
return MetricInfo(
metric_name=PrebuiltMetrics.FINAL_RESPONSE_MATCH_V2.value,
description=(
"This metric evaluates if the agent's final response matches a"
" golden/expected final response using LLM as a judge. Value range"
" for this metric is [0,1], with values closer to 1 more desirable."
),
metric_value_info=MetricValueInfo(
interval=Interval(min_value=0.0, max_value=1.0)
),
)

@override
def format_auto_rater_prompt(
self, actual_invocation: Invocation, expected_invocation: Invocation
@@ -185,8 +203,7 @@ def aggregate_per_invocation_samples(
tie, consider the result to be invalid.

Args:
per_invocation_samples: Samples of per-invocation results to
aggregate.
per_invocation_samples: Samples of per-invocation results to aggregate.

Returns:
If there is a majority of valid results, return the first valid result.
40 changes: 30 additions & 10 deletions src/google/adk/evaluation/metric_evaluator_registry.py
@@ -17,7 +17,9 @@
import logging

from ..errors.not_found_error import NotFoundError
from ..utils.feature_decorator import experimental
from .eval_metrics import EvalMetric
from .eval_metrics import MetricInfo
from .eval_metrics import MetricName
from .eval_metrics import PrebuiltMetrics
from .evaluator import Evaluator
@@ -29,10 +31,11 @@
logger = logging.getLogger("google_adk." + __name__)


@experimental
class MetricEvaluatorRegistry:
"""A registry for metric Evaluators."""

_registry: dict[str, type[Evaluator]] = {}
_registry: dict[str, tuple[type[Evaluator], MetricInfo]] = {}

def get_evaluator(self, eval_metric: EvalMetric) -> Evaluator:
"""Returns an Evaluator for the given metric.
@@ -48,15 +51,18 @@ def get_evaluator(self, eval_metric: EvalMetric) -> Evaluator:
if eval_metric.metric_name not in self._registry:
raise NotFoundError(f"{eval_metric.metric_name} not found in registry.")

return self._registry[eval_metric.metric_name](eval_metric=eval_metric)
return self._registry[eval_metric.metric_name][0](eval_metric=eval_metric)

def register_evaluator(
self, metric_name: MetricName, evaluator: type[Evaluator]
self,
metric_info: MetricInfo,
evaluator: type[Evaluator],
):
"""Registers an evaluator given the metric name.
"""Registers an evaluator given the metric info.

If a mapping already exists, then it is updated.
"""
metric_name = metric_info.metric_name
if metric_name in self._registry:
logger.info(
"Updating Evaluator class for %s from %s to %s",
@@ -65,31 +71,45 @@ def register_evaluator(
evaluator,
)

self._registry[str(metric_name)] = evaluator
self._registry[str(metric_name)] = (evaluator, metric_info)

def get_registered_metrics(
self,
) -> list[MetricInfo]:
"""Returns a list of MetricInfo about the metrics registered so far."""
return [
evaluator_and_metric_info[1].model_copy(deep=True)
for _, evaluator_and_metric_info in self._registry.items()
]


def _get_default_metric_evaluator_registry() -> MetricEvaluatorRegistry:
"""Returns an instance of MetricEvaluatorRegistry with standard metrics already registered in it."""
metric_evaluator_registry = MetricEvaluatorRegistry()

metric_evaluator_registry.register_evaluator(
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
metric_info=TrajectoryEvaluator.get_metric_info(),
evaluator=TrajectoryEvaluator,
)

metric_evaluator_registry.register_evaluator(
metric_name=PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value,
metric_info=ResponseEvaluator.get_metric_info(
PrebuiltMetrics.RESPONSE_EVALUATION_SCORE.value
),
evaluator=ResponseEvaluator,
)
metric_evaluator_registry.register_evaluator(
metric_name=PrebuiltMetrics.RESPONSE_MATCH_SCORE.value,
metric_info=ResponseEvaluator.get_metric_info(
PrebuiltMetrics.RESPONSE_MATCH_SCORE.value
),
evaluator=ResponseEvaluator,
)
metric_evaluator_registry.register_evaluator(
metric_name=PrebuiltMetrics.SAFETY_V1.value,
metric_info=SafetyEvaluatorV1.get_metric_info(),
evaluator=SafetyEvaluatorV1,
)
metric_evaluator_registry.register_evaluator(
metric_name=PrebuiltMetrics.FINAL_RESPONSE_MATCH_V2.value,
metric_info=FinalResponseMatchV2Evaluator.get_metric_info(),
evaluator=FinalResponseMatchV2Evaluator,
)
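
To show how the updated `register_evaluator` signature is meant to be used, here is a rough sketch that registers a hypothetical custom evaluator and then lists the registered metrics; `CustomExactMatchEvaluator` and its metric name are invented for this example and are not part of the PR:

```python
# Sketch only: wiring a hypothetical evaluator through the updated registry API.
from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.eval_metrics import Interval
from google.adk.evaluation.eval_metrics import MetricInfo
from google.adk.evaluation.eval_metrics import MetricValueInfo
from google.adk.evaluation.evaluator import Evaluator
from google.adk.evaluation.metric_evaluator_registry import MetricEvaluatorRegistry


class CustomExactMatchEvaluator(Evaluator):
  """Hypothetical evaluator used only to illustrate registration."""

  def __init__(self, eval_metric: EvalMetric):
    self._eval_metric = eval_metric

  def evaluate_invocations(self, actual_invocations, expected_invocations):
    # Real scoring logic omitted; this sketch only shows the registry wiring.
    raise NotImplementedError()


registry = MetricEvaluatorRegistry()
registry.register_evaluator(
    metric_info=MetricInfo(
        metric_name="custom_exact_match",  # invented metric name
        description="Scores 1.0 only when the final response matches exactly.",
        metric_value_info=MetricValueInfo(
            interval=Interval(min_value=0.0, max_value=1.0)
        ),
    ),
    evaluator=CustomExactMatchEvaluator,
)

# The new /eval_metrics endpoint ultimately serves copies of these objects.
for metric_info in registry.get_registered_metrics():
  print(metric_info.metric_name)
```

Because `get_registered_metrics` returns deep copies, callers can mutate the results without affecting the registry.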

Expand Down
Loading