diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml
index f7e1fe58..15f66472 100644
--- a/charts/model-engine/values_sample.yaml
+++ b/charts/model-engine/values_sample.yaml
@@ -157,7 +157,7 @@ serviceTemplate:
 config:
   values:
     infra:
-      # cloud_provider [required]; either "aws" or "azure"
+      # cloud_provider [required]; either "aws" or "azure" or "gcp"
       cloud_provider: aws
       # k8s_cluster_name [required] is the name of the k8s cluster
       k8s_cluster_name: main_cluster
diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py
index e120fbf0..1a418eb4 100644
--- a/model-engine/model_engine_server/api/dependencies.py
+++ b/model-engine/model_engine_server/api/dependencies.py
@@ -128,6 +128,9 @@
     LiveLLMModelEndpointService,
 )
 from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
+from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway
+from model_engine_server.infra.gateways.gcs_llm_artifact_gateway import GCSLLMArtifactGateway
+from model_engine_server.infra.gateways.gcs_file_storage_gateway import GCSFileStorageGateway
 
 logger = make_logger(logger_name())
 
@@ -258,16 +261,20 @@ def _get_external_interfaces(
         monitoring_metrics_gateway=monitoring_metrics_gateway,
         use_asyncio=(not CIRCLECI),
     )
-    filesystem_gateway = (
-        ABSFilesystemGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3FilesystemGateway()
-    )
-    llm_artifact_gateway = (
-        ABSLLMArtifactGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3LLMArtifactGateway()
-    )
+    if infra_config().cloud_provider == "azure":
+        filesystem_gateway = ABSFilesystemGateway()
+    elif infra_config().cloud_provider == "gcp":
+        filesystem_gateway = GCSFilesystemGateway()
+    else:
+        filesystem_gateway = S3FilesystemGateway()
+
+    if infra_config().cloud_provider == "azure":
+        llm_artifact_gateway = ABSLLMArtifactGateway()
+    elif infra_config().cloud_provider == "gcp":
+        llm_artifact_gateway = GCSLLMArtifactGateway()
+    else:
+        llm_artifact_gateway = S3LLMArtifactGateway()
+
     model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway(
         filesystem_gateway=filesystem_gateway
     )
@@ -334,11 +341,12 @@ def _get_external_interfaces(
         docker_image_batch_job_gateway=docker_image_batch_job_gateway
     )
 
-    file_storage_gateway = (
-        ABSFileStorageGateway()
-        if infra_config().cloud_provider == "azure"
-        else S3FileStorageGateway()
-    )
+    if infra_config().cloud_provider == "azure":
+        file_storage_gateway = ABSFileStorageGateway()
+    elif infra_config().cloud_provider == "gcp":
+        file_storage_gateway = GCSFileStorageGateway()
+    else:
+        file_storage_gateway = S3FileStorageGateway()
 
     docker_repository: DockerRepository
     if CIRCLECI:
diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py
index 0474976c..0bf2476f 100644
--- a/model-engine/model_engine_server/core/config.py
+++ b/model-engine/model_engine_server/core/config.py
@@ -41,6 +41,7 @@ class _InfraConfig:
     ml_account_id: str
     docker_repo_prefix: str
     s3_bucket: str
+    gcs_bucket: Optional[str] = None
     redis_host: Optional[str] = None
     redis_aws_secret_name: Optional[str] = None
     profile_ml_worker: str = "default"
