Google Cloud Storage Integration #683

Open · wants to merge 3 commits into main

2 changes: 1 addition & 1 deletion charts/model-engine/values_sample.yaml
@@ -157,7 +157,7 @@ serviceTemplate:
config:
values:
infra:
# cloud_provider [required]; either "aws" or "azure"
# cloud_provider [required]; either "aws" or "azure" or "gcp"
cloud_provider: aws
# k8s_cluster_name [required] is the name of the k8s cluster
k8s_cluster_name: main_cluster
38 changes: 23 additions & 15 deletions model-engine/model_engine_server/api/dependencies.py
@@ -128,6 +128,9 @@
LiveLLMModelEndpointService,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway
from model_engine_server.infra.gateways.gcs_llm_artifact_gateway import GCSLLMArtifactGateway
from model_engine_server.infra.gateways.gcs_file_storage_gateway import GCSFileStorageGateway

logger = make_logger(logger_name())

@@ -258,16 +261,20 @@ def _get_external_interfaces(
monitoring_metrics_gateway=monitoring_metrics_gateway,
use_asyncio=(not CIRCLECI),
)
filesystem_gateway = (
ABSFilesystemGateway()
if infra_config().cloud_provider == "azure"
else S3FilesystemGateway()
)
llm_artifact_gateway = (
ABSLLMArtifactGateway()
if infra_config().cloud_provider == "azure"
else S3LLMArtifactGateway()
)
if infra_config().cloud_provider == "azure":
filesystem_gateway = ABSFilesystemGateway()
elif infra_config().cloud_provider == "gcp":
filesystem_gateway = GCSFilesystemGateway()
else:
filesystem_gateway = S3FilesystemGateway()

if infra_config().cloud_provider == "azure":
llm_artifact_gateway = ABSLLMArtifactGateway()
elif infra_config().cloud_provider == "gcp":
llm_artifact_gateway = GCSLLMArtifactGateway()
else:
llm_artifact_gateway = S3LLMArtifactGateway()

model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
filesystem_gateway=filesystem_gateway
)
@@ -334,11 +341,12 @@ def _get_external_interfaces(
docker_image_batch_job_gateway=docker_image_batch_job_gateway
)

file_storage_gateway = (
ABSFileStorageGateway()
if infra_config().cloud_provider == "azure"
else S3FileStorageGateway()
)
if infra_config().cloud_provider == "azure":
file_storage_gateway = ABSFileStorageGateway()
elif infra_config().cloud_provider == "gcp":
file_storage_gateway = GCSFileStorageGateway()
else:
file_storage_gateway = S3FileStorageGateway()

docker_repository: DockerRepository
if CIRCLECI:
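The three provider checks in this file follow the same if/elif/else shape. A possible consolidation, sketched here only for illustration and not part of this PR (it also assumes all three gateways are constructed at the same point, which differs slightly from the current ordering), would pick the classes from a single provider table:

# Sketch: map cloud_provider to the gateway classes in one place.
_GATEWAYS_BY_PROVIDER = {
    "azure": (ABSFilesystemGateway, ABSLLMArtifactGateway, ABSFileStorageGateway),
    "gcp": (GCSFilesystemGateway, GCSLLMArtifactGateway, GCSFileStorageGateway),
    "aws": (S3FilesystemGateway, S3LLMArtifactGateway, S3FileStorageGateway),
}

fs_cls, artifact_cls, file_storage_cls = _GATEWAYS_BY_PROVIDER.get(
    infra_config().cloud_provider, _GATEWAYS_BY_PROVIDER["aws"]
)
filesystem_gateway = fs_cls()
llm_artifact_gateway = artifact_cls()
file_storage_gateway = file_storage_cls()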
1 change: 1 addition & 0 deletions model-engine/model_engine_server/core/config.py
@@ -41,6 +41,7 @@ class _InfraConfig:
ml_account_id: str
docker_repo_prefix: str
s3_bucket: str
gcs_bucket: Optional[str] = None
redis_host: Optional[str] = None
redis_aws_secret_name: Optional[str] = None
profile_ml_worker: str = "default"
74 changes: 74 additions & 0 deletions model-engine/model_engine_server/core/gcp/storage_client.py
@@ -0,0 +1,74 @@
import time
from typing import IO, Callable, Iterable, Optional, Sequence

import smart_open
from google.cloud import storage
from model_engine_server.core.loggers import logger_name, make_logger

logger = make_logger(logger_name())

__all__: Sequence[str] = (
"sync_storage_client",
# `open` should be used, but so as to not shadow the built-in, the preferred import is:
# >>> storage_client.open
# Thus, it's not included in the wildcard imports.
"sync_storage_client_keepalive",
"gcs_fileobj_exists",
)


def sync_storage_client(**kwargs) -> storage.Client:
return storage.Client(**kwargs)


def open(uri: str, mode: str = "rt", **kwargs) -> IO: # pylint: disable=redefined-builtin
if "transport_params" not in kwargs:
kwargs["transport_params"] = {"client": sync_storage_client()}
return smart_open.open(uri, mode, **kwargs)


def sync_storage_client_keepalive(
gcp_client: storage.Client,
buckets: Iterable[str],
interval: int,
is_cancelled: Callable[[], bool],
) -> None:
"""Keeps connection pool warmed up for access on list of GCP buckets.

NOTE: :param:`is_cancelled` **MUST BE THREADSAFE**.
"""
while True:
if is_cancelled():
logger.info("Ending GCP client keepalive: cancel invoked.")
return
for bucket in buckets:
try:
# Instead of head_bucket, for GCP we obtain the bucket object and reload it.
bucket_obj = gcp_client.bucket(bucket)
bucket_obj.reload() # refreshes metadata and validates connectivity
except Exception: # pylint:disable=broad-except
logger.exception(
f"Unexpected error in keepalive loop on accessing bucket: {bucket}"
)
time.sleep(interval)


def gcs_fileobj_exists(bucket: str, key: str, client: Optional[storage.Client] = None) -> bool:
"""
Test if file exists in GCP storage.
:param bucket: GCP bucket name
:param key: Blob name or file's path within the bucket
:param client: A google.cloud.storage.Client instance
:return: Whether the file exists on GCP or not
"""
if client is None:
client = sync_storage_client()
try:
bucket_obj = client.bucket(bucket)
# get_blob returns None if the blob does not exist.
blob = bucket_obj.get_blob(key)
except Exception as e:
logger.exception(f"Error checking file existence in bucket {bucket} for key {key}")
raise e
else:
return blob is not None
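
Since the docstring above requires `is_cancelled` to be thread-safe, a minimal usage sketch would run the keepalive on a daemon thread and cancel it via a threading.Event (the bucket name below is hypothetical, not part of this PR):

import threading

from model_engine_server.core.gcp import storage_client as gcp_storage_client

cancel_event = threading.Event()  # Event.is_set is safe to call across threads
client = gcp_storage_client.sync_storage_client()

keepalive_thread = threading.Thread(
    target=gcp_storage_client.sync_storage_client_keepalive,
    kwargs={
        "gcp_client": client,
        "buckets": ["some-model-artifact-bucket"],  # hypothetical bucket name
        "interval": 60,  # refresh bucket metadata every 60 seconds
        "is_cancelled": cancel_event.is_set,
    },
    daemon=True,
)
keepalive_thread.start()

# ... application work ...

cancel_event.set()  # the loop exits on its next iteration
keepalive_thread.join()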
@@ -1,9 +1,11 @@
import argparse
import shutil

import model_engine_server.core.aws.storage_client as storage_client
from model_engine_server.common.serialization_utils import b64_to_str
from model_engine_server.core.aws.storage_client import s3_fileobj_exists
from model_engine_server.core.aws import storage_client as aws_storage_client

# Top-level imports for remote storage clients with aliases.
from model_engine_server.core.gcp import storage_client as gcp_storage_client
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.core.utils.url import parse_attachment_url

@@ -20,11 +22,26 @@ def main(input_local: str, local_file: str, remote_file: str, file_contents_b64e
else:
logger.info("Copying file from remote")
parsed_remote = parse_attachment_url(remote_file)
if not s3_fileobj_exists(bucket=parsed_remote.bucket, key=parsed_remote.key):
logger.warning("S3 file doesn't exist, aborting")
raise ValueError # TODO propagate error to the gateway
# TODO if we need we can s5cmd this
with storage_client.open(remote_file, "rb") as fr, open(local_file, "wb") as fw2:
# Conditional logic to support GCS file URLs without breaking S3 behavior.
if remote_file.startswith("gs://"):
# Use the GCP storage client.
file_exists = gcp_storage_client.gcs_fileobj_exists(
bucket=parsed_remote.bucket, key=parsed_remote.key
)
storage_open = gcp_storage_client.open
file_label = "GCS"
else:
# Use the AWS storage client for backward compatibility.
file_exists = aws_storage_client.s3_fileobj_exists(
bucket=parsed_remote.bucket, key=parsed_remote.key
)
storage_open = aws_storage_client.open
file_label = "S3"
if not file_exists:
logger.warning(f"{file_label} file doesn't exist, aborting")
raise ValueError # TODO: propagate error to the gateway
# Open the remote file (using the appropriate storage client) and copy its contents locally.
with storage_open(remote_file, "rb") as fr, open(local_file, "wb") as fw2:
shutil.copyfileobj(fr, fw2)
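
The scheme check above could also be factored into a small helper so it can be unit-tested on its own; a sketch (helper name and example URLs are hypothetical, not part of this PR):

def _select_storage(remote_file: str):
    """Pick the existence check, opener, and log label from the URL scheme."""
    if remote_file.startswith("gs://"):
        return gcp_storage_client.gcs_fileobj_exists, gcp_storage_client.open, "GCS"
    return aws_storage_client.s3_fileobj_exists, aws_storage_client.open, "S3"

# _select_storage("gs://some-bucket/some/key") -> GCS existence check, opener, "GCS"
# _select_storage("s3://some-bucket/some/key") -> S3 existence check, opener, "S3"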


104 changes: 104 additions & 0 deletions model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
@@ -0,0 +1,104 @@
import os
Contributor:
note: I think this is only really used for some fine tuning apis that aren't really used at this point, think it's fine to keep ofc since you'll probably need to initialize dependencies anyways, but this code probably won't really get exercised at all

from typing import List, Optional

from google.cloud import storage

from model_engine_server.core.config import infra_config
from model_engine_server.domain.gateways.file_storage_gateway import (
FileMetadata,
FileStorageGateway,
)
from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway


def get_gcs_key(owner: str, file_id: str) -> str:
Member:
nit: I'd prefix these w/ an underscore so that no one is tempted to try and import these from outside this file, thus breaking Clean Architecture norms.

Contributor:
btw I think the s3_file_storage_gateway also doesn't have the prefixes

"""
Constructs a GCS object key from the owner and file_id.
"""
return os.path.join(owner, file_id)


def get_gcs_url(owner: str, file_id: str) -> str:
"""
Returns the gs:// URL for the bucket, using the GCS key.
"""
return f"gs://{infra_config().gcs_bucket}/{get_gcs_key(owner, file_id)}"


class GCSFileStorageGateway(FileStorageGateway):
"""
Concrete implementation of a file storage gateway backed by GCS.
"""

def __init__(self):
self.filesystem_gateway = GCSFilesystemGateway()

async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
"""
Returns a signed GCS URL for the given file.
"""
try:
return self.filesystem_gateway.generate_signed_url(get_gcs_url(owner, file_id))
except Exception:
return None

async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
"""
Retrieves file metadata if it exists. Returns None if the file is missing.
"""
try:
client = self.filesystem_gateway.get_storage_client({})
bucket = client.bucket(infra_config().gcs_bucket)
Member:
I know this pattern was already there, but I think it'd probably make more sense to pass in the bucket into the constructor of this class. This way, there's one less dependency on the old infra_config object. @seanshi-scale @tiffzhao5 thoughts?

Could also make the argument to just pass in the bucket as an argument with every get_file call, but that's outside of the scope of this change I'd say.

Contributor:
I'm fine with having the bucket be passed in the constructor (in addition to anything else from any configs); dependencies.py does read in from infra_config at times to figure out constructor arguments, so there's precedent already

blob = bucket.blob(get_gcs_key(owner, file_id))
blob.reload() # Fetch metadata
return FileMetadata(
id=file_id,
filename=file_id,
size=blob.size,
owner=owner,
updated_at=blob.updated,
)
except Exception:
return None

async def get_file_content(self, owner: str, file_id: str) -> Optional[str]:
"""
Reads and returns the string content of the file.
"""
try:
with self.filesystem_gateway.open(get_gcs_url(owner, file_id)) as f:
return f.read()
except Exception:
return None

async def upload_file(self, owner: str, filename: str, content: bytes) -> str:
"""
Uploads the file to the GCS bucket. Returns the filename used in the bucket.
"""
with self.filesystem_gateway.open(
get_gcs_url(owner, filename), mode="w"
) as f:
f.write(content.decode("utf-8"))
return filename

async def delete_file(self, owner: str, file_id: str) -> bool:
"""
Deletes the file from the GCS bucket. Returns True if successful, False otherwise.
"""
try:
client = self.filesystem_gateway.get_storage_client({})
bucket = client.bucket(infra_config().gcs_bucket)
blob = bucket.blob(get_gcs_key(owner, file_id))
blob.delete()
return True
except Exception:
return False

async def list_files(self, owner: str) -> List[FileMetadata]:
"""
Lists all files in the GCS bucket for the given owner.
"""
client = self.filesystem_gateway.get_storage_client({})
blobs = client.list_blobs(infra_config().gcs_bucket, prefix=owner)
files = [await self.get_file(owner, b.name[len(owner) + 1 :]) for b in blobs if b.name != owner]
return [f for f in files if f is not None]
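
A minimal sketch of the constructor-injection idea from the review thread above; the parameter name is hypothetical, the infra_config fallback keeps the existing zero-argument construction in dependencies.py working, and an early check surfaces a missing bucket since gcs_bucket is optional in _InfraConfig:

class GCSFileStorageGateway(FileStorageGateway):
    def __init__(self, gcs_bucket: Optional[str] = None):
        # Prefer an explicitly injected bucket; fall back to the infra config.
        self.gcs_bucket = gcs_bucket or infra_config().gcs_bucket
        if not self.gcs_bucket:
            raise ValueError("gcs_bucket is not configured")
        self.filesystem_gateway = GCSFilesystemGateway()

The methods would then read self.gcs_bucket instead of calling infra_config().gcs_bucket at each call site.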
46 changes: 46 additions & 0 deletions model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py
@@ -0,0 +1,46 @@
import os
import re
from datetime import timedelta
from typing import IO, Dict, Optional

import smart_open
from google.cloud import storage
from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway


class GCSFilesystemGateway(FilesystemGateway):
"""
Concrete implementation for interacting with Google Cloud Storage.
"""

def get_storage_client(self, kwargs: Optional[Dict]) -> storage.Client:
"""
Retrieve or create a Google Cloud Storage client. Could optionally
utilize environment variables or passed-in credentials.
"""
kwargs = kwargs or {}  # the annotation allows None, so guard before calling .get()
project = kwargs.get("gcp_project", os.getenv("GCP_PROJECT"))
Contributor:
where does this env var get set? it seems analogous to AWS_PROFILE but those changes would need to be baked into any relevant k8s yamls most likely

return storage.Client(project=project)

def open(self, uri: str, mode: str = "rt", **kwargs) -> IO:
"""
Uses smart_open to handle reading/writing to GCS.
"""
# The `transport_params` is how smart_open passes in the storage client
client = self.get_storage_client(kwargs)
transport_params = {"client": client}
return smart_open.open(uri, mode, transport_params=transport_params)

def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str:
"""
Generate a signed URL for the given GCS URI, valid for `expiration` seconds.
"""
# Expecting URIs in the form: 'gs://bucket_name/some_key'
match = re.search(r"^gs://([^/]+)/(.+)$", uri)
if not match:
raise ValueError(f"Invalid GCS URI: {uri}")

bucket_name, blob_name = match.groups()
client = self.get_storage_client(kwargs)
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name)

# Use a timedelta so `expiration` is treated as a duration (seconds from now)
# rather than an absolute timestamp.
return blob.generate_signed_url(expiration=timedelta(seconds=expiration))
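
On the GCP_PROJECT question above: nothing in this PR sets that variable, so it would have to be added to the relevant k8s templates, much like AWS_PROFILE. Alternatively the project can be left unset, since storage.Client(project=None) tries to infer it from the application default credentials. A sketch of an explicit fallback chain (the helper below is hypothetical, not part of this PR):

import google.auth


def _resolve_gcp_project(explicit_project: Optional[str] = None) -> Optional[str]:
    """Prefer an explicit project, then GCP_PROJECT, then the ADC default project."""
    if explicit_project:
        return explicit_project
    env_project = os.getenv("GCP_PROJECT")
    if env_project:
        return env_project
    # Application default credentials (e.g. the service account bound to the pod)
    # usually carry a default project id.
    _, adc_project = google.auth.default()
    return adc_project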