Skip to content

Commit e194ebb

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Implement artifact_version related methods in GcsArtifactService
PiperOrigin-RevId: 824646770
1 parent 1a4261a commit e194ebb

File tree

2 files changed

+196
-57
lines changed

2 files changed

+196
-57
lines changed

src/google/adk/artifacts/gcs_artifact_service.py

Lines changed: 120 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,23 @@ def _file_has_user_namespace(self, filename: str) -> bool:
148148
"""
149149
return filename.startswith("user:")
150150

151+
def _get_blob_prefix(
152+
self,
153+
app_name: str,
154+
user_id: str,
155+
filename: str,
156+
session_id: Optional[str] = None,
157+
) -> str:
158+
"""Constructs the blob name prefix in GCS for a given artifact."""
159+
if self._file_has_user_namespace(filename):
160+
return f"{app_name}/{user_id}/user/{filename}"
161+
162+
if session_id is None:
163+
raise ValueError(
164+
"Session ID must be provided for session-scoped artifacts."
165+
)
166+
return f"{app_name}/{user_id}/{session_id}/{filename}"
167+
151168
def _get_blob_name(
152169
self,
153170
app_name: str,
@@ -168,14 +185,9 @@ def _get_blob_name(
168185
Returns:
169186
The constructed blob name in GCS.
170187
"""
171-
if self._file_has_user_namespace(filename):
172-
return f"{app_name}/{user_id}/user/{filename}/{version}"
173-
174-
if session_id is None:
175-
raise ValueError(
176-
"Session ID must be provided for session-scoped artifacts."
177-
)
178-
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
188+
return (
189+
f"{self._get_blob_prefix(app_name, user_id, filename, session_id)}/{version}"
190+
)
179191