diff --git a/model-engine/model_engine_server/core/gcp/storage_client.py b/model-engine/model_engine_server/core/gcp/storage_client.py
new file mode 100644
index 00000000..44d986b0
--- /dev/null
+++ b/model-engine/model_engine_server/core/gcp/storage_client.py
@@ -0,0 +1,74 @@
+import time
+from typing import IO, Callable, Iterable, Optional, Sequence
+
+import smart_open
+from google.cloud import storage
+from model_engine_server.core.loggers import logger_name, make_logger
+
+logger = make_logger(logger_name())
+
+__all__: Sequence[str] = (
+    "sync_storage_client",
+    # `open` should be used, but so as to not shadow the built-in, the preferred import is:
+    # >>> storage_client.open
+    # Thus, it's not included in the wildcard imports.
+    "sync_storage_client_keepalive",
+    "gcs_fileobj_exists",
+)
+
+
+def sync_storage_client(**kwargs) -> storage.Client:
+    return storage.Client(**kwargs)
+
+
+def open(uri: str, mode: str = "rt", **kwargs) -> IO:  # pylint: disable=redefined-builtin
+    if "transport_params" not in kwargs:
+        kwargs["transport_params"] = {"client": sync_storage_client()}
+    return smart_open.open(uri, mode, **kwargs)
+
+
+def sync_storage_client_keepalive(
+    gcp_client: storage.Client,
+    buckets: Iterable[str],
+    interval: int,
+    is_cancelled: Callable[[], bool],
+) -> None:
+    """Keeps connection pool warmed up for access on list of GCP buckets.
+
+    NOTE: :param:`is_cancelled` **MUST BE THREADSAFE**.
+    """
+    while True:
+        if is_cancelled():
+            logger.info("Ending GCP client keepalive: cancel invoked.")
+            return
+        for bucket in buckets:
+            try:
+                # Instead of head_bucket, for GCP we obtain the bucket object and reload it.
+                bucket_obj = gcp_client.bucket(bucket)
+                bucket_obj.reload()  # refreshes metadata and validates connectivity
+            except Exception:  # pylint:disable=broad-except
+                logger.exception(
+                    f"Unexpected error in keepalive loop on accessing bucket: {bucket}"
+                )
+        time.sleep(interval)
+
+
+def gcs_fileobj_exists(bucket: str, key: str, client: Optional[storage.Client] = None) -> bool:
+    """
+    Test if file exists in GCP storage.
+    :param bucket: GCP bucket name
+    :param key: Blob name or file's path within the bucket
+    :param client: A google.cloud.storage.Client instance
+    :return: Whether the file exists on GCP or not
+    """
+    if client is None:
+        client = sync_storage_client()
+    try:
+        bucket_obj = client.bucket(bucket)
+        # get_blob returns None if the blob does not exist.
+        blob = bucket_obj.get_blob(key)
+    except Exception as e:
+        logger.exception(f"Error checking file existence in bucket {bucket} for key {key}")
+        raise e
+    else:
+        return blob is not None
diff --git a/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py
index 1c0048be..54eaa6f3 100644
--- a/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py
+++ b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py
@@ -1,9 +1,11 @@
 import argparse
 import shutil
 
-import model_engine_server.core.aws.storage_client as storage_client
 from model_engine_server.common.serialization_utils import b64_to_str
-from model_engine_server.core.aws.storage_client import s3_fileobj_exists
+from model_engine_server.core.aws import storage_client as aws_storage_client
+
+# Top-level imports for remote storage clients with aliases.
+from model_engine_server.core.gcp import storage_client as gcp_storage_client
 from model_engine_server.core.loggers import logger_name, make_logger
 from model_engine_server.core.utils.url import parse_attachment_url
 
@@ -20,11 +22,26 @@ def main(input_local: str, local_file: str, remote_file: str, file_contents_b64e
     else:
         logger.info("Copying file from remote")
         parsed_remote = parse_attachment_url(remote_file)
-        if not s3_fileobj_exists(bucket=parsed_remote.bucket, key=parsed_remote.key):
-            logger.warning("S3 file doesn't exist, aborting")
-            raise ValueError  # TODO propagate error to the gateway
-        # TODO if we need we can s5cmd this
-        with storage_client.open(remote_file, "rb") as fr, open(local_file, "wb") as fw2:
+        # Conditional logic to support GCS file URLs without breaking S3 behavior.
+        if remote_file.startswith("gs://"):
+            # Use the GCP storage client.
+            file_exists = gcp_storage_client.gcs_fileobj_exists(
+                bucket=parsed_remote.bucket, key=parsed_remote.key
+            )
+            storage_open = gcp_storage_client.open
+            file_label = "GCS"
+        else:
+            # Use the AWS storage client for backward compatibility.
+            file_exists = aws_storage_client.s3_fileobj_exists(
+                bucket=parsed_remote.bucket, key=parsed_remote.key
+            )
+            storage_open = aws_storage_client.open
+            file_label = "S3"
+        if not file_exists:
+            logger.warning(f"{file_label} file doesn't exist, aborting")
+            raise ValueError  # TODO: propagate error to the gateway
+        # Open the remote file (using the appropriate storage client) and copy its contents locally.
+        with storage_open(remote_file, "rb") as fr, open(local_file, "wb") as fw2:
             shutil.copyfileobj(fr, fw2)
 
 
diff --git a/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
new file mode 100644
index 00000000..29a8c2bd
--- /dev/null
+++ b/model-engine/model_engine_server/infra/gateways/gcs_file_storage_gateway.py
@@ -0,0 +1,104 @@
+import os
+from typing import List, Optional
+
+from google.cloud import storage
+
+from model_engine_server.core.config import infra_config
+from model_engine_server.domain.gateways.file_storage_gateway import (
+    FileMetadata,
+    FileStorageGateway,
+)
+from model_engine_server.infra.gateways.gcs_filesystem_gateway import GCSFilesystemGateway
+
+
+def get_gcs_key(owner: str, file_id: str) -> str:
+    """
+    Constructs a GCS object key from the owner and file_id.
+    """
+    return os.path.join(owner, file_id)
+
+
+def get_gcs_url(owner: str, file_id: str) -> str:
+    """
+    Returns the gs:// URL for the bucket, using the GCS key.
+    """
+    return f"gs://{infra_config().gcs_bucket}/{get_gcs_key(owner, file_id)}"
+
+
+class GCSFileStorageGateway(FileStorageGateway):
+    """
+    Concrete implementation of a file storage gateway backed by GCS.
+    """
+
+    def __init__(self):
+        self.filesystem_gateway = GCSFilesystemGateway()
+
+    async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]:
+        """
+        Returns a signed GCS URL for the given file.
+        """
+        try:
+            return self.filesystem_gateway.generate_signed_url(get_gcs_url(owner, file_id))
+        except Exception:
+            return None
+
+    async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]:
+        """
+        Retrieves file metadata if it exists. Returns None if the file is missing.
+ """ + try: + client = self.filesystem_gateway.get_storage_client({}) + bucket = client.bucket(infra_config().gcs_bucket) + blob = bucket.blob(get_gcs_key(owner, file_id)) + blob.reload() # Fetch metadata + return FileMetadata( + id=file_id, + filename=file_id, + size=blob.size, + owner=owner, + updated_at=blob.updated, + ) + except Exception: + return None + + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + """ + Reads and returns the string content of the file. + """ + try: + with self.filesystem_gateway.open(get_gcs_url(owner, file_id)) as f: + return f.read() + except Exception: + return None + + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + """ + Uploads the file to the GCS bucket. Returns the filename used in bucket. + """ + with self.filesystem_gateway.open( + get_gcs_url(owner, filename), mode="w" + ) as f: + f.write(content.decode("utf-8")) + return filename + + async def delete_file(self, owner: str, file_id: str) -> bool: + """ + Deletes the file from the GCS bucket. Returns True if successful, False otherwise. + """ + try: + client = self.filesystem_gateway.get_storage_client({}) + bucket = client.bucket(infra_config().gcs_bucket) + blob = bucket.blob(get_gcs_key(owner, file_id)) + blob.delete() + return True + except Exception: + return False + + async def list_files(self, owner: str) -> List[FileMetadata]: + """ + Lists all files in the GCS bucket for the given owner. + """ + client = self.filesystem_gateway.get_storage_client({}) + blobs = client.list_blobs(infra_config().gcs_bucket, prefix=owner) + files = [await self.get_file(owner, b.name[len(owner) + 1 :]) for b in blobs if b.name != owner] + return [f for f in files if f is not None] \ No newline at end of file diff --git a/model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py new file mode 100644 index 00000000..83d7a400 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/gcs_filesystem_gateway.py @@ -0,0 +1,46 @@ +import os +import re +from typing import IO, Optional, Dict + +import smart_open +from google.cloud import storage +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway + + +class GCSFilesystemGateway(FilesystemGateway): + """ + Concrete implementation for interacting with Google Cloud Storage. + """ + + def get_storage_client(self, kwargs: Optional[Dict]) -> storage.Client: + """ + Retrieve or create a Google Cloud Storage client. Could optionally + utilize environment variables or passed-in credentials. + """ + project = kwargs.get("gcp_project", os.getenv("GCP_PROJECT")) + return storage.Client(project=project) + + def open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + """ + Uses smart_open to handle reading/writing to GCS. + """ + # The `transport_params` is how smart_open passes in the storage client + client = self.get_storage_client(kwargs) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str: + """ + Generate a signed URL for the given GCS URI, valid for `expiration` seconds. 
+ """ + # Expecting URIs in the form: 'gs://bucket_name/some_key' + match = re.search(r"^gs://([^/]+)/(.+)$", uri) + if not match: + raise ValueError(f"Invalid GCS URI: {uri}") + + bucket_name, blob_name = match.groups() + client = self.get_storage_client(kwargs) + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + return blob.generate_signed_url(expiration=expiration) \ No newline at end of file diff --git a/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py new file mode 100644 index 00000000..e935d7f4 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py @@ -0,0 +1,115 @@ +import json +import os +from typing import Any, Dict, List + +from google.cloud import storage +from model_engine_server.common.config import get_model_cache_directory_name, hmi_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.url import parse_attachment_url +from model_engine_server.domain.gateways import LLMArtifactGateway + +logger = make_logger(logger_name()) + + +class GCSLLMArtifactGateway(LLMArtifactGateway): + """ + Concrete implementation for interacting with a filesystem backed by GCS. + """ + + def _get_gcs_client(self, kwargs) -> storage.Client: + """ + Returns a GCS client. If desired, you could pass in project info + or credentials via `kwargs`. + """ + project = kwargs.get("gcp_project", os.getenv("GCP_PROJECT")) + return storage.Client(project=project) + + def list_files(self, path: str, **kwargs) -> List[str]: + """ + Lists all files under the path argument in GCS. The path is expected + to be in the form 'gs://bucket/prefix'. + """ + gcs = self._get_gcs_client(kwargs) + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket_name = parsed_remote.bucket + prefix = parsed_remote.key + + bucket = gcs.bucket(bucket_name) + files = [blob.name for blob in bucket.list_blobs(prefix=prefix)] + return files + + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + """ + Downloads all files under the given path to the local target_path directory. 
+ """ + gcs = self._get_gcs_client(kwargs) + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket_name = parsed_remote.bucket + prefix = parsed_remote.key + + bucket = gcs.bucket(bucket_name) + blobs = bucket.list_blobs(prefix=prefix) + downloaded_files = [] + + for blob in blobs: + # Remove prefix and leading slash to derive local name + file_path_suffix = blob.name.replace(prefix, "").lstrip("/") + local_path = os.path.join(target_path, file_path_suffix).rstrip("/") + + if not overwrite and os.path.exists(local_path): + downloaded_files.append(local_path) + continue + + local_dir = os.path.dirname(local_path) + if not os.path.exists(local_dir): + os.makedirs(local_dir) + + logger.info(f"Downloading {blob.name} to {local_path}") + blob.download_to_filename(local_path) + downloaded_files.append(local_path) + + return downloaded_files + + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + """ + Retrieves URLs for all model weight artifacts stored under the + prefix: hmi_config.hf_user_fine_tuned_weights_prefix / {owner} / {model_cache_name} + """ + gcs = self._get_gcs_client(kwargs) + prefix_base = hmi_config.hf_user_fine_tuned_weights_prefix + if prefix_base.startswith("gs://"): + # Strip "gs://" for prefix logic below + prefix_base = prefix_base[5:] + bucket_name, prefix_base = prefix_base.split("/", 1) + + model_cache_name = get_model_cache_directory_name(model_name) + prefix = f"{prefix_base}/{owner}/{model_cache_name}" + + bucket = gcs.bucket(bucket_name) + blobs = bucket.list_blobs(prefix=prefix) + + model_files = [f"gs://{bucket_name}/{blob.name}" for blob in blobs] + return model_files + + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + """ + Downloads a 'config.json' file from GCS located at path/config.json + and returns it as a dictionary. + """ + gcs = self._get_gcs_client(kwargs) + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket_name = parsed_remote.bucket + # The key from parse_attachment_url might be e.g. "weight_prefix/model_dir" + # so we append "/config.json" and build a local path to download it. + key_with_config = os.path.join(parsed_remote.key, "config.json") + + bucket = gcs.bucket(bucket_name) + blob = bucket.blob(key_with_config) + + # Download to a tmp path and load + filepath = os.path.join("/tmp", key_with_config.replace("/", "_")) + os.makedirs(os.path.dirname(filepath), exist_ok=True) + blob.download_to_filename(filepath) + + with open(filepath, "r") as f: + return json.load(f) \ No newline at end of file