diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 7d93b54360..d1862ceba2 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -38,6 +38,7 @@ from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager +from ..memory.base_memory_service import BaseMemoryService from ..memory.in_memory_memory_service import InMemoryMemoryService from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..runners import Runner @@ -60,6 +61,7 @@ def get_fast_api_app( session_db_kwargs: Optional[Mapping[str, Any]] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, + memory_service: Optional[BaseMemoryService] = None, eval_storage_uri: Optional[str] = None, allow_origins: Optional[list[str]] = None, web: bool, @@ -107,31 +109,32 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name): return project, location, agent_engine_id # Build the Memory service - if memory_service_uri: - if memory_service_uri.startswith("rag://"): - from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService - - rag_corpus = memory_service_uri.split("://")[1] - if not rag_corpus: - raise click.ClickException("Rag corpus can not be empty.") - envs.load_dotenv_for_agent("", agents_dir) - memory_service = VertexAiRagMemoryService( - rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}' - ) - elif memory_service_uri.startswith("agentengine://"): - agent_engine_id_or_resource_name = memory_service_uri.split("://")[1] - project, location, agent_engine_id = _parse_agent_engine_resource_name( - agent_engine_id_or_resource_name - ) - memory_service = VertexAiMemoryBankService( - project=project, - location=location, - agent_engine_id=agent_engine_id, - ) - else: - raise click.ClickException( - "Unsupported memory service URI: %s" % memory_service_uri - ) + if not memory_service: + if memory_service_uri: + if memory_service_uri.startswith("rag://"): + from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService + + rag_corpus = memory_service_uri.split("://")[1] + if not rag_corpus: + raise click.ClickException("Rag corpus can not be empty.") + envs.load_dotenv_for_agent("", agents_dir) + memory_service = VertexAiRagMemoryService( + rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}' + ) + elif memory_service_uri.startswith("agentengine://"): + agent_engine_id_or_resource_name = memory_service_uri.split("://")[1] + project, location, agent_engine_id = _parse_agent_engine_resource_name( + agent_engine_id_or_resource_name + ) + memory_service = VertexAiMemoryBankService( + project=project, + location=location, + agent_engine_id=agent_engine_id, + ) + else: + raise click.ClickException( + "Unsupported memory service URI: %s" % memory_service_uri + ) else: memory_service = InMemoryMemoryService()