@@ -148,6 +148,23 @@ def _file_has_user_namespace(self, filename: str) -> bool:
148148 """
149149 return filename .startswith ("user:" )
150150
151+ def _get_blob_prefix (
152+ self ,
153+ app_name : str ,
154+ user_id : str ,
155+ filename : str ,
156+ session_id : Optional [str ] = None ,
157+ ) -> str :
158+ """Constructs the blob name prefix in GCS for a given artifact."""
159+ if self ._file_has_user_namespace (filename ):
160+ return f"{ app_name } /{ user_id } /user/{ filename } "
161+
162+ if session_id is None :
163+ raise ValueError (
164+ "Session ID must be provided for session-scoped artifacts."
165+ )
166+ return f"{ app_name } /{ user_id } /{ session_id } /{ filename } "
167+
151168 def _get_blob_name (
152169 self ,
153170 app_name : str ,
@@ -168,14 +185,9 @@ def _get_blob_name(
168185 Returns:
169186 The constructed blob name in GCS.
170187 """
171- if self ._file_has_user_namespace (filename ):
172- return f"{ app_name } /{ user_id } /user/{ filename } /{ version } "
173-
174- if session_id is None :
175- raise ValueError (
176- "Session ID must be provided for session-scoped artifacts."
177- )
178- return f"{ app_name } /{ user_id } /{ session_id } /{ filename } /{ version } "
188+ return (
189+ f"{ self ._get_blob_prefix (app_name , user_id , filename , session_id )} /{ version } "
190+ )
179191
180192 def _save_artifact (
181193 self ,
@@ -186,10 +198,6 @@ def _save_artifact(
186198 artifact : types .Part ,
187199 custom_metadata : Optional [dict [str , Any ]] = None ,
188200 ) -> int :
189- if custom_metadata :
190- # TODO: b/447451270 - support saving artifact with custom metadata.
191- raise NotImplementedError ("custom_metadata is not supported yet." )
192-
193201 versions = self ._list_versions (
194202 app_name = app_name ,
195203 user_id = user_id ,
@@ -202,6 +210,8 @@ def _save_artifact(
202210 app_name , user_id , filename , version , session_id
203211 )
204212 blob = self .bucket .blob (blob_name )
213+ if custom_metadata :
214+ blob .metadata = {k : str (v ) for k , v in custom_metadata .items ()}
205215
206216 if artifact .inline_data :
207217 blob .upload_from_string (
@@ -211,6 +221,7 @@ def _save_artifact(
211221 elif artifact .text :
212222 blob .upload_from_string (
213223 data = artifact .text ,
224+ content_type = "text/plain" ,
214225 )
215226 elif artifact .file_data :
216227 raise NotImplementedError (
@@ -265,15 +276,22 @@ def _list_artifact_keys(
265276 self .bucket , prefix = session_prefix
266277 )
267278 for blob in session_blobs :
268- * _ , filename , _ = blob .name .split ("/" )
279+ # blob.name is like session_prefix/filename/version
280+ # or session_prefix/path/to/filename/version
281+ # we need to extract filename including slashes, but remove prefix
282+ # and /version
283+ fn_and_version = blob .name [len (session_prefix ) :]
284+ filename = "/" .join (fn_and_version .split ("/" )[:- 1 ])
269285 filenames .add (filename )
270286
271287 user_namespace_prefix = f"{ app_name } /{ user_id } /user/"
272288 user_namespace_blobs = self .storage_client .list_blobs (
273289 self .bucket , prefix = user_namespace_prefix
274290 )
275291 for blob in user_namespace_blobs :
276- * _ , filename , _ = blob .name .split ("/" )
292+ # blob.name is like user_namespace_prefix/filename/version
293+ fn_and_version = blob .name [len (user_namespace_prefix ) :]
294+ filename = "/" .join (fn_and_version .split ("/" )[:- 1 ])
277295 filenames .add (filename )
278296
279297 return sorted (list (filenames ))
@@ -323,14 +341,85 @@ def _list_versions(
323341 artifact.
324342 Returns an empty list if no versions are found.
325343 """
326- prefix = self ._get_blob_name (app_name , user_id , filename , "" , session_id )
327- blobs = self .storage_client .list_blobs (self .bucket , prefix = prefix )
344+ prefix = self ._get_blob_prefix (app_name , user_id , filename , session_id )
345+ blobs = self .storage_client .list_blobs (self .bucket , prefix = f" { prefix } /" )
328346 versions = []
329347 for blob in blobs :
330348 * _ , version = blob .name .split ("/" )
331349 versions .append (int (version ))
332350 return versions
333351
352+ def _get_artifact_version_sync (
353+ self ,
354+ app_name : str ,
355+ user_id : str ,
356+ session_id : Optional [str ],
357+ filename : str ,
358+ version : Optional [int ] = None ,
359+ ) -> Optional [ArtifactVersion ]:
360+ if version is None :
361+ versions = self ._list_versions (
362+ app_name = app_name ,
363+ user_id = user_id ,
364+ session_id = session_id ,
365+ filename = filename ,
366+ )
367+ if not versions :
368+ return None
369+ version = max (versions )
370+
371+ blob_name = self ._get_blob_name (
372+ app_name , user_id , filename , version , session_id
373+ )
374+ blob = self .bucket .get_blob (blob_name )
375+
376+ if not blob :
377+ return None
378+
379+ canonical_uri = f"gs://{ self .bucket_name } /{ blob .name } "
380+
381+ return ArtifactVersion (
382+ version = version ,
383+ canonical_uri = canonical_uri ,
384+ create_time = blob .time_created .timestamp (),
385+ mime_type = blob .content_type ,
386+ custom_metadata = blob .metadata if blob .metadata else {},
387+ )
388+
389+ def _list_artifact_versions_sync (
390+ self ,
391+ app_name : str ,
392+ user_id : str ,
393+ session_id : Optional [str ],
394+ filename : str ,
395+ ) -> list [ArtifactVersion ]:
396+ """Lists all versions and their metadata of an artifact."""
397+ prefix = self ._get_blob_prefix (app_name , user_id , filename , session_id )
398+ blobs = self .storage_client .list_blobs (self .bucket , prefix = f"{ prefix } /" )
399+ artifact_versions = []
400+ for blob in blobs :
401+ try :
402+ version = int (blob .name .split ("/" )[- 1 ])
403+ except ValueError :
404+ logger .warning (
405+ "Skipping blob %s because it does not end with a version number." ,
406+ blob .name ,
407+ )
408+ continue
409+
410+ canonical_uri = f"gs://{ self .bucket_name } /{ blob .name } "
411+ av = ArtifactVersion (
412+ version = version ,
413+ canonical_uri = canonical_uri ,
414+ create_time = blob .time_created .timestamp (),
415+ mime_type = blob .content_type ,
416+ custom_metadata = blob .metadata if blob .metadata else {},
417+ )
418+ artifact_versions .append (av )
419+
420+ artifact_versions .sort (key = lambda x : x .version )
421+ return artifact_versions
422+
334423 @override
335424 async def list_artifact_versions (
336425 self ,
@@ -340,8 +429,13 @@ async def list_artifact_versions(
340429 filename : str ,
341430 session_id : Optional [str ] = None ,
342431 ) -> list [ArtifactVersion ]:
343- # TODO: b/447451270 - Support list_artifact_versions.
344- raise NotImplementedError ("list_artifact_versions is not implemented yet." )
432+ return await asyncio .to_thread (
433+ self ._list_artifact_versions_sync ,
434+ app_name ,
435+ user_id ,
436+ session_id ,
437+ filename ,
438+ )
345439
346440 @override
347441 async def get_artifact_version (
@@ -353,5 +447,11 @@ async def get_artifact_version(
353447 session_id : Optional [str ] = None ,
354448 version : Optional [int ] = None ,
355449 ) -> Optional [ArtifactVersion ]:
356- # TODO: b/447451270 - Support get_artifact_version.
357- raise NotImplementedError ("get_artifact_version is not implemented yet." )
450+ return await asyncio .to_thread (
451+ self ._get_artifact_version_sync ,
452+ app_name ,
453+ user_id ,
454+ session_id ,
455+ filename ,
456+ version ,
457+ )
0 commit comments