180192
def _save_artifact(
181193
self,
@@ -186,10 +198,6 @@ def _save_artifact(
186198
artifact: types.Part,
187199
custom_metadata: Optional[dict[str, Any]] = None,
188200
) -> int:
189-
if custom_metadata:
190-
# TODO: b/447451270 - support saving artifact with custom metadata.
191-
raise NotImplementedError("custom_metadata is not supported yet.")
192-
193201
versions = self._list_versions(
194202
app_name=app_name,
195203
user_id=user_id,
@@ -202,6 +210,8 @@ def _save_artifact(
202210
app_name, user_id, filename, version, session_id
203211
)
204212
blob = self.bucket.blob(blob_name)
213+
if custom_metadata:
214+
blob.metadata = {k: str(v) for k, v in custom_metadata.items()}
205215

206216
if artifact.inline_data:
207217
blob.upload_from_string(
@@ -211,6 +221,7 @@ def _save_artifact(
211221
elif artifact.text:
212222
blob.upload_from_string(
213223
data=artifact.text,
224+
content_type="text/plain",
214225
)
215226
elif artifact.file_data:
216227
raise NotImplementedError(
@@ -265,15 +276,22 @@ def _list_artifact_keys(
265276
self.bucket, prefix=session_prefix
266277
)
267278
for blob in session_blobs:
268-
*_, filename, _ = blob.name.split("/")
279+
# blob.name is like session_prefix/filename/version
280+
# or session_prefix/path/to/filename/version
281+
# we need to extract filename including slashes, but remove prefix
282+
# and /version
283+
fn_and_version = blob.name[len(session_prefix) :]
284+
filename = "/".join(fn_and_version.split("/")[:-1])
269285
filenames.add(filename)
270286

271287
user_namespace_prefix = f"{app_name}/{user_id}/user/"
272288
user_namespace_blobs = self.storage_client.list_blobs(
273289
self.bucket, prefix=user_namespace_prefix
274290
)
275291
for blob in user_namespace_blobs:
276-
*_, filename, _ = blob.name.split("/")
292+
# blob.name is like user_namespace_prefix/filename/version
293+
fn_and_version = blob.name[len(user_namespace_prefix) :]
294+
filename = "/".join(fn_and_version.split("/")[:-1])
277295
filenames.add(filename)
278296

279297
return sorted(list(filenames))
@@ -323,14 +341,85 @@ def _list_versions(
323341
artifact.
324342
Returns an empty list if no versions are found.
325343
"""
326-
prefix = self._get_blob_name(app_name, user_id, filename, "", session_id)
327-
blobs = self.storage_client.list_blobs(self.bucket, prefix=prefix)
344+
prefix = self._get_blob_prefix(app_name, user_id, filename, session_id)
345+
blobs = self.storage_client.list_blobs(self.bucket, prefix=f"{prefix}/")
328346
versions = []
329347
for blob in blobs:
330348
*_, version = blob.name.split("/")
331349
versions.append(int(version))
332350
return versions
333351

352+
def _get_artifact_version_sync(
353+
self,
354+
app_name: str,
355+
user_id: str,
356+
session_id: Optional[str],
357+
filename: str,
358+
version: Optional[int] = None,
359+
) -> Optional[ArtifactVersion]:
360+
if version is None:
361+
versions = self._list_versions(
362+
app_name=app_name,
363+
user_id=user_id,
364+
session_id=session_id,
365+
filename=filename,
366+
)
367+
if not versions:
368+
return None
369+
version = max(versions)
370+
371+
blob_name = self._get_blob_name(
372+
app_name, user_id, filename, version, session_id
373+
)
374+
blob = self.bucket.get_blob(blob_name)
375+
376+
if not blob:
377+
return None
378+
379+
canonical_uri = f"gs://{self.bucket_name}/{blob.name}"
380+
381+
return ArtifactVersion(
382+
version=version,
383+
canonical_uri=canonical_uri,
384+
create_time=blob.time_created.timestamp(),
385+
mime_type=blob.content_type,
386+
custom_metadata=blob.metadata if blob.metadata else {},
387+
)
388+
389+
def _list_artifact_versions_sync(
390+
self,
391+
app_name: str,
392+
user_id: str,
393+
session_id: Optional[str],
394+
filename: str,
395+
) -> list[ArtifactVersion]:
396+
"""Lists all versions and their metadata of an artifact."""
397+
prefix = self._get_blob_prefix(app_name, user_id, filename, session_id)
398+
blobs = self.storage_client.list_blobs(self.bucket, prefix=f"{prefix}/")
399+
artifact_versions = []
400+
for blob in blobs:
401+
try:
402+
version = int(blob.name.split("/")[-1])
403+
except ValueError:
404+
logger.warning(
405+
"Skipping blob %s because it does not end with a version number.",
406+
blob.name,
407+
)
408+
continue
409+
410+
canonical_uri = f"gs://{self.bucket_name}/{blob.name}"
411+
av = ArtifactVersion(
412+
version=version,
413+
canonical_uri=canonical_uri,
414+
create_time=blob.time_created.timestamp(),
415+
mime_type=blob.content_type,
416+
custom_metadata=blob.metadata if blob.metadata else {},
417+
)
418+
artifact_versions.append(av)
419+
420+
artifact_versions.sort(key=lambda x: x.version)
421+
return artifact_versions
422+
334423
@override
335424
async def list_artifact_versions(
336425
self,
@@ -340,8 +429,13 @@ async def list_artifact_versions(
340429
filename: str,
341430
session_id: Optional[str] = None,
342431
) -> list[ArtifactVersion]:
343-
# TODO: b/447451270 - Support list_artifact_versions.
344-
raise NotImplementedError("list_artifact_versions is not implemented yet.")
432+
return await asyncio.to_thread(
433+
self._list_artifact_versions_sync,
434+
app_name,
435+
user_id,
436+
session_id,
437+
filename,
438+
)
345439

346440
@override
347441
async def get_artifact_version(
@@ -353,5 +447,11 @@ async def get_artifact_version(
353447
session_id: Optional[str] = None,
354448
version: Optional[int] = None,
355449
) -> Optional[ArtifactVersion]:
356-
# TODO: b/447451270 - Support get_artifact_version.
357-
raise NotImplementedError("get_artifact_version is not implemented yet.")
450+
return await asyncio.to_thread(
451+
self._get_artifact_version_sync,
452+
app_name,
453+
user_id,
454+
session_id,
455+
filename,
456+
version,
457+
)

0 commit comments

Comments
 (0)