diff --git a/CHANGELOG.md b/CHANGELOG.md index 8276edc61..86f428ec3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch. +- Fixed bug where GCS upload does not retry on transient failures. ## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27 diff --git a/src/olmo_core/io.py b/src/olmo_core/io.py index 5fda27410..dce249938 100644 --- a/src/olmo_core/io.py +++ b/src/olmo_core/io.py @@ -2,6 +2,7 @@ import logging import os import pickle +import random import re import shutil import time @@ -590,11 +591,20 @@ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = storage_client = _get_gcs_client() bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) - if not save_overwrite and blob.exists(): - raise FileExistsError( - f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." - ) - blob.upload_from_filename(source, retry=_get_gcs_conditional_retry()) + + generation: int = 0 + if blob.exists(): + if not save_overwrite: + raise FileExistsError( + f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." + ) + + assert blob.generation is not None + generation = blob.generation + + blob.upload_from_filename( + source, if_generation_match=generation, retry=_get_gcs_conditional_retry() + ) @retriable()