diff --git a/next.config.js b/next.config.js index 658404a..c9ff46c 100644 --- a/next.config.js +++ b/next.config.js @@ -2,3 +2,14 @@ const nextConfig = {}; module.exports = nextConfig; + +module.exports = { + async rewrites() { + return [ + { + source: '/api/:path*', + destination: 'http://127.0.0.1:8000/:path*', // Proxy to Backend + }, + ] + }, +} \ No newline at end of file diff --git a/pdm.lock b/pdm.lock index 8b67780..fc55a24 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:06b2cad7b63461c9d6500270ccddb3849214f1a4a43b7b3e8ab67fcfc307303e" +content_hash = "sha256:7fa7164d0df0acb97c9d67e7bbcfaac164ae04dc08b382a5d7243eb7bb517705" [[package]] name = "aiohttp" @@ -160,6 +160,30 @@ files = [ {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, ] +[[package]] +name = "cffi" +version = "1.16.0" +requires_python = ">=3.8" +summary = "Foreign Function Interface for Python calling C code." 
+groups = ["default"] +marker = "platform_python_implementation != \"PyPy\"" +dependencies = [ + "pycparser", +] +files = [ + {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"}, + {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"}, + {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"}, + {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"}, + {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"}, + {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -223,6 +247,50 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = 
"sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "cryptography" +version = "42.0.7" +requires_python = ">=3.7" +summary = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +groups = ["default"] +dependencies = [ + "cffi>=1.12; platform_python_implementation != \"PyPy\"", +] +files = [ + {file = "cryptography-42.0.7-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:a987f840718078212fdf4504d0fd4c6effe34a7e4740378e59d47696e8dfb477"}, + {file = "cryptography-42.0.7-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd13b5e9b543532453de08bcdc3cc7cebec6f9883e886fd20a92f26940fd3e7a"}, + {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79165431551042cc9d1d90e6145d5d0d3ab0f2d66326c201d9b0e7f5bf43604"}, + {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a47787a5e3649008a1102d3df55424e86606c9bae6fb77ac59afe06d234605f8"}, + {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:02c0eee2d7133bdbbc5e24441258d5d2244beb31da5ed19fbb80315f4bbbff55"}, + {file = "cryptography-42.0.7-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:5e44507bf8d14b36b8389b226665d597bc0f18ea035d75b4e53c7b1ea84583cc"}, + {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:7f8b25fa616d8b846aef64b15c606bb0828dbc35faf90566eb139aa9cff67af2"}, + {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:93a3209f6bb2b33e725ed08ee0991b92976dfdcf4e8b38646540674fc7508e13"}, + {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e6b8f1881dac458c34778d0a424ae5769de30544fc678eac51c1c8bb2183e9da"}, + {file = "cryptography-42.0.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3de9a45d3b2b7d8088c3fbf1ed4395dfeff79d07842217b38df14ef09ce1d8d7"}, + {file = 
"cryptography-42.0.7-cp37-abi3-win32.whl", hash = "sha256:789caea816c6704f63f6241a519bfa347f72fbd67ba28d04636b7c6b7da94b0b"}, + {file = "cryptography-42.0.7-cp37-abi3-win_amd64.whl", hash = "sha256:8cb8ce7c3347fcf9446f201dc30e2d5a3c898d009126010cbd1f443f28b52678"}, + {file = "cryptography-42.0.7-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:a3a5ac8b56fe37f3125e5b72b61dcde43283e5370827f5233893d461b7360cd4"}, + {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:779245e13b9a6638df14641d029add5dc17edbef6ec915688f3acb9e720a5858"}, + {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d563795db98b4cd57742a78a288cdbdc9daedac29f2239793071fe114f13785"}, + {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:31adb7d06fe4383226c3e963471f6837742889b3c4caa55aac20ad951bc8ffda"}, + {file = "cryptography-42.0.7-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:efd0bf5205240182e0f13bcaea41be4fdf5c22c5129fc7ced4a0282ac86998c9"}, + {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a9bc127cdc4ecf87a5ea22a2556cab6c7eda2923f84e4f3cc588e8470ce4e42e"}, + {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:3577d029bc3f4827dd5bf8bf7710cac13527b470bbf1820a3f394adb38ed7d5f"}, + {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2e47577f9b18723fa294b0ea9a17d5e53a227867a0a4904a1a076d1646d45ca1"}, + {file = "cryptography-42.0.7-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1a58839984d9cb34c855197043eaae2c187d930ca6d644612843b4fe8513c886"}, + {file = "cryptography-42.0.7-cp39-abi3-win32.whl", hash = "sha256:e6b79d0adb01aae87e8a44c2b64bc3f3fe59515280e00fb6d57a7267a2583cda"}, + {file = "cryptography-42.0.7-cp39-abi3-win_amd64.whl", hash = "sha256:16268d46086bb8ad5bf0a2b5544d8a9ed87a0e33f5e77dd3c3301e63d941a83b"}, + {file = 
"cryptography-42.0.7-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2954fccea107026512b15afb4aa664a5640cd0af630e2ee3962f2602693f0c82"}, + {file = "cryptography-42.0.7-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:362e7197754c231797ec45ee081f3088a27a47c6c01eff2ac83f60f85a50fe60"}, + {file = "cryptography-42.0.7-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4f698edacf9c9e0371112792558d2f705b5645076cc0aaae02f816a0171770fd"}, + {file = "cryptography-42.0.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5482e789294854c28237bba77c4c83be698be740e31a3ae5e879ee5444166582"}, + {file = "cryptography-42.0.7-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e9b2a6309f14c0497f348d08a065d52f3020656f675819fc405fb63bbcd26562"}, + {file = "cryptography-42.0.7-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d8e3098721b84392ee45af2dd554c947c32cc52f862b6a3ae982dbb90f577f14"}, + {file = "cryptography-42.0.7-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c65f96dad14f8528a447414125e1fc8feb2ad5a272b8f68477abbcc1ea7d94b9"}, + {file = "cryptography-42.0.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:36017400817987670037fbb0324d71489b6ead6231c9604f8fc1f7d008087c68"}, + {file = "cryptography-42.0.7.tar.gz", hash = "sha256:ecbfbc00bf55888edda9868a4cf927205de8499e7fabe6c050322298382953f2"}, +] + [[package]] name = "datasets" version = "2.19.1" @@ -251,6 +319,19 @@ files = [ {file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"}, ] +[[package]] +name = "deprecation" +version = "2.1.0" +summary = "A library to handle automated deprecations" +groups = ["default"] +dependencies = [ + "packaging", +] +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = 
"sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + [[package]] name = "dill" version = "0.3.8" @@ -448,6 +529,21 @@ files = [ {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"}, ] +[[package]] +name = "gotrue" +version = "2.4.2" +requires_python = "<4.0,>=3.8" +summary = "Python Client Library for Supabase Auth" +groups = ["default"] +dependencies = [ + "httpx<0.28,>=0.23", + "pydantic<3,>=1.10", +] +files = [ + {file = "gotrue-2.4.2-py3-none-any.whl", hash = "sha256:64cd40933d1f0a5d5cc4f4bd93bc51d730b94812447b6600f774790a4901e455"}, + {file = "gotrue-2.4.2.tar.gz", hash = "sha256:e100745161f1c58dd05b9c1ef8bcd4cd78cdfb38d8d2c253ade63143a3dc6aeb"}, +] + [[package]] name = "h11" version = "0.14.0" @@ -923,6 +1019,23 @@ files = [ {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] +[[package]] +name = "postgrest" +version = "0.16.4" +requires_python = "<4.0,>=3.8" +summary = "PostgREST client for Python. This library provides an ORM interface to PostgREST." 
+groups = ["default"] +dependencies = [ + "deprecation<3.0.0,>=2.1.0", + "httpx<0.28,>=0.24", + "pydantic<3.0,>=1.9", + "strenum<0.5.0,>=0.4.9", +] +files = [ + {file = "postgrest-0.16.4-py3-none-any.whl", hash = "sha256:304425381eb38e31018832a524943d7d1f07687be80c3c7397d8ae69ca56cb88"}, + {file = "postgrest-0.16.4.tar.gz", hash = "sha256:e16973155be1464101d18a51cc060707cd177b918f4b01ea8afa51746ca870ef"}, +] + [[package]] name = "pox" version = "0.3.4" @@ -1022,6 +1135,18 @@ files = [ {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, ] +[[package]] +name = "pycparser" +version = "2.22" +requires_python = ">=3.8" +summary = "C parser in Python" +groups = ["default"] +marker = "platform_python_implementation != \"PyPy\"" +files = [ + {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, + {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, +] + [[package]] name = "pydantic" version = "2.7.1" @@ -1197,6 +1322,22 @@ files = [ {file = "readchar-4.0.6.tar.gz", hash = "sha256:e0dae942d3a746f8d5423f83dbad67efe704004baafe31b626477929faaee472"}, ] +[[package]] +name = "realtime" +version = "1.0.4" +requires_python = "<4.0,>=3.8" +summary = "" +groups = ["default"] +dependencies = [ + "python-dateutil<3.0.0,>=2.8.1", + "typing-extensions<5.0.0,>=4.11.0", + "websockets<13,>=11", +] +files = [ + {file = "realtime-1.0.4-py3-none-any.whl", hash = "sha256:b06bea001985f089167320bda1e91c6b2d866f56ca810bb8d768ee3cf695ee21"}, + {file = "realtime-1.0.4.tar.gz", hash = "sha256:a9095f60121a365e84656c582e6ccd8dc8b3a732ddddb2ccd26cc3d32b77bdf6"}, +] + [[package]] name = "referencing" version = "0.35.1" @@ -1515,6 +1656,65 @@ files = [ {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, ] +[[package]] +name = "storage3" 
+version = "0.7.4" +requires_python = "<4.0,>=3.8" +summary = "Supabase Storage client for Python." +groups = ["default"] +dependencies = [ + "httpx<0.28,>=0.24", + "python-dateutil<3.0.0,>=2.8.2", + "typing-extensions<5.0.0,>=4.2.0", +] +files = [ + {file = "storage3-0.7.4-py3-none-any.whl", hash = "sha256:0b8e8839b10a64063796ce55a41462c7ffd6842e0ada74f25f5dcf37e1d1bade"}, + {file = "storage3-0.7.4.tar.gz", hash = "sha256:61fcbf836f566405981722abb7d56caa57025b261e7a316e73316701abf0c040"}, +] + +[[package]] +name = "strenum" +version = "0.4.15" +summary = "An Enum that inherits from str." +groups = ["default"] +files = [ + {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, + {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, +] + +[[package]] +name = "supabase" +version = "2.4.5" +requires_python = "<4.0,>=3.8" +summary = "Supabase client for Python." 
+groups = ["default"] +dependencies = [ + "gotrue<3.0,>=1.3", + "httpx<0.28,>=0.24", + "postgrest<0.17.0,>=0.14", + "realtime<2.0.0,>=1.0.0", + "storage3<0.8.0,>=0.5.3", + "supafunc<0.5.0,>=0.3.1", +] +files = [ + {file = "supabase-2.4.5-py3-none-any.whl", hash = "sha256:100441c36bf3390b040818c636c372a91645d18b6a9e0c12cea061fb00db664c"}, + {file = "supabase-2.4.5.tar.gz", hash = "sha256:8520b5a194c6d8fdbdd71b45aefc8b5a42d1a6711a2c693b6d299aeb785e8532"}, +] + +[[package]] +name = "supafunc" +version = "0.4.5" +requires_python = "<4.0,>=3.8" +summary = "Library for Supabase Functions" +groups = ["default"] +dependencies = [ + "httpx<0.28,>=0.24", +] +files = [ + {file = "supafunc-0.4.5-py3-none-any.whl", hash = "sha256:2208045f8f5c797924666f6a332efad75ad368f8030b2e4ceb9d2bf63f329373"}, + {file = "supafunc-0.4.5.tar.gz", hash = "sha256:a6466d78bdcaa58b7f0303793643103baae8106a87acd5d01e196179a9d0d024"}, +] + [[package]] name = "tblib" version = "3.0.0" diff --git a/pyproject.toml b/pyproject.toml index ff23b23..a9144c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "uvicorn>=0.29.0", "fastapi>=0.111.0", "litellm>=1.37.5", + "cryptography>=42.0.7", + "supabase>=2.4.5", ] requires-python = ">=3.12" readme = "README.md" diff --git a/server.py b/server.py index 8ffb950..b563aff 100644 --- a/server.py +++ b/server.py @@ -1,18 +1,51 @@ +import json import uvicorn import os +import yaml from dotenv import dotenv_values -from fastapi import FastAPI +from fastapi import FastAPI, Request +from functools import wraps +from litellm import completion from src.config import get_config_for_endpoint, get_endpoints_for_model -from src.sagemaker.resources import get_sagemaker_endpoint +from src.sagemaker.create_model import deploy_model from src.sagemaker.query_endpoint import make_query_request +from src.sagemaker.resources import get_sagemaker_endpoint from src.schemas.query import Query, ChatCompletion -from src.session import session -from 
litellm import completion +from src.schemas.secret import Secrets +from src.supabase.secret import set_secrets, get_secrets +from src.supabase import supabase_client, supabase_id +from urllib.parse import unquote +from pydantic import BaseModel -os.environ["AWS_REGION_NAME"] = session.region_name app = FastAPI() +class DeploymentConfig(BaseModel): + path: str + + +class NotAuthenticatedException(Exception): + pass + + +def auth_required(func): + @wraps(func) + async def wrapper(*args, **kwargs): + request = kwargs['request'] + auth_token = request.cookies.get(f'sb-{supabase_id}-auth-token') + if not auth_token: + raise NotAuthenticatedException + + payload = json.loads(unquote(auth_token)) + access_token = payload["access_token"] + refresh_token = payload["refresh_token"] + + supabase_client.auth.set_session(access_token, refresh_token) + return await func(*args, **kwargs) + + return wrapper + + class NotDeployedException(Exception): pass @@ -22,6 +55,20 @@ def get_endpoint(endpoint_name: str): return get_sagemaker_endpoint(endpoint_name) +@app.post("/endpoint/deploy") +@auth_required +async def deploy_endpoint(request: Request, deployment_config_path: DeploymentConfig): + deployment = None + model = None + with open(deployment_config_path.path) as config: + configuration = yaml.safe_load(config) + deployment = configuration['deployment'] + + # TODO: Support multi-model endpoints + model = configuration['models'][0] + deploy_model(deployment, model) + + @app.post("/endpoint/{endpoint_name}/query") def query_endpoint(endpoint_name: str, query: Query): config = get_config_for_endpoint(endpoint_name) @@ -56,5 +103,21 @@ def chat_completion(chat_completion: ChatCompletion): return res +@app.post("/secrets/add") +@auth_required +async def store_secret(request: Request, secrets: Secrets): + user_res = supabase_client.auth.get_user() + set_secrets(user_res.user.id, secrets.secrets) + + return + + +@app.get("/secrets/fetch") +@auth_required +async def fetch_secret(request: 
Request): + user_res = supabase_client.auth.get_user() + return get_secrets(user_res.user.id) + + if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/sagemaker/create_model.py b/src/sagemaker/create_model.py index da71ae2..cb9c447 100644 --- a/src/sagemaker/create_model.py +++ b/src/sagemaker/create_model.py @@ -12,7 +12,7 @@ from src.config import write_config from src.schemas.model import Model, ModelSource from src.schemas.deployment import Deployment -from src.session import session, sagemaker_session +from src.session import get_sagemaker_session, get_boto_session from src.console import console from src.utils.aws_utils import construct_s3_uri, is_s3_uri from src.utils.rich_utils import print_error, print_success @@ -36,7 +36,7 @@ def deploy_model(deployment: Deployment, model: Model): def deploy_huggingface_model(deployment: Deployment, model: Model): - region_name = session.region_name + region_name = get_boto_session().region_name task = get_hf_task(model) model.task = task env = { @@ -67,7 +67,8 @@ def deploy_huggingface_model(deployment: Deployment, model: Model): transformers_version="4.37", pytorch_version="2.1", py_version="py310", - image_uri=image_uri + image_uri=image_uri, + sagemaker_session=get_sagemaker_session(), ) endpoint_name = get_unique_endpoint_name( @@ -109,7 +110,7 @@ def deploy_huggingface_model(deployment: Deployment, model: Model): def deploy_custom_huggingface_model(deployment: Deployment, model: Model): - region_name = session.region_name + region_name = get_boto_session().region_name if model.location is None: print_error("Missing model source location.") return @@ -117,7 +118,7 @@ def deploy_custom_huggingface_model(deployment: Deployment, model: Model): s3_path = model.location if not is_s3_uri(model.location): # Local file. 
Upload to s3 before deploying - bucket = sagemaker_session.default_bucket() + bucket = get_sagemaker_session().default_bucket() s3_path = construct_s3_uri(bucket, f"models/{model.id}") with console.status(f"[bold green]Uploading custom {model.id} model to S3 at {s3_path}...") as status: try: @@ -145,6 +146,7 @@ def deploy_custom_huggingface_model(deployment: Deployment, model: Model): transformers_version="4.37", pytorch_version="2.1", py_version="py310", + sagemaker_session=get_sagemaker_session() ) with console.status("[bold green]Deploying model...") as status: @@ -175,7 +177,7 @@ def deploy_custom_huggingface_model(deployment: Deployment, model: Model): def create_and_deploy_jumpstart_model(deployment: Deployment, model: Model): - region_name = session.region_name + region_name = get_boto_session().region_name endpoint_name = get_unique_endpoint_name( model.id, deployment.endpoint_name) deployment.endpoint_name = endpoint_name @@ -198,7 +200,11 @@ def create_and_deploy_jumpstart_model(deployment: Deployment, model: Model): console.print(table) jumpstart_model = JumpStartModel( - model_id=model.id, instance_type=deployment.instance_type, role=SAGEMAKER_ROLE) + model_id=model.id, + instance_type=deployment.instance_type, + role=SAGEMAKER_ROLE, + sagemaker_session=get_sagemaker_session() + ) # Attempt to deploy to AWS try: diff --git a/src/sagemaker/fine_tune_model.py b/src/sagemaker/fine_tune_model.py index 7960a3d..b147179 100644 --- a/src/sagemaker/fine_tune_model.py +++ b/src/sagemaker/fine_tune_model.py @@ -9,7 +9,7 @@ from src.console import console from src.schemas.model import Model, ModelSource from src.schemas.training import Training -from src.session import sagemaker_session +from src.session import get_sagemaker_session from src.utils.aws_utils import is_s3_uri from src.utils.rich_utils import print_success, print_error from transformers import AutoTokenizer @@ -91,7 +91,7 @@ def fine_tune_model(training: Training, model: Model):
output_path=training.output_path, environment={"accept_eula": "true"}, role=SAGEMAKER_ROLE, - sagemaker_session=sagemaker_session, + sagemaker_session=get_sagemaker_session(), hyperparameters=hyperparameters ) case ModelSource.HuggingFace: diff --git a/src/sagemaker/query_endpoint.py b/src/sagemaker/query_endpoint.py index 0bcddc8..9312be9 100644 --- a/src/sagemaker/query_endpoint.py +++ b/src/sagemaker/query_endpoint.py @@ -12,7 +12,7 @@ from src.schemas.deployment import Deployment from src.schemas.model import Model from src.schemas.query import Query -from src.session import sagemaker_session +from src.session import get_sagemaker_session from typing import Dict, Tuple, Optional @@ -33,7 +33,7 @@ def parse_response(query_response): def query_hugging_face_endpoint(endpoint_name: str, user_query: Query, config: Tuple[Deployment, Model]): task = get_model_and_task(endpoint_name, config)['task'] predictor = HuggingFacePredictor(endpoint_name=endpoint_name, - sagemaker_session=sagemaker_session) + sagemaker_session=get_sagemaker_session()) query = user_query.query context = user_query.context diff --git a/src/sagemaker/search_jumpstart_models.py b/src/sagemaker/search_jumpstart_models.py index 3d1f658..74ac8a3 100644 --- a/src/sagemaker/search_jumpstart_models.py +++ b/src/sagemaker/search_jumpstart_models.py @@ -2,7 +2,7 @@ from enum import StrEnum, auto from sagemaker.jumpstart.notebook_utils import list_jumpstart_models from src.utils.rich_utils import print_error -from src.session import session, sagemaker_session +from src.session import get_boto_session, get_sagemaker_session class Frameworks(StrEnum): @@ -32,5 +32,5 @@ def search_sagemaker_jumpstart_model(): filter_value = "framework == {}".format(answers["framework"]) models = list_jumpstart_models(filter=filter_value, - region=session.region_name, sagemaker_session=sagemaker_session) + region=get_boto_session().region_name, sagemaker_session=get_sagemaker_session()) return models diff --git 
a/src/schemas/secret.py b/src/schemas/secret.py new file mode 100644 index 0000000..8c02ea9 --- /dev/null +++ b/src/schemas/secret.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel, SecretStr +from enum import StrEnum + + +class SecretKeys(StrEnum): + AWSAccessKey = 'AWS_ACCESS_KEY' + AWSSecretAccessKey = 'AWS_SECRET_ACCESS_KEY' + HuggingFaceHubKey = 'HF_HUB_KEY' + + +class Secret(BaseModel): + key: SecretKeys + value: SecretStr + + +class Secrets(BaseModel): + secrets: list[Secret] diff --git a/src/session.py b/src/session.py index 64a1c73..be7aba8 100644 --- a/src/session.py +++ b/src/session.py @@ -1,5 +1,27 @@ import boto3 +import boto3.session import sagemaker +from src.supabase import supabase_client +from src.supabase.secret import get_secrets +from src.schemas.secret import SecretKeys -session = boto3.session.Session() -sagemaker_session = sagemaker.session.Session(boto_session=session) + +def get_boto_session(): + if supabase_client is not None and supabase_client.auth.get_user() is not None: + user_secrets = get_secrets(supabase_client.auth.get_user().user.id) + aws_access_key = next((secret.value.get_secret_value( + ) for secret in user_secrets if secret.key == SecretKeys.AWSAccessKey), None).decode() + aws_secret_access_key = next((secret.value.get_secret_value( + ) for secret in user_secrets if secret.key == SecretKeys.AWSSecretAccessKey), None).decode() + + # Sort out region + return boto3.session.Session( + aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_access_key, region_name='us-east-1') + else: + # rely on local setup + return boto3.session.Session() + + +def get_sagemaker_session(): + session = get_boto_session() + return sagemaker.session.Session(boto_session=session) diff --git a/src/supabase/__init__.py b/src/supabase/__init__.py new file mode 100644 index 0000000..a4b6ec8 --- /dev/null +++ b/src/supabase/__init__.py @@ -0,0 +1,16 @@ +import os +import urllib.parse as urlparse +from dotenv import load_dotenv +from 
supabase import create_client, Client + + +load_dotenv() + +url: str = os.environ.get("SUPABASE_URL") +key: str = os.environ.get("SUPABASE_KEY") +if url and key: + supabase_id = urlparse.urlparse(url).hostname.split('.')[0] + supabase_client: Client = create_client(url, key) +else: + supabase_id = None + supabase_client = None diff --git a/src/supabase/secret.py b/src/supabase/secret.py new file mode 100644 index 0000000..618e1c3 --- /dev/null +++ b/src/supabase/secret.py @@ -0,0 +1,27 @@ +from src.supabase import supabase_client +from src.schemas.secret import Secret +from src.utils.crypto import encrypt, decrypt +from pydantic import SecretStr + + +def set_secrets(user_id: str, secrets: list[Secret]): + data = [{ + 'user_id': user_id, + 'key': secret.key, + 'value': encrypt(secret.value.get_secret_value()) + } for secret in secrets] + + res = supabase_client.table('access_keys').upsert(data).execute() + return + + +def get_secrets(user_id: str) -> list[Secret]: + # TODO: optional keys to filter for + + # RLS for user_id but just in case + db_secrets = supabase_client.table( + 'access_keys').select('*').eq('user_id', user_id).execute().data + + secrets = [Secret(key=secret['key'], value=SecretStr(decrypt(secret['value']))) + for secret in db_secrets] + return secrets diff --git a/src/utils/crypto.py b/src/utils/crypto.py new file mode 100644 index 0000000..fcf777f --- /dev/null +++ b/src/utils/crypto.py @@ -0,0 +1,13 @@ +import os +from cryptography.fernet import Fernet + +ENCRYPTION_KEY = os.environ["DATA_ENCRYPTION_KEY"] +key = Fernet(ENCRYPTION_KEY) + + +def encrypt(value: str): + return str(key.encrypt(value.encode()), 'utf-8') + + +def decrypt(value: str): + return key.decrypt(value) diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py index cbe3018..1253a05 100644 --- a/src/utils/model_utils.py +++ b/src/utils/model_utils.py @@ -7,7 +7,6 @@ from src.schemas.deployment import Deployment from src.schemas.model import Model, ModelSource from 
src.schemas.query import Query -from src.session import sagemaker_session from typing import Dict, Tuple, Optional HUGGING_FACE_HUB_TOKEN = dotenv_values(".env").get("HUGGING_FACE_HUB_KEY") diff --git a/utils/supabase/client.ts b/utils/supabase/client.ts index e2660d0..2ff2a15 100644 --- a/utils/supabase/client.ts +++ b/utils/supabase/client.ts @@ -2,6 +2,6 @@ import { createBrowserClient } from "@supabase/ssr"; export const createClient = () => createBrowserClient( - process.env.NEXT_PUBLIC_SUPABASE_URL!, - process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!, + (process.env.NEXT_PUBLIC_SUPABASE_URL ?? process.env.SUPABASE_URL)!, + (process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY ?? process.env.SUPABASE_KEY)!, ); diff --git a/utils/supabase/middleware.ts b/utils/supabase/middleware.ts index 8c6338c..b72b4d7 100644 --- a/utils/supabase/middleware.ts +++ b/utils/supabase/middleware.ts @@ -13,8 +13,8 @@ export const updateSession = async (request: NextRequest) => { }); const supabase = createServerClient( - process.env.NEXT_PUBLIC_SUPABASE_URL!, - process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!, + process.env.SUPABASE_URL!, + process.env.SUPABASE_KEY!, { cookies: { get(name: string) { diff --git a/utils/supabase/server.ts b/utils/supabase/server.ts index ecadfb1..31fe74c 100644 --- a/utils/supabase/server.ts +++ b/utils/supabase/server.ts @@ -5,8 +5,8 @@ export const createClient = () => { const cookieStore = cookies(); return createServerClient( - process.env.NEXT_PUBLIC_SUPABASE_URL!, - process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!, + process.env.SUPABASE_URL!, + process.env.SUPABASE_KEY!, { cookies: { get(name: string) {