diff --git a/.dockerignore b/.dockerignore
index 2d1c758..c4346f7 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -5,4 +5,91 @@ cache/
 .llm-cache/
 **/.llm-cache/
 node_modules/
-**/node_modules/
\ No newline at end of file
+**/node_modules/
+logs/
+openweights/dashboard/backend/backend.log
+openweights/dashboard/backend/backend.pid
+build-docker-in-runpod
+.env
+.env.dev
+.env.prod
+.env.ow-dev
+.env.ow-migrations
+openweights/dashboard/backend/.env.ow-dev
+openweights/dashboard/backend/.env.ow-migrations
+openweights/dashboard/frontend/.env.ow-dev
+openweights/dashboard/frontend/.env.ow-migrations
+artifacts/
+debug/
+example/.ipynb_checkpoints/
+example/Untitled1.ipynb
+dev.py
+planb/
+vulnerable-code/
+openweights/client/.llm-cache
+# misc cache
+cache
+# Bazel
+/bazel-*
+/bazel-bin
+/bazel-genfiles
+/bazel-out
+/bazel-testlogs
+/bazel-workspace
+
+# Bazel symlinks
+/bazel-*
+
+# Bazel disk cache
+.bazel-cache/
+
+# Bazel IntelliJ plugin
+.ijwb/
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+env/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+example/ft_job_artifacts/
+example/mcq_dataset.jsonl
+openweights/jobs/unsloth/logp.ipynb
+
+openweights/dashboard/backend/static/
+example/_*
+.logs
+openweights/jobs/unsloth/check.ipynb
+.cache
+.logs/
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 7ef80c8..5f9402e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -83,4 +83,5 @@ openweights/dashboard/backend/static/
 example/_*
 .logs
 openweights/jobs/unsloth/check.ipynb
-.cache
\ No newline at end of file
+.cache
+admin/
\ No newline at end of file
diff --git a/docs/ttl.md b/docs/ttl.md
new file mode 100644
index 0000000..1ab5c8b
--- /dev/null
+++ b/docs/ttl.md
@@ -0,0 +1,121 @@
+# TTL (Time To Live) Feature
+
+The TTL feature provides automatic pod termination to prevent runaway costs and to ensure resource cleanup.
+
+## Overview
+
+- **Default TTL**: 24 hours for all pods
+- **Automatic termination**: Pods self-terminate when their TTL expires
+- **Extensible**: The TTL can be extended from within the pod
+- **Dev mode support**: TTL monitoring runs for both dev and worker instances
+
+## Usage
+
+### Starting pods with custom TTL
+
+```bash
+# Start a dev instance with the default 24-hour TTL
+python openweights/cluster/start_runpod.py A100 default --dev_mode=true
+
+# Start a dev instance with a 2-hour TTL
+python openweights/cluster/start_runpod.py A100 default --dev_mode=true --ttl_hours=2
+
+# Start a worker with a 12-hour TTL
+python openweights/cluster/start_runpod.py A100 finetuning --ttl_hours=12
+```
+
+### Managing TTL from within a pod
+
+Once inside a pod, use the TTL manager utility:
+
+```bash
+# Check current TTL status
+python openweights/worker/services/ttl_manager.py --check
+
+# Extend TTL by 5 more hours
+python openweights/worker/services/ttl_manager.py --extend 5
+
+# Set TTL to 10 hours from now
+python openweights/worker/services/ttl_manager.py --set 10
+```
+
+### Manual TTL management
+
+You can also update the TTL manually by writing a new ISO timestamp to `~/shutdown.txt` (note that Python's `open()` does not expand `~`, so expand it explicitly):
+
+```bash
+python3 -c "
+import datetime, os
+new_time = datetime.datetime.now() + datetime.timedelta(hours=48)
+with open(os.path.expanduser('~/shutdown.txt'), 'w') as f:
+    f.write(new_time.isoformat())
+print(f'TTL extended to {new_time}')
+"
+```
+
+## How it works
+
+1. **TTL Setup**: When a pod starts, the TTL monitor service calculates the shutdown time and writes it to `~/shutdown.txt`
+2. **Monitoring**: A background service checks the shutdown time every minute
+3. **Termination**: When the current time exceeds the shutdown time, the service terminates the pod using the RunPod API
+4. **Extension**: Jobs or users can extend the TTL by updating the shutdown time in the file
+
+A sketch of the monitor loop is shown at the end of this document.
+
+## Architecture
+
+- **TTL Monitor Service**: `openweights/worker/services/ttl_monitor.py`
+- **TTL Manager Utility**: `openweights/worker/services/ttl_manager.py`
+- **Configuration**: TTL is passed via the `TTL_HOURS` environment variable
+- **Shutdown File**: `~/shutdown.txt` contains an ISO-format datetime
+
+## Environment Variables
+
+- `TTL_HOURS`: Number of hours for the TTL (default: 24)
+- `RUNPOD_API_KEY`: RunPod API key used for pod termination
+- `OW_DEV`: Indicates whether the pod runs in dev mode (affects other services, not the TTL)
+
+## Notes
+
+- TTL monitoring runs for both dev and worker instances
+- This provides an additional safety net, especially for dev instances
+- The pod ID is detected automatically from the RunPod metadata API
+- Failed termination attempts are retried every minute
+- The TTL can be reset or extended any number of times before expiration
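+
+## Monitor loop sketch
+
+The monitor service itself is not part of this diff. The sketch below shows the loop it plausibly runs; the `RUNPOD_POD_ID` lookup and the exact file handling are assumptions, so treat it as an illustration rather than the actual implementation in `openweights/worker/services/ttl_monitor.py`.
+
+```python
+# Minimal sketch, assuming the pod ID is exposed via the RUNPOD_POD_ID
+# environment variable and the `runpod` SDK is installed.
+import datetime
+import os
+import time
+
+import runpod
+
+SHUTDOWN_FILE = os.path.expanduser("~/shutdown.txt")
+
+def main():
+    runpod.api_key = os.environ["RUNPOD_API_KEY"]
+    ttl_hours = float(os.environ.get("TTL_HOURS", "24"))
+    # Write the initial deadline; jobs may overwrite the file later to extend it.
+    if not os.path.exists(SHUTDOWN_FILE):
+        deadline = datetime.datetime.now() + datetime.timedelta(hours=ttl_hours)
+        with open(SHUTDOWN_FILE, "w") as f:
+            f.write(deadline.isoformat())
+    while True:
+        with open(SHUTDOWN_FILE) as f:
+            deadline = datetime.datetime.fromisoformat(f.read().strip())
+        if datetime.datetime.now() > deadline:
+            try:
+                runpod.terminate_pod(os.environ["RUNPOD_POD_ID"])
+            except Exception:
+                pass  # failed attempts are retried on the next iteration
+        time.sleep(60)
+
+if __name__ == "__main__":
+    main()
+```
\ No newline at end of file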
diff --git a/entrypoint.sh b/entrypoint.sh
index 084607c..812023b 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -13,78 +13,44 @@ if [ -n "$PUBLIC_KEY" ]; then
 else
     echo "[$(date)] No PUBLIC_KEY provided, skipping SSH key setup"
 fi
-echo "Authorized keys added."

-# if OW_COMMIT is set, checkout the commit
+# Repository checkout if needed
 echo "[$(date)] Checking for OW_COMMIT environment variable"
 if [ -n "$OW_COMMIT" ]; then
-    echo "[$(date)] Starting repository checkout for commit: $OW_COMMIT"
-    rm -rf openweights
-    git clone https://github.com/longtermrisk/openweights.git openweights_dev
-    cd openweights_dev
-    git checkout $OW_COMMIT
-    mv openweights ../openweights
-    cd ..
-    rm -rf openweights_dev
-    echo "[$(date)] Repository checkout completed"
+    echo "[$(date)] Starting repository checkout"
+    python3 openweights/worker/services/checkout.py
+    if [ $? -ne 0 ]; then
+        echo "[$(date)] Repository checkout failed"
+        exit 1
+    fi
 else
     echo "[$(date)] No OW_COMMIT specified, skipping repository checkout"
 fi

 # Login to huggingface
 echo "[$(date)] Attempting to login to Hugging Face"
-python3 -c "from huggingface_hub.hf_api import HfFolder; import os; HfFolder.save_token(os.environ['HF_TOKEN'])"
+python3 openweights/worker/services/hf_login.py
 echo "[$(date)] Hugging Face login completed"

-# Generate SSH host keys
-echo "[$(date)] Generating SSH host keys"
+# Generate SSH host keys and start SSH service
+echo "[$(date)] Setting up SSH service"
 ssh-keygen -A
-echo "[$(date)] SSH host keys generated"
-
-# Start SSH service
-echo "[$(date)] Starting SSH service"
 service ssh start
 echo "[$(date)] SSH service started"

 # Print sshd logs to stdout
 tail -f /var/log/auth.log &

-# Start a simple server that serves the content of main.log on port 10101
-# Create main.log if it doesn't exist
-touch main.log
-
-# Start a simple Python HTTP server to serve files from logs/
+# Start background services
 echo "[$(date)] Starting HTTP log server on port 10101"
-python3 -c '
-import http.server
-import socketserver
-import os
+mkdir -p logs
+python3 openweights/worker/services/log_server.py &

-class LogHandler(http.server.SimpleHTTPRequestHandler):
-    def do_GET(self):
-        # If path is /logs, serve logs/main
-        if self.path == "/logs":
-            file_path = "logs/main"
-        else:
-            # Remove leading slash and ensure path is within logs directory
-            path = self.path.lstrip("/")
-            file_path = os.path.join("logs", path)
-
-        # Check if file exists and is within logs directory
-        if os.path.exists(file_path) and os.path.commonprefix([os.path.abspath(file_path), os.path.abspath("logs")]) == os.path.abspath("logs"):
-            self.send_response(200)
-            self.send_header("Content-type", "text/plain")
-            self.end_headers()
-            with open(file_path, "rb") as f:
-                self.wfile.write(f.read())
-        else:
-            self.send_error(404, "File not found")
+# Start TTL monitoring service
+echo "[$(date)] Starting TTL monitoring service"
+python3 openweights/worker/services/ttl_monitor.py &

-with socketserver.TCPServer(("", 10101), LogHandler) as httpd:
-    httpd.serve_forever()
-' &
-
-echo "[$(date)] HTTP log server started"
+echo "[$(date)] All services started"

 # Execute the main application or run in dev mode
 if [ "$OW_DEV" = "true" ]; then
@@ -92,5 +58,5 @@ if [ "$OW_DEV" = "true" ]; then
     exec tail -f /dev/null
 else
     echo "[$(date)] Starting main application"
-    exec python3 openweights/worker/main.py \ > >(tee logs/main) \ 2> >(tee -a logs/main >&2)
+    exec python3 openweights/worker/main.py > >(tee logs/main) 2> >(tee -a logs/main >&2)
-fi
+fi
\ No newline at end of file
diff --git a/example/ui.py b/example/ui.py
new file mode 100644
index 0000000..31e8474
--- /dev/null
+++ b/example/ui.py
@@ -0,0 +1,31 @@
+import gradio as gr # type: ignore
+from openai import OpenAI # type: ignore
+import os
+
+def chat_with(model):
+    # The OpenAI client requires an api_key; vLLM-style endpoints typically accept any string.
+    client = OpenAI(base_url="https://ag5a2je35kxz7y-8000.proxy.runpod.net/v1", api_key=os.environ.get("OPENAI_API_KEY", "EMPTY"))
+    def predict(message, history):
+        messages = []
+        for human, assistant in history:
+            messages.append({"role": "user", "content": human})
+            messages.append({"role": "assistant", "content": assistant})
+        messages.append({"role": "user", "content": message})
+
+        stream = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            stream=True
+        )
+
+        partial_message = ""
+        for chunk in stream:
+            if chunk.choices[0].delta.content is not None:
+                partial_message += chunk.choices[0].delta.content
+                yield partial_message
+
+    
gr.ChatInterface(predict).queue().launch() + + +if __name__ == '__main__': + chat_with('Qwen/Qwen3-235B-A22B-Instruct-2507-FP8') \ No newline at end of file diff --git a/openweights/client/__init__.py b/openweights/client/__init__.py index 53506fa..57819c1 100644 --- a/openweights/client/__init__.py +++ b/openweights/client/__init__.py @@ -1,16 +1,6 @@ -import asyncio -import atexit -import json -from typing import Optional, BinaryIO, Dict, Any, List, Union +from typing import Optional, Dict, Any import os -import sys -from postgrest.exceptions import APIError -import hashlib -from datetime import datetime -from openai import OpenAI, AsyncOpenAI -import backoff -import time -from supabase import create_client, Client +from supabase import create_client from supabase.lib.client_options import ClientOptions from openweights.client.files import Files, validate_messages, validate_preference_dataset @@ -20,7 +10,13 @@ from openweights.client.temporary_api import TemporaryApi from openweights.client.chat import ChatCompletions, AsyncChatCompletions from openweights.client.utils import group_models_or_adapters_by_model, get_lora_rank +from openweights.client.decorators import supabase_retry +import logging + +# Reduce noise to only warnings+errors +for name in ["httpx", "httpx._client", "postgrest", "gotrue", "supabase"]: + logging.getLogger(name).setLevel(logging.WARNING) def create_authenticated_client(supabase_url: str, supabase_anon_key: str, auth_token: Optional[str] = None): """Create a Supabase client with authentication. @@ -114,7 +110,7 @@ def __init__(self, setattr(self, name, cls(self)) OpenWeights._INSTANCES.append(self) - @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}")) + @supabase_retry() def get_organization_id(self) -> str: """Get the organization ID associated with the current token""" result = self._supabase.rpc('get_organization_from_token').execute() @@ -122,7 +118,7 @@ def get_organization_id(self) -> str: raise ValueError("Could not determine organization ID from token") return result.data - @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}")) + @supabase_retry() def get_organization_name(self): """Get the organization ID associated with the current token""" result = self._supabase.table('organizations')\ @@ -131,7 +127,7 @@ def get_organization_name(self): .single().execute() return result.data['name'] - @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}")) + @supabase_retry() def get_hf_org(self): """Get organization secrets from the database.""" result = self._supabase.table('organization_secrets')\ diff --git a/openweights/client/cache_on_disk.py b/openweights/client/cache_on_disk.py deleted file mode 100644 index 6c291cf..0000000 --- a/openweights/client/cache_on_disk.py +++ /dev/null @@ -1,78 +0,0 @@ -import asyncio -import hashlib -import json -import os -from functools import wraps - -import diskcache as dc - -class CacheOnDisk: - def __init__(self, n_semaphore=100, cache_dir=None): - """ - Create a CacheOnDisk instance. - - Parameters: - n_semaphore (int): Maximum number of parallel cache accesses. - cache_dir (str): Path to the cache directory. Defaults to a ".llm-cache" - directory alongside this file. 
- """ - if cache_dir is None: - cache_dir = os.path.join(os.path.dirname(__file__), ".llm-cache") - os.makedirs(cache_dir, exist_ok=True) - self.cache = dc.FanoutCache(cache_dir, shards=64, timeout=10) - self.semaphore = asyncio.Semaphore(n_semaphore) - - def __call__(self, possible_func=None, *, required_kwargs=None): - """ - When used as a decorator, CacheOnDisk works in two ways: - - 1. As a no-argument decorator: - @cache_on_disk - async def my_func(...): ... - - 2. As a parameterized decorator: - @cache_on_disk(required_kwargs=["foo"]) - async def my_func(...): ... - - The `required_kwargs` parameter (a list) determines which keyword - arguments are needed for caching. If they are not present, the function - is simply executed. - """ - if possible_func is not None and callable(possible_func): - # Used as "@cache_on_disk" without explicit parameters. - return self._make_decorator(required_kwargs or [])(possible_func) - else: - # Used as "@cache_on_disk(required_kwargs=[...])". Return a decorator. - required_kwargs = required_kwargs or [] - def decorator(func): - return self._make_decorator(required_kwargs)(func) - return decorator - - def _make_decorator(self, required_kwargs): - def decorator(function): - @wraps(function) - async def wrapper(*args, **kwargs): - # Only attempt caching if all required keyword arguments are present. - if not all(k in kwargs for k in required_kwargs): - return await function(*args, **kwargs) - - # Serialize args/kwargs and compute the cache key. - serialized = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True) - key = hashlib.sha256(serialized.encode()).hexdigest() - - # Limit the number of concurrent cache accesses. - async with self.semaphore: - cached_result = await asyncio.to_thread(self.cache.get, key, None) - if cached_result is not None: - return cached_result - - result = await function(*args, **kwargs) - - async with self.semaphore: - await asyncio.to_thread(self.cache.set, key, result) - return result - return wrapper - return decorator - -# Create a default object for easy importing. 
-cache_on_disk = CacheOnDisk() \ No newline at end of file diff --git a/openweights/client/chat.py b/openweights/client/chat.py index b91a66b..1fd809a 100644 --- a/openweights/client/chat.py +++ b/openweights/client/chat.py @@ -1,10 +1,9 @@ from collections import defaultdict import openai import asyncio -from openweights.client.cache_on_disk import cache_on_disk -import backoff from openweights.client.temporary_api import TemporaryApi, APIS import openai +from openweights.client.decorators import openai_retry DEPLOYMENT_QUEUE = [] @@ -31,7 +30,6 @@ def __init__(self, ow, deploy_kwargs={}, request_timeout=300, per_token_timeout= self.per_token_timeout = per_token_timeout async def create(self, model: str, **kwargs): - # @cache_on_disk(required_kwargs=['seed']) async def cached_create(model, **kwargs): return await self._create(model, **kwargs) return await cached_create(model, **kwargs) @@ -41,18 +39,7 @@ async def _create(self, model, **kwargs): async with api.sem: return await self._create_with_backoff(api, model, **kwargs) - @backoff.on_exception( - wait_gen=backoff.expo, - exception=( - openai.RateLimitError, - openai.APIConnectionError, - openai.APITimeoutError, - openai.InternalServerError - ), - max_value=60, - factor=1.5, - max_tries=10 - ) + @openai_retry() async def _create_with_backoff(self, api, model, **kwargs): timeout = kwargs.pop('timeout', None) or max( self.request_timeout, diff --git a/openweights/client/decorators.py b/openweights/client/decorators.py new file mode 100644 index 0000000..93a8057 --- /dev/null +++ b/openweights/client/decorators.py @@ -0,0 +1,202 @@ +import backoff +import httpx +import openai +import functools +import time +import random +from typing import Optional + +# Optional: postgrest is used under the hood by supabase-py; handle if present +try: + import postgrest +except Exception: # pragma: no cover + postgrest = None + + +def _is_transient_http_status(status: int) -> bool: + # Retry on server errors & rate limiting + return status >= 500 or status == 429 + + +def _is_transient(exc: BaseException) -> bool: + """ + Returns True for errors that are likely to be temporary network/service hiccups: + - httpx timeouts / connect errors / protocol errors + - HTTPStatusError with 5xx or 429 + - postgrest.APIError with 5xx or 429 (if postgrest available) + """ + # httpx family (network-ish) + if isinstance(exc, ( + httpx.ConnectError, + httpx.ConnectTimeout, + httpx.ReadTimeout, + httpx.WriteTimeout, + httpx.PoolTimeout, + httpx.NetworkError, + httpx.RemoteProtocolError, + )): + return True + + # httpx raised because .raise_for_status() was called + if isinstance(exc, httpx.HTTPStatusError): + try: + return _is_transient_http_status(exc.response.status_code) + except Exception: + return False + + # postgrest API errors (supabase-py) + if postgrest is not None and isinstance(exc, postgrest.APIError): + try: + code = getattr(exc, "code", None) + # code may be a string; try to coerce + code_int = int(code) if code is not None else None + return code_int is not None and _is_transient_http_status(code_int) + except Exception: + return False + + # Sometimes libraries wrap the real error; walk the causal chain + cause = getattr(exc, "__cause__", None) or getattr(exc, "__context__", None) + if cause and cause is not exc: + return _is_transient(cause) + + return False + + +def openai_retry( + *, + # Exponential mode (default) + factor: float = 1.5, + max_value: int = 60, + # Constant mode (set interval to enable) + interval: Optional[float] = None, + max_time: 
Optional[float] = None, + # Common + max_tries: int = 10, +): + """ + Retry transient OpenAI API errors with backoff + jitter. + + Modes: + • Exponential (default): pass `factor`, `max_value`, `max_tries` + • Constant: pass `interval` (seconds) and optionally `max_time`, `max_tries` + + Examples: + @openai_retry() # exponential (default) + def call(...): ... + + @openai_retry(interval=10, max_time=3600, max_tries=3600) # constant + def call(...): ... + + Retries on: + - openai.RateLimitError + - openai.APIConnectionError + - openai.APITimeoutError + - openai.InternalServerError + """ + exceptions = ( + openai.RateLimitError, + openai.APIConnectionError, + openai.APITimeoutError, + openai.InternalServerError, + ) + + def _decorator(fn): + if interval is not None: + # Constant backoff mode + decorated = backoff.on_exception( + wait_gen=backoff.constant, + exception=exceptions, + interval=interval, + max_time=max_time, # total wall-clock cap (optional) + max_tries=max_tries, # total attempts cap + jitter=backoff.full_jitter, + logger=None, # stay quiet + )(fn) + else: + # Exponential backoff mode + decorated = backoff.on_exception( + wait_gen=backoff.expo, + exception=exceptions, + factor=factor, # growth factor + max_value=max_value, # cap per-wait + max_tries=max_tries, # total attempts cap + jitter=backoff.full_jitter, + logger=None, # stay quiet + )(fn) + + @functools.wraps(fn) + def inner(*args, **kwargs): + return decorated(*args, **kwargs) + + return inner + + return _decorator + + +# sentinel to indicate "raise on exhaustion" +_RAISE = object() + +def supabase_retry( + max_time: float = 60, + max_tries: int = 8, + *, + base: float = 1.0, # initial delay + factor: float = 2.0, # exponential growth + max_delay: float = 60.0, # cap for each delay step + return_on_exhaustion=_RAISE # e.g., set to None to "ignore" after retries +): + """ + Retries ONLY transient Supabase/http errors (see _is_transient) with exponential backoff + full jitter. + If `return_on_exhaustion` is not `_RAISE`, return that value after retry budget is exhausted for a + transient error. Non-transient errors still raise immediately. + + Args: + max_time: maximum total wall-clock seconds spent retrying + max_tries: maximum attempts (including the first) + base: initial backoff delay (seconds) + factor: exponential growth factor per attempt (>= 1) + max_delay: max per-attempt delay (seconds) + return_on_exhaustion: value to return after exhausting retries on a transient error. + Leave as `_RAISE` to re-raise instead. + """ + def _next_delay(attempt: int) -> float: + # attempt starts at 1 for the first retry (after one failure) + raw = base * (factor ** (attempt - 1)) + return min(raw, max_delay) * random.random() # full jitter + + def _decorator(fn): + @functools.wraps(fn) + def inner(*args, **kwargs): + # quick path: try once + start = time.monotonic() + attempt = 0 + while True: + try: + return fn(*args, **kwargs) + except Exception as exc: + # Non-transient? bubble up immediately. + if not _is_transient(exc): + raise + + attempt += 1 + # Have we exhausted attempts? 
+ if attempt >= max_tries: + if return_on_exhaustion is _RAISE: + raise + return return_on_exhaustion + + # Compute delay with jitter, ensure we don't break max_time + delay = _next_delay(attempt) + if max_time is not None: + elapsed = time.monotonic() - start + remaining = max_time - elapsed + if remaining <= 0: + if return_on_exhaustion is _RAISE: + raise + return return_on_exhaustion + # don't sleep past the deadline + delay = min(delay, max(0.0, remaining)) + + time.sleep(delay) + return inner + return _decorator diff --git a/openweights/client/events.py b/openweights/client/events.py index d4bdfd1..8aafaca 100644 --- a/openweights/client/events.py +++ b/openweights/client/events.py @@ -1,10 +1,11 @@ from typing import Optional, Dict, Any, List - +from openweights.client.decorators import supabase_retry class Events(): def __init__(self, supabase): self._supabase = supabase + @supabase_retry() def list(self, job_id: Optional[str]=None, run_id: Optional[str]=None): """List events by job_id or run_id, sorted by created_at in ascending order""" if run_id: diff --git a/openweights/client/files.py b/openweights/client/files.py index 1b4dd29..05762ea 100644 --- a/openweights/client/files.py +++ b/openweights/client/files.py @@ -1,12 +1,14 @@ from typing import Optional, BinaryIO, Dict, Any, List, Union import os import hashlib +import io +import tempfile from datetime import datetime from supabase import Client import backoff import json import logging - +from openweights.client.decorators import supabase_retry def validate_message(message): try: @@ -72,14 +74,17 @@ def __init__(self, supabase: Client, organization_id: str): self._supabase = supabase self._org_id = organization_id - def _calculate_file_hash(self, file: BinaryIO) -> str: + def _calculate_file_hash(self, stream: BinaryIO) -> str: """Calculate SHA-256 hash of file content""" sha256_hash = hashlib.sha256() - for byte_block in iter(lambda: file.read(4096), b""): + for byte_block in iter(lambda: stream.read(4096), b""): sha256_hash.update(byte_block) # Add the org ID to the hash to ensure uniqueness sha256_hash.update(self._org_id.encode()) - file.seek(0) # Reset file pointer + try: + stream.seek(0) + except Exception: + pass return f"file-{sha256_hash.hexdigest()[:12]}" def _get_storage_path(self, file_id: str) -> str: @@ -94,45 +99,75 @@ def _get_storage_path(self, file_id: str) -> str: # Fallback if RPC fails return f"organizations/{self._org_id}/{file_id}" - @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}")) + @supabase_retry() def create(self, file: BinaryIO, purpose: str) -> Dict[str, Any]: - """Upload a file and create a database entry""" - file.seek(0) - file_id = f"{purpose}:{self._calculate_file_hash(file)}" + """Upload a file and create a database entry. + Robust to retries by buffering the input stream into memory once + and using fresh BytesIO streams for hashing, validation, and upload. 
+ """ + # Read all bytes once; support both real files and file-like objects + try: + # Ensure at start (some callers might pass a consumed stream) + if hasattr(file, 'seek'): + try: + file.seek(0) + except Exception: + pass + data = file.read() + finally: + # Do not close the caller's file handle; just leave it as-is + # (the caller used a context manager typically) + pass + + if not isinstance(data, (bytes, bytearray)): + raise TypeError("Files.create expects a binary file-like object returning bytes") + + file_id = f"{purpose}:{self._calculate_file_hash(io.BytesIO(data))}" # If the file already exists, return the existing file try: existing_file = self._supabase.table('files').select('*').eq('id', file_id).single().execute().data if existing_file: return existing_file - except: + except Exception: pass # File doesn't exist yet, continue with creation - # Validate file content - if not self.validate(file, purpose): + # Validate file content using a fresh buffer + if not self.validate(io.BytesIO(data), purpose): raise ValueError("File content is not valid") - file_size = os.fstat(file.fileno()).st_size + file_size = len(data) filename = getattr(file, 'name', 'unknown') # Get organization-specific storage path storage_path = self._get_storage_path(file_id) # Store file in Supabase Storage with organization path - self._supabase.storage.from_('files').upload( - path=storage_path, - file=file - ) - + # storage3's sync client expects a file path-like in some versions; write to a temp file for compatibility + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp.write(data) + tmp.flush() + tmp_path = tmp.name + try: + self._supabase.storage.from_('files').upload( + path=storage_path, + file=tmp_path, + file_options={"upsert": "true"} + ) + finally: + try: + os.remove(tmp_path) + except Exception: + pass # Create database entry - data = { + data_row = { 'id': file_id, 'filename': filename, 'purpose': purpose, 'bytes': file_size } - result = self._supabase.table('files').insert(data).execute() + result = self._supabase.table('files').insert(data_row).execute() return { 'id': file_id, @@ -143,14 +178,14 @@ def create(self, file: BinaryIO, purpose: str) -> Dict[str, Any]: 'purpose': purpose, } - @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}")) + @supabase_retry() def content(self, file_id: str) -> bytes: """Get file content""" storage_path = self._get_storage_path(file_id) return self._supabase.storage.from_('files').download(storage_path) def validate(self, file: BinaryIO, purpose: str) -> bool: - """Validate file content""" + """Validate file content. The passed stream will be consumed.""" if purpose in ['conversations']: content = file.read().decode('utf-8') return validate_messages(content) @@ -160,7 +195,7 @@ def validate(self, file: BinaryIO, purpose: str) -> bool: else: return True - @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... 
{details['exception']}")) + @supabase_retry() def get_by_id(self, file_id: str) -> Dict[str, Any]: """Get file details by ID""" return self._supabase.table('files').select('*').eq('id', file_id).single().execute().data diff --git a/openweights/client/jobs.py b/openweights/client/jobs.py index 4071e16..7089cd6 100644 --- a/openweights/client/jobs.py +++ b/openweights/client/jobs.py @@ -1,18 +1,14 @@ -import re import json -from typing import BinaryIO, Dict, Any, List, Union, Tuple, Type +from typing import Dict, Any, List, Type import os from postgrest.exceptions import APIError -import backoff import hashlib -from supabase import Client -from pydantic import BaseModel, Field +from pydantic import BaseModel from datetime import datetime from dataclasses import dataclass -from openweights.client.utils import resolve_lora_model, get_lora_rank from openweights.cluster.start_runpod import GPUs - +from openweights.client.decorators import supabase_retry @dataclass class Job: @@ -133,14 +129,7 @@ def _upload_mounted_files(self, extra_files=None) -> Dict[str, str]: return uploaded_files - @backoff.on_exception( - backoff.constant, - Exception, - interval=1, - max_time=60, - max_tries=60, - on_backoff=lambda details: print(f"Retrying... {details['exception']}"), - ) + @supabase_retry() def list(self, limit: int = 10) -> List[Dict[str, Any]]: """List jobs""" result = ( @@ -152,14 +141,7 @@ def list(self, limit: int = 10) -> List[Dict[str, Any]]: ) return [Job(**row, _manager=self) for row in result.data] - @backoff.on_exception( - backoff.constant, - Exception, - interval=1, - max_time=60, - max_tries=60, - on_backoff=lambda details: print(f"Retrying... {details['exception']}"), - ) + @supabase_retry() def retrieve(self, job_id: str) -> Dict[str, Any]: """Get job details""" result = ( @@ -167,14 +149,7 @@ def retrieve(self, job_id: str) -> Dict[str, Any]: ) return Job(**result.data, _manager=self) - @backoff.on_exception( - backoff.constant, - Exception, - interval=1, - max_time=60, - max_tries=60, - on_backoff=lambda details: print(f"Retrying... {details['exception']}"), - ) + @supabase_retry() def cancel(self, job_id: str) -> Dict[str, Any]: """Cancel a job""" result = ( @@ -185,14 +160,7 @@ def cancel(self, job_id: str) -> Dict[str, Any]: ) return Job(**result.data[0], _manager=self) - @backoff.on_exception( - backoff.constant, - Exception, - interval=1, - max_time=60, - max_tries=60, - on_backoff=lambda details: print(f"Retrying... {details['exception']}"), - ) + @supabase_retry() def restart(self, job_id: str) -> Dict[str, Any]: """Restart a job""" result = ( @@ -220,14 +188,7 @@ def compute_id(self, data: Dict[str, Any]) -> str: job_id += f"-{data['params']['validated_params']['job_id_suffix']}" return job_id - @backoff.on_exception( - backoff.constant, - Exception, - interval=1, - max_time=60, - max_tries=60, - on_backoff=lambda details: print(f"Retrying... {details['exception']}"), - ) + @supabase_retry() def get_or_create_or_reset(self, data: Dict[str, Any]) -> Dict[str, Any]: """If job exists and is [pending, in_progress, completed] return it. If job exists and is [failed, canceled] reset it to pending and return it. @@ -274,14 +235,7 @@ def get_or_create_or_reset(self, data: Dict[str, Any]) -> Dict[str, Any]: else: raise ValueError(f"Invalid job status: {job['status']}") - @backoff.on_exception( - backoff.constant, - Exception, - interval=1, - max_time=60, - max_tries=60, - on_backoff=lambda details: print(f"Retrying... 
{details['exception']}"), - ) + @supabase_retry() def find(self, **params) -> List[Dict[str, Any]]: """Find jobs by their JSON values in job.params Example: diff --git a/openweights/client/run.py b/openweights/client/run.py index 8eba698..b4c5039 100644 --- a/openweights/client/run.py +++ b/openweights/client/run.py @@ -1,39 +1,11 @@ -import asyncio -import atexit -import json -from typing import Optional, BinaryIO, Dict, Any, List, Union +from typing import Optional, BinaryIO, Dict, Any, List import os import sys from postgrest.exceptions import APIError import hashlib from datetime import datetime -import backoff from supabase import Client -import httpx -import logging -from functools import wraps - - -def is_transient_error(e): - """Check if an error is likely transient and should be retried""" - if isinstance(e, (httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectTimeout)): - return True - return False - -def retry_or_ignore(func, n_retries=5): - """Retry a function, if it continues to fail, ignore the error""" - @wraps(func) - def wrapper(*args, **kwargs): - for i in range(n_retries): - try: - return func(*args, **kwargs) - except Exception as e: - logging.error(f"Error in {func.__name__} (attempt {i+1}/{n_retries}): {e}") - if i == n_retries - 1: - # On last attempt, return None instead of raising - return None - return None - return wrapper +from openweights.client.decorators import supabase_retry class Run: @@ -43,14 +15,7 @@ def __init__(self, client: 'OpenWeights', job_id: Optional[str] = None, worker_i self.organization_id = organization_id self.id = run_id or os.getenv('OPENWEIGHTS_RUN_ID') if self.id: - # Run ID exists, fetch the data - try: - self._fetch_and_init_run(job_id, worker_id) - except Exception as e: - if not is_transient_error(e): - raise - # For transient errors, retry with backoff - self._fetch_and_init_run_with_retry(job_id, worker_id) + self._fetch_and_init_run(job_id, worker_id) else: # Create new run data = { @@ -83,15 +48,7 @@ def __init__(self, client: 'OpenWeights', job_id: Optional[str] = None, worker_i result = self._supabase.table('runs').insert(data).execute() self._load_data(result.data[0]) - @backoff.on_exception( - backoff.expo, - (httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectTimeout), - max_tries=5 - ) - def _fetch_and_init_run_with_retry(self, job_id, worker_id): - """Fetch run data with retry logic""" - self._fetch_and_init_run(job_id, worker_id) - + @supabase_retry() def _fetch_and_init_run(self, job_id, worker_id): """Fetch run data and initialize""" try: @@ -113,11 +70,7 @@ def _fetch_and_init_run(self, job_id, worker_id): self._load_data(run_data) - @backoff.on_exception( - backoff.expo, - (httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectTimeout), - max_tries=5 - ) + @supabase_retry() def _get_job_org_id_with_retry(self, job_id): """Get job organization ID with retry logic""" return self._supabase.table('jobs').select('organization_id').eq('id', job_id).single().execute() @@ -144,7 +97,7 @@ def get(supabase: Client, run_id: int) -> 'Run': run._load_data(result.data) return run - @retry_or_ignore + @supabase_retry(return_on_exhaustion=None) def update(self, status: Optional[str] = None, logfile: Optional[str] = None): """Update run status and/or logfile""" data = {} @@ -157,7 +110,7 @@ def update(self, status: Optional[str] = None, logfile: Optional[str] = None): result = self._supabase.table('runs').update(data).eq('id', self.id).execute() self._load_data(result.data[0]) - @retry_or_ignore + 
@supabase_retry(return_on_exhaustion=None) def log(self, event_data: Dict[str, Any], file: Optional[BinaryIO] = None): """Log an event for this run""" if file: @@ -204,7 +157,7 @@ def __init__(self, client: 'OpenWeights'): self.client = client self._supabase = client._supabase - @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}")) + @supabase_retry() def list(self, job_id: Optional[str] = None, worker_id: Optional[str] = None, limit: int = 10, status: Optional[str]=None) -> List[Dict[str, Any]]: """List runs by job_id or worker_id""" query = self._supabase.table('runs').select('*').limit(limit) diff --git a/openweights/client/temporary_api.py b/openweights/client/temporary_api.py index 3e44434..1a6289f 100644 --- a/openweights/client/temporary_api.py +++ b/openweights/client/temporary_api.py @@ -1,31 +1,14 @@ import asyncio -import atexit from openai import OpenAI, AsyncOpenAI -import backoff import time import threading from datetime import datetime, timedelta, timezone -import json -from huggingface_hub import HfApi, hf_hub_download -from collections import defaultdict -from typing import List, Dict -from functools import lru_cache - -from openweights.client.utils import group_models_or_adapters_by_model, get_lora_rank +from openweights.client.decorators import openai_retry APIS = {} -def on_backoff(details): - exception_info = details['exception'] - if 'Bad Gateway' in exception_info: - return - if '' in str(exception_info): - exception_info = str(exception_info).split('<title>')[1].split('')[0] - print(f"Retrying... {exception_info}") - - class TemporaryApi: def __init__(self, ow, job_id): self.ow = ow @@ -68,12 +51,12 @@ def up(self): def __enter__(self): return self.up() - @backoff.on_exception(backoff.constant, Exception, interval=10, max_time=3600, max_tries=3600) + @openai_retry(interval=10, max_time=3600, max_tries=3600) def wait_until_ready(self, openai, model): print('Waiting for API to be ready...') openai.chat.completions.create(model=model, messages=[dict(role='user', content='Hello')], max_tokens=200) - @backoff.on_exception(backoff.constant, Exception, interval=1, max_tries=10) + @openai_retry(interval=1, max_tries=10) async def async_up(self): self._stop_timeout_thread = False self._timeout_thread = threading.Thread(target=self._manage_timeout, daemon=True) diff --git a/openweights/cluster/start_runpod.py b/openweights/cluster/start_runpod.py index c4d798c..0857704 100644 --- a/openweights/cluster/start_runpod.py +++ b/openweights/cluster/start_runpod.py @@ -1,6 +1,18 @@ """ Usage: - python start_runpod.py --gpu A6000 --container_disk_in_gb 25 --volume_in_gb 30 + python start_runpod.py --gpu A6000 --container_disk_in_gb 25 --volume_in_gb 30 --ttl_hours 24 + +TTL (Time To Live) Feature: + - All pods have a default TTL of 24 hours to prevent runaway costs + - TTL can be customized with --ttl_hours parameter + - TTL can be extended from within the pod by updating ~/shutdown.txt with a new timestamp + - Example to extend TTL from within pod: + python3 -c " + import datetime + with open('~/shutdown.txt', 'w') as f: + new_time = datetime.datetime.now() + datetime.timedelta(hours=48) + f.write(new_time.isoformat()) + " Note: possible unknown error with echo when running the script. 
""" @@ -16,7 +28,6 @@ from scp import SCPClient from functools import lru_cache -load_dotenv(override=True) IMAGES = { 'default': 'nielsrolf/ow-default', @@ -81,7 +92,6 @@ # Build map of memory -> hardware configu HARDWARE_CONFIG = {} -print("Available hardware configurations:") def populate_hardware_config(runpod_client): runpod_gpus = runpod_client.get_gpus() for gpu_short, gpu_full in GPUs.items(): @@ -199,7 +209,7 @@ def check_correct_cuda(pod, allowed=allowed_cuda_versions, runpod_client=None): @backoff.on_exception(backoff.expo, Exception, max_time=60, max_tries=5) -def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, pending_workers=None, env=None, runpod_client=None): +def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, ttl_hours=24, pending_workers=None, env=None, runpod_client=None): client = runpod_client or runpod gpu = GPUs[gpu] # default name: -worker- @@ -213,7 +223,9 @@ def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=5 env.update({ 'WORKER_ID': worker_id, 'DOCKER_IMAGE': image, - 'OW_DEV': 'true' if dev_mode else 'false' + 'OW_DEV': 'true' if dev_mode else 'false', + 'TTL_HOURS': str(ttl_hours), + 'RUNPOD_API_KEY': os.getenv('RUNPOD_API_KEY') }) if worker_id is None: worker_id = uuid.uuid4().hex[:8] @@ -239,7 +251,7 @@ def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=5 return pod -def start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, env=None, runpod_client=None): +def start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, ttl_hours=24, env=None, runpod_client=None): pending_workers = [] if dev_mode: env = {var: os.environ.get(var) for var in [ @@ -249,7 +261,7 @@ def start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=50 runpod.api_key = os.getenv('RUNPOD_API_KEY') runpod_client = runpod try: - pod = _start_worker(gpu, image, count, name, container_disk_in_gb, volume_in_gb, worker_id, dev_mode, pending_workers, env, runpod_client) + pod = _start_worker(gpu, image, count, name, container_disk_in_gb, volume_in_gb, worker_id, dev_mode, ttl_hours, pending_workers, env, runpod_client) return pod except Exception as e: import traceback diff --git a/openweights/dashboard/backend/main.py b/openweights/dashboard/backend/main.py index c013938..870e1bd 100644 --- a/openweights/dashboard/backend/main.py +++ b/openweights/dashboard/backend/main.py @@ -17,6 +17,7 @@ CORSMiddleware, allow_origins=[ "http://localhost:5173", # Vite dev server + "http://localhost:5174", # Vite dev server "http://localhost:8124", # FastAPI dev server "http://localhost:4173", # Vite preview ], diff --git a/openweights/dashboard/frontend/src/App.tsx b/openweights/dashboard/frontend/src/App.tsx index 9921ed2..9e4728b 100644 --- a/openweights/dashboard/frontend/src/App.tsx +++ b/openweights/dashboard/frontend/src/App.tsx @@ -16,7 +16,7 @@ import { import MoreVertIcon from '@mui/icons-material/MoreVert'; import SettingsIcon from '@mui/icons-material/Settings'; import { JobsView } from './components/JobsView'; -import { RunsView } from './components/RunsView'; +// import { RunsView } from './components/RunsView'; import { WorkersView } from './components/WorkersView'; import { JobDetailView, RunDetailView, WorkerDetailView 
} from './components/DetailViews'; import { Auth } from './components/Auth/Auth'; @@ -123,7 +123,7 @@ function NavBar() { {currentOrganization && ( <> - + )} @@ -216,10 +216,10 @@ function OrganizationRoutes() { } /> } /> - } /> - } /> + } /> } /> + } /> } /> } /> diff --git a/openweights/dashboard/frontend/src/components/JobsListView.tsx b/openweights/dashboard/frontend/src/components/JobsListView.tsx index 4094257..2ca738c 100644 --- a/openweights/dashboard/frontend/src/components/JobsListView.tsx +++ b/openweights/dashboard/frontend/src/components/JobsListView.tsx @@ -38,6 +38,7 @@ interface JobsListViewProps { onPageChange: (event: unknown, newPage: number) => void; onRowsPerPageChange: (event: React.ChangeEvent) => void; orgId: string; + onCancelJob: (jobId: string) => Promise; } export const JobsListView: React.FC = ({ @@ -48,6 +49,7 @@ export const JobsListView: React.FC = ({ onPageChange, onRowsPerPageChange, orgId, + onCancelJob, }) => { const filteredJobs = jobs.filter(job => { const searchStr = filter.toLowerCase(); @@ -77,6 +79,7 @@ export const JobsListView: React.FC = ({ Docker Image Created At Actions + Manage @@ -110,6 +113,18 @@ export const JobsListView: React.FC = ({ View Details + + {(job.status === 'pending' || job.status === 'in_progress') && ( + + )} + ))} {emptyRows > 0 && ( diff --git a/openweights/dashboard/frontend/src/components/JobsView.tsx b/openweights/dashboard/frontend/src/components/JobsView.tsx index 7baa535..d3a4102 100644 --- a/openweights/dashboard/frontend/src/components/JobsView.tsx +++ b/openweights/dashboard/frontend/src/components/JobsView.tsx @@ -26,7 +26,7 @@ import { ViewToggle } from './ViewToggle'; import { JobsListView } from './JobsListView'; import { useOrganization } from '../contexts/OrganizationContext'; -const JobCard: React.FC<{ job: Job; orgId: string }> = ({ job, orgId }) => ( +const JobCard: React.FC<{ job: Job; orgId: string; onCancelJob: (jobId: string) => Promise }> = ({ job, orgId, onCancelJob }) => ( = ({ job, orgId }) => ( Created: {new Date(job.created_at).toLocaleString()} - + + + {(job.status === 'pending' || job.status === 'in_progress') && ( + + )} + ); @@ -94,6 +106,7 @@ interface JobsColumnProps { onRefresh: () => void; loading?: boolean; orgId: string; + onCancelJob: (jobId: string) => Promise; } const JobsColumn: React.FC = ({ @@ -107,8 +120,10 @@ const JobsColumn: React.FC = ({ lastRefresh, onRefresh, loading, - orgId + orgId, + onCancelJob }) => { + const filteredJobs = jobs.filter(job => { const searchStr = filter.toLowerCase(); const jobId = String(job.id); @@ -152,7 +167,7 @@ const JobsColumn: React.FC = ({ {paginatedJobs.map(job => ( - + ))} { const [loading, setLoading] = useState(false); const [lastRefresh, setLastRefresh] = useState(); const [autoRefresh, setAutoRefresh] = useState(true); + const [isCancelling, setIsCancelling] = useState(null); const [view, setView] = useState<'three-column' | 'list'>('three-column'); const [statusFilters, setStatusFilters] = useState({ completed: true, @@ -230,6 +246,19 @@ export const JobsView: React.FC = () => { setPages({ pending: 0, inProgress: 0, completed: 0 }); }; + const cancelJob = async (jobId: string) => { + if (!orgId) return; + try { + setIsCancelling(jobId); + await api.cancelJob(orgId, jobId); + await fetchJobs(); + } catch (error) { + console.error('Failed to cancel job', error); + } finally { + setIsCancelling(null); + } + }; + const filteredJobs = jobs.filter(job => { const matchesType = typeFilter === 'all' || job.type === typeFilter; return matchesType; 
@@ -322,6 +351,7 @@ export const JobsView: React.FC = () => { onRefresh={fetchJobs} loading={loading} orgId={orgId} + onCancelJob={cancelJob} /> { onRefresh={fetchJobs} loading={loading} orgId={orgId} + onCancelJob={cancelJob} /> { onPageChange={(_, newPage) => handlePageChange('completed')(newPage)} onRowsPerPageChange={(event) => handleRowsPerPageChange(parseInt(event.target.value, 10))} orgId={orgId} + onCancelJob={cancelJob} /> )} diff --git a/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx b/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx index e6579e0..60674ab 100644 --- a/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx +++ b/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx @@ -18,6 +18,7 @@ export function OrganizationProvider({ children }: { children: React.ReactNode } const [organizations, setOrganizations] = useState([]); const [currentOrganization, setCurrentOrganization] = useState(null); const [loading, setLoading] = useState(true); + const [initialLoaded, setInitialLoaded] = useState(false); const navigate = useNavigate(); const location = useLocation(); const { user } = useAuth(); @@ -29,21 +30,24 @@ export function OrganizationProvider({ children }: { children: React.ReactNode } console.log('Loaded organizations:', orgs); setOrganizations(orgs); setLoading(false); + setInitialLoaded(true); } catch (error) { console.error('Failed to load organizations:', error); setLoading(false); + setInitialLoaded(true); } }; // Load organizations when user changes useEffect(() => { - if (user) { + if (user && !initialLoaded) { loadOrganizations(); - } else { + } else if (!user) { setOrganizations([]); setCurrentOrganization(null); + setInitialLoaded(false); } - }, [user]); + }, [user, initialLoaded]); // Try to extract organization ID from URL and set it as current useEffect(() => { diff --git a/openweights/dashboard/frontend/src/main.tsx b/openweights/dashboard/frontend/src/main.tsx index bef5202..9898468 100644 --- a/openweights/dashboard/frontend/src/main.tsx +++ b/openweights/dashboard/frontend/src/main.tsx @@ -4,7 +4,5 @@ import './index.css' import App from './App.tsx' createRoot(document.getElementById('root')!).render( - - - , + ) diff --git a/openweights/jobs/inference/cli.py b/openweights/jobs/inference/cli.py index 3ec98b6..0d8370a 100644 --- a/openweights/jobs/inference/cli.py +++ b/openweights/jobs/inference/cli.py @@ -4,7 +4,6 @@ import time import torch -from dotenv import load_dotenv from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest from transformers import AutoModelForCausalLM, BitsAndBytesConfig @@ -17,7 +16,6 @@ from validate import InferenceConfig -load_dotenv() client = OpenWeights() diff --git a/openweights/jobs/inspect_ai.py b/openweights/jobs/inspect_ai.py index 3b7ab28..3e09104 100644 --- a/openweights/jobs/inspect_ai.py +++ b/openweights/jobs/inspect_ai.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing import List -from dotenv import load_dotenv class InspectAiConfig(BaseModel): diff --git a/openweights/jobs/mmlu_pro/__init__.py b/openweights/jobs/mmlu_pro/__init__.py index f7f8522..1203d62 100644 --- a/openweights/jobs/mmlu_pro/__init__.py +++ b/openweights/jobs/mmlu_pro/__init__.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing import List -from dotenv import load_dotenv class MMLUProArgs(BaseModel): diff --git a/openweights/jobs/sft/logp_callback.py 
b/openweights/jobs/sft/logp_callback.py index ec44e1a..0147536 100644 --- a/openweights/jobs/sft/logp_callback.py +++ b/openweights/jobs/sft/logp_callback.py @@ -1,8 +1,5 @@ -import math import json import os -import torch -import torch.nn.functional as F from transformers import TrainerCallback from utils import client @@ -10,7 +7,7 @@ class LogTestLossCallback(TrainerCallback): - def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_as='test_loss'): + def __init__(self, test_dataset, tokenizer, eval_steps, log_as, batch_size): """ A callback that evaluates model performance on a test dataset and logs the results. @@ -18,7 +15,6 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_ test_dataset: Dataset with 'messages' field containing conversation messages tokenizer: The tokenizer to use for encoding conversations eval_steps: Evaluate every `eval_steps` training steps - output_dir: Directory where token-level logP data will be saved batch_size: Batch size to use during evaluation log_as: Key to use when logging the loss metric """ @@ -27,47 +23,45 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_ self.eval_steps = eval_steps self.batch_size = batch_size self.log_as = log_as + self.is_block_format = False + if 'messages' in self.test_dataset.column_names and len(self.test_dataset) > 0: + first_example = self.test_dataset[0] + if 'messages' in first_example and len(first_example['messages']) > 0: + first_message = first_example['messages'][0] + if 'content' in first_message and isinstance(first_message['content'], list): + self.is_block_format = True os.environ['UNSLOTH_RETURN_LOGITS'] = '1' + + def on_init_end(self, args, state, control, **kwargs): + self.model = kwargs["model"] + + def on_train_begin(self, args, state, control, **kwargs): + self.run(model=self.model, step=0) def on_step_end(self, args, state, control, **kwargs): """Called at the end of each training step.""" - print(f"Evaluating every {self.eval_steps} steps") if state.global_step % self.eval_steps != 0: return - - # Get the model from kwargs - model = kwargs["model"] + self.run(kwargs['model'], state.global_step) + def run(self, model, step): # Set model to eval mode model.eval() - - # Check if the original dataset has weighted content format - # The test_dataset should be the raw dataset with 'messages' field - has_weighted_content = False - if 'messages' in self.test_dataset.column_names and len(self.test_dataset) > 0: - first_example = self.test_dataset[0] - if 'messages' in first_example and len(first_example['messages']) > 0: - first_message = first_example['messages'][0] - if 'content' in first_message and isinstance(first_message['content'], list): - # Check if it's the weighted format - has_weighted_content = all( - isinstance(block, dict) and 'weight' in block - for block in first_message['content'] - ) - if has_weighted_content: + if self.is_block_format: dataset_with_logprobs = get_logprobs_blockwise(model, self.tokenizer, self.test_dataset, self.batch_size) - with open(f'logp_{self.log_as}_{state.global_step}.json', 'w') as f: + with open(f'logp_{self.log_as}_{step}.json', 'w') as f: json.dump(dataset_with_logprobs, f) - with open(f'logp_{self.log_as}_{state.global_step}.json', 'rb') as f: + with open(f'logp_{self.log_as}_{step}.json', 'rb') as f: logprobs_file = client.files.create(f, purpose="logp_blockwise") # For blockwise, we don't have a simple loss value, just log the file client.run.log({ "type": 
"logprobs_blockwise", - "step": state.global_step, - "file": logprobs_file['id'] + "step": step, + "file": logprobs_file['id'], + "tag": self.log_as }) else: token_logp, total_loss = get_logprobs(model, self.tokenizer, self.test_dataset, self.batch_size) @@ -75,17 +69,18 @@ def on_step_end(self, args, state, control, **kwargs): # Calculate average loss across all batches avg_loss = total_loss / (len(self.test_dataset) / self.batch_size) - with open(f'logp_{self.log_as}_{state.global_step}.json', 'w') as f: + with open(f'logp_{self.log_as}_{step}.json', 'w') as f: json.dump(token_logp, f) - with open(f'logp_{self.log_as}_{state.global_step}.json', 'rb') as f: + with open(f'logp_{self.log_as}_{step}.json', 'rb') as f: logprobs_file = client.files.create(f, purpose="logp") # Log the test loss client.run.log({ "type": "logprobs", self.log_as: avg_loss, - "step": state.global_step, - "file": logprobs_file['id'] + "step": step, + "file": logprobs_file['id'], + "tag": self.log_as }) # Return model to training mode diff --git a/openweights/jobs/sft/training.py b/openweights/jobs/sft/training.py index 61aca8a..785ac0c 100644 --- a/openweights/jobs/sft/training.py +++ b/openweights/jobs/sft/training.py @@ -57,15 +57,14 @@ def train(training_cfg, skip_client_logging: bool = False): kwargs["max_steps"] = training_cfg.max_steps trainer = sft_train(training_cfg, dataset, model, tokenizer, test_dataset=test_dataset, logp_datasets=logp_datasets, **kwargs) + trainer.evaluate() trainer.train() finetuned_model_id = training_cfg.finetuned_model_id or f"{training_cfg.model}:ft-{client.run.id}" push_model(training_cfg,finetuned_model_id, model, tokenizer) try: - eval_results = trainer.evaluate() - if not skip_client_logging: - client.run.log(eval_results) + trainer.evaluate() except Exception as e: print(f"Error evaluating model: {e}. The model has already been pushed to the hub.") diff --git a/openweights/jobs/sft/utils.py b/openweights/jobs/sft/utils.py index 8f9760a..c80fdf9 100644 --- a/openweights/jobs/sft/utils.py +++ b/openweights/jobs/sft/utils.py @@ -2,28 +2,29 @@ import os import torch -from dotenv import load_dotenv from transformers import AutoTokenizer, TrainerCallback from functools import wraps from openweights.client import OpenWeights -load_dotenv() client = OpenWeights() -def load_model_and_tokenizer(model_id, load_in_4bit=False): +def load_model_and_tokenizer(model_id, load_in_4bit=False, max_seq_length=2048): from unsloth import FastLanguageModel, is_bfloat16_supported + model, tokenizer = FastLanguageModel.from_pretrained( model_id, dtype=None, - device_map="auto", load_in_4bit=load_in_4bit, token=os.environ["HF_TOKEN"], - max_seq_length=2048, + max_seq_length=max_seq_length, + device_map=None, # important: no lazy/meta map + low_cpu_mem_usage=False, # force real tensors ) + model = model.to("cuda") if tokenizer.pad_token is None: print("WARNING: tokenizer.pad_token is None. 
Setting it to tokenizer.eos_token")
         tokenizer.pad_token = tokenizer.eos_token
@@ -35,10 +36,21 @@ class LogMetrics(TrainerCallback):
+    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+        if metrics is None:
+            return
+        if args.process_index == 0:  # only log once in distributed
+            payload = {k: v for k, v in metrics.items()}
+            payload["tag"] = "eval"
+            payload["step"] = state.global_step
+            client.run.log(payload)
+
     def on_step_end(self, args, state, control, **kwargs):
         try:
             if len(state.log_history) == 0:
                 return
-            client.run.log(state.log_history[-1])
+            payload = {k: v for k, v in state.log_history[-1].items()}
+            payload["tag"] = "train"
+            client.run.log(payload)
         except Exception as e:
             # Sometimes there are connection errors to supabase etc that we can ignore
diff --git a/openweights/jobs/unsloth/logp_callback.py b/openweights/jobs/unsloth/logp_callback.py
index c2a358b..0147536 100644
--- a/openweights/jobs/unsloth/logp_callback.py
+++ b/openweights/jobs/unsloth/logp_callback.py
@@ -1,16 +1,13 @@
-import math
 import json
 import os
-import torch
-import torch.nn.functional as F
 from transformers import TrainerCallback
 from utils import client
-from logprobs import get_logprobs
+from logprobs import get_logprobs, get_logprobs_blockwise


 class LogTestLossCallback(TrainerCallback):
-    def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_as='test_loss'):
+    def __init__(self, test_dataset, tokenizer, eval_steps, log_as, batch_size):
         """
         A callback that evaluates model performance on a test dataset and logs the results.

@@ -18,7 +15,6 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_
         test_dataset: Dataset with 'messages' field containing conversation messages
         tokenizer: The tokenizer to use for encoding conversations
         eval_steps: Evaluate every `eval_steps` training steps
-        output_dir: Directory where token-level logP data will be saved
         batch_size: Batch size to use during evaluation
         log_as: Key to use when logging the loss metric
         """
@@ -27,38 +23,65 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_
         self.eval_steps = eval_steps
         self.batch_size = batch_size
         self.log_as = log_as
+        self.is_block_format = False
+        if 'messages' in self.test_dataset.column_names and len(self.test_dataset) > 0:
+            first_example = self.test_dataset[0]
+            if 'messages' in first_example and len(first_example['messages']) > 0:
+                first_message = first_example['messages'][0]
+                if 'content' in first_message and isinstance(first_message['content'], list):
+                    self.is_block_format = True
         os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+
+    def on_init_end(self, args, state, control, **kwargs):
+        self.model = kwargs["model"]
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        self.run(model=self.model, step=0)

     def on_step_end(self, args, state, control, **kwargs):
         """Called at the end of each training step."""
-        print(f"Evaluating every {self.eval_steps} steps")
         if state.global_step % self.eval_steps != 0:
             return
-
-        # Get the model from kwargs
-        model = kwargs["model"]
+        self.run(kwargs['model'], state.global_step)

+    def run(self, model, step):
         # Set model to eval mode
         model.eval()

-        token_logp, total_loss = get_logprobs(model, self.tokenizer, self.test_dataset, self.batch_size)
+        if self.is_block_format:
+            dataset_with_logprobs = get_logprobs_blockwise(model, self.tokenizer, self.test_dataset, self.batch_size)
+            with open(f'logp_{self.log_as}_{step}.json', 'w') as f:
json.dump(dataset_with_logprobs, f) + with open(f'logp_{self.log_as}_{step}.json', 'rb') as f: + logprobs_file = client.files.create(f, purpose="logp_blockwise") + + # For blockwise, we don't have a simple loss value, just log the file + client.run.log({ + "type": "logprobs_blockwise", + "step": step, + "file": logprobs_file['id'], + "tag": self.log_as + }) + else: + token_logp, total_loss = get_logprobs(model, self.tokenizer, self.test_dataset, self.batch_size) - # Calculate average loss across all batches - avg_loss = total_loss / (len(self.test_dataset) / self.batch_size) + # Calculate average loss across all batches + avg_loss = total_loss / (len(self.test_dataset) / self.batch_size) - with open(f'logp_{self.log_as}_{state.global_step}.json', 'w') as f: - json.dump(token_logp, f) - with open(f'logp_{self.log_as}_{state.global_step}.json', 'rb') as f: - logprobs_file = client.files.create(f, purpose="logp") + with open(f'logp_{self.log_as}_{step}.json', 'w') as f: + json.dump(token_logp, f) + with open(f'logp_{self.log_as}_{step}.json', 'rb') as f: + logprobs_file = client.files.create(f, purpose="logp") - # Log the test loss - client.run.log({ - "type": "logprobs", - self.log_as: avg_loss, - "step": state.global_step, - "file": logprobs_file['id'] - }) + # Log the test loss + client.run.log({ + "type": "logprobs", + self.log_as: avg_loss, + "step": step, + "file": logprobs_file['id'], + "tag": self.log_as + }) # Return model to training mode model.train() diff --git a/openweights/jobs/unsloth/sampling_callback.py b/openweights/jobs/unsloth/sampling_callback.py index f9282ea..f1c4b25 100644 --- a/openweights/jobs/unsloth/sampling_callback.py +++ b/openweights/jobs/unsloth/sampling_callback.py @@ -72,28 +72,29 @@ def __init__(self, dataset, tokenizer, eval_steps='log', batch_size=8, tag='samp self.tag = tag self.temperature = temperature self.max_tokens = max_tokens + + def on_init_end(self, args, state, control, **kwargs): + self.model = kwargs["model"] + def on_train_begin(self, args, state, control, **kwargs): + self.run(model=self.model, step=0) def on_step_end(self, args, state, control, **kwargs): """Called at the end of each training step.""" - eval_steps = 10 ** int(math.log10(max(1, state.global_step))) - if self.eval_steps == 'log': - eval_steps = eval_steps - else: - eval_steps = min(eval_steps, self.eval_steps) - - if state.global_step % eval_steps != 0: + if state.global_step % self.eval_steps != 0: return - + self.run(kwargs['model'], state.global_step) + + def run(self, model, step): + """Sample completions from the model and log them for the given step.""" # Get the model from kwargs - model = kwargs["model"] FastLanguageModel.for_inference(model) completions = sample( model, self.tokenizer, self.dataset, batch_size=self.batch_size, max_tokens=self.max_tokens, temperature=self.temperature) - results_file = f'samples_{self.tag}_{state.global_step}.jsonl' + results_file = f'samples_{self.tag}_{step}.jsonl' with open(results_file, 'w') as f: for row, completion in zip(self.dataset, completions): row['completion'] = completion @@ -105,7 +106,7 @@ def on_step_end(self, args, state, control, **kwargs): # Log the test loss client.run.log({ "type": "samples", - "step": state.global_step, + "step": step, "file": samples_file['id'], "tag": self.tag }) diff --git a/openweights/jobs/unsloth/sft.py b/openweights/jobs/unsloth/sft.py index a14c5ba..8dd1d5e 100644 --- a/openweights/jobs/unsloth/sft.py +++ b/openweights/jobs/unsloth/sft.py @@ -1,8 +1,5 @@ -import json -import os from os.path import commonprefix
-from datasets import Dataset from transformers import TrainingArguments from trl import SFTTrainer from unsloth import is_bfloat16_supported @@ -123,7 +120,7 @@ def apply_chat_template(examples): if training_cfg.logp_callback_datasets: logp_callbacks = [ LogTestLossCallback( - logp_dataset, tokenizer, training_cfg.eval_every_n_steps, log_as=key + logp_dataset, tokenizer, training_cfg.eval_every_n_steps, log_as=key, batch_size=training_cfg.eval_batch_size ) for key, logp_dataset in logp_datasets.items() ] diff --git a/openweights/jobs/unsloth/training.py b/openweights/jobs/unsloth/training.py index d825daf..1cb31a7 100644 --- a/openweights/jobs/unsloth/training.py +++ b/openweights/jobs/unsloth/training.py @@ -77,15 +77,14 @@ def train(training_cfg, skip_client_logging: bool = False): else: raise ValueError(f"Unknown loss function: {training_cfg.loss}") + trainer.evaluate() trainer.train() finetuned_model_id = training_cfg.finetuned_model_id or f"{training_cfg.model}:ft-{client.run.id}" push_model(training_cfg,finetuned_model_id, model, tokenizer) try: - eval_results = trainer.evaluate() - if not skip_client_logging: - client.run.log(eval_results) + trainer.evaluate() except Exception as e: print(f"Error evaluating model: {e}. The model has already been pushed to the hub.") diff --git a/openweights/jobs/unsloth/utils.py b/openweights/jobs/unsloth/utils.py index 21c603d..c80fdf9 100644 --- a/openweights/jobs/unsloth/utils.py +++ b/openweights/jobs/unsloth/utils.py @@ -2,13 +2,11 @@ import os import torch -from dotenv import load_dotenv from transformers import AutoTokenizer, TrainerCallback from functools import wraps from openweights.client import OpenWeights -load_dotenv() client = OpenWeights() @@ -38,10 +36,21 @@ def load_model_and_tokenizer(model_id, load_in_4bit=False, max_seq_length=2048): class LogMetrics(TrainerCallback): + def on_evaluate(self, args, state, control, metrics=None, **kwargs): + if metrics is None: + return + if args.process_index == 0: # only log once in distributed + payload = {k: v for k, v in metrics.items()} + payload["tag"] = "eval" + payload["step"] = state.global_step + client.run.log(payload) + def on_step_end(self, args, state, control, **kwargs): try: if len(state.log_history) == 0: return + payload = {k: v for k, v in state.log_history[-1].items()} + payload["tag"] = "train" - client.run.log(state.log_history[-1]) + client.run.log(payload) except Exception as e: # Sometimes there are connection errors to supabase etc that we can ignore diff --git a/openweights/jobs/unsloth/validate.py b/openweights/jobs/unsloth/validate.py index 1ac4893..29840b4 100644 --- a/openweights/jobs/unsloth/validate.py +++ b/openweights/jobs/unsloth/validate.py @@ -312,7 +312,7 @@ class LogProbJobModel(BaseModel): class SamplingCallbackModel(BaseModel): dataset: str - eval_steps: Union[Literal["log"], int] = "log" + eval_steps: Union[Literal["log"], int] = 10 batch_size: int = 8 tag: str = "samples" temperature: float = 0 diff --git a/openweights/worker/main.py b/openweights/worker/main.py index 58aa9af..f3c86b9 100644 --- a/openweights/worker/main.py +++ b/openweights/worker/main.py @@ -290,13 +290,6 @@ def shutdown_handler(self): if self.current_run: self.current_run.update(logfile=log_response["id"]) - # Update worker record with logfile ID - with open("logs/main", "rb") as log_file: - log_response = self.files.create(log_file, purpose="logs") - self.supabase.table("worker").update( - {"logfile": log_response["id"]} - ).eq("id", self.worker_id).execute() - # Mark job as 'pending' only if it is still
'in_progress' by this worker self.update_job_status_if_in_progress( self.current_job["id"], @@ -308,6 +301,16 @@ def shutdown_handler(self): self.current_run.update(status="canceled") except Exception as e: logging.error(f"Error updating job status during shutdown: {e}") + + try: + # Update worker record with logfile ID + with open("logs/main", "rb") as log_file: + log_response = self.files.create(log_file, purpose="logs") + self.supabase.table("worker").update( + {"logfile": log_response["id"]} + ).eq("id", self.worker_id).execute() + except Exception as e: + logging.error(f"Error updating worker logs during shutdown: {e}") # Update worker status try: diff --git a/openweights/worker/services/__init__.py b/openweights/worker/services/__init__.py new file mode 100755 index 0000000..5ccc672 --- /dev/null +++ b/openweights/worker/services/__init__.py @@ -0,0 +1 @@ +# Services for worker processes \ No newline at end of file diff --git a/openweights/worker/services/checkout.py b/openweights/worker/services/checkout.py new file mode 100755 index 0000000..f0e50b5 --- /dev/null +++ b/openweights/worker/services/checkout.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +""" +Repository checkout service for specific commits +""" +import os +import subprocess +import sys +import shutil + +def checkout_commit(): + """Checkout specific commit if OW_COMMIT is set""" + commit = os.environ.get('OW_COMMIT') + if not commit: + print("No OW_COMMIT specified, skipping repository checkout") + return True + + try: + print(f"Starting repository checkout for commit: {commit}") + + # Remove existing openweights directory + if os.path.exists('openweights'): + shutil.rmtree('openweights') + + # Clone repository + subprocess.run(['git', 'clone', 'https://github.com/longtermrisk/openweights.git', 'openweights_dev'], check=True) + + # Checkout specific commit + os.chdir('openweights_dev') + subprocess.run(['git', 'checkout', commit], check=True) + + # Move openweights directory back + os.chdir('..') + shutil.move('openweights_dev/openweights', 'openweights') + shutil.rmtree('openweights_dev') + + print("Repository checkout completed") + return True + + except subprocess.CalledProcessError as e: + print(f"Failed to checkout repository: {e}") + return False + except Exception as e: + print(f"Unexpected error during checkout: {e}") + return False + +if __name__ == "__main__": + success = checkout_commit() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/openweights/worker/services/hf_login.py b/openweights/worker/services/hf_login.py new file mode 100755 index 0000000..f77f8b5 --- /dev/null +++ b/openweights/worker/services/hf_login.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +Hugging Face login service +""" +import os +import sys +from huggingface_hub.hf_api import HfFolder + +def login(): + """Login to Hugging Face using HF_TOKEN from environment""" + try: + hf_token = os.environ.get('HF_TOKEN') + if not hf_token: + print("Warning: HF_TOKEN not found in environment") + return False + + HfFolder.save_token(hf_token) + print("Successfully logged in to Hugging Face") + return True + except Exception as e: + print(f"Failed to login to Hugging Face: {e}") + return False + +if __name__ == "__main__": + success = login() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/openweights/worker/services/log_server.py b/openweights/worker/services/log_server.py new file mode 100755 index 0000000..e1decb5 --- /dev/null +++ b/openweights/worker/services/log_server.py @@ -0,0 +1,44 @@ 
+#!/usr/bin/env python3 +""" +HTTP log server that serves log files on port 10101 +""" +import http.server +import socketserver +import os +import sys + +class LogHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): + # If path is /logs, serve logs/main + if self.path == "/logs" or self.path == "/logs/": + file_path = "logs/main" + else: + # Remove leading slash and ensure path is within logs directory + path = self.path.lstrip("/") + file_path = os.path.join("logs", path) + + # Check if file exists and resolves to a path inside the logs directory + # (commonpath compares whole path components, unlike commonprefix, + # so sibling directories like logs_evil/ cannot slip through) + if os.path.exists(file_path) and os.path.commonpath([os.path.abspath(file_path), os.path.abspath("logs")]) == os.path.abspath("logs"): + self.send_response(200) + self.send_header("Content-type", "text/plain") + self.end_headers() + with open(file_path, "rb") as f: + self.wfile.write(f.read()) + else: + self.send_error(404, "File not found") + +def start_server(): + """Start the log server on port 10101""" + try: + # Ensure logs directory exists + os.makedirs("logs", exist_ok=True) + + with socketserver.TCPServer(("", 10101), LogHandler) as httpd: + print("Starting HTTP log server on port 10101") + httpd.serve_forever() + except Exception as e: + print(f"Failed to start log server: {e}") + sys.exit(1) + +if __name__ == "__main__": + start_server() \ No newline at end of file diff --git a/openweights/worker/services/ttl_manager.py b/openweights/worker/services/ttl_manager.py new file mode 100755 index 0000000..35a8915 --- /dev/null +++ b/openweights/worker/services/ttl_manager.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +""" +TTL management utility for extending or checking pod TTL +""" +import os +import datetime +import argparse + +def get_shutdown_time(): + """Get the current shutdown time""" + shutdown_file = os.path.expanduser('~/shutdown.txt') + if not os.path.exists(shutdown_file): + return None + + with open(shutdown_file, 'r') as f: + shutdown_time_str = f.read().strip() + + try: + return datetime.datetime.fromisoformat(shutdown_time_str) + except ValueError: + return None + +def set_shutdown_time(shutdown_time): + """Set a new shutdown time""" + shutdown_file = os.path.expanduser('~/shutdown.txt') + with open(shutdown_file, 'w') as f: + f.write(shutdown_time.isoformat()) + +def extend_ttl(hours): + """Extend TTL by specified hours from now""" + new_shutdown_time = datetime.datetime.now() + datetime.timedelta(hours=hours) + set_shutdown_time(new_shutdown_time) + return new_shutdown_time + +def set_ttl(hours): + """Set TTL to specified hours from now""" + return extend_ttl(hours) + +def check_ttl(): + """Check current TTL status""" + shutdown_time = get_shutdown_time() + if not shutdown_time: + return "No TTL set" + + current_time = datetime.datetime.now() + time_left = shutdown_time - current_time + + if time_left.total_seconds() <= 0: + return f"TTL expired at {shutdown_time}" + else: + hours_left = time_left.total_seconds() / 3600 + return f"TTL expires at {shutdown_time} ({hours_left:.1f} hours remaining)" + +def main(): + parser = argparse.ArgumentParser(description="Manage pod TTL (Time To Live)") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--check', action='store_true', help='Check current TTL status') + group.add_argument('--extend', type=float, help='Extend TTL by specified hours from now') + group.add_argument('--set', type=float, help='Set TTL to specified hours from now') + + args = parser.parse_args() + + if args.check: + print(check_ttl()) + elif args.extend: + new_time =
extend_ttl(args.extend) + print(f"TTL extended by {args.extend} hours. New shutdown time: {new_time}") + elif args.set: + new_time = set_ttl(args.set) + print(f"TTL set to {args.set} hours from now. Shutdown time: {new_time}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/openweights/worker/services/ttl_monitor.py b/openweights/worker/services/ttl_monitor.py new file mode 100755 index 0000000..abf591e --- /dev/null +++ b/openweights/worker/services/ttl_monitor.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +""" +TTL monitoring service that terminates pods after expiration +""" +import os +import time +import datetime +import sys + +def setup_ttl(): + """Setup initial TTL based on environment variable""" + ttl_hours = float(os.environ.get('TTL_HOURS', '24')) + shutdown_time = datetime.datetime.now() + datetime.timedelta(hours=ttl_hours) + + shutdown_file = os.path.expanduser('~/shutdown.txt') + with open(shutdown_file, 'w') as f: + f.write(shutdown_time.isoformat()) + + print(f"TTL set to {ttl_hours} hours. Shutdown scheduled for: {shutdown_time}") + return shutdown_file + +def get_pod_id(): + """Get the current pod ID from RunPod metadata or environment""" + # First try environment variable + pod_id = os.environ.get('RUNPOD_POD_ID') + if pod_id: + return pod_id + + # Try to get from RunPod metadata API + try: + import httpx + with httpx.Client() as client: + response = client.get('http://metadata.runpod.ai/v1/instance/id', timeout=10) + if response.status_code == 200: + return response.text.strip() + except Exception as e: + print(f"Could not get pod ID from metadata API: {e}") + + return None + +def terminate_pod(): + """Terminate the current pod using RunPod API""" + try: + import runpod + + api_key = os.environ.get('RUNPOD_API_KEY') + pod_id = get_pod_id() + + if not api_key: + print("ERROR: RUNPOD_API_KEY not found in environment") + return False + + if not pod_id: + print("ERROR: Could not determine pod ID") + return False + + runpod.api_key = api_key + result = runpod.terminate_pod(pod_id) + print(f"Pod termination initiated for {pod_id}: {result}") + return True + + except ImportError: + print("ERROR: runpod package not available") + return False + except Exception as e: + print(f"ERROR: Failed to terminate pod: {e}") + return False + +def monitor_ttl(): + """Monitor TTL and terminate pod when expired""" + shutdown_file = setup_ttl() + + print("Starting TTL monitoring service...") + + while True: + try: + if os.path.exists(shutdown_file): + with open(shutdown_file, 'r') as f: + shutdown_time_str = f.read().strip() + + try: + shutdown_time = datetime.datetime.fromisoformat(shutdown_time_str) + current_time = datetime.datetime.now() + + if current_time >= shutdown_time: + print(f"TTL expired at {shutdown_time}. 
Current time: {current_time}") + print("Initiating pod termination...") + + if terminate_pod(): + print("Pod termination successful") + break + else: + print("Pod termination failed, will retry in 60 seconds") + else: + time_left = shutdown_time - current_time + print(f"TTL check: {time_left} remaining until shutdown") + + except ValueError as e: + print(f"Invalid shutdown time format in {shutdown_file}: {e}") + # Re-setup TTL if file is corrupted + shutdown_file = setup_ttl() + else: + print(f"Shutdown file {shutdown_file} not found, recreating...") + shutdown_file = setup_ttl() + + except Exception as e: + print(f"Error in TTL monitoring: {e}") + + time.sleep(60) # Check every minute + +if __name__ == "__main__": + try: + monitor_ttl() + except KeyboardInterrupt: + print("TTL monitoring service stopped") + sys.exit(0) + except Exception as e: + print(f"TTL monitoring service failed: {e}") + sys.exit(1) \ No newline at end of file diff --git a/ow-default.Dockerfile b/ow-default.Dockerfile index f9e2a82..71bc138 100644 --- a/ow-default.Dockerfile +++ b/ow-default.Dockerfile @@ -35,5 +35,6 @@ RUN ln -s /usr/bin/python3 /usr/bin/python EXPOSE 22 EXPOSE 8000 +EXPOSE 10101 ENTRYPOINT ["/my_app/entrypoint.sh"] \ No newline at end of file diff --git a/supabase/migrations/20250917150000_update_call_sites_to_app_auth.sql b/supabase/migrations/20250917150000_update_call_sites_to_app_auth.sql new file mode 100644 index 0000000..c320315 --- /dev/null +++ b/supabase/migrations/20250917150000_update_call_sites_to_app_auth.sql @@ -0,0 +1,76 @@ +-- Switch call sites from auth.check_if_service_account() -> app_auth.check_if_service_account() +-- Keep SECURITY DEFINER and add a safe search_path that includes app_auth. + +-- 1) public.get_organization_from_token (final version from 2024-12-05-00:00:10) +CREATE OR REPLACE FUNCTION public.get_organization_from_token() +RETURNS uuid +LANGUAGE plpgsql +SECURITY DEFINER +SET search_path = public, auth, app_auth +AS $$ +DECLARE + org_id uuid; +BEGIN + -- Only handle service account tokens + IF NOT app_auth.check_if_service_account() THEN + RAISE EXCEPTION 'Only service account tokens are supported'; + END IF; + + -- Get org from claims + org_id := (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid; + + -- Update last_used_at in tokens table + UPDATE public.tokens + SET last_used_at = now() + WHERE id = (current_setting('request.jwt.claims', true)::json->>'token_id')::uuid; + + RETURN org_id; +END; +$$; + +-- 2) public.is_organization_member(uuid) +CREATE OR REPLACE FUNCTION public.is_organization_member(org_id uuid) +RETURNS boolean +LANGUAGE plpgsql +SECURITY DEFINER +SET search_path = public, auth, app_auth +AS $$ +BEGIN + -- If this is a service account, check the organization claim + IF app_auth.check_if_service_account() THEN + RETURN (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid = org_id; + END IF; + + -- Otherwise check normal membership + RETURN EXISTS ( + SELECT 1 + FROM public.organization_members + WHERE organization_id = org_id + AND user_id = auth.uid() + ); +END; +$$; + +-- 3) public.is_organization_admin(uuid) +CREATE OR REPLACE FUNCTION public.is_organization_admin(org_id uuid) +RETURNS boolean +LANGUAGE plpgsql +SECURITY DEFINER +SET search_path = public, auth, app_auth +AS $$ +BEGIN + -- Service accounts have admin access to their organization + IF app_auth.check_if_service_account() THEN + RETURN (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid = org_id; + 
END IF; + + -- Otherwise check normal admin membership + RETURN EXISTS ( + SELECT 1 + FROM public.organization_members + WHERE organization_id = org_id + AND user_id = auth.uid() + AND role = 'admin' + ); +END; +$$; diff --git a/supabase/migrations/20250917153500_storage_policies_move_to_public.sql b/supabase/migrations/20250917153500_storage_policies_move_to_public.sql new file mode 100644 index 0000000..534da42 --- /dev/null +++ b/supabase/migrations/20250917153500_storage_policies_move_to_public.sql @@ -0,0 +1,117 @@ +-- All app-owned helpers live in public. Storage policies reference these. +-- No objects created/modified inside 'storage' or 'app_storage'. + +-- 1) Path helper: organizations//... +CREATE OR REPLACE FUNCTION public.get_path_organization_id(path text) +RETURNS uuid +LANGUAGE plpgsql +STABLE +AS $$ +DECLARE + parts text[]; + org_id uuid; +BEGIN + parts := string_to_array(path, '/'); + + IF array_length(parts, 1) IS NULL OR parts[1] <> 'organizations' THEN + RETURN NULL; + END IF; + + BEGIN + org_id := parts[2]::uuid; + RETURN org_id; + EXCEPTION WHEN others THEN + RETURN NULL; + END; +END; +$$; + +-- 2) Access helper: service-account claim OR membership +CREATE OR REPLACE FUNCTION public.has_organization_access(org_id uuid) +RETURNS boolean +LANGUAGE plpgsql +SECURITY DEFINER +SET search_path = public, auth, app_auth +STABLE +AS $$ +BEGIN + -- Service account? + IF app_auth.check_if_service_account() THEN + RETURN (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid = org_id; + END IF; + + -- Otherwise, membership + RETURN EXISTS ( + SELECT 1 + FROM public.organization_members + WHERE organization_id = org_id + AND user_id = auth.uid() + ); +END; +$$; + +-- Permissions (optional but harmless) +GRANT EXECUTE ON FUNCTION public.get_path_organization_id(text) TO anon, authenticated, service_role, postgres; +GRANT EXECUTE ON FUNCTION public.has_organization_access(uuid) TO anon, authenticated, service_role, postgres; + +-- 3) Re-create Storage policies to reference public.* helpers + +-- Read +DROP POLICY IF EXISTS "Organization members can read their files" ON storage.objects; +CREATE POLICY "Organization members can read their files" +ON storage.objects FOR SELECT +USING ( + bucket_id = 'files' + AND ( + name LIKE '%.keep' + OR ( + name LIKE 'organizations/%' + AND public.has_organization_access(public.get_path_organization_id(name)) + ) + ) +); + +-- Insert +DROP POLICY IF EXISTS "Organization members can upload files" ON storage.objects; +CREATE POLICY "Organization members can upload files" +ON storage.objects FOR INSERT +WITH CHECK ( + bucket_id = 'files' + AND ( + name LIKE '%.keep' + OR ( + name LIKE 'organizations/%' + AND public.has_organization_access(public.get_path_organization_id(name)) + ) + ) +); + +-- Update +DROP POLICY IF EXISTS "Organization members can update their files" ON storage.objects; +CREATE POLICY "Organization members can update their files" +ON storage.objects FOR UPDATE +USING ( + bucket_id = 'files' + AND ( + name LIKE '%.keep' + OR ( + name LIKE 'organizations/%' + AND public.has_organization_access(public.get_path_organization_id(name)) + ) + ) +); + +-- Delete +DROP POLICY IF EXISTS "Organization members can delete their files" ON storage.objects; +CREATE POLICY "Organization members can delete their files" +ON storage.objects FOR DELETE +USING ( + bucket_id = 'files' + AND ( + name LIKE '%.keep' + OR ( + name LIKE 'organizations/%' + AND public.has_organization_access(public.get_path_organization_id(name)) + ) + ) 
+); diff --git a/supabase/migrations_dev/20241120094415_remote_schema.sql b/supabase/migrations_dev/20241120094415_remote_schema.sql deleted file mode 100644 index d09ba13..0000000 --- a/supabase/migrations_dev/20241120094415_remote_schema.sql +++ /dev/null @@ -1,161 +0,0 @@ - -SET statement_timeout = 0; -SET lock_timeout = 0; -SET idle_in_transaction_session_timeout = 0; -SET client_encoding = 'UTF8'; -SET standard_conforming_strings = on; -SELECT pg_catalog.set_config('search_path', '', false); -SET check_function_bodies = false; -SET xmloption = content; -SET client_min_messages = warning; -SET row_security = off; -CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium"; -COMMENT ON SCHEMA "public" IS 'standard public schema'; -CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql"; -CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions"; -CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions"; -CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions"; -CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault"; -CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions"; -CREATE TYPE "public"."job_status" AS ENUM ( - 'pending', - 'in_progress', - 'completed', - 'failed', - 'canceled' -); -ALTER TYPE "public"."job_status" OWNER TO "postgres"; -CREATE TYPE "public"."job_type" AS ENUM ( - 'fine-tuning', - 'inference', - 'script' -); -ALTER TYPE "public"."job_type" OWNER TO "postgres"; -SET default_tablespace = ''; -SET default_table_access_method = "heap"; -CREATE TABLE IF NOT EXISTS "public"."events" ( - "id" integer NOT NULL, - "run_id" integer, - "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP, - "data" "jsonb" NOT NULL, - "file" "text" -); -ALTER TABLE "public"."events" OWNER TO "postgres"; -CREATE SEQUENCE IF NOT EXISTS "public"."events_id_seq" - AS integer - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; -ALTER TABLE "public"."events_id_seq" OWNER TO "postgres"; -ALTER SEQUENCE "public"."events_id_seq" OWNED BY "public"."events"."id"; -CREATE TABLE IF NOT EXISTS "public"."files" ( - "id" "text" NOT NULL, - "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP, - "filename" "text" NOT NULL, - "purpose" "text" NOT NULL, - "bytes" integer NOT NULL -); -ALTER TABLE "public"."files" OWNER TO "postgres"; -CREATE TABLE IF NOT EXISTS "public"."jobs" ( - "id" "text" NOT NULL, - "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP, - "type" "public"."job_type" NOT NULL, - "model" "text", - "params" "jsonb", - "script" "text", - "outputs" "jsonb", - "requires_vram_gb" integer DEFAULT 24, - "status" "public"."job_status" DEFAULT 'pending'::"public"."job_status", - "worker_id" "text" -); -ALTER TABLE "public"."jobs" OWNER TO "postgres"; -CREATE TABLE IF NOT EXISTS "public"."runs" ( - "id" integer NOT NULL, - "job_id" "text", - "worker_id" "text", - "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP, - "status" "public"."job_status", - "log_file" "text" -); -ALTER TABLE "public"."runs" OWNER TO "postgres"; -CREATE SEQUENCE IF NOT EXISTS "public"."runs_id_seq" - AS integer - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; -ALTER TABLE "public"."runs_id_seq" OWNER TO "postgres"; -ALTER SEQUENCE "public"."runs_id_seq" OWNED BY "public"."runs"."id"; -CREATE TABLE IF NOT EXISTS "public"."worker" ( - "id" "text" NOT NULL, - "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP, - "status" "text", - "cached_models" "text"[], 
- "vram_gb" integer, - "pod_id" "text", - "ping" timestamp with time zone -); -ALTER TABLE "public"."worker" OWNER TO "postgres"; -ALTER TABLE ONLY "public"."events" ALTER COLUMN "id" SET DEFAULT "nextval"('"public"."events_id_seq"'::"regclass"); -ALTER TABLE ONLY "public"."runs" ALTER COLUMN "id" SET DEFAULT "nextval"('"public"."runs_id_seq"'::"regclass"); -ALTER TABLE ONLY "public"."events" - ADD CONSTRAINT "events_pkey" PRIMARY KEY ("id"); -ALTER TABLE ONLY "public"."files" - ADD CONSTRAINT "files_pkey" PRIMARY KEY ("id"); -ALTER TABLE ONLY "public"."jobs" - ADD CONSTRAINT "jobs_pkey" PRIMARY KEY ("id"); -ALTER TABLE ONLY "public"."runs" - ADD CONSTRAINT "runs_pkey" PRIMARY KEY ("id"); -ALTER TABLE ONLY "public"."worker" - ADD CONSTRAINT "worker_pkey" PRIMARY KEY ("id"); -CREATE INDEX "events_run_id_idx" ON "public"."events" USING "btree" ("run_id"); -ALTER TABLE ONLY "public"."events" - ADD CONSTRAINT "events_run_id_fkey" FOREIGN KEY ("run_id") REFERENCES "public"."runs"("id") ON DELETE CASCADE; -ALTER TABLE ONLY "public"."jobs" - ADD CONSTRAINT "jobs_worker_id_fkey" FOREIGN KEY ("worker_id") REFERENCES "public"."worker"("id"); -ALTER TABLE ONLY "public"."runs" - ADD CONSTRAINT "runs_job_id_fkey" FOREIGN KEY ("job_id") REFERENCES "public"."jobs"("id") ON DELETE CASCADE; -ALTER TABLE ONLY "public"."runs" - ADD CONSTRAINT "runs_worker_id_fkey" FOREIGN KEY ("worker_id") REFERENCES "public"."worker"("id") ON DELETE SET NULL; -ALTER PUBLICATION "supabase_realtime" OWNER TO "postgres"; -GRANT USAGE ON SCHEMA "public" TO "postgres"; -GRANT USAGE ON SCHEMA "public" TO "anon"; -GRANT USAGE ON SCHEMA "public" TO "authenticated"; -GRANT USAGE ON SCHEMA "public" TO "service_role"; -GRANT ALL ON TABLE "public"."events" TO "anon"; -GRANT ALL ON TABLE "public"."events" TO "authenticated"; -GRANT ALL ON TABLE "public"."events" TO "service_role"; -GRANT ALL ON SEQUENCE "public"."events_id_seq" TO "anon"; -GRANT ALL ON SEQUENCE "public"."events_id_seq" TO "authenticated"; -GRANT ALL ON SEQUENCE "public"."events_id_seq" TO "service_role"; -GRANT ALL ON TABLE "public"."files" TO "anon"; -GRANT ALL ON TABLE "public"."files" TO "authenticated"; -GRANT ALL ON TABLE "public"."files" TO "service_role"; -GRANT ALL ON TABLE "public"."jobs" TO "anon"; -GRANT ALL ON TABLE "public"."jobs" TO "authenticated"; -GRANT ALL ON TABLE "public"."jobs" TO "service_role"; -GRANT ALL ON TABLE "public"."runs" TO "anon"; -GRANT ALL ON TABLE "public"."runs" TO "authenticated"; -GRANT ALL ON TABLE "public"."runs" TO "service_role"; -GRANT ALL ON SEQUENCE "public"."runs_id_seq" TO "anon"; -GRANT ALL ON SEQUENCE "public"."runs_id_seq" TO "authenticated"; -GRANT ALL ON SEQUENCE "public"."runs_id_seq" TO "service_role"; -GRANT ALL ON TABLE "public"."worker" TO "anon"; -GRANT ALL ON TABLE "public"."worker" TO "authenticated"; -GRANT ALL ON TABLE "public"."worker" TO "service_role"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "postgres"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "anon"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "authenticated"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "service_role"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "postgres"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "anon"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" 
IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "authenticated"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "service_role"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "postgres"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "anon"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "authenticated"; -ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "service_role"; -RESET ALL; diff --git a/supabase/migrations_dev/20241120161019_add_gpu_columns.sql b/supabase/migrations_dev/20241120161019_add_gpu_columns.sql deleted file mode 100644 index c0821b1..0000000 --- a/supabase/migrations_dev/20241120161019_add_gpu_columns.sql +++ /dev/null @@ -1,9 +0,0 @@ --- Add gpu_type and gpu_count columns to worker table -ALTER TABLE "public"."worker" - ADD COLUMN "gpu_type" text, - ADD COLUMN "gpu_count" integer; - --- Grant permissions to match existing table permissions -GRANT ALL ON TABLE "public"."worker" TO "anon"; -GRANT ALL ON TABLE "public"."worker" TO "authenticated"; -GRANT ALL ON TABLE "public"."worker" TO "service_role"; \ No newline at end of file diff --git a/supabase/migrations_dev/20241127132826_add_docker_image.sql b/supabase/migrations_dev/20241127132826_add_docker_image.sql deleted file mode 100644 index 93aa998..0000000 --- a/supabase/migrations_dev/20241127132826_add_docker_image.sql +++ /dev/null @@ -1,8 +0,0 @@ --- Add docker_image column to jobs table -ALTER TABLE "public"."jobs" - ADD COLUMN "docker_image" text; - --- Grant permissions to match existing table permissions -GRANT ALL ON TABLE "public"."jobs" TO "anon"; -GRANT ALL ON TABLE "public"."jobs" TO "authenticated"; -GRANT ALL ON TABLE "public"."jobs" TO "service_role"; \ No newline at end of file diff --git a/supabase/migrations_dev/20241127153031_add_docker_image_to_workers.sql b/supabase/migrations_dev/20241127153031_add_docker_image_to_workers.sql deleted file mode 100644 index 09dbfd5..0000000 --- a/supabase/migrations_dev/20241127153031_add_docker_image_to_workers.sql +++ /dev/null @@ -1 +0,0 @@ -alter table "public"."worker" add column "docker_image" text; \ No newline at end of file diff --git a/supabase/migrations_dev/20241202233213_add_updated_at_columns.sql b/supabase/migrations_dev/20241202233213_add_updated_at_columns.sql deleted file mode 100644 index 3f8ca4c..0000000 --- a/supabase/migrations_dev/20241202233213_add_updated_at_columns.sql +++ /dev/null @@ -1,52 +0,0 @@ --- Add updated_at columns to tables -ALTER TABLE "public"."worker" - ADD COLUMN "updated_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP; - -ALTER TABLE "public"."jobs" - ADD COLUMN "updated_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP; - -ALTER TABLE "public"."runs" - ADD COLUMN "updated_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP; - --- Update existing rows to set updated_at = created_at -UPDATE "public"."worker" SET "updated_at" = "created_at"; -UPDATE "public"."jobs" SET "updated_at" = "created_at"; -UPDATE "public"."runs" SET "updated_at" = "created_at"; - --- Create a function to automatically set updated_at -CREATE OR REPLACE FUNCTION public.handle_updated_at() -RETURNS TRIGGER AS $$ -BEGIN - NEW.updated_at = CURRENT_TIMESTAMP; - RETURN NEW; -END; -$$ language 'plpgsql'; - --- Create triggers for each table -CREATE TRIGGER set_updated_at_worker - BEFORE UPDATE ON public.worker - FOR EACH ROW - EXECUTE FUNCTION 
public.handle_updated_at(); - -CREATE TRIGGER set_updated_at_jobs - BEFORE UPDATE ON public.jobs - FOR EACH ROW - EXECUTE FUNCTION public.handle_updated_at(); - -CREATE TRIGGER set_updated_at_runs - BEFORE UPDATE ON public.runs - FOR EACH ROW - EXECUTE FUNCTION public.handle_updated_at(); - --- Grant permissions to match existing table permissions -GRANT ALL ON TABLE "public"."worker" TO "anon"; -GRANT ALL ON TABLE "public"."worker" TO "authenticated"; -GRANT ALL ON TABLE "public"."worker" TO "service_role"; - -GRANT ALL ON TABLE "public"."jobs" TO "anon"; -GRANT ALL ON TABLE "public"."jobs" TO "authenticated"; -GRANT ALL ON TABLE "public"."jobs" TO "service_role"; - -GRANT ALL ON TABLE "public"."runs" TO "anon"; -GRANT ALL ON TABLE "public"."runs" TO "authenticated"; -GRANT ALL ON TABLE "public"."runs" TO "service_role"; \ No newline at end of file diff --git a/supabase/migrations_dev/20241203000542_add_api_job_type.sql b/supabase/migrations_dev/20241203000542_add_api_job_type.sql deleted file mode 100644 index 517015b..0000000 --- a/supabase/migrations_dev/20241203000542_add_api_job_type.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add 'api' to the job_type enum -ALTER TYPE "public"."job_type" ADD VALUE IF NOT EXISTS 'api'; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000000_add_organizations.sql b/supabase/migrations_dev/20241205000000_add_organizations.sql deleted file mode 100644 index 5d9eaae..0000000 --- a/supabase/migrations_dev/20241205000000_add_organizations.sql +++ /dev/null @@ -1,206 +0,0 @@ --- Create role enum type -create type "public"."organization_role" as enum ('admin', 'user'); - --- Create organizations table -create table if not exists "public"."organizations" ( - "id" uuid not null default gen_random_uuid(), - "created_at" timestamp with time zone default timezone('utc'::text, now()) not null, - "updated_at" timestamp with time zone default timezone('utc'::text, now()) not null, - "name" text not null, - primary key (id) -); - --- Create organization_members junction table with role -create table if not exists "public"."organization_members" ( - "organization_id" uuid not null references organizations(id) on delete cascade, - "user_id" uuid not null references auth.users(id) on delete cascade, - "role" organization_role not null default 'user', - "created_at" timestamp with time zone default timezone('utc'::text, now()) not null, - primary key (organization_id, user_id) -); - --- Create third_party_api_keys table for RunPod, HuggingFace, etc. 
-create table if not exists "public"."third_party_api_keys" ( - "organization_id" uuid not null references organizations(id) on delete cascade, - "service" text not null, -- 'runpod', 'huggingface' - "api_key" text not null, - "created_at" timestamp with time zone default timezone('utc'::text, now()) not null, - "updated_at" timestamp with time zone default timezone('utc'::text, now()) not null, - primary key (organization_id, service) -); - --- Create a default organization for existing data -insert into "public"."organizations" (id, name) -values ('00000000-0000-0000-0000-000000000000', 'Default Organization'); - --- Add organization ownership to jobs -alter table "public"."jobs" - add column "organization_id" uuid references organizations(id) on delete cascade; - --- Add organization ownership to workers -alter table "public"."worker" - add column "organization_id" uuid references organizations(id) on delete cascade; - --- Migrate existing jobs and workers to default organization -update "public"."jobs" -set "organization_id" = '00000000-0000-0000-0000-000000000000' -where "organization_id" is null; - -update "public"."worker" -set "organization_id" = '00000000-0000-0000-0000-000000000000' -where "organization_id" is null; - --- Make organization_id not null after migration -alter table "public"."jobs" - alter column "organization_id" set not null; - -alter table "public"."worker" - alter column "organization_id" set not null; - --- Enable RLS on all tables -alter table "public"."organizations" enable row level security; -alter table "public"."organization_members" enable row level security; -alter table "public"."third_party_api_keys" enable row level security; -alter table "public"."jobs" enable row level security; -alter table "public"."runs" enable row level security; -alter table "public"."worker" enable row level security; -alter table "public"."events" enable row level security; - --- Organizations: members can read, admins can update -create policy "Members can view their organizations" - on organizations for select - using ( - exists ( - select 1 from organization_members - where organization_id = organizations.id - and user_id = auth.uid() - ) - ); - -create policy "Admins can update their organizations" - on organizations for update - using ( - exists ( - select 1 from organization_members - where organization_id = organizations.id - and user_id = auth.uid() - and role = 'admin' - ) - ); - --- Organization members: members can view, admins can manage -create policy "Members can view other members in their organizations" - on organization_members for select - using ( - exists ( - select 1 from organization_members as om - where om.organization_id = organization_members.organization_id - and om.user_id = auth.uid() - ) - ); - -create policy "Admins can manage members in their organizations" - on organization_members for all - using ( - exists ( - select 1 from organization_members - where organization_id = organization_members.organization_id - and user_id = auth.uid() - and role = 'admin' - ) - ); - --- Jobs: members can CRUD jobs in their organizations -create policy "Members can manage jobs in their organizations" - on jobs for all - using ( - exists ( - select 1 from organization_members - where organization_id = jobs.organization_id - and user_id = auth.uid() - ) - ); - --- Runs: members can manage runs of jobs in their organizations -create policy "Members can manage runs in their organizations" - on runs for all - using ( - exists ( - select 1 from organization_members 
om - join jobs j on j.organization_id = om.organization_id - where j.id = runs.job_id - and om.user_id = auth.uid() - ) - ); - --- Workers: members can manage workers in their organizations -create policy "Members can manage workers in their organizations" - on worker for all - using ( - exists ( - select 1 from organization_members - where organization_id = worker.organization_id - and user_id = auth.uid() - ) - ); - --- Events: members can manage events of runs in their organizations -create policy "Members can manage events in their organizations" - on events for all - using ( - exists ( - select 1 from organization_members om - join jobs j on j.organization_id = om.organization_id - join runs r on r.job_id = j.id - where r.id = events.run_id - and om.user_id = auth.uid() - ) - ); - --- Third-party API Keys: members can view, admins can manage -create policy "Members can view their organization's API keys" - on third_party_api_keys for select - using ( - exists ( - select 1 from organization_members - where organization_id = third_party_api_keys.organization_id - and user_id = auth.uid() - ) - ); - -create policy "Admins can manage their organization's API keys" - on third_party_api_keys for all - using ( - exists ( - select 1 from organization_members - where organization_id = third_party_api_keys.organization_id - and user_id = auth.uid() - and role = 'admin' - ) - ); - --- Grant permissions -grant usage on schema public to postgres, anon, authenticated, service_role; - -grant all privileges on all tables in schema public to postgres, service_role; -grant all privileges on all sequences in schema public to postgres, service_role; - -grant select, insert, update, delete on public.organizations to anon, authenticated; -grant select, insert, update, delete on public.organization_members to anon, authenticated; -grant select, insert, update, delete on public.third_party_api_keys to anon, authenticated; -grant select, insert, update, delete on public.jobs to anon, authenticated; -grant select, insert, update, delete on public.runs to anon, authenticated; -grant select, insert, update, delete on public.worker to anon, authenticated; -grant select, insert, update, delete on public.events to anon, authenticated; - --- Add updated_at trigger for organizations -create trigger set_updated_at_organizations - before update on public.organizations - for each row - execute function public.handle_updated_at(); - --- Add updated_at trigger for third_party_api_keys -create trigger set_updated_at_third_party_api_keys - before update on public.third_party_api_keys - for each row - execute function public.handle_updated_at(); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000002_fix_organization_policies.sql b/supabase/migrations_dev/20241205000002_fix_organization_policies.sql deleted file mode 100644 index 7446f83..0000000 --- a/supabase/migrations_dev/20241205000002_fix_organization_policies.sql +++ /dev/null @@ -1,47 +0,0 @@ --- Drop existing policies that might cause recursion -drop policy if exists "Members can manage members in their organizations" on organization_members; -drop policy if exists "Admins can manage members in their organizations" on organization_members; - --- Create new, more specific policies for organization members -create policy "Members can view organization members" - on organization_members for select - using ( - exists ( - select 1 from organization_members as om - where om.organization_id = organization_members.organization_id - and om.user_id = 
auth.uid() - ) - ); - -create policy "Admins can insert organization members" - on organization_members for insert - with check ( - exists ( - select 1 from organization_members as om - where om.organization_id = organization_members.organization_id - and om.user_id = auth.uid() - and om.role = 'admin' - ) - ); - -create policy "Admins can update organization members" - on organization_members for update - using ( - exists ( - select 1 from organization_members as om - where om.organization_id = organization_members.organization_id - and om.user_id = auth.uid() - and om.role = 'admin' - ) - ); - -create policy "Admins can delete organization members" - on organization_members for delete - using ( - exists ( - select 1 from organization_members as om - where om.organization_id = organization_members.organization_id - and om.user_id = auth.uid() - and om.role = 'admin' - ) - ); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000003_fix_organization_member_policies.sql b/supabase/migrations_dev/20241205000003_fix_organization_member_policies.sql deleted file mode 100644 index 86a76ba..0000000 --- a/supabase/migrations_dev/20241205000003_fix_organization_member_policies.sql +++ /dev/null @@ -1,24 +0,0 @@ --- Drop all existing policies on organization_members to start fresh -drop policy if exists "Members can view organization members" on organization_members; -drop policy if exists "Admins can insert organization members" on organization_members; -drop policy if exists "Admins can update organization members" on organization_members; -drop policy if exists "Admins can delete organization members" on organization_members; -drop policy if exists "Members can view other members in their organizations" on organization_members; -drop policy if exists "Admins can manage members in their organizations" on organization_members; - --- Create new simplified policies that avoid recursion -create policy "Enable read access for organization members" - on organization_members for select - using (auth.uid() = user_id); - -create policy "Enable write access for organization admins" - on organization_members for all - using ( - auth.uid() in ( - select user_id - from organization_members - where organization_id = organization_members.organization_id - and role = 'admin' - and user_id = auth.uid() - ) - ); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000004_fix_all_organization_policies.sql b/supabase/migrations_dev/20241205000004_fix_all_organization_policies.sql deleted file mode 100644 index e28c87f..0000000 --- a/supabase/migrations_dev/20241205000004_fix_all_organization_policies.sql +++ /dev/null @@ -1,100 +0,0 @@ --- First, drop all existing policies that might cause recursion -drop policy if exists "Members can view their organizations" on organizations; -drop policy if exists "Admins can update their organizations" on organizations; -drop policy if exists "Enable read access for organization members" on organization_members; -drop policy if exists "Enable write access for organization admins" on organization_members; -drop policy if exists "Members can manage jobs in their organizations" on jobs; -drop policy if exists "Members can manage runs in their organizations" on runs; -drop policy if exists "Members can manage workers in their organizations" on worker; -drop policy if exists "Members can manage events in their organizations" on events; -drop policy if exists "Members can view their organization's API keys" on third_party_api_keys; -drop 
policy if exists "Admins can manage their organization's API keys" on third_party_api_keys; - --- Create a function to check organization membership -create or replace function is_organization_member(org_id uuid) -returns boolean as $$ -begin - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - ); -end; -$$ language plpgsql security definer; - --- Create a function to check organization admin status -create or replace function is_organization_admin(org_id uuid) -returns boolean as $$ -begin - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ); -end; -$$ language plpgsql security definer; - --- Organizations policies -create policy "Enable read access for organization members" - on organizations for select - using (is_organization_member(id)); - -create policy "Enable write access for organization admins" - on organizations for all - using (is_organization_admin(id)); - --- Organization members policies -create policy "Enable read access for members" - on organization_members for select - using (is_organization_member(organization_id)); - -create policy "Enable write access for admins" - on organization_members for all - using (is_organization_admin(organization_id)); - --- Jobs policies -create policy "Enable access for organization members" - on jobs for all - using (is_organization_member(organization_id)); - --- Runs policies (through jobs) -create policy "Enable access for organization members" - on runs for all - using ( - exists ( - select 1 - from jobs - where jobs.id = runs.job_id - and is_organization_member(jobs.organization_id) - ) - ); - --- Workers policies -create policy "Enable access for organization members" - on worker for all - using (is_organization_member(organization_id)); - --- Events policies (through runs and jobs) -create policy "Enable access for organization members" - on events for all - using ( - exists ( - select 1 - from runs - join jobs on jobs.id = runs.job_id - where runs.id = events.run_id - and is_organization_member(jobs.organization_id) - ) - ); - --- Third-party API keys policies -create policy "Enable read for members" - on third_party_api_keys for select - using (is_organization_member(organization_id)); - -create policy "Enable write for admins" - on third_party_api_keys for all - using (is_organization_admin(organization_id)); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000005_add_organization_functions.sql b/supabase/migrations_dev/20241205000005_add_organization_functions.sql deleted file mode 100644 index 14291e2..0000000 --- a/supabase/migrations_dev/20241205000005_add_organization_functions.sql +++ /dev/null @@ -1,51 +0,0 @@ --- Function to get organization members with their email addresses -create or replace function public.get_organization_members(org_id uuid) -returns table ( - user_id uuid, - email text, - role public.organization_role -) security definer -set search_path = public -language plpgsql -as $$ -begin - return query - select - om.user_id, - au.email, - om.role - from public.organization_members om - join auth.users au on au.id = om.user_id - where om.organization_id = org_id - and exists ( - select 1 - from public.organization_members viewer - where viewer.organization_id = org_id - and viewer.user_id = auth.uid() - ); -end; -$$; - --- Function to get user by email -create or replace function public.get_user_by_email(user_email text) -returns 
table ( - id uuid, - email text -) security definer -set search_path = public -language plpgsql -as $$ -begin - return query - select - au.id, - au.email - from auth.users au - where au.email = user_email - limit 1; -end; -$$; - --- Grant execute permissions on the functions -grant execute on function public.get_organization_members(uuid) to authenticated; -grant execute on function public.get_user_by_email(text) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000006_fix_organization_function_types.sql b/supabase/migrations_dev/20241205000006_fix_organization_function_types.sql deleted file mode 100644 index c939638..0000000 --- a/supabase/migrations_dev/20241205000006_fix_organization_function_types.sql +++ /dev/null @@ -1,55 +0,0 @@ --- Drop existing functions -drop function if exists public.get_organization_members(uuid); -drop function if exists public.get_user_by_email(text); - --- Recreate functions with correct types -create or replace function public.get_organization_members(org_id uuid) -returns table ( - user_id uuid, - email varchar(255), - role public.organization_role -) security definer -set search_path = public -language plpgsql -as $$ -begin - return query - select - om.user_id, - au.email, - om.role - from public.organization_members om - join auth.users au on au.id = om.user_id - where om.organization_id = org_id - and exists ( - select 1 - from public.organization_members viewer - where viewer.organization_id = org_id - and viewer.user_id = auth.uid() - ); -end; -$$; - --- Function to get user by email -create or replace function public.get_user_by_email(user_email varchar(255)) -returns table ( - id uuid, - email varchar(255) -) security definer -set search_path = public -language plpgsql -as $$ -begin - return query - select - au.id, - au.email - from auth.users au - where au.email = user_email - limit 1; -end; -$$; - --- Grant execute permissions on the functions -grant execute on function public.get_organization_members(uuid) to authenticated; -grant execute on function public.get_user_by_email(varchar) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000007_fix_get_user_by_email_param.sql b/supabase/migrations_dev/20241205000007_fix_get_user_by_email_param.sql deleted file mode 100644 index 91d04af..0000000 --- a/supabase/migrations_dev/20241205000007_fix_get_user_by_email_param.sql +++ /dev/null @@ -1,25 +0,0 @@ --- Drop existing function -drop function if exists public.get_user_by_email(varchar); - --- Recreate function with parameter named 'email' -create or replace function public.get_user_by_email(email varchar(255)) -returns table ( - user_id uuid, - user_email varchar(255) -) security definer -set search_path = public -language plpgsql -as $$ -begin - return query - select - au.id as user_id, - au.email as user_email - from auth.users au - where au.email = get_user_by_email.email - limit 1; -end; -$$; - --- Grant execute permissions on the function -grant execute on function public.get_user_by_email(varchar) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000008_improve_get_user_by_email.sql b/supabase/migrations_dev/20241205000008_improve_get_user_by_email.sql deleted file mode 100644 index 6ca47b2..0000000 --- a/supabase/migrations_dev/20241205000008_improve_get_user_by_email.sql +++ /dev/null @@ -1,33 +0,0 @@ --- Drop existing function -drop function if exists public.get_user_by_email(varchar); - --- Recreate function with 
parameter named 'email' and better error handling -create or replace function public.get_user_by_email(email varchar(255)) -returns table ( - user_id uuid, - user_email varchar(255) -) security definer -set search_path = public -language plpgsql -as $$ -begin - if email is null then - raise exception 'Email parameter cannot be null'; - end if; - - return query - select - au.id as user_id, - au.email as user_email - from auth.users au - where lower(au.email) = lower(get_user_by_email.email) - limit 1; - - if not found then - raise exception 'User with email % not found', email; - end if; -end; -$$; - --- Grant execute permissions on the function -grant execute on function public.get_user_by_email(varchar) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000009_add_invite_member_function.sql b/supabase/migrations_dev/20241205000009_add_invite_member_function.sql deleted file mode 100644 index 4fcd423..0000000 --- a/supabase/migrations_dev/20241205000009_add_invite_member_function.sql +++ /dev/null @@ -1,60 +0,0 @@ --- Function to invite a member to an organization -create or replace function public.invite_organization_member( - org_id uuid, - member_email varchar(255), - member_role public.organization_role -) -returns table ( - user_id uuid, - email varchar(255), - role public.organization_role -) security definer -set search_path = public -language plpgsql -as $$ -declare - v_user_id uuid; - v_email varchar(255); -begin - -- Check if the inviter is an admin of the organization - if not exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ) then - raise exception 'Only organization admins can invite members'; - end if; - - -- Get the user ID for the email - select au.id, au.email - into v_user_id, v_email - from auth.users au - where lower(au.email) = lower(member_email); - - if v_user_id is null then - raise exception 'User with email % not found', member_email; - end if; - - -- Check if user is already a member - if exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = v_user_id - ) then - raise exception 'User is already a member of this organization'; - end if; - - -- Insert the new member - insert into organization_members (organization_id, user_id, role) - values (org_id, v_user_id, member_role) - returning user_id, v_email as email, role into user_id, email, role; - - return next; -end; -$$; - --- Grant execute permissions on the function -grant execute on function public.invite_organization_member(uuid, varchar, public.organization_role) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000010_fix_invite_member_function.sql b/supabase/migrations_dev/20241205000010_fix_invite_member_function.sql deleted file mode 100644 index 127a4fe..0000000 --- a/supabase/migrations_dev/20241205000010_fix_invite_member_function.sql +++ /dev/null @@ -1,67 +0,0 @@ --- Drop existing function -drop function if exists public.invite_organization_member(uuid, varchar, public.organization_role); - --- Recreate function with fixed column references -create or replace function public.invite_organization_member( - org_id uuid, - member_email varchar(255), - member_role public.organization_role -) -returns table ( - user_id uuid, - email varchar(255), - role public.organization_role -) security definer -set search_path = public -language plpgsql -as $$ -declare - v_user_id uuid; - v_email varchar(255); -begin 
- -- Check if the inviter is an admin of the organization - if not exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ) then - raise exception 'Only organization admins can invite members'; - end if; - - -- Get the user ID for the email - select au.id, au.email - into v_user_id, v_email - from auth.users au - where lower(au.email) = lower(member_email); - - if v_user_id is null then - raise exception 'User with email % not found', member_email; - end if; - - -- Check if user is already a member - if exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = v_user_id - ) then - raise exception 'User is already a member of this organization'; - end if; - - -- Insert the new member - insert into organization_members (organization_id, user_id, role) - values (org_id, v_user_id, member_role); - - -- Return the result - return query - select - v_user_id as user_id, - v_email as email, - member_role as role; -end; -$$; - --- Grant execute permissions on the function -grant execute on function public.invite_organization_member(uuid, varchar, public.organization_role) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000011_fix_get_members_function.sql b/supabase/migrations_dev/20241205000011_fix_get_members_function.sql deleted file mode 100644 index adb4a3a..0000000 --- a/supabase/migrations_dev/20241205000011_fix_get_members_function.sql +++ /dev/null @@ -1,33 +0,0 @@ --- Drop existing function -drop function if exists public.get_organization_members(uuid); - --- Recreate function with explicit column references -create or replace function public.get_organization_members(org_id uuid) -returns table ( - user_id uuid, - email varchar(255), - role public.organization_role -) security definer -set search_path = public -language plpgsql -as $$ -begin - return query - select - om.user_id, - au.email, - om.role - from public.organization_members om - join auth.users au on au.id = om.user_id - where om.organization_id = org_id - and exists ( - select 1 - from public.organization_members viewer - where viewer.organization_id = org_id - and viewer.user_id = auth.uid() - ); -end; -$$; - --- Grant execute permissions on the function -grant execute on function public.get_organization_members(uuid) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000012_fix_organization_check_functions.sql b/supabase/migrations_dev/20241205000012_fix_organization_check_functions.sql deleted file mode 100644 index 970f14a..0000000 --- a/supabase/migrations_dev/20241205000012_fix_organization_check_functions.sql +++ /dev/null @@ -1,102 +0,0 @@ --- First drop all policies that use these functions -drop policy if exists "Enable read access for organization members" on organizations; -drop policy if exists "Enable write access for organization admins" on organizations; -drop policy if exists "Enable read access for members" on organization_members; -drop policy if exists "Enable write access for admins" on organization_members; -drop policy if exists "Enable access for organization members" on jobs; -drop policy if exists "Enable access for organization members" on runs; -drop policy if exists "Enable access for organization members" on worker; -drop policy if exists "Enable access for organization members" on events; -drop policy if exists "Enable read for members" on third_party_api_keys; -drop policy if exists "Enable write for 
admins" on third_party_api_keys; - --- Now we can safely drop the functions -drop function if exists public.is_organization_member(uuid); -drop function if exists public.is_organization_admin(uuid); - --- Recreate is_organization_member function with explicit column references -create or replace function is_organization_member(org_id uuid) -returns boolean as $$ -begin - return exists ( - select 1 - from organization_members om - where om.organization_id = org_id - and om.user_id = auth.uid() - ); -end; -$$ language plpgsql security definer; - --- Recreate is_organization_admin function with explicit column references -create or replace function is_organization_admin(org_id uuid) -returns boolean as $$ -begin - return exists ( - select 1 - from organization_members om - where om.organization_id = org_id - and om.user_id = auth.uid() - and om.role = 'admin' - ); -end; -$$ language plpgsql security definer; - --- Recreate policies with explicit table references -create policy "Enable read access for organization members" - on organizations for select - using (is_organization_member(id)); - -create policy "Enable write access for organization admins" - on organizations for all - using (is_organization_admin(id)); - -create policy "Enable read access for members" - on organization_members for select - using (is_organization_member(organization_id)); - -create policy "Enable write access for admins" - on organization_members for all - using (is_organization_admin(organization_id)); - -create policy "Enable access for organization members" - on jobs for all - using (is_organization_member(organization_id)); - -create policy "Enable access for organization members" - on runs for all - using ( - exists ( - select 1 - from jobs j - where j.id = runs.job_id - and is_organization_member(j.organization_id) - ) - ); - -create policy "Enable access for organization members" - on worker for all - using (is_organization_member(organization_id)); - -create policy "Enable access for organization members" - on events for all - using ( - exists ( - select 1 - from runs r - join jobs j on j.id = r.job_id - where r.id = events.run_id - and is_organization_member(j.organization_id) - ) - ); - -create policy "Enable read for members" - on third_party_api_keys for select - using (is_organization_member(organization_id)); - -create policy "Enable write for admins" - on third_party_api_keys for all - using (is_organization_admin(organization_id)); - --- Grant execute permissions -grant execute on function public.is_organization_member(uuid) to authenticated; -grant execute on function public.is_organization_admin(uuid) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000013_fix_all_user_id_references.sql b/supabase/migrations_dev/20241205000013_fix_all_user_id_references.sql deleted file mode 100644 index 5f2d2e5..0000000 --- a/supabase/migrations_dev/20241205000013_fix_all_user_id_references.sql +++ /dev/null @@ -1,96 +0,0 @@ --- Drop and recreate get_user_by_email function with explicit references -drop function if exists public.get_user_by_email(varchar); - -create or replace function public.get_user_by_email(email varchar(255)) -returns table ( - user_id uuid, - user_email varchar(255) -) security definer -set search_path = public -language plpgsql -as $$ -begin - if email is null then - raise exception 'Email parameter cannot be null'; - end if; - - return query - select - au.id as user_id, - au.email as user_email - from auth.users au - where lower(au.email) = 
lower(get_user_by_email.email) - limit 1; - - if not found then - raise exception 'User with email % not found', email; - end if; -end; -$$; - --- Drop and recreate invite_organization_member function with explicit references -drop function if exists public.invite_organization_member(uuid, varchar, public.organization_role); - -create or replace function public.invite_organization_member( - org_id uuid, - member_email varchar(255), - member_role public.organization_role -) -returns table ( - user_id uuid, - email varchar(255), - role public.organization_role -) security definer -set search_path = public -language plpgsql -as $$ -declare - v_user_id uuid; - v_email varchar(255); -begin - -- Check if the inviter is an admin of the organization - if not exists ( - select 1 - from organization_members om - where om.organization_id = org_id - and om.user_id = auth.uid() - and om.role = 'admin' - ) then - raise exception 'Only organization admins can invite members'; - end if; - - -- Get the user ID for the email - select au.id, au.email - into v_user_id, v_email - from auth.users au - where lower(au.email) = lower(member_email); - - if v_user_id is null then - raise exception 'User with email % not found', member_email; - end if; - - -- Check if user is already a member - if exists ( - select 1 - from organization_members om - where om.organization_id = org_id - and om.user_id = v_user_id - ) then - raise exception 'User is already a member of this organization'; - end if; - - -- Insert the new member - insert into organization_members (organization_id, user_id, role) - values (org_id, v_user_id, member_role); - - -- Return the result explicitly - user_id := v_user_id; - email := v_email; - role := member_role; - return next; -end; -$$; - --- Grant execute permissions -grant execute on function public.get_user_by_email(varchar) to authenticated; -grant execute on function public.invite_organization_member(uuid, varchar, public.organization_role) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000014_add_storage_policies.sql b/supabase/migrations_dev/20241205000014_add_storage_policies.sql deleted file mode 100644 index f1ec8ef..0000000 --- a/supabase/migrations_dev/20241205000014_add_storage_policies.sql +++ /dev/null @@ -1,79 +0,0 @@ --- Create storage policies for the files bucket --- Note: The bucket itself needs to be created manually in the dashboard - --- Drop any existing policies to avoid conflicts -begin; - drop policy if exists "Organization members can read their files" on storage.objects; - drop policy if exists "Organization members can upload files" on storage.objects; - drop policy if exists "Organization members can delete their files" on storage.objects; - - -- Policy for reading files: - -- Allow if user is member of the organization that owns the file - create policy "Organization members can read their files" - on storage.objects for select - using ( - -- Check if file is in an organization folder - (storage.foldername(name))[1] = 'organizations' - and ( - -- User is member of the organization specified in the path - exists ( - select 1 - from public.organization_members - where organization_id = uuid((storage.foldername(name))[2]) - and user_id = auth.uid() - ) - ) - ); - - -- Policy for inserting files: - -- Allow if user is member of the organization and file path matches organization - create policy "Organization members can upload files" - on storage.objects for insert - with check ( - -- Ensure file is being uploaded to an 
organization folder - (storage.foldername(name))[1] = 'organizations' - and ( - -- User is member of the organization specified in the path - exists ( - select 1 - from public.organization_members - where organization_id = uuid((storage.foldername(name))[2]) - and user_id = auth.uid() - ) - ) - ); - - -- Policy for deleting files: - -- Allow if user is admin of the organization that owns the file - create policy "Organization members can delete their files" - on storage.objects for delete - using ( - -- Check if file is in an organization folder - (storage.foldername(name))[1] = 'organizations' - and ( - -- User is admin of the organization specified in the path - exists ( - select 1 - from public.organization_members - where organization_id = uuid((storage.foldername(name))[2]) - and user_id = auth.uid() - and role = 'admin' - ) - ) - ); - - -- Create a function to get the correct organization path for file storage - create or replace function public.get_organization_storage_path( - org_id uuid, - filename text - ) returns text - language sql - stable - as $$ - select 'organizations/' || org_id || '/' || filename; - $$; - - -- Grant execute permissions - grant execute on function public.get_organization_storage_path(uuid, text) to authenticated; - -commit; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000015_migrate_existing_files.sql b/supabase/migrations_dev/20241205000015_migrate_existing_files.sql deleted file mode 100644 index e7451b2..0000000 --- a/supabase/migrations_dev/20241205000015_migrate_existing_files.sql +++ /dev/null @@ -1,40 +0,0 @@ --- Function to migrate existing files to default organization -create or replace function migrate_files_to_default_organization() -returns void -language plpgsql -security definer -as $$ -declare - file_record record; - new_path text; - default_org_id uuid := '00000000-0000-0000-0000-000000000000'; -begin - -- Loop through all files that are not in an organizations folder - for file_record in - select name, id - from storage.objects - where (storage.foldername(name))[1] != 'organizations' - loop - -- Generate new path - new_path := 'organizations/' || default_org_id || '/' || - case - when position('/' in file_record.name) > 0 - then substring(file_record.name from position('/' in file_record.name) + 1) - else file_record.name - end; - - -- Move file to new location - update storage.objects - set name = new_path - where id = file_record.id; - - raise notice 'Moved file % to %', file_record.name, new_path; - end loop; -end; -$$; - --- Execute the migration function -select migrate_files_to_default_organization(); - --- Drop the migration function as it's no longer needed -drop function migrate_files_to_default_organization(); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000016_fix_file_migration.sql b/supabase/migrations_dev/20241205000016_fix_file_migration.sql deleted file mode 100644 index ffebc6d..0000000 --- a/supabase/migrations_dev/20241205000016_fix_file_migration.sql +++ /dev/null @@ -1,67 +0,0 @@ --- Function to ensure folder exists and migrate files -create or replace function migrate_files_to_default_organization_v2() -returns void -language plpgsql -security definer -as $$ -declare - file_record record; - new_path text; - default_org_id uuid := '00000000-0000-0000-0000-000000000000'; -begin - -- Create the organizations folder if it doesn't exist - insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version) - values ('files', 'organizations/.keep', 
auth.uid(), now(), now(), '1') - on conflict do nothing; - - -- Create the default organization folder if it doesn't exist - insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version) - values ('files', 'organizations/' || default_org_id || '/.keep', auth.uid(), now(), now(), '1') - on conflict do nothing; - - -- Loop through all files that are not in an organizations folder - for file_record in - select name, id, owner, created_at, updated_at, version, metadata - from storage.objects - where bucket_id = 'files' - and name not like 'organizations/%' - and name not like '%.keep' - loop - -- Generate new path - new_path := 'organizations/' || default_org_id || '/' || file_record.name; - - -- Insert file with new path - insert into storage.objects ( - bucket_id, - name, - owner, - created_at, - updated_at, - version, - metadata - ) - values ( - 'files', - new_path, - file_record.owner, - file_record.created_at, - file_record.updated_at, - file_record.version, - file_record.metadata - ); - - -- Delete old file entry - delete from storage.objects - where id = file_record.id; - - raise notice 'Moved file % to %', file_record.name, new_path; - end loop; -end; -$$; - --- Execute the migration function -select migrate_files_to_default_organization_v2(); - --- Drop the migration functions as they're no longer needed -drop function if exists migrate_files_to_default_organization(); -drop function migrate_files_to_default_organization_v2(); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000017_migrate_file_references.sql b/supabase/migrations_dev/20241205000017_migrate_file_references.sql deleted file mode 100644 index 30d54eb..0000000 --- a/supabase/migrations_dev/20241205000017_migrate_file_references.sql +++ /dev/null @@ -1,76 +0,0 @@ --- Function to get the organization storage path -create or replace function get_file_path(file_id text) -returns text -language sql -stable -as $$ - select 'organizations/00000000-0000-0000-0000-000000000000/' || file_id; -$$; - --- Function to update file references in a JSON object -create or replace function update_file_refs_in_json(data jsonb) -returns jsonb -language plpgsql -as $$ -declare - result jsonb := data; - file_id text; -begin - -- Handle direct 'file' key - if result ? 'file' and (result->>'file' is not null) then - file_id := result->>'file'; - result := jsonb_set(result, '{file}', to_jsonb(get_file_path(file_id))); - end if; - - -- Handle input_file_id - if result ? 'input_file_id' and (result->>'input_file_id' is not null) then - file_id := result->>'input_file_id'; - result := jsonb_set(result, '{input_file_id}', to_jsonb(get_file_path(file_id))); - end if; - - -- Handle training_file - if result ? 'training_file' and (result->>'training_file' is not null) then - file_id := result->>'training_file'; - result := jsonb_set(result, '{training_file}', to_jsonb(get_file_path(file_id))); - end if; - - -- Handle test_file - if result ? 
'test_file' and (result->>'test_file' is not null) then - file_id := result->>'test_file'; - result := jsonb_set(result, '{test_file}', to_jsonb(get_file_path(file_id))); - end if; - - return result; -end; -$$; - --- Begin the migration -do $$ -declare - r record; -begin - -- Update runs.log_file - update runs - set log_file = get_file_path(log_file) - where log_file is not null - and log_file not like 'organizations/%'; - - -- Update jobs.outputs - update jobs - set outputs = update_file_refs_in_json(outputs) - where outputs is not null - and outputs::text not like '%organizations/%'; - - -- Update jobs.params - update jobs - set params = update_file_refs_in_json(params) - where params is not null - and params::text not like '%organizations/%'; - - raise notice 'File references migration completed'; -end; -$$; - --- Clean up temporary functions -drop function update_file_refs_in_json(jsonb); -drop function get_file_path(text); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000018_revert_file_references.sql b/supabase/migrations_dev/20241205000018_revert_file_references.sql deleted file mode 100644 index 0b609fd..0000000 --- a/supabase/migrations_dev/20241205000018_revert_file_references.sql +++ /dev/null @@ -1,62 +0,0 @@ --- Revert any changes to file references -update runs -set log_file = split_part(log_file, '/', 3) -where log_file like 'organizations/%'; - --- For jobs.outputs and jobs.params, we need to extract the file ID from the path -create or replace function extract_file_id_from_path(path text) -returns text -language sql -immutable -as $$ - select split_part(path, '/', 3) - where path like 'organizations/%'; -$$; - -create or replace function revert_file_refs_in_json(data jsonb) -returns jsonb -language plpgsql -as $$ -declare - result jsonb := data; - file_path text; -begin - -- Handle direct 'file' key - if result ? 'file' and (result->>'file' like 'organizations/%') then - file_path := result->>'file'; - result := jsonb_set(result, '{file}', to_jsonb(extract_file_id_from_path(file_path))); - end if; - - -- Handle input_file_id - if result ? 'input_file_id' and (result->>'input_file_id' like 'organizations/%') then - file_path := result->>'input_file_id'; - result := jsonb_set(result, '{input_file_id}', to_jsonb(extract_file_id_from_path(file_path))); - end if; - - -- Handle training_file - if result ? 'training_file' and (result->>'training_file' like 'organizations/%') then - file_path := result->>'training_file'; - result := jsonb_set(result, '{training_file}', to_jsonb(extract_file_id_from_path(file_path))); - end if; - - -- Handle test_file - if result ? 
'test_file' and (result->>'test_file' like 'organizations/%') then - file_path := result->>'test_file'; - result := jsonb_set(result, '{test_file}', to_jsonb(extract_file_id_from_path(file_path))); - end if; - - return result; -end; -$$; - --- Revert changes in jobs table -update jobs -set - outputs = revert_file_refs_in_json(outputs), - params = revert_file_refs_in_json(params) -where - (outputs::text like '%organizations/%' or params::text like '%organizations/%'); - --- Clean up functions -drop function revert_file_refs_in_json(jsonb); -drop function extract_file_id_from_path(text); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000019_verify_file_locations.sql b/supabase/migrations_dev/20241205000019_verify_file_locations.sql deleted file mode 100644 index dc8e2e3..0000000 --- a/supabase/migrations_dev/20241205000019_verify_file_locations.sql +++ /dev/null @@ -1,115 +0,0 @@ --- Create a function to list all files in storage -create or replace function list_storage_files() -returns table ( - id bigint, - name text, - bucket_id text, - owner uuid, - created_at timestamptz, - updated_at timestamptz, - last_accessed_at timestamptz, - metadata jsonb, - version text, - size bigint -) -security definer -language plpgsql -as $$ -begin - return query - select * - from storage.objects - where bucket_id = 'files' - order by name; -end; -$$; - --- Grant execute permission -grant execute on function list_storage_files() to postgres; - --- Create a function to check and fix file locations -create or replace function verify_and_fix_file_locations() -returns table ( - file_name text, - old_path text, - new_path text, - status text -) -language plpgsql -security definer -as $$ -declare - file_record record; - new_path text; - default_org_id uuid := '00000000-0000-0000-0000-000000000000'; -begin - create temp table if not exists migration_log ( - file_name text, - old_path text, - new_path text, - status text - ); - - for file_record in - select * from storage.objects - where bucket_id = 'files' - and name not like 'organizations/%' - and name not like '%.keep' - loop - new_path := 'organizations/' || default_org_id || '/' || file_record.name; - - begin - -- Copy the file to new location - insert into storage.objects ( - bucket_id, - name, - owner, - created_at, - updated_at, - version, - size, - metadata - ) - select - bucket_id, - new_path, - owner, - created_at, - updated_at, - version, - size, - metadata - from storage.objects - where id = file_record.id; - - -- Log successful copy - insert into migration_log values ( - file_record.name, - file_record.name, - new_path, - 'copied' - ); - - exception when others then - -- Log failed copy - insert into migration_log values ( - file_record.name, - file_record.name, - new_path, - 'failed: ' || SQLERRM - ); - end; - end loop; - - return query select * from migration_log; - - drop table migration_log; -end; -$$; - --- Execute the verification and get results -select * from verify_and_fix_file_locations(); - --- Clean up -drop function verify_and_fix_file_locations(); -drop function list_storage_files(); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000020_fix_storage_locations.sql b/supabase/migrations_dev/20241205000020_fix_storage_locations.sql deleted file mode 100644 index 37ed383..0000000 --- a/supabase/migrations_dev/20241205000020_fix_storage_locations.sql +++ /dev/null @@ -1,102 +0,0 @@ --- Create a function to move files in storage -create or replace function move_files_in_storage() -returns 
table ( - file_name text, - old_path text, - new_path text, - status text -) -language plpgsql -security definer -as $$ -declare - file_record record; - new_path text; - default_org_id uuid := '00000000-0000-0000-0000-000000000000'; -begin - create temp table if not exists migration_log ( - file_name text, - old_path text, - new_path text, - status text - ); - - -- First ensure the organization folders exist - insert into storage.objects ( - bucket_id, - name, - owner, - created_at, - updated_at, - version, - metadata - ) values - ('files', 'organizations/.keep', auth.uid(), now(), now(), '1', '{}'), - ('files', 'organizations/' || default_org_id || '/.keep', auth.uid(), now(), now(), '1', '{}') - on conflict (bucket_id, name) do nothing; - - -- Move each file - for file_record in - select * from storage.objects - where bucket_id = 'files' - and name not like 'organizations/%' - and name not like '%.keep' - loop - new_path := 'organizations/' || default_org_id || '/' || file_record.name; - - begin - -- Move the file using storage.copy_object - insert into storage.objects ( - bucket_id, - name, - owner, - created_at, - updated_at, - version, - metadata - ) - select - bucket_id, - new_path, - owner, - created_at, - updated_at, - version, - metadata - from storage.objects - where id = file_record.id - returning name into new_path; - - -- If copy successful, delete the old object - delete from storage.objects where id = file_record.id; - - -- Log successful move - insert into migration_log values ( - file_record.name, - file_record.name, - new_path, - 'moved' - ); - - exception when others then - -- Log failed move - insert into migration_log values ( - file_record.name, - file_record.name, - new_path, - 'failed: ' || SQLERRM - ); - end; - end loop; - - return query select * from migration_log; - - drop table migration_log; -end; -$$; - --- Execute the move operation and get results -select * from move_files_in_storage(); - --- Clean up -drop function move_files_in_storage(); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000021_check_storage.sql b/supabase/migrations_dev/20241205000021_check_storage.sql deleted file mode 100644 index 39cceef..0000000 --- a/supabase/migrations_dev/20241205000021_check_storage.sql +++ /dev/null @@ -1,6 +0,0 @@ --- Check storage objects -select id, bucket_id, name -from storage.objects -where name like '%14e5eef3bbad%' -or name like '%organizations%' -order by name; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000022_fix_storage_policies.sql b/supabase/migrations_dev/20241205000022_fix_storage_policies.sql deleted file mode 100644 index 2a9b80b..0000000 --- a/supabase/migrations_dev/20241205000022_fix_storage_policies.sql +++ /dev/null @@ -1,71 +0,0 @@ --- Drop existing policies to avoid conflicts -drop policy if exists "Organization members can read their files" on storage.objects; -drop policy if exists "Organization members can upload files" on storage.objects; -drop policy if exists "Organization members can delete their files" on storage.objects; - --- Enable RLS on storage.objects -alter table storage.objects enable row level security; - --- Policy for reading files: -create policy "Organization members can read their files" -on storage.objects for select -using ( - bucket_id = 'files' - and ( - -- Allow access to .keep files - name like '%.keep' - or ( - -- Check if file is in an organization folder and user is a member - name like 'organizations/%' - and exists ( - select 1 - from 
public.organization_members - where organization_id = uuid(split_part(name, '/', 2)) - and user_id = auth.uid() - ) - ) - ) -); - --- Policy for inserting files: -create policy "Organization members can upload files" -on storage.objects for insert -with check ( - bucket_id = 'files' - and ( - -- Allow .keep files - name like '%.keep' - or ( - -- Check if file is in an organization folder and user is a member - name like 'organizations/%' - and exists ( - select 1 - from public.organization_members - where organization_id = uuid(split_part(name, '/', 2)) - and user_id = auth.uid() - ) - ) - ) -); - --- Policy for deleting files: -create policy "Organization members can delete their files" -on storage.objects for delete -using ( - bucket_id = 'files' - and ( - -- Allow .keep files - name like '%.keep' - or ( - -- Check if file is in an organization folder and user is an admin - name like 'organizations/%' - and exists ( - select 1 - from public.organization_members - where organization_id = uuid(split_part(name, '/', 2)) - and user_id = auth.uid() - and role = 'admin' - ) - ) - ) -); \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000023_add_api_tokens.sql b/supabase/migrations_dev/20241205000023_add_api_tokens.sql deleted file mode 100644 index d969703..0000000 --- a/supabase/migrations_dev/20241205000023_add_api_tokens.sql +++ /dev/null @@ -1,122 +0,0 @@ --- Create API tokens table -create table if not exists "public"."api_tokens" ( - "id" uuid not null default gen_random_uuid(), - "organization_id" uuid not null references organizations(id) on delete cascade, - "name" text not null, - "token" text not null, - "created_at" timestamp with time zone default timezone('utc'::text, now()) not null, - "created_by" uuid not null references auth.users(id) on delete cascade, - "last_used_at" timestamp with time zone, - primary key (id) -); - --- Enable RLS -alter table "public"."api_tokens" enable row level security; - --- Create policy for reading tokens -create policy "Organization members can view their tokens" - on api_tokens for select - using ( - exists ( - select 1 from organization_members - where organization_id = api_tokens.organization_id - and user_id = auth.uid() - ) - ); - --- Create policy for managing tokens (admin only) -create policy "Organization admins can manage tokens" - on api_tokens for all - using ( - exists ( - select 1 from organization_members - where organization_id = api_tokens.organization_id - and user_id = auth.uid() - and role = 'admin' - ) - ); - --- Function to generate a secure random token -create or replace function generate_api_token() -returns text -language plpgsql -as $$ -declare - token text; -begin - -- Generate a random UUID and encode it to base64 - token := encode(digest(gen_random_uuid()::text || now()::text, 'sha256'), 'base64'); - -- Remove any non-alphanumeric characters and trim to 32 characters - token := regexp_replace(token, '[^a-zA-Z0-9]', '', 'g'); - return substring(token, 1, 32); -end; -$$; - --- Function to create a new API token -create or replace function create_api_token( - org_id uuid, - token_name text -) -returns table ( - id uuid, - name text, - token text, - created_at timestamptz -) -language plpgsql -security definer -as $$ -declare - new_token text; - result record; -begin - -- Check if user is an admin of the organization - if not exists ( - select 1 from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ) then - raise exception 'Only organization 
admins can create API tokens'; - end if; - - -- Generate new token - new_token := generate_api_token(); - - -- Insert new token - insert into api_tokens (organization_id, name, token, created_by) - values (org_id, token_name, new_token, auth.uid()) - returning id, name, token, created_at into result; - - return query select result.id, result.name, result.token, result.created_at; -end; -$$; - --- Function to delete an API token -create or replace function delete_api_token(token_id uuid) -returns void -language plpgsql -security definer -as $$ -begin - -- Check if user is an admin of the organization that owns the token - if not exists ( - select 1 from api_tokens t - join organization_members m on m.organization_id = t.organization_id - where t.id = token_id - and m.user_id = auth.uid() - and m.role = 'admin' - ) then - raise exception 'Only organization admins can delete API tokens'; - end if; - - -- Delete the token - delete from api_tokens where id = token_id; -end; -$$; - --- Grant necessary permissions -grant all on table api_tokens to postgres, authenticated; -grant execute on function generate_api_token() to postgres, authenticated; -grant execute on function create_api_token(uuid, text) to postgres, authenticated; -grant execute on function delete_api_token(uuid) to postgres, authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241205000024_add_token_validation.sql b/supabase/migrations_dev/20241205000024_add_token_validation.sql deleted file mode 100644 index 4e3ed2e..0000000 --- a/supabase/migrations_dev/20241205000024_add_token_validation.sql +++ /dev/null @@ -1,28 +0,0 @@ --- Function to validate an API token -create or replace function validate_api_token(token_text text) -returns table ( - organization_id uuid -) -language plpgsql -security definer -as $$ -begin - -- Update last_used_at timestamp - update api_tokens - set last_used_at = now() - where token = token_text; - - -- Return the organization_id if token is valid - return query - select t.organization_id - from api_tokens t - where t.token = token_text; - - if not found then - raise exception 'Invalid API token'; - end if; -end; -$$; - --- Grant execute permission -grant execute on function validate_api_token(text) to authenticated, anon; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000001_add_long_lived_tokens.sql b/supabase/migrations_dev/20241219000001_add_long_lived_tokens.sql deleted file mode 100644 index f9d1e32..0000000 --- a/supabase/migrations_dev/20241219000001_add_long_lived_tokens.sql +++ /dev/null @@ -1,40 +0,0 @@ --- Create tokens table to keep track of issued tokens -create table if not exists "public"."tokens" ( - "id" uuid not null default gen_random_uuid(), - "organization_id" uuid not null references organizations(id) on delete cascade, - "name" text not null, - "expires_at" timestamp with time zone, - "created_at" timestamp with time zone default timezone('utc'::text, now()) not null, - "created_by" uuid not null references auth.users(id) on delete cascade, - "last_used_at" timestamp with time zone, - primary key (id) -); - --- Enable RLS -alter table "public"."tokens" enable row level security; - --- Create policy for reading tokens -create policy "Organization members can view their tokens" - on tokens for select - using ( - exists ( - select 1 from organization_members - where organization_id = tokens.organization_id - and user_id = auth.uid() - ) - ); - --- Create policy for managing tokens (admin only) -create policy "Organization admins 
can manage tokens" - on tokens for all - using ( - exists ( - select 1 from organization_members - where organization_id = tokens.organization_id - and user_id = auth.uid() - and role = 'admin' - ) - ); - --- Grant necessary permissions -grant all on table tokens to postgres, authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000002_add_token_hash.sql b/supabase/migrations_dev/20241219000002_add_token_hash.sql deleted file mode 100644 index 7cef74b..0000000 --- a/supabase/migrations_dev/20241219000002_add_token_hash.sql +++ /dev/null @@ -1,3 +0,0 @@ --- Add token_hash column to tokens table -alter table "public"."tokens" - add column "token_hash" text; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000003_update_rls_for_custom_tokens.sql b/supabase/migrations_dev/20241219000003_update_rls_for_custom_tokens.sql deleted file mode 100644 index a2caed0..0000000 --- a/supabase/migrations_dev/20241219000003_update_rls_for_custom_tokens.sql +++ /dev/null @@ -1,100 +0,0 @@ --- Function to get organization ID from token -create or replace function get_organization_from_token() -returns uuid -language plpgsql -security definer -as $$ -declare - auth_token text; - org_id uuid; -begin - -- Get the Authorization header value - auth_token := coalesce( - current_setting('request.headers', true)::json->>'authorization', - '' - ); - - -- Extract token from "Bearer " - auth_token := replace(auth_token, 'Bearer ', ''); - - -- If it's a custom token (starts with ow_) - if auth_token like 'ow_%' then - -- Look up organization ID from tokens table - select organization_id into org_id - from tokens - where token_hash = auth_token - and (expires_at is null or expires_at > now()); - - if found then - -- Update last_used_at - update tokens - set last_used_at = now() - where token_hash = auth_token; - - return org_id; - end if; - end if; - - -- If not a custom token or token not found, - -- return null to fall back to normal auth - return null; -end; -$$; - --- Function to check if current token has access to an organization -create or replace function has_organization_access(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -declare - token_org_id uuid; -begin - -- First check custom tokens - token_org_id := get_organization_from_token(); - if token_org_id is not null then - return token_org_id = org_id; - end if; - - -- Fall back to checking organization membership - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - ); -end; -$$; - --- Update all RLS policies to use the new function -create or replace function is_organization_member(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -begin - return has_organization_access(org_id); -end; -$$; - -create or replace function is_organization_admin(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -begin - -- For custom tokens, they have admin access - if get_organization_from_token() is not null then - return true; - end if; - - -- Otherwise check normal admin membership - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ); -end; -$$; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000004_update_token_auth.sql b/supabase/migrations_dev/20241219000004_update_token_auth.sql deleted file mode 100644 index cc452ed..0000000 --- 
a/supabase/migrations_dev/20241219000004_update_token_auth.sql +++ /dev/null @@ -1,128 +0,0 @@ --- Drop existing functions that we're replacing -drop function if exists get_organization_from_token(); -drop function if exists has_organization_access(uuid); - --- Function to get organization ID from token -create or replace function get_organization_from_token() -returns uuid -language plpgsql -security definer -as $$ -declare - auth_token text; - custom_token text; - org_id uuid; -begin - -- First try custom token header - custom_token := coalesce( - current_setting('request.headers', true)::json->>'x-openweights-token', - '' - ); - - if custom_token != '' and custom_token like 'ow_%' then - -- Look up organization ID from tokens table - select organization_id into org_id - from tokens - where token_hash = custom_token - and (expires_at is null or expires_at > now()); - - if found then - -- Update last_used_at - update tokens - set last_used_at = now() - where token_hash = custom_token; - - return org_id; - end if; - end if; - - -- If no custom token, try normal auth - auth_token := coalesce( - current_setting('request.headers', true)::json->>'authorization', - '' - ); - - -- Extract token from "Bearer " - auth_token := replace(auth_token, 'Bearer ', ''); - - -- If it's a valid JWT, auth.uid() will work - if auth_token != '' then - -- Return organization ID from membership - select organization_id into org_id - from organization_members - where user_id = auth.uid() - limit 1; - - return org_id; - end if; - - -- No valid token found - return null; -end; -$$; - --- Function to check if current token has access to an organization -create or replace function has_organization_access(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -declare - token_org_id uuid; -begin - -- First check custom tokens - token_org_id := get_organization_from_token(); - if token_org_id is not null then - return token_org_id = org_id; - end if; - - -- Fall back to checking organization membership - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - ); -end; -$$; - --- Update existing functions to use the new logic -create or replace function is_organization_member(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -begin - return has_organization_access(org_id); -end; -$$; - -create or replace function is_organization_admin(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -declare - custom_token text; -begin - -- First try custom token header - custom_token := coalesce( - current_setting('request.headers', true)::json->>'x-openweights-token', - '' - ); - - -- API tokens have admin access - if custom_token != '' and custom_token like 'ow_%' then - return has_organization_access(org_id); - end if; - - -- Otherwise check normal admin membership - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ); -end; -$$; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000005_service_accounts.sql b/supabase/migrations_dev/20241219000005_service_accounts.sql deleted file mode 100644 index 9295255..0000000 --- a/supabase/migrations_dev/20241219000005_service_accounts.sql +++ /dev/null @@ -1,121 +0,0 @@ --- Add service account metadata to auth.users -create or replace function auth.check_if_service_account() -returns boolean as $$ - select coalesce( - current_setting('request.jwt.claims', 
true)::json->>'is_service_account', - 'false' - )::boolean; -$$ language sql security definer; - --- Update RLS functions to handle service accounts -create or replace function get_organization_from_token() -returns uuid -language plpgsql -security definer -as $$ -declare - org_id uuid; -begin - -- If this is a service account token, get org from claims - if auth.check_if_service_account() then - org_id := (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid; - - -- Update last_used_at in tokens table - update tokens - set last_used_at = now() - where id = (current_setting('request.jwt.claims', true)::json->>'token_id')::uuid; - - return org_id; - end if; - - -- Otherwise, get organization from membership - select organization_id into org_id - from organization_members - where user_id = auth.uid() - limit 1; - - return org_id; -end; -$$; - --- Update is_organization_admin to give service accounts admin access -create or replace function is_organization_admin(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -begin - -- Service accounts have admin access to their organization - if auth.check_if_service_account() then - return get_organization_from_token() = org_id; - end if; - - -- Otherwise check normal admin membership - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ); -end; -$$; - --- Add function to create service account -create or replace function create_service_account_token( - org_id uuid, - token_name text, - expires_at timestamp with time zone default null -) -returns table ( - token_id uuid, - jwt_token text -) -language plpgsql -security definer -set search_path = public -as $$ -declare - v_token_id uuid; - v_user_id uuid; - v_jwt_secret text; - v_jwt_token text; -begin - -- Get JWT secret from vault - select decrypted_secret into v_jwt_secret - from vault.decrypted_secrets - where name = 'jwt_secret' - limit 1; - - if v_jwt_secret is null then - raise exception 'JWT secret not found in vault'; - end if; - - -- Create token record - insert into tokens (organization_id, name, expires_at, created_by) - values (org_id, token_name, expires_at, auth.uid()) - returning id into v_token_id; - - -- Create JWT token with service account claims - v_jwt_token := extensions.sign( - json_build_object( - 'role', 'authenticated', - 'iss', 'supabase', - 'iat', extract(epoch from now())::integer, - 'exp', case - when expires_at is null then extract(epoch from now() + interval '10 years')::integer - else extract(epoch from expires_at)::integer - end, - 'is_service_account', true, - 'organization_id', org_id, - 'token_id', v_token_id - )::json, - v_jwt_secret - ); - - return query select v_token_id, v_jwt_token; -end; -$$; - --- Grant execute permissions -grant execute on function create_service_account_token(uuid, text, timestamp with time zone) to authenticated; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000006_cleanup_token_tables.sql b/supabase/migrations_dev/20241219000006_cleanup_token_tables.sql deleted file mode 100644 index ad886a6..0000000 --- a/supabase/migrations_dev/20241219000006_cleanup_token_tables.sql +++ /dev/null @@ -1,9 +0,0 @@ --- Drop the older api_tokens table and related functions -drop function if exists public.generate_api_token(); -drop function if exists public.create_api_token(uuid, text); -drop function if exists public.delete_api_token(uuid); -drop table if exists public.api_tokens; - --- Keep the 
newer tokens table and its functions --- But add an index on token_hash for faster lookups -create index if not exists idx_tokens_token_hash on public.tokens(token_hash); \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000007_fix_token_creation.sql b/supabase/migrations_dev/20241219000007_fix_token_creation.sql deleted file mode 100644 index 90f37ef..0000000 --- a/supabase/migrations_dev/20241219000007_fix_token_creation.sql +++ /dev/null @@ -1,60 +0,0 @@ --- Drop existing function -drop function if exists create_service_account_token(uuid, text, timestamp with time zone); - --- Recreate with user_id parameter -create or replace function create_service_account_token( - org_id uuid, - token_name text, - created_by uuid, - expires_at timestamp with time zone default null -) -returns table ( - token_id uuid, - jwt_token text -) -language plpgsql -security definer -set search_path = public -as $$ -declare - v_token_id uuid; - v_jwt_secret text; - v_jwt_token text; -begin - -- Get JWT secret from env var for now (we'll move this to vault later) - v_jwt_secret := current_setting('app.jwt_secret', true); - - if v_jwt_secret is null then - raise exception 'JWT secret not configured'; - end if; - - -- Create token record - insert into tokens (organization_id, name, expires_at, created_by) - values (org_id, token_name, expires_at, created_by) - returning id into v_token_id; - - -- Create JWT token with service account claims - v_jwt_token := extensions.sign( - json_build_object( - 'role', 'authenticated', - 'iss', 'supabase', - 'iat', extract(epoch from now())::integer, - 'exp', case - when expires_at is null then extract(epoch from now() + interval '10 years')::integer - else extract(epoch from expires_at)::integer - end, - 'is_service_account', true, - 'organization_id', org_id, - 'token_id', v_token_id - )::json, - v_jwt_secret - ); - - -- Update token hash - update tokens - set token_hash = v_jwt_token - where id = v_token_id; - - return query select v_token_id, v_jwt_token; -end; -$$; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000009_jwt_secret_function.sql b/supabase/migrations_dev/20241219000009_jwt_secret_function.sql deleted file mode 100644 index bbd0e43..0000000 --- a/supabase/migrations_dev/20241219000009_jwt_secret_function.sql +++ /dev/null @@ -1,69 +0,0 @@ --- Create a function to get the JWT secret -create or replace function get_jwt_secret() -returns text -language plpgsql -security definer -as $$ -begin - -- This is a placeholder - in production, you'd want to store this more securely - return 'AsQjcwl78lW6ND4aXiFXGg5bEuEfw7fcnte8opUfUrTR65Mz83YuksM+kRCcneAdp+yW/5NNDCZ6Gb2mZ+VJrw=='; -end; -$$; - --- Update the token creation function to use the new get_jwt_secret function -create or replace function create_service_account_token( - org_id uuid, - token_name text, - created_by uuid, - expires_at timestamp with time zone default null -) -returns table ( - token_id uuid, - jwt_token text -) -language plpgsql -security definer -set search_path = public -as $$ -declare - v_token_id uuid; - v_jwt_secret text; - v_jwt_token text; -begin - -- Get JWT secret - v_jwt_secret := get_jwt_secret(); - - if v_jwt_secret is null then - raise exception 'JWT secret not configured'; - end if; - - -- Create token record - insert into tokens (organization_id, name, expires_at, created_by) - values (org_id, token_name, expires_at, created_by) - returning id into v_token_id; - - -- Create JWT token with service account claims - v_jwt_token := 
extensions.sign( - json_build_object( - 'role', 'authenticated', - 'iss', 'supabase', - 'iat', extract(epoch from now())::integer, - 'exp', case - when expires_at is null then extract(epoch from now() + interval '10 years')::integer - else extract(epoch from expires_at)::integer - end, - 'is_service_account', true, - 'organization_id', org_id, - 'token_id', v_token_id - )::json, - v_jwt_secret - ); - - -- Update token hash - update tokens - set token_hash = v_jwt_token - where id = v_token_id; - - return query select v_token_id, v_jwt_token; -end; -$$; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000010_fix_token_validation.sql b/supabase/migrations_dev/20241219000010_fix_token_validation.sql deleted file mode 100644 index 0ccd0e5..0000000 --- a/supabase/migrations_dev/20241219000010_fix_token_validation.sql +++ /dev/null @@ -1,221 +0,0 @@ --- First drop all policies that use these functions -drop policy if exists "Enable read access for organization members" on organizations; -drop policy if exists "Enable write access for organization admins" on organizations; -drop policy if exists "Enable read access for members" on organization_members; -drop policy if exists "Enable write access for admins" on organization_members; -drop policy if exists "Enable access for organization members" on jobs; -drop policy if exists "Enable access for organization members" on runs; -drop policy if exists "Enable access for organization members" on worker; -drop policy if exists "Enable access for organization members" on events; -drop policy if exists "Enable read for members" on third_party_api_keys; -drop policy if exists "Enable write for admins" on third_party_api_keys; - --- Now we can safely drop the functions -drop function if exists get_organization_from_token(); -drop function if exists has_organization_access(uuid); -drop function if exists is_organization_member(uuid); -drop function if exists is_organization_admin(uuid); - --- Function to update token last used timestamp -create or replace function update_token_last_used(token_id uuid) -returns void -language plpgsql -security definer -as $$ -begin - -- Update last_used_at in a separate transaction - update tokens - set last_used_at = now() - where id = token_id; -exception - when others then - -- Ignore any errors during update - null; -end; -$$; - --- Function to get organization ID from token -create or replace function get_organization_from_token() -returns uuid -language plpgsql -security definer -as $$ -declare - auth_token text; - custom_token text; - org_id uuid; - token_id uuid; -begin - -- First try custom token header - custom_token := coalesce( - current_setting('request.headers', true)::json->>'x-openweights-token', - '' - ); - - if custom_token != '' and custom_token like 'ow_%' then - -- Look up organization ID from tokens table - select t.organization_id, t.id into org_id, token_id - from tokens t - where t.token_hash = custom_token - and (t.expires_at is null or t.expires_at > now()); - - if found then - -- Try to update last_used_at in background - perform update_token_last_used(token_id); - return org_id; - end if; - end if; - - -- If no custom token, try normal auth - auth_token := coalesce( - current_setting('request.headers', true)::json->>'authorization', - '' - ); - - -- Extract token from "Bearer " - auth_token := replace(auth_token, 'Bearer ', ''); - - -- If it's a valid JWT, auth.uid() will work - if auth_token != '' then - -- Return organization ID from membership - select organization_id 
into org_id - from organization_members - where user_id = auth.uid() - limit 1; - - return org_id; - end if; - - -- No valid token found - return null; -end; -$$; - --- Function to check if current token has access to an organization -create or replace function has_organization_access(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -declare - token_org_id uuid; -begin - -- First check custom tokens - token_org_id := get_organization_from_token(); - if token_org_id is not null then - return token_org_id = org_id; - end if; - - -- Fall back to checking organization membership - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - ); -end; -$$; - --- Update existing functions to use the new logic -create or replace function is_organization_member(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -begin - return has_organization_access(org_id); -end; -$$; - -create or replace function is_organization_admin(org_id uuid) -returns boolean -language plpgsql -security definer -as $$ -declare - custom_token text; -begin - -- First try custom token header - custom_token := coalesce( - current_setting('request.headers', true)::json->>'x-openweights-token', - '' - ); - - -- API tokens have admin access - if custom_token != '' and custom_token like 'ow_%' then - return has_organization_access(org_id); - end if; - - -- Otherwise check normal admin membership - return exists ( - select 1 - from organization_members - where organization_id = org_id - and user_id = auth.uid() - and role = 'admin' - ); -end; -$$; - --- Recreate policies with explicit table references -create policy "Enable read access for organization members" - on organizations for select - using (is_organization_member(id)); - -create policy "Enable write access for organization admins" - on organizations for all - using (is_organization_admin(id)); - -create policy "Enable read access for members" - on organization_members for select - using (is_organization_member(organization_id)); - -create policy "Enable write access for admins" - on organization_members for all - using (is_organization_admin(organization_id)); - -create policy "Enable access for organization members" - on jobs for all - using (is_organization_member(organization_id)); - -create policy "Enable access for organization members" - on runs for all - using ( - exists ( - select 1 - from jobs j - where j.id = runs.job_id - and is_organization_member(j.organization_id) - ) - ); - -create policy "Enable access for organization members" - on worker for all - using (is_organization_member(organization_id)); - -create policy "Enable access for organization members" - on events for all - using ( - exists ( - select 1 - from runs r - join jobs j on j.id = r.job_id - where r.id = events.run_id - and is_organization_member(j.organization_id) - ) - ); - -create policy "Enable read for members" - on third_party_api_keys for select - using (is_organization_member(organization_id)); - -create policy "Enable write for admins" - on third_party_api_keys for all - using (is_organization_admin(organization_id)); - --- Grant execute permissions -grant execute on function update_token_last_used(uuid) to postgres, authenticated, anon; -grant execute on function get_organization_from_token() to postgres, authenticated, anon; -grant execute on function has_organization_access(uuid) to postgres, authenticated, anon; -grant execute on function is_organization_member(uuid) to postgres, 
authenticated, anon; -grant execute on function is_organization_admin(uuid) to postgres, authenticated, anon; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000011_fix_job_policies.sql b/supabase/migrations_dev/20241219000011_fix_job_policies.sql deleted file mode 100644 index 35df091..0000000 --- a/supabase/migrations_dev/20241219000011_fix_job_policies.sql +++ /dev/null @@ -1,22 +0,0 @@ --- Drop existing job policy -drop policy if exists "Enable access for organization members" on jobs; - --- Create separate policies for read and write operations -create policy "Organization members can read jobs" - on jobs for select - using (is_organization_member(organization_id)); - -create policy "Organization members can insert jobs" - on jobs for insert - with check ( - -- For new jobs, organization_id must match the user's organization - organization_id = get_organization_from_token() - ); - -create policy "Organization members can update their jobs" - on jobs for update - using (is_organization_member(organization_id)); - -create policy "Organization members can delete their jobs" - on jobs for delete - using (is_organization_member(organization_id)); \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000012_add_get_org_function.sql b/supabase/migrations_dev/20241219000012_add_get_org_function.sql deleted file mode 100644 index d162aa5..0000000 --- a/supabase/migrations_dev/20241219000012_add_get_org_function.sql +++ /dev/null @@ -1,31 +0,0 @@ --- Function to get organization ID from token -create or replace function get_organization_from_token() -returns uuid -language plpgsql -security definer -as $$ -declare - org_id uuid; -begin - -- If this is a service account token, get org from claims - if auth.check_if_service_account() then - org_id := (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid; - return org_id; - end if; - - -- Otherwise, get organization from membership - select organization_id into org_id - from organization_members - where user_id = auth.uid() - limit 1; - - if org_id is null then - raise exception 'No organization found for current user/token'; - end if; - - return org_id; -end; -$$; - --- Grant execute permissions -grant execute on function get_organization_from_token() to authenticated, anon; \ No newline at end of file diff --git a/supabase/migrations_dev/20241219000013_fix_storage_policies.sql b/supabase/migrations_dev/20241219000013_fix_storage_policies.sql deleted file mode 100644 index 0f67801..0000000 --- a/supabase/migrations_dev/20241219000013_fix_storage_policies.sql +++ /dev/null @@ -1,119 +0,0 @@ --- Drop existing storage policies -drop policy if exists "Organization members can read their files" on storage.objects; -drop policy if exists "Organization members can upload files" on storage.objects; -drop policy if exists "Organization members can delete their files" on storage.objects; - --- Enable RLS on storage.objects -alter table storage.objects enable row level security; - --- Function to extract organization ID from storage path -create or replace function storage.get_path_organization_id(path text) -returns uuid -language plpgsql -as $$ -declare - parts text[]; - org_id uuid; -begin - -- Split path into parts - parts := string_to_array(path, '/'); - - -- Check if path starts with 'organizations' - if parts[1] != 'organizations' then - return null; - end if; - - -- Try to convert second part to UUID - begin - org_id := parts[2]::uuid; - return org_id; - exception when others then - 
diff --git a/supabase/migrations_dev/20241219000013_fix_storage_policies.sql b/supabase/migrations_dev/20241219000013_fix_storage_policies.sql
deleted file mode 100644
index 0f67801..0000000
--- a/supabase/migrations_dev/20241219000013_fix_storage_policies.sql
+++ /dev/null
@@ -1,119 +0,0 @@
--- Drop existing storage policies
-drop policy if exists "Organization members can read their files" on storage.objects;
-drop policy if exists "Organization members can upload files" on storage.objects;
-drop policy if exists "Organization members can delete their files" on storage.objects;
-
--- Enable RLS on storage.objects
-alter table storage.objects enable row level security;
-
--- Function to extract organization ID from storage path
-create or replace function storage.get_path_organization_id(path text)
-returns uuid
-language plpgsql
-as $$
-declare
-  parts text[];
-  org_id uuid;
-begin
-  -- Split path into parts
-  parts := string_to_array(path, '/');
-
-  -- Check if path starts with 'organizations'
-  if parts[1] != 'organizations' then
-    return null;
-  end if;
-
-  -- Try to convert second part to UUID
-  begin
-    org_id := parts[2]::uuid;
-    return org_id;
-  exception when others then
-    return null;
-  end;
-end;
-$$;
-
--- Policy for reading files:
-create policy "Organization members can read their files"
-on storage.objects for select
-using (
-  bucket_id = 'files'
-  and (
-    -- Allow access to .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user is a member
-      name like 'organizations/%'
-      and exists (
-        select 1
-        from public.organization_members
-        where organization_id = storage.get_path_organization_id(name)
-        and user_id = auth.uid()
-      )
-    )
-  )
-);
-
--- Policy for inserting files:
-create policy "Organization members can upload files"
-on storage.objects for insert
-with check (
-  bucket_id = 'files'
-  and (
-    -- Allow .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user is a member
-      name like 'organizations/%'
-      and exists (
-        select 1
-        from public.organization_members
-        where organization_id = storage.get_path_organization_id(name)
-        and user_id = auth.uid()
-      )
-    )
-  )
-);
-
--- Policy for updating files:
-create policy "Organization members can update their files"
-on storage.objects for update
-using (
-  bucket_id = 'files'
-  and (
-    -- Allow .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user is a member
-      name like 'organizations/%'
-      and exists (
-        select 1
-        from public.organization_members
-        where organization_id = storage.get_path_organization_id(name)
-        and user_id = auth.uid()
-      )
-    )
-  )
-);
-
--- Policy for deleting files:
-create policy "Organization members can delete their files"
-on storage.objects for delete
-using (
-  bucket_id = 'files'
-  and (
-    -- Allow .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user is an admin
-      name like 'organizations/%'
-      and exists (
-        select 1
-        from public.organization_members
-        where organization_id = storage.get_path_organization_id(name)
-        and user_id = auth.uid()
-        and role = 'admin'
-      )
-    )
-  )
-);
\ No newline at end of file
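The `storage.get_path_organization_id` helper removed above (and recreated in the next migration) parses the organization UUID out of a storage object path, returning null for anything it cannot parse. A few illustrative calls, assuming the function is installed; the paths and UUID are made up:

```sql
-- Valid organization path: returns the embedded UUID.
select storage.get_path_organization_id(
  'organizations/123e4567-e89b-12d3-a456-426614174000/datasets/train.jsonl'
);  -- -> 123e4567-e89b-12d3-a456-426614174000

-- Path outside organizations/: returns null.
select storage.get_path_organization_id('users/alice/notes.txt');  -- -> null

-- Malformed UUID segment: the exception handler turns the cast error into null.
select storage.get_path_organization_id('organizations/not-a-uuid/file.bin');  -- -> null
```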
diff --git a/supabase/migrations_dev/20241219000014_fix_storage_setup.sql b/supabase/migrations_dev/20241219000014_fix_storage_setup.sql
deleted file mode 100644
index f4c0f44..0000000
--- a/supabase/migrations_dev/20241219000014_fix_storage_setup.sql
+++ /dev/null
@@ -1,141 +0,0 @@
--- Drop existing storage policies
-drop policy if exists "Organization members can read their files" on storage.objects;
-drop policy if exists "Organization members can upload files" on storage.objects;
-drop policy if exists "Organization members can update their files" on storage.objects;
-drop policy if exists "Organization members can delete their files" on storage.objects;
-
--- Enable RLS on storage.objects
-alter table storage.objects enable row level security;
-
--- Function to extract organization ID from storage path
-create or replace function storage.get_path_organization_id(path text)
-returns uuid
-language plpgsql
-as $$
-declare
-  parts text[];
-  org_id uuid;
-begin
-  -- Split path into parts
-  parts := string_to_array(path, '/');
-
-  -- Check if path starts with 'organizations'
-  if parts[1] != 'organizations' then
-    return null;
-  end if;
-
-  -- Try to convert second part to UUID
-  begin
-    org_id := parts[2]::uuid;
-    return org_id;
-  exception when others then
-    return null;
-  end;
-end;
-$$;
-
--- Function to check if user has access to organization
-create or replace function storage.has_organization_access(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-begin
-  -- First check custom tokens
-  if auth.jwt() ? 'organization_id' then
-    return (auth.jwt()->>'organization_id')::uuid = org_id;
-  end if;
-
-  -- Otherwise check organization membership
-  return exists (
-    select 1
-    from public.organization_members
-    where organization_id = org_id
-    and user_id = auth.uid()
-  );
-end;
-$$;
-
--- Policy for reading files:
-create policy "Organization members can read their files"
-on storage.objects for select
-using (
-  bucket_id = 'files'
-  and (
-    -- Allow access to .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user has access
-      name like 'organizations/%'
-      and storage.has_organization_access(storage.get_path_organization_id(name))
-    )
-  )
-);
-
--- Policy for inserting files:
-create policy "Organization members can upload files"
-on storage.objects for insert
-with check (
-  bucket_id = 'files'
-  and (
-    -- Allow .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user has access
-      name like 'organizations/%'
-      and storage.has_organization_access(storage.get_path_organization_id(name))
-    )
-  )
-);
-
--- Policy for updating files:
-create policy "Organization members can update their files"
-on storage.objects for update
-using (
-  bucket_id = 'files'
-  and (
-    -- Allow .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user has access
-      name like 'organizations/%'
-      and storage.has_organization_access(storage.get_path_organization_id(name))
-    )
-  )
-);
-
--- Policy for deleting files:
-create policy "Organization members can delete their files"
-on storage.objects for delete
-using (
-  bucket_id = 'files'
-  and (
-    -- Allow .keep files
-    name like '%.keep'
-    or (
-      -- Check if file is in an organization folder and user has access
-      name like 'organizations/%'
-      and storage.has_organization_access(storage.get_path_organization_id(name))
-    )
-  )
-);
-
--- Create organization folders
-do $$
-declare
-  org record;
-begin
-  -- Create organizations folder
-  insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version)
-  values ('files', 'organizations/.keep', auth.uid(), now(), now(), '1')
-  on conflict do nothing;
-
-  -- Create folder for each organization
-  for org in select id from public.organizations
-  loop
-    insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version)
-    values ('files', 'organizations/' || org.id || '/.keep', auth.uid(), now(), now(), '1')
-    on conflict do nothing;
-  end loop;
-end;
-$$;
\ No newline at end of file
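`storage.has_organization_access` branches on the jsonb `?` key-existence operator: a service-account token carries an `organization_id` claim, while a regular user token does not and falls through to the membership lookup. A minimal standalone sketch of that test (the claim payloads are made up):

```sql
-- jsonb ? 'key' is true when the key is present, regardless of its value.
select '{"organization_id": "123e4567-e89b-12d3-a456-426614174000"}'::jsonb
  ? 'organization_id';  -- true: treated as a service-account token
select '{"sub": "user-1", "role": "authenticated"}'::jsonb
  ? 'organization_id';  -- false: falls through to the membership check
```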
diff --git a/supabase/migrations_dev/20241219000015_fix_run_policies.sql b/supabase/migrations_dev/20241219000015_fix_run_policies.sql
deleted file mode 100644
index 42941a8..0000000
--- a/supabase/migrations_dev/20241219000015_fix_run_policies.sql
+++ /dev/null
@@ -1,75 +0,0 @@
--- Drop existing run policy
-drop policy if exists "Enable access for organization members" on runs;
-
--- Create separate policies for read and write operations
-create policy "Organization members can read runs"
-  on runs for select
-  using (
-    exists (
-      select 1
-      from jobs j
-      where j.id = runs.job_id
-      and (
-        is_organization_member(j.organization_id)
-        or
-        (
-          auth.jwt() ? 'is_service_account'
-          and (auth.jwt()->>'organization_id')::uuid = j.organization_id
-        )
-      )
-    )
-  );
-
-create policy "Organization members can insert runs"
-  on runs for insert
-  with check (
-    exists (
-      select 1
-      from jobs j
-      where j.id = job_id
-      and (
-        is_organization_member(j.organization_id)
-        or
-        (
-          auth.jwt() ? 'is_service_account'
-          and (auth.jwt()->>'organization_id')::uuid = j.organization_id
-        )
-      )
-    )
-  );
-
-create policy "Organization members can update runs"
-  on runs for update
-  using (
-    exists (
-      select 1
-      from jobs j
-      where j.id = runs.job_id
-      and (
-        is_organization_member(j.organization_id)
-        or
-        (
-          auth.jwt() ? 'is_service_account'
-          and (auth.jwt()->>'organization_id')::uuid = j.organization_id
-        )
-      )
-    )
-  );
-
-create policy "Organization members can delete runs"
-  on runs for delete
-  using (
-    exists (
-      select 1
-      from jobs j
-      where j.id = runs.job_id
-      and (
-        is_organization_member(j.organization_id)
-        or
-        (
-          auth.jwt() ? 'is_service_account'
-          and (auth.jwt()->>'organization_id')::uuid = j.organization_id
-        )
-      )
-    )
-  );
\ No newline at end of file
diff --git a/supabase/seed.sql b/supabase/seed.sql
new file mode 100644
index 0000000..d7cd575
--- /dev/null
+++ b/supabase/seed.sql
@@ -0,0 +1,6 @@
+-- Ensure the 'files' bucket exists in the shadow DB used by `supabase db pull`.
+-- This runs before migrations during local/shadow setup and is a no-op on subsequent runs.
+
+insert into storage.buckets (id, name, public)
+values ('files', 'files', false)
+on conflict (id) do nothing;
\ No newline at end of file
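Because the seed uses `on conflict (id) do nothing`, re-running it (as repeated local resets or shadow-DB setups will) is harmless. A quick way to convince yourself, sketched against a local Supabase database where `storage.buckets` exists:

```sql
-- Run the seed statement twice; the second execution hits the conflict and does nothing.
insert into storage.buckets (id, name, public)
values ('files', 'files', false)
on conflict (id) do nothing;

insert into storage.buckets (id, name, public)
values ('files', 'files', false)
on conflict (id) do nothing;

-- Exactly one row either way.
select count(*) from storage.buckets where id = 'files';  -- 1
```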