diff --git a/.dockerignore b/.dockerignore
index 2d1c758..c4346f7 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -5,4 +5,91 @@ cache/
.llm-cache/
**/.llm-cache/
node_modules/
-**/node_modules/
\ No newline at end of file
+**/node_modules/
+logs/
+openweights/dashboard/backend/backend.log
+openweights/dashboard/backend/backend.pid
+build-docker-in-runpod
+.env
+.env.dev
+.env.prod
+.env.ow-dev
+.env.ow-migrations
+openweights/dashboard/backend/.env.ow-dev
+openweights/dashboard/backend/.env.ow-migrations
+openweights/dashboard/frontend/.env.ow-dev
+openweights/dashboard/frontend/.env.ow-migrations
+artifacts/
+debug/
+example/.ipynb_checkpoints/
+example/Untitled1.ipynb
+dev.py
+planb/
+vulnerable-code/
+openweights/client/.llm-cache
+# Cache directories
+cache
+# Bazel
+/bazel-*
+/bazel-bin
+/bazel-genfiles
+/bazel-out
+/bazel-testlogs
+/bazel-workspace
+
+# Bazel symlinks
+/bazel-*
+
+# Bazel disk cache
+.bazel-cache/
+
+# Bazel IntelliJ plugin
+.ijwb/
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+env/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+example/ft_job_artifacts/
+example/mcq_dataset.jsonl
+openweights/jobs/unsloth/logp.ipynb
+
+openweights/dashboard/backend/static/
+example/_*
+.logs
+openweights/jobs/unsloth/check.ipynb
+.cache
+.logs/
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 7ef80c8..5f9402e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -83,4 +83,5 @@ openweights/dashboard/backend/static/
example/_*
.logs
openweights/jobs/unsloth/check.ipynb
-.cache
\ No newline at end of file
+.cache
+admin/
\ No newline at end of file
diff --git a/docs/ttl.md b/docs/ttl.md
new file mode 100644
index 0000000..1ab5c8b
--- /dev/null
+++ b/docs/ttl.md
@@ -0,0 +1,82 @@
+# TTL (Time To Live) Feature
+
+The TTL feature provides automatic pod termination to prevent runaway costs and ensure resource cleanup.
+
+## Overview
+
+- **Default TTL**: 24 hours for all pods
+- **Automatic termination**: Pods self-terminate when TTL expires
+- **Extensible**: TTL can be extended from within the pod
+- **Dev mode support**: TTL monitoring runs for both dev and worker instances
+
+## Usage
+
+### Starting pods with custom TTL
+
+```bash
+# Start dev instance with default 24-hour TTL
+python openweights/cluster/start_runpod.py A100 default --dev_mode=true
+
+# Start dev instance with 2-hour TTL
+python openweights/cluster/start_runpod.py A100 default --dev_mode=true --ttl_hours=2
+
+# Start worker with 12-hour TTL
+python openweights/cluster/start_runpod.py A100 finetuning --ttl_hours=12
+```
+
+### Managing TTL from within a pod
+
+Once inside a pod, use the TTL manager utility:
+
+```bash
+# Check current TTL status
+python openweights/worker/services/ttl_manager.py --check
+
+# Extend TTL by 5 more hours
+python openweights/worker/services/ttl_manager.py --extend 5
+
+# Set TTL to 10 hours from now
+python openweights/worker/services/ttl_manager.py --set 10
+```
+
+### Manual TTL management
+
+You can also manually update the TTL by editing `~/shutdown.txt`:
+
+```bash
+python3 -c "
+import datetime
+with open('~/shutdown.txt', 'w') as f:
+ new_time = datetime.datetime.now() + datetime.timedelta(hours=48)
+ f.write(new_time.isoformat())
+print(f'TTL extended to {new_time}')
+"
+```
+
+## How it works
+
+1. **TTL Setup**: When a pod starts, the TTL monitor service calculates the shutdown time and writes it to `~/shutdown.txt`
+2. **Monitoring**: A background service checks the shutdown time every minute (see the sketch after this list)
+3. **Termination**: When the current time exceeds the shutdown time, the service terminates the pod using the RunPod API
+4. **Extension**: Jobs or users can extend the TTL by updating the shutdown time in the file
+
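+A minimal sketch of that monitor loop (illustrative only: the real implementation lives in
+`openweights/worker/services/ttl_monitor.py`; the `runpod.terminate_pod` call and the way
+the pod ID is obtained are assumptions here):
+
+```python
+import datetime
+import os
+import time
+
+import runpod  # assumption: the RunPod SDK is available inside the pod
+
+runpod.api_key = os.environ["RUNPOD_API_KEY"]
+SHUTDOWN_FILE = os.path.expanduser("~/shutdown.txt")
+
+def monitor(pod_id: str):
+    while True:
+        # Re-read the file every minute so TTL extensions take effect immediately
+        with open(SHUTDOWN_FILE) as f:
+            shutdown_at = datetime.datetime.fromisoformat(f.read().strip())
+        if datetime.datetime.now() >= shutdown_at:
+            try:
+                runpod.terminate_pod(pod_id)
+            except Exception:
+                pass  # failed attempts are simply retried on the next iteration
+        time.sleep(60)
+```
+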
+## Architecture
+
+- **TTL Monitor Service**: `openweights/worker/services/ttl_monitor.py`
+- **TTL Manager Utility**: `openweights/worker/services/ttl_manager.py`
+- **Configuration**: TTL passed via `TTL_HOURS` environment variable
+- **Shutdown File**: `~/shutdown.txt` contains ISO format datetime
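+
+For example, a pod scheduled to terminate at 18:00 on 2 Jan 2025 would contain a single line like (illustrative value):
+
+```
+2025-01-02T18:00:00
+```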
+
+## Environment Variables
+
+- `TTL_HOURS`: Number of hours for TTL (default: 24)
+- `RUNPOD_API_KEY`: RunPod API key for pod termination
+- `OW_DEV`: Indicates if running in dev mode (affects other services, not TTL)
+
+## Notes
+
+- TTL monitoring runs for both dev and worker instances
+- This provides an additional safety net, especially for dev instances
+- Pod ID is automatically detected from the RunPod metadata API
+- Failed termination attempts are retried every minute
+- TTL can be reset/extended unlimited times before expiration
\ No newline at end of file
diff --git a/entrypoint.sh b/entrypoint.sh
index 084607c..812023b 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -13,78 +13,44 @@ if [ -n "$PUBLIC_KEY" ]; then
else
echo "[$(date)] No PUBLIC_KEY provided, skipping SSH key setup"
fi
-echo "Authorized keys added."
-# if OW_COMMIT is set, checkout the commit
+# Repository checkout if needed
echo "[$(date)] Checking for OW_COMMIT environment variable"
if [ -n "$OW_COMMIT" ]; then
- echo "[$(date)] Starting repository checkout for commit: $OW_COMMIT"
- rm -rf openweights
- git clone https://github.com/longtermrisk/openweights.git openweights_dev
- cd openweights_dev
- git checkout $OW_COMMIT
- mv openweights ../openweights
- cd ..
- rm -rf openweights_dev
- echo "[$(date)] Repository checkout completed"
+ echo "[$(date)] Starting repository checkout"
+ python3 openweights/worker/services/checkout.py
+ if [ $? -ne 0 ]; then
+ echo "[$(date)] Repository checkout failed"
+ exit 1
+ fi
else
echo "[$(date)] No OW_COMMIT specified, skipping repository checkout"
fi
# Login to huggingface
echo "[$(date)] Attempting to login to Hugging Face"
-python3 -c "from huggingface_hub.hf_api import HfFolder; import os; HfFolder.save_token(os.environ['HF_TOKEN'])"
+python3 openweights/worker/services/hf_login.py
echo "[$(date)] Hugging Face login completed"
-# Generate SSH host keys
-echo "[$(date)] Generating SSH host keys"
+# Generate SSH host keys and start SSH service
+echo "[$(date)] Setting up SSH service"
ssh-keygen -A
-echo "[$(date)] SSH host keys generated"
-
-# Start SSH service
-echo "[$(date)] Starting SSH service"
service ssh start
echo "[$(date)] SSH service started"
# Print sshd logs to stdout
tail -f /var/log/auth.log &
-# Start a simple server that serves the content of main.log on port 10101
-# Create main.log if it doesn't exist
-touch main.log
-
-# Start a simple Python HTTP server to serve files from logs/
+# Start background services
echo "[$(date)] Starting HTTP log server on port 10101"
-python3 -c '
-import http.server
-import socketserver
-import os
+mkdir -p logs
+python3 openweights/worker/services/log_server.py &
-class LogHandler(http.server.SimpleHTTPRequestHandler):
- def do_GET(self):
- # If path is /logs, serve logs/main
- if self.path == "/logs":
- file_path = "logs/main"
- else:
- # Remove leading slash and ensure path is within logs directory
- path = self.path.lstrip("/")
- file_path = os.path.join("logs", path)
-
- # Check if file exists and is within logs directory
- if os.path.exists(file_path) and os.path.commonprefix([os.path.abspath(file_path), os.path.abspath("logs")]) == os.path.abspath("logs"):
- self.send_response(200)
- self.send_header("Content-type", "text/plain")
- self.end_headers()
- with open(file_path, "rb") as f:
- self.wfile.write(f.read())
- else:
- self.send_error(404, "File not found")
+# Start TTL monitoring service
+echo "[$(date)] Starting TTL monitoring service"
+python3 openweights/worker/services/ttl_monitor.py &
-with socketserver.TCPServer(("", 10101), LogHandler) as httpd:
- httpd.serve_forever()
-' &
-
-echo "[$(date)] HTTP log server started"
+echo "[$(date)] All services started"
# Execute the main application or run in dev mode
if [ "$OW_DEV" = "true" ]; then
@@ -92,5 +58,5 @@ if [ "$OW_DEV" = "true" ]; then
exec tail -f /dev/null
else
echo "[$(date)] Starting main application"
- exec python3 openweights/worker/main.py \ > >(tee logs/main) \ 2> >(tee -a logs/main >&2)
-fi
+ exec python3 openweights/worker/main.py > >(tee logs/main) 2> >(tee -a logs/main >&2)
+fi
\ No newline at end of file
diff --git a/example/ui.py b/example/ui.py
new file mode 100644
index 0000000..31e8474
--- /dev/null
+++ b/example/ui.py
@@ -0,0 +1,29 @@
+import gradio as gr # type: ignore
+from openai import OpenAI # type: ignore
+
+def chat_with(model):
+    client = OpenAI(base_url="https://ag5a2je35kxz7y-8000.proxy.runpod.net/v1", api_key="EMPTY")  # vLLM-style servers ignore the key, but the client requires one
+ def predict(message, history):
+ messages = []
+ for human, assistant in history:
+ messages.append({"role": "user", "content": human})
+ messages.append({"role": "assistant", "content": assistant})
+ messages.append({"role": "user", "content": message})
+
+ stream = client.chat.completions.create(
+ model=model,
+ messages=messages,
+ stream=True
+ )
+
+ partial_message = ""
+ for chunk in stream:
+ if chunk.choices[0].delta.content is not None:
+ partial_message += chunk.choices[0].delta.content
+ yield partial_message
+
+ gr.ChatInterface(predict).queue().launch()
+
+
+if __name__ == '__main__':
+ chat_with('Qwen/Qwen3-235B-A22B-Instruct-2507-FP8')
\ No newline at end of file
diff --git a/openweights/client/__init__.py b/openweights/client/__init__.py
index 53506fa..57819c1 100644
--- a/openweights/client/__init__.py
+++ b/openweights/client/__init__.py
@@ -1,16 +1,6 @@
-import asyncio
-import atexit
-import json
-from typing import Optional, BinaryIO, Dict, Any, List, Union
+from typing import Optional, Dict, Any
import os
-import sys
-from postgrest.exceptions import APIError
-import hashlib
-from datetime import datetime
-from openai import OpenAI, AsyncOpenAI
-import backoff
-import time
-from supabase import create_client, Client
+from supabase import create_client
from supabase.lib.client_options import ClientOptions
from openweights.client.files import Files, validate_messages, validate_preference_dataset
@@ -20,7 +10,13 @@
from openweights.client.temporary_api import TemporaryApi
from openweights.client.chat import ChatCompletions, AsyncChatCompletions
from openweights.client.utils import group_models_or_adapters_by_model, get_lora_rank
+from openweights.client.decorators import supabase_retry
+import logging
+
+# Reduce noise to only warnings+errors
+for name in ["httpx", "httpx._client", "postgrest", "gotrue", "supabase"]:
+ logging.getLogger(name).setLevel(logging.WARNING)
def create_authenticated_client(supabase_url: str, supabase_anon_key: str, auth_token: Optional[str] = None):
"""Create a Supabase client with authentication.
@@ -114,7 +110,7 @@ def __init__(self,
setattr(self, name, cls(self))
OpenWeights._INSTANCES.append(self)
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}"))
+ @supabase_retry()
def get_organization_id(self) -> str:
"""Get the organization ID associated with the current token"""
result = self._supabase.rpc('get_organization_from_token').execute()
@@ -122,7 +118,7 @@ def get_organization_id(self) -> str:
raise ValueError("Could not determine organization ID from token")
return result.data
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}"))
+ @supabase_retry()
def get_organization_name(self):
"""Get the organization ID associated with the current token"""
result = self._supabase.table('organizations')\
@@ -131,7 +127,7 @@ def get_organization_name(self):
.single().execute()
return result.data['name']
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}"))
+ @supabase_retry()
def get_hf_org(self):
"""Get organization secrets from the database."""
result = self._supabase.table('organization_secrets')\
diff --git a/openweights/client/cache_on_disk.py b/openweights/client/cache_on_disk.py
deleted file mode 100644
index 6c291cf..0000000
--- a/openweights/client/cache_on_disk.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import asyncio
-import hashlib
-import json
-import os
-from functools import wraps
-
-import diskcache as dc
-
-class CacheOnDisk:
- def __init__(self, n_semaphore=100, cache_dir=None):
- """
- Create a CacheOnDisk instance.
-
- Parameters:
- n_semaphore (int): Maximum number of parallel cache accesses.
- cache_dir (str): Path to the cache directory. Defaults to a ".llm-cache"
- directory alongside this file.
- """
- if cache_dir is None:
- cache_dir = os.path.join(os.path.dirname(__file__), ".llm-cache")
- os.makedirs(cache_dir, exist_ok=True)
- self.cache = dc.FanoutCache(cache_dir, shards=64, timeout=10)
- self.semaphore = asyncio.Semaphore(n_semaphore)
-
- def __call__(self, possible_func=None, *, required_kwargs=None):
- """
- When used as a decorator, CacheOnDisk works in two ways:
-
- 1. As a no-argument decorator:
- @cache_on_disk
- async def my_func(...): ...
-
- 2. As a parameterized decorator:
- @cache_on_disk(required_kwargs=["foo"])
- async def my_func(...): ...
-
- The `required_kwargs` parameter (a list) determines which keyword
- arguments are needed for caching. If they are not present, the function
- is simply executed.
- """
- if possible_func is not None and callable(possible_func):
- # Used as "@cache_on_disk" without explicit parameters.
- return self._make_decorator(required_kwargs or [])(possible_func)
- else:
- # Used as "@cache_on_disk(required_kwargs=[...])". Return a decorator.
- required_kwargs = required_kwargs or []
- def decorator(func):
- return self._make_decorator(required_kwargs)(func)
- return decorator
-
- def _make_decorator(self, required_kwargs):
- def decorator(function):
- @wraps(function)
- async def wrapper(*args, **kwargs):
- # Only attempt caching if all required keyword arguments are present.
- if not all(k in kwargs for k in required_kwargs):
- return await function(*args, **kwargs)
-
- # Serialize args/kwargs and compute the cache key.
- serialized = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True)
- key = hashlib.sha256(serialized.encode()).hexdigest()
-
- # Limit the number of concurrent cache accesses.
- async with self.semaphore:
- cached_result = await asyncio.to_thread(self.cache.get, key, None)
- if cached_result is not None:
- return cached_result
-
- result = await function(*args, **kwargs)
-
- async with self.semaphore:
- await asyncio.to_thread(self.cache.set, key, result)
- return result
- return wrapper
- return decorator
-
-# Create a default object for easy importing.
-cache_on_disk = CacheOnDisk()
\ No newline at end of file
diff --git a/openweights/client/chat.py b/openweights/client/chat.py
index b91a66b..1fd809a 100644
--- a/openweights/client/chat.py
+++ b/openweights/client/chat.py
@@ -1,10 +1,9 @@
from collections import defaultdict
import openai
import asyncio
-from openweights.client.cache_on_disk import cache_on_disk
-import backoff
from openweights.client.temporary_api import TemporaryApi, APIS
import openai
+from openweights.client.decorators import openai_retry
DEPLOYMENT_QUEUE = []
@@ -31,7 +30,6 @@ def __init__(self, ow, deploy_kwargs={}, request_timeout=300, per_token_timeout=
self.per_token_timeout = per_token_timeout
async def create(self, model: str, **kwargs):
- # @cache_on_disk(required_kwargs=['seed'])
async def cached_create(model, **kwargs):
return await self._create(model, **kwargs)
return await cached_create(model, **kwargs)
@@ -41,18 +39,7 @@ async def _create(self, model, **kwargs):
async with api.sem:
return await self._create_with_backoff(api, model, **kwargs)
- @backoff.on_exception(
- wait_gen=backoff.expo,
- exception=(
- openai.RateLimitError,
- openai.APIConnectionError,
- openai.APITimeoutError,
- openai.InternalServerError
- ),
- max_value=60,
- factor=1.5,
- max_tries=10
- )
+ @openai_retry()
async def _create_with_backoff(self, api, model, **kwargs):
timeout = kwargs.pop('timeout', None) or max(
self.request_timeout,
diff --git a/openweights/client/decorators.py b/openweights/client/decorators.py
new file mode 100644
index 0000000..93a8057
--- /dev/null
+++ b/openweights/client/decorators.py
@@ -0,0 +1,202 @@
+import backoff
+import httpx
+import openai
+import functools
+import time
+import random
+from typing import Optional
+
+# Optional: postgrest is used under the hood by supabase-py; handle if present
+try:
+ import postgrest
+except Exception: # pragma: no cover
+ postgrest = None
+
+
+def _is_transient_http_status(status: int) -> bool:
+ # Retry on server errors & rate limiting
+ return status >= 500 or status == 429
+
+
+def _is_transient(exc: BaseException) -> bool:
+ """
+ Returns True for errors that are likely to be temporary network/service hiccups:
+ - httpx timeouts / connect errors / protocol errors
+ - HTTPStatusError with 5xx or 429
+ - postgrest.APIError with 5xx or 429 (if postgrest available)
+ """
+ # httpx family (network-ish)
+ if isinstance(exc, (
+ httpx.ConnectError,
+ httpx.ConnectTimeout,
+ httpx.ReadTimeout,
+ httpx.WriteTimeout,
+ httpx.PoolTimeout,
+ httpx.NetworkError,
+ httpx.RemoteProtocolError,
+ )):
+ return True
+
+ # httpx raised because .raise_for_status() was called
+ if isinstance(exc, httpx.HTTPStatusError):
+ try:
+ return _is_transient_http_status(exc.response.status_code)
+ except Exception:
+ return False
+
+ # postgrest API errors (supabase-py)
+ if postgrest is not None and isinstance(exc, postgrest.APIError):
+ try:
+ code = getattr(exc, "code", None)
+ # code may be a string; try to coerce
+ code_int = int(code) if code is not None else None
+ return code_int is not None and _is_transient_http_status(code_int)
+ except Exception:
+ return False
+
+ # Sometimes libraries wrap the real error; walk the causal chain
+ cause = getattr(exc, "__cause__", None) or getattr(exc, "__context__", None)
+ if cause and cause is not exc:
+ return _is_transient(cause)
+
+ return False
+
+
+def openai_retry(
+ *,
+ # Exponential mode (default)
+ factor: float = 1.5,
+ max_value: int = 60,
+ # Constant mode (set interval to enable)
+ interval: Optional[float] = None,
+ max_time: Optional[float] = None,
+ # Common
+ max_tries: int = 10,
+):
+ """
+ Retry transient OpenAI API errors with backoff + jitter.
+
+ Modes:
+ • Exponential (default): pass `factor`, `max_value`, `max_tries`
+ • Constant: pass `interval` (seconds) and optionally `max_time`, `max_tries`
+
+ Examples:
+ @openai_retry() # exponential (default)
+ def call(...): ...
+
+ @openai_retry(interval=10, max_time=3600, max_tries=3600) # constant
+ def call(...): ...
+
+ Retries on:
+ - openai.RateLimitError
+ - openai.APIConnectionError
+ - openai.APITimeoutError
+ - openai.InternalServerError
+ """
+ exceptions = (
+ openai.RateLimitError,
+ openai.APIConnectionError,
+ openai.APITimeoutError,
+ openai.InternalServerError,
+ )
+
+ def _decorator(fn):
+ if interval is not None:
+ # Constant backoff mode
+ decorated = backoff.on_exception(
+ wait_gen=backoff.constant,
+ exception=exceptions,
+ interval=interval,
+ max_time=max_time, # total wall-clock cap (optional)
+ max_tries=max_tries, # total attempts cap
+ jitter=backoff.full_jitter,
+ logger=None, # stay quiet
+ )(fn)
+ else:
+ # Exponential backoff mode
+ decorated = backoff.on_exception(
+ wait_gen=backoff.expo,
+ exception=exceptions,
+ factor=factor, # growth factor
+ max_value=max_value, # cap per-wait
+ max_tries=max_tries, # total attempts cap
+ jitter=backoff.full_jitter,
+ logger=None, # stay quiet
+ )(fn)
+
+ @functools.wraps(fn)
+ def inner(*args, **kwargs):
+ return decorated(*args, **kwargs)
+
+ return inner
+
+ return _decorator
+
+
+# sentinel to indicate "raise on exhaustion"
+_RAISE = object()
+
+def supabase_retry(
+ max_time: float = 60,
+ max_tries: int = 8,
+ *,
+ base: float = 1.0, # initial delay
+ factor: float = 2.0, # exponential growth
+ max_delay: float = 60.0, # cap for each delay step
+ return_on_exhaustion=_RAISE # e.g., set to None to "ignore" after retries
+):
+ """
+ Retries ONLY transient Supabase/http errors (see _is_transient) with exponential backoff + full jitter.
+ If `return_on_exhaustion` is not `_RAISE`, return that value after retry budget is exhausted for a
+ transient error. Non-transient errors still raise immediately.
+
+ Args:
+ max_time: maximum total wall-clock seconds spent retrying
+ max_tries: maximum attempts (including the first)
+ base: initial backoff delay (seconds)
+ factor: exponential growth factor per attempt (>= 1)
+ max_delay: max per-attempt delay (seconds)
+ return_on_exhaustion: value to return after exhausting retries on a transient error.
+ Leave as `_RAISE` to re-raise instead.
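+
+    Example:
+        @supabase_retry()  # raise after exhausting the retry budget
+        def fetch_row(...): ...
+
+        @supabase_retry(return_on_exhaustion=None)  # best-effort: ignore after retries
+        def log_event(...): ...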
+ """
+ def _next_delay(attempt: int) -> float:
+ # attempt starts at 1 for the first retry (after one failure)
+ raw = base * (factor ** (attempt - 1))
+ return min(raw, max_delay) * random.random() # full jitter
+
+ def _decorator(fn):
+ @functools.wraps(fn)
+ def inner(*args, **kwargs):
+            # first attempt runs immediately; only transient errors are retried
+ start = time.monotonic()
+ attempt = 0
+ while True:
+ try:
+ return fn(*args, **kwargs)
+ except Exception as exc:
+ # Non-transient? bubble up immediately.
+ if not _is_transient(exc):
+ raise
+
+ attempt += 1
+ # Have we exhausted attempts?
+ if attempt >= max_tries:
+ if return_on_exhaustion is _RAISE:
+ raise
+ return return_on_exhaustion
+
+ # Compute delay with jitter, ensure we don't break max_time
+ delay = _next_delay(attempt)
+ if max_time is not None:
+ elapsed = time.monotonic() - start
+ remaining = max_time - elapsed
+ if remaining <= 0:
+ if return_on_exhaustion is _RAISE:
+ raise
+ return return_on_exhaustion
+ # don't sleep past the deadline
+ delay = min(delay, max(0.0, remaining))
+
+ time.sleep(delay)
+ return inner
+ return _decorator
diff --git a/openweights/client/events.py b/openweights/client/events.py
index d4bdfd1..8aafaca 100644
--- a/openweights/client/events.py
+++ b/openweights/client/events.py
@@ -1,10 +1,11 @@
from typing import Optional, Dict, Any, List
-
+from openweights.client.decorators import supabase_retry
class Events():
def __init__(self, supabase):
self._supabase = supabase
+ @supabase_retry()
def list(self, job_id: Optional[str]=None, run_id: Optional[str]=None):
"""List events by job_id or run_id, sorted by created_at in ascending order"""
if run_id:
diff --git a/openweights/client/files.py b/openweights/client/files.py
index 1b4dd29..05762ea 100644
--- a/openweights/client/files.py
+++ b/openweights/client/files.py
@@ -1,12 +1,14 @@
from typing import Optional, BinaryIO, Dict, Any, List, Union
import os
import hashlib
+import io
+import tempfile
from datetime import datetime
from supabase import Client
import backoff
import json
import logging
-
+from openweights.client.decorators import supabase_retry
def validate_message(message):
try:
@@ -72,14 +74,17 @@ def __init__(self, supabase: Client, organization_id: str):
self._supabase = supabase
self._org_id = organization_id
- def _calculate_file_hash(self, file: BinaryIO) -> str:
+ def _calculate_file_hash(self, stream: BinaryIO) -> str:
"""Calculate SHA-256 hash of file content"""
sha256_hash = hashlib.sha256()
- for byte_block in iter(lambda: file.read(4096), b""):
+ for byte_block in iter(lambda: stream.read(4096), b""):
sha256_hash.update(byte_block)
# Add the org ID to the hash to ensure uniqueness
sha256_hash.update(self._org_id.encode())
- file.seek(0) # Reset file pointer
+ try:
+ stream.seek(0)
+ except Exception:
+ pass
return f"file-{sha256_hash.hexdigest()[:12]}"
def _get_storage_path(self, file_id: str) -> str:
@@ -94,45 +99,75 @@ def _get_storage_path(self, file_id: str) -> str:
# Fallback if RPC fails
return f"organizations/{self._org_id}/{file_id}"
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}"))
+ @supabase_retry()
def create(self, file: BinaryIO, purpose: str) -> Dict[str, Any]:
- """Upload a file and create a database entry"""
- file.seek(0)
- file_id = f"{purpose}:{self._calculate_file_hash(file)}"
+ """Upload a file and create a database entry.
+ Robust to retries by buffering the input stream into memory once
+ and using fresh BytesIO streams for hashing, validation, and upload.
+ """
+ # Read all bytes once; support both real files and file-like objects
+ try:
+ # Ensure at start (some callers might pass a consumed stream)
+ if hasattr(file, 'seek'):
+ try:
+ file.seek(0)
+ except Exception:
+ pass
+ data = file.read()
+ finally:
+ # Do not close the caller's file handle; just leave it as-is
+ # (the caller used a context manager typically)
+ pass
+
+ if not isinstance(data, (bytes, bytearray)):
+ raise TypeError("Files.create expects a binary file-like object returning bytes")
+
+ file_id = f"{purpose}:{self._calculate_file_hash(io.BytesIO(data))}"
# If the file already exists, return the existing file
try:
existing_file = self._supabase.table('files').select('*').eq('id', file_id).single().execute().data
if existing_file:
return existing_file
- except:
+ except Exception:
pass # File doesn't exist yet, continue with creation
- # Validate file content
- if not self.validate(file, purpose):
+ # Validate file content using a fresh buffer
+ if not self.validate(io.BytesIO(data), purpose):
raise ValueError("File content is not valid")
- file_size = os.fstat(file.fileno()).st_size
+ file_size = len(data)
filename = getattr(file, 'name', 'unknown')
# Get organization-specific storage path
storage_path = self._get_storage_path(file_id)
# Store file in Supabase Storage with organization path
- self._supabase.storage.from_('files').upload(
- path=storage_path,
- file=file
- )
-
+ # storage3's sync client expects a file path-like in some versions; write to a temp file for compatibility
+ with tempfile.NamedTemporaryFile(delete=False) as tmp:
+ tmp.write(data)
+ tmp.flush()
+ tmp_path = tmp.name
+ try:
+ self._supabase.storage.from_('files').upload(
+ path=storage_path,
+ file=tmp_path,
+ file_options={"upsert": "true"}
+ )
+ finally:
+ try:
+ os.remove(tmp_path)
+ except Exception:
+ pass
# Create database entry
- data = {
+ data_row = {
'id': file_id,
'filename': filename,
'purpose': purpose,
'bytes': file_size
}
- result = self._supabase.table('files').insert(data).execute()
+ result = self._supabase.table('files').insert(data_row).execute()
return {
'id': file_id,
@@ -143,14 +178,14 @@ def create(self, file: BinaryIO, purpose: str) -> Dict[str, Any]:
'purpose': purpose,
}
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}"))
+ @supabase_retry()
def content(self, file_id: str) -> bytes:
"""Get file content"""
storage_path = self._get_storage_path(file_id)
return self._supabase.storage.from_('files').download(storage_path)
def validate(self, file: BinaryIO, purpose: str) -> bool:
- """Validate file content"""
+ """Validate file content. The passed stream will be consumed."""
if purpose in ['conversations']:
content = file.read().decode('utf-8')
return validate_messages(content)
@@ -160,7 +195,7 @@ def validate(self, file: BinaryIO, purpose: str) -> bool:
else:
return True
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}"))
+ @supabase_retry()
def get_by_id(self, file_id: str) -> Dict[str, Any]:
"""Get file details by ID"""
return self._supabase.table('files').select('*').eq('id', file_id).single().execute().data
diff --git a/openweights/client/jobs.py b/openweights/client/jobs.py
index 4071e16..7089cd6 100644
--- a/openweights/client/jobs.py
+++ b/openweights/client/jobs.py
@@ -1,18 +1,14 @@
-import re
import json
-from typing import BinaryIO, Dict, Any, List, Union, Tuple, Type
+from typing import Dict, Any, List, Type
import os
from postgrest.exceptions import APIError
-import backoff
import hashlib
-from supabase import Client
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
from datetime import datetime
from dataclasses import dataclass
-from openweights.client.utils import resolve_lora_model, get_lora_rank
from openweights.cluster.start_runpod import GPUs
-
+from openweights.client.decorators import supabase_retry
@dataclass
class Job:
@@ -133,14 +129,7 @@ def _upload_mounted_files(self, extra_files=None) -> Dict[str, str]:
return uploaded_files
- @backoff.on_exception(
- backoff.constant,
- Exception,
- interval=1,
- max_time=60,
- max_tries=60,
- on_backoff=lambda details: print(f"Retrying... {details['exception']}"),
- )
+ @supabase_retry()
def list(self, limit: int = 10) -> List[Dict[str, Any]]:
"""List jobs"""
result = (
@@ -152,14 +141,7 @@ def list(self, limit: int = 10) -> List[Dict[str, Any]]:
)
return [Job(**row, _manager=self) for row in result.data]
- @backoff.on_exception(
- backoff.constant,
- Exception,
- interval=1,
- max_time=60,
- max_tries=60,
- on_backoff=lambda details: print(f"Retrying... {details['exception']}"),
- )
+ @supabase_retry()
def retrieve(self, job_id: str) -> Dict[str, Any]:
"""Get job details"""
result = (
@@ -167,14 +149,7 @@ def retrieve(self, job_id: str) -> Dict[str, Any]:
)
return Job(**result.data, _manager=self)
- @backoff.on_exception(
- backoff.constant,
- Exception,
- interval=1,
- max_time=60,
- max_tries=60,
- on_backoff=lambda details: print(f"Retrying... {details['exception']}"),
- )
+ @supabase_retry()
def cancel(self, job_id: str) -> Dict[str, Any]:
"""Cancel a job"""
result = (
@@ -185,14 +160,7 @@ def cancel(self, job_id: str) -> Dict[str, Any]:
)
return Job(**result.data[0], _manager=self)
- @backoff.on_exception(
- backoff.constant,
- Exception,
- interval=1,
- max_time=60,
- max_tries=60,
- on_backoff=lambda details: print(f"Retrying... {details['exception']}"),
- )
+ @supabase_retry()
def restart(self, job_id: str) -> Dict[str, Any]:
"""Restart a job"""
result = (
@@ -220,14 +188,7 @@ def compute_id(self, data: Dict[str, Any]) -> str:
job_id += f"-{data['params']['validated_params']['job_id_suffix']}"
return job_id
- @backoff.on_exception(
- backoff.constant,
- Exception,
- interval=1,
- max_time=60,
- max_tries=60,
- on_backoff=lambda details: print(f"Retrying... {details['exception']}"),
- )
+ @supabase_retry()
def get_or_create_or_reset(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""If job exists and is [pending, in_progress, completed] return it.
If job exists and is [failed, canceled] reset it to pending and return it.
@@ -274,14 +235,7 @@ def get_or_create_or_reset(self, data: Dict[str, Any]) -> Dict[str, Any]:
else:
raise ValueError(f"Invalid job status: {job['status']}")
- @backoff.on_exception(
- backoff.constant,
- Exception,
- interval=1,
- max_time=60,
- max_tries=60,
- on_backoff=lambda details: print(f"Retrying... {details['exception']}"),
- )
+ @supabase_retry()
def find(self, **params) -> List[Dict[str, Any]]:
"""Find jobs by their JSON values in job.params
Example:
diff --git a/openweights/client/run.py b/openweights/client/run.py
index 8eba698..b4c5039 100644
--- a/openweights/client/run.py
+++ b/openweights/client/run.py
@@ -1,39 +1,11 @@
-import asyncio
-import atexit
-import json
-from typing import Optional, BinaryIO, Dict, Any, List, Union
+from typing import Optional, BinaryIO, Dict, Any, List
import os
import sys
from postgrest.exceptions import APIError
import hashlib
from datetime import datetime
-import backoff
from supabase import Client
-import httpx
-import logging
-from functools import wraps
-
-
-def is_transient_error(e):
- """Check if an error is likely transient and should be retried"""
- if isinstance(e, (httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectTimeout)):
- return True
- return False
-
-def retry_or_ignore(func, n_retries=5):
- """Retry a function, if it continues to fail, ignore the error"""
- @wraps(func)
- def wrapper(*args, **kwargs):
- for i in range(n_retries):
- try:
- return func(*args, **kwargs)
- except Exception as e:
- logging.error(f"Error in {func.__name__} (attempt {i+1}/{n_retries}): {e}")
- if i == n_retries - 1:
- # On last attempt, return None instead of raising
- return None
- return None
- return wrapper
+from openweights.client.decorators import supabase_retry
class Run:
@@ -43,14 +15,7 @@ def __init__(self, client: 'OpenWeights', job_id: Optional[str] = None, worker_i
self.organization_id = organization_id
self.id = run_id or os.getenv('OPENWEIGHTS_RUN_ID')
if self.id:
- # Run ID exists, fetch the data
- try:
- self._fetch_and_init_run(job_id, worker_id)
- except Exception as e:
- if not is_transient_error(e):
- raise
- # For transient errors, retry with backoff
- self._fetch_and_init_run_with_retry(job_id, worker_id)
+ self._fetch_and_init_run(job_id, worker_id)
else:
# Create new run
data = {
@@ -83,15 +48,7 @@ def __init__(self, client: 'OpenWeights', job_id: Optional[str] = None, worker_i
result = self._supabase.table('runs').insert(data).execute()
self._load_data(result.data[0])
- @backoff.on_exception(
- backoff.expo,
- (httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectTimeout),
- max_tries=5
- )
- def _fetch_and_init_run_with_retry(self, job_id, worker_id):
- """Fetch run data with retry logic"""
- self._fetch_and_init_run(job_id, worker_id)
-
+ @supabase_retry()
def _fetch_and_init_run(self, job_id, worker_id):
"""Fetch run data and initialize"""
try:
@@ -113,11 +70,7 @@ def _fetch_and_init_run(self, job_id, worker_id):
self._load_data(run_data)
- @backoff.on_exception(
- backoff.expo,
- (httpx.RemoteProtocolError, httpx.ReadTimeout, httpx.ConnectTimeout),
- max_tries=5
- )
+ @supabase_retry()
def _get_job_org_id_with_retry(self, job_id):
"""Get job organization ID with retry logic"""
return self._supabase.table('jobs').select('organization_id').eq('id', job_id).single().execute()
@@ -144,7 +97,7 @@ def get(supabase: Client, run_id: int) -> 'Run':
run._load_data(result.data)
return run
- @retry_or_ignore
+ @supabase_retry(return_on_exhaustion=None)
def update(self, status: Optional[str] = None, logfile: Optional[str] = None):
"""Update run status and/or logfile"""
data = {}
@@ -157,7 +110,7 @@ def update(self, status: Optional[str] = None, logfile: Optional[str] = None):
result = self._supabase.table('runs').update(data).eq('id', self.id).execute()
self._load_data(result.data[0])
- @retry_or_ignore
+ @supabase_retry(return_on_exhaustion=None)
def log(self, event_data: Dict[str, Any], file: Optional[BinaryIO] = None):
"""Log an event for this run"""
if file:
@@ -204,7 +157,7 @@ def __init__(self, client: 'OpenWeights'):
self.client = client
self._supabase = client._supabase
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_time=60, max_tries=60, on_backoff=lambda details: print(f"Retrying... {details['exception']}"))
+ @supabase_retry()
def list(self, job_id: Optional[str] = None, worker_id: Optional[str] = None, limit: int = 10, status: Optional[str]=None) -> List[Dict[str, Any]]:
"""List runs by job_id or worker_id"""
query = self._supabase.table('runs').select('*').limit(limit)
diff --git a/openweights/client/temporary_api.py b/openweights/client/temporary_api.py
index 3e44434..1a6289f 100644
--- a/openweights/client/temporary_api.py
+++ b/openweights/client/temporary_api.py
@@ -1,31 +1,14 @@
import asyncio
-import atexit
from openai import OpenAI, AsyncOpenAI
-import backoff
import time
import threading
from datetime import datetime, timedelta, timezone
-import json
-from huggingface_hub import HfApi, hf_hub_download
-from collections import defaultdict
-from typing import List, Dict
-from functools import lru_cache
-
-from openweights.client.utils import group_models_or_adapters_by_model, get_lora_rank
+from openweights.client.decorators import openai_retry
APIS = {}
-def on_backoff(details):
- exception_info = details['exception']
- if 'Bad Gateway' in exception_info:
- return
-    if '<html>' in str(exception_info):
-        exception_info = str(exception_info).split('<html>')[1].split('</html>')[0]
- print(f"Retrying... {exception_info}")
-
-
class TemporaryApi:
def __init__(self, ow, job_id):
self.ow = ow
@@ -68,12 +51,12 @@ def up(self):
def __enter__(self):
return self.up()
- @backoff.on_exception(backoff.constant, Exception, interval=10, max_time=3600, max_tries=3600)
+ @openai_retry(interval=10, max_time=3600, max_tries=3600)
def wait_until_ready(self, openai, model):
print('Waiting for API to be ready...')
openai.chat.completions.create(model=model, messages=[dict(role='user', content='Hello')], max_tokens=200)
- @backoff.on_exception(backoff.constant, Exception, interval=1, max_tries=10)
+ @openai_retry(interval=1, max_tries=10)
async def async_up(self):
self._stop_timeout_thread = False
self._timeout_thread = threading.Thread(target=self._manage_timeout, daemon=True)
diff --git a/openweights/cluster/start_runpod.py b/openweights/cluster/start_runpod.py
index c4d798c..0857704 100644
--- a/openweights/cluster/start_runpod.py
+++ b/openweights/cluster/start_runpod.py
@@ -1,6 +1,18 @@
"""
Usage:
- python start_runpod.py --gpu A6000 --container_disk_in_gb 25 --volume_in_gb 30
+ python start_runpod.py --gpu A6000 --container_disk_in_gb 25 --volume_in_gb 30 --ttl_hours 24
+
+TTL (Time To Live) Feature:
+ - All pods have a default TTL of 24 hours to prevent runaway costs
+ - TTL can be customized with --ttl_hours parameter
+ - TTL can be extended from within the pod by updating ~/shutdown.txt with a new timestamp
+ - Example to extend TTL from within pod:
+      python3 -c "
+      import datetime, os
+      new_time = datetime.datetime.now() + datetime.timedelta(hours=48)
+      with open(os.path.expanduser('~/shutdown.txt'), 'w') as f:
+          f.write(new_time.isoformat())
+      "
Note: possible unknown error with echo when running the script.
"""
@@ -16,7 +28,6 @@
from scp import SCPClient
from functools import lru_cache
-load_dotenv(override=True)
IMAGES = {
'default': 'nielsrolf/ow-default',
@@ -81,7 +92,6 @@
# Build map of memory -> hardware configuration
HARDWARE_CONFIG = {}
-print("Available hardware configurations:")
def populate_hardware_config(runpod_client):
runpod_gpus = runpod_client.get_gpus()
for gpu_short, gpu_full in GPUs.items():
@@ -199,7 +209,7 @@ def check_correct_cuda(pod, allowed=allowed_cuda_versions, runpod_client=None):
@backoff.on_exception(backoff.expo, Exception, max_time=60, max_tries=5)
-def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, pending_workers=None, env=None, runpod_client=None):
+def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, ttl_hours=24, pending_workers=None, env=None, runpod_client=None):
client = runpod_client or runpod
gpu = GPUs[gpu]
    # default name: <image>-worker-<worker_id>
@@ -213,7 +223,9 @@ def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=5
env.update({
'WORKER_ID': worker_id,
'DOCKER_IMAGE': image,
- 'OW_DEV': 'true' if dev_mode else 'false'
+ 'OW_DEV': 'true' if dev_mode else 'false',
+ 'TTL_HOURS': str(ttl_hours),
+ 'RUNPOD_API_KEY': os.getenv('RUNPOD_API_KEY')
})
if worker_id is None:
worker_id = uuid.uuid4().hex[:8]
@@ -239,7 +251,7 @@ def _start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=5
return pod
-def start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, env=None, runpod_client=None):
+def start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=500, volume_in_gb=500, worker_id=None, dev_mode=False, ttl_hours=24, env=None, runpod_client=None):
pending_workers = []
if dev_mode:
env = {var: os.environ.get(var) for var in [
@@ -249,7 +261,7 @@ def start_worker(gpu, image, count=GPU_COUNT, name=None, container_disk_in_gb=50
runpod.api_key = os.getenv('RUNPOD_API_KEY')
runpod_client = runpod
try:
- pod = _start_worker(gpu, image, count, name, container_disk_in_gb, volume_in_gb, worker_id, dev_mode, pending_workers, env, runpod_client)
+ pod = _start_worker(gpu, image, count, name, container_disk_in_gb, volume_in_gb, worker_id, dev_mode, ttl_hours, pending_workers, env, runpod_client)
return pod
except Exception as e:
import traceback
diff --git a/openweights/dashboard/backend/main.py b/openweights/dashboard/backend/main.py
index c013938..870e1bd 100644
--- a/openweights/dashboard/backend/main.py
+++ b/openweights/dashboard/backend/main.py
@@ -17,6 +17,7 @@
CORSMiddleware,
allow_origins=[
"http://localhost:5173", # Vite dev server
+ "http://localhost:5174", # Vite dev server
"http://localhost:8124", # FastAPI dev server
"http://localhost:4173", # Vite preview
],
diff --git a/openweights/dashboard/frontend/src/App.tsx b/openweights/dashboard/frontend/src/App.tsx
index 9921ed2..9e4728b 100644
--- a/openweights/dashboard/frontend/src/App.tsx
+++ b/openweights/dashboard/frontend/src/App.tsx
@@ -16,7 +16,7 @@ import {
import MoreVertIcon from '@mui/icons-material/MoreVert';
import SettingsIcon from '@mui/icons-material/Settings';
import { JobsView } from './components/JobsView';
-import { RunsView } from './components/RunsView';
+// import { RunsView } from './components/RunsView';
import { WorkersView } from './components/WorkersView';
import { JobDetailView, RunDetailView, WorkerDetailView } from './components/DetailViews';
import { Auth } from './components/Auth/Auth';
@@ -123,7 +123,7 @@ function NavBar() {
{currentOrganization && (
<>
-
+
                        </>
)}
@@ -216,10 +216,10 @@ function OrganizationRoutes() {
} />
} />
- } />
- } />
+
} />
} />
+ } />
} />
} />
diff --git a/openweights/dashboard/frontend/src/components/JobsListView.tsx b/openweights/dashboard/frontend/src/components/JobsListView.tsx
index 4094257..2ca738c 100644
--- a/openweights/dashboard/frontend/src/components/JobsListView.tsx
+++ b/openweights/dashboard/frontend/src/components/JobsListView.tsx
@@ -38,6 +38,7 @@ interface JobsListViewProps {
onPageChange: (event: unknown, newPage: number) => void;
    onRowsPerPageChange: (event: React.ChangeEvent<HTMLInputElement>) => void;
orgId: string;
+    onCancelJob: (jobId: string) => Promise<void>;
}
export const JobsListView: React.FC = ({
@@ -48,6 +49,7 @@ export const JobsListView: React.FC = ({
onPageChange,
onRowsPerPageChange,
orgId,
+ onCancelJob,
}) => {
const filteredJobs = jobs.filter(job => {
const searchStr = filter.toLowerCase();
@@ -77,6 +79,7 @@ export const JobsListView: React.FC = ({
                                <TableCell>Docker Image</TableCell><TableCell>Created At</TableCell><TableCell>Actions</TableCell>
+                                <TableCell>Manage</TableCell>
@@ -110,6 +113,18 @@ export const JobsListView: React.FC = ({
View Details
+
+ {(job.status === 'pending' || job.status === 'in_progress') && (
+
+ )}
+
))}
{emptyRows > 0 && (
diff --git a/openweights/dashboard/frontend/src/components/JobsView.tsx b/openweights/dashboard/frontend/src/components/JobsView.tsx
index 7baa535..d3a4102 100644
--- a/openweights/dashboard/frontend/src/components/JobsView.tsx
+++ b/openweights/dashboard/frontend/src/components/JobsView.tsx
@@ -26,7 +26,7 @@ import { ViewToggle } from './ViewToggle';
import { JobsListView } from './JobsListView';
import { useOrganization } from '../contexts/OrganizationContext';
-const JobCard: React.FC<{ job: Job; orgId: string }> = ({ job, orgId }) => (
+const JobCard: React.FC<{ job: Job; orgId: string; onCancelJob: (jobId: string) => Promise<void> }> = ({ job, orgId, onCancelJob }) => (
= ({ job, orgId }) => (
Created: {new Date(job.created_at).toLocaleString()}
-
+
+
+ {(job.status === 'pending' || job.status === 'in_progress') && (
+
+ )}
+
);
@@ -94,6 +106,7 @@ interface JobsColumnProps {
onRefresh: () => void;
loading?: boolean;
orgId: string;
+    onCancelJob: (jobId: string) => Promise<void>;
}
const JobsColumn: React.FC = ({
@@ -107,8 +120,10 @@ const JobsColumn: React.FC = ({
lastRefresh,
onRefresh,
loading,
- orgId
+ orgId,
+ onCancelJob
}) => {
+
const filteredJobs = jobs.filter(job => {
const searchStr = filter.toLowerCase();
const jobId = String(job.id);
@@ -152,7 +167,7 @@ const JobsColumn: React.FC = ({
{paginatedJobs.map(job => (
-                        <JobCard key={job.id} job={job} orgId={orgId} />
+                        <JobCard key={job.id} job={job} orgId={orgId} onCancelJob={onCancelJob} />
))}
{
const [loading, setLoading] = useState(false);
const [lastRefresh, setLastRefresh] = useState();
const [autoRefresh, setAutoRefresh] = useState(true);
+    const [isCancelling, setIsCancelling] = useState<string | null>(null);
const [view, setView] = useState<'three-column' | 'list'>('three-column');
const [statusFilters, setStatusFilters] = useState({
completed: true,
@@ -230,6 +246,19 @@ export const JobsView: React.FC = () => {
setPages({ pending: 0, inProgress: 0, completed: 0 });
};
+ const cancelJob = async (jobId: string) => {
+ if (!orgId) return;
+ try {
+ setIsCancelling(jobId);
+ await api.cancelJob(orgId, jobId);
+ await fetchJobs();
+ } catch (error) {
+ console.error('Failed to cancel job', error);
+ } finally {
+ setIsCancelling(null);
+ }
+ };
+
const filteredJobs = jobs.filter(job => {
const matchesType = typeFilter === 'all' || job.type === typeFilter;
return matchesType;
@@ -322,6 +351,7 @@ export const JobsView: React.FC = () => {
onRefresh={fetchJobs}
loading={loading}
orgId={orgId}
+ onCancelJob={cancelJob}
/>
{
onRefresh={fetchJobs}
loading={loading}
orgId={orgId}
+ onCancelJob={cancelJob}
/>
{
onPageChange={(_, newPage) => handlePageChange('completed')(newPage)}
onRowsPerPageChange={(event) => handleRowsPerPageChange(parseInt(event.target.value, 10))}
orgId={orgId}
+ onCancelJob={cancelJob}
/>
)}
diff --git a/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx b/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx
index e6579e0..60674ab 100644
--- a/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx
+++ b/openweights/dashboard/frontend/src/contexts/OrganizationContext.tsx
@@ -18,6 +18,7 @@ export function OrganizationProvider({ children }: { children: React.ReactNode }
const [organizations, setOrganizations] = useState([]);
const [currentOrganization, setCurrentOrganization] = useState(null);
const [loading, setLoading] = useState(true);
+ const [initialLoaded, setInitialLoaded] = useState(false);
const navigate = useNavigate();
const location = useLocation();
const { user } = useAuth();
@@ -29,21 +30,24 @@ export function OrganizationProvider({ children }: { children: React.ReactNode }
console.log('Loaded organizations:', orgs);
setOrganizations(orgs);
setLoading(false);
+ setInitialLoaded(true);
} catch (error) {
console.error('Failed to load organizations:', error);
setLoading(false);
+ setInitialLoaded(true);
}
};
// Load organizations when user changes
useEffect(() => {
- if (user) {
+ if (user && !initialLoaded) {
loadOrganizations();
- } else {
+ } else if (!user) {
setOrganizations([]);
setCurrentOrganization(null);
+ setInitialLoaded(false);
}
- }, [user]);
+ }, [user, initialLoaded]);
// Try to extract organization ID from URL and set it as current
useEffect(() => {
diff --git a/openweights/dashboard/frontend/src/main.tsx b/openweights/dashboard/frontend/src/main.tsx
index bef5202..9898468 100644
--- a/openweights/dashboard/frontend/src/main.tsx
+++ b/openweights/dashboard/frontend/src/main.tsx
@@ -4,7 +4,5 @@ import './index.css'
import App from './App.tsx'
createRoot(document.getElementById('root')!).render(
-  <StrictMode>
-    <App />
-  </StrictMode>,
+  <App />
)
diff --git a/openweights/jobs/inference/cli.py b/openweights/jobs/inference/cli.py
index 3ec98b6..0d8370a 100644
--- a/openweights/jobs/inference/cli.py
+++ b/openweights/jobs/inference/cli.py
@@ -4,7 +4,6 @@
import time
import torch
-from dotenv import load_dotenv
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
@@ -17,7 +16,6 @@
from validate import InferenceConfig
-load_dotenv()
client = OpenWeights()
diff --git a/openweights/jobs/inspect_ai.py b/openweights/jobs/inspect_ai.py
index 3b7ab28..3e09104 100644
--- a/openweights/jobs/inspect_ai.py
+++ b/openweights/jobs/inspect_ai.py
@@ -7,7 +7,6 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import List
-from dotenv import load_dotenv
class InspectAiConfig(BaseModel):
diff --git a/openweights/jobs/mmlu_pro/__init__.py b/openweights/jobs/mmlu_pro/__init__.py
index f7f8522..1203d62 100644
--- a/openweights/jobs/mmlu_pro/__init__.py
+++ b/openweights/jobs/mmlu_pro/__init__.py
@@ -7,7 +7,6 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import List
-from dotenv import load_dotenv
class MMLUProArgs(BaseModel):
diff --git a/openweights/jobs/sft/logp_callback.py b/openweights/jobs/sft/logp_callback.py
index ec44e1a..0147536 100644
--- a/openweights/jobs/sft/logp_callback.py
+++ b/openweights/jobs/sft/logp_callback.py
@@ -1,8 +1,5 @@
-import math
import json
import os
-import torch
-import torch.nn.functional as F
from transformers import TrainerCallback
from utils import client
@@ -10,7 +7,7 @@
class LogTestLossCallback(TrainerCallback):
- def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_as='test_loss'):
+ def __init__(self, test_dataset, tokenizer, eval_steps, log_as, batch_size):
"""
A callback that evaluates model performance on a test dataset and logs the results.
@@ -18,7 +15,6 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_
test_dataset: Dataset with 'messages' field containing conversation messages
tokenizer: The tokenizer to use for encoding conversations
eval_steps: Evaluate every `eval_steps` training steps
- output_dir: Directory where token-level logP data will be saved
batch_size: Batch size to use during evaluation
log_as: Key to use when logging the loss metric
"""
@@ -27,47 +23,45 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_
self.eval_steps = eval_steps
self.batch_size = batch_size
self.log_as = log_as
+ self.is_block_format = False
+ if 'messages' in self.test_dataset.column_names and len(self.test_dataset) > 0:
+ first_example = self.test_dataset[0]
+ if 'messages' in first_example and len(first_example['messages']) > 0:
+ first_message = first_example['messages'][0]
+ if 'content' in first_message and isinstance(first_message['content'], list):
+ self.is_block_format = True
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+
+ def on_init_end(self, args, state, control, **kwargs):
+ self.model = kwargs["model"]
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ self.run(model=self.model, step=0)
def on_step_end(self, args, state, control, **kwargs):
"""Called at the end of each training step."""
- print(f"Evaluating every {self.eval_steps} steps")
if state.global_step % self.eval_steps != 0:
return
-
- # Get the model from kwargs
- model = kwargs["model"]
+ self.run(kwargs['model'], state.global_step)
+ def run(self, model, step):
# Set model to eval mode
model.eval()
-
- # Check if the original dataset has weighted content format
- # The test_dataset should be the raw dataset with 'messages' field
- has_weighted_content = False
- if 'messages' in self.test_dataset.column_names and len(self.test_dataset) > 0:
- first_example = self.test_dataset[0]
- if 'messages' in first_example and len(first_example['messages']) > 0:
- first_message = first_example['messages'][0]
- if 'content' in first_message and isinstance(first_message['content'], list):
- # Check if it's the weighted format
- has_weighted_content = all(
- isinstance(block, dict) and 'weight' in block
- for block in first_message['content']
- )
- if has_weighted_content:
+ if self.is_block_format:
dataset_with_logprobs = get_logprobs_blockwise(model, self.tokenizer, self.test_dataset, self.batch_size)
- with open(f'logp_{self.log_as}_{state.global_step}.json', 'w') as f:
+ with open(f'logp_{self.log_as}_{step}.json', 'w') as f:
json.dump(dataset_with_logprobs, f)
- with open(f'logp_{self.log_as}_{state.global_step}.json', 'rb') as f:
+ with open(f'logp_{self.log_as}_{step}.json', 'rb') as f:
logprobs_file = client.files.create(f, purpose="logp_blockwise")
# For blockwise, we don't have a simple loss value, just log the file
client.run.log({
"type": "logprobs_blockwise",
- "step": state.global_step,
- "file": logprobs_file['id']
+ "step": step,
+ "file": logprobs_file['id'],
+ "tag": self.log_as
})
else:
token_logp, total_loss = get_logprobs(model, self.tokenizer, self.test_dataset, self.batch_size)
@@ -75,17 +69,18 @@ def on_step_end(self, args, state, control, **kwargs):
# Calculate average loss across all batches
avg_loss = total_loss / (len(self.test_dataset) / self.batch_size)
- with open(f'logp_{self.log_as}_{state.global_step}.json', 'w') as f:
+ with open(f'logp_{self.log_as}_{step}.json', 'w') as f:
json.dump(token_logp, f)
- with open(f'logp_{self.log_as}_{state.global_step}.json', 'rb') as f:
+ with open(f'logp_{self.log_as}_{step}.json', 'rb') as f:
logprobs_file = client.files.create(f, purpose="logp")
# Log the test loss
client.run.log({
"type": "logprobs",
self.log_as: avg_loss,
- "step": state.global_step,
- "file": logprobs_file['id']
+ "step": step,
+ "file": logprobs_file['id'],
+ "tag": self.log_as
})
# Return model to training mode
diff --git a/openweights/jobs/sft/training.py b/openweights/jobs/sft/training.py
index 61aca8a..785ac0c 100644
--- a/openweights/jobs/sft/training.py
+++ b/openweights/jobs/sft/training.py
@@ -57,15 +57,14 @@ def train(training_cfg, skip_client_logging: bool = False):
kwargs["max_steps"] = training_cfg.max_steps
trainer = sft_train(training_cfg, dataset, model, tokenizer, test_dataset=test_dataset, logp_datasets=logp_datasets, **kwargs)
+ trainer.evaluate()
trainer.train()
finetuned_model_id = training_cfg.finetuned_model_id or f"{training_cfg.model}:ft-{client.run.id}"
push_model(training_cfg,finetuned_model_id, model, tokenizer)
try:
- eval_results = trainer.evaluate()
- if not skip_client_logging:
- client.run.log(eval_results)
+ trainer.evaluate()
except Exception as e:
print(f"Error evaluating model: {e}. The model has already been pushed to the hub.")
diff --git a/openweights/jobs/sft/utils.py b/openweights/jobs/sft/utils.py
index 8f9760a..c80fdf9 100644
--- a/openweights/jobs/sft/utils.py
+++ b/openweights/jobs/sft/utils.py
@@ -2,28 +2,29 @@
import os
import torch
-from dotenv import load_dotenv
from transformers import AutoTokenizer, TrainerCallback
from functools import wraps
from openweights.client import OpenWeights
-load_dotenv()
client = OpenWeights()
-def load_model_and_tokenizer(model_id, load_in_4bit=False):
+def load_model_and_tokenizer(model_id, load_in_4bit=False, max_seq_length=2048):
from unsloth import FastLanguageModel, is_bfloat16_supported
+
model, tokenizer = FastLanguageModel.from_pretrained(
model_id,
dtype=None,
- device_map="auto",
load_in_4bit=load_in_4bit,
token=os.environ["HF_TOKEN"],
- max_seq_length=2048,
+ max_seq_length=max_seq_length,
+ device_map=None, # important: no lazy/meta map
+ low_cpu_mem_usage=False, # force real tensors
)
+ model = model.to("cuda")
if tokenizer.pad_token is None:
print("WARNING: tokenizer.pad_token is None. Setting it to tokenizer.eos_token")
tokenizer.pad_token = tokenizer.eos_token
@@ -35,10 +36,21 @@ def load_model_and_tokenizer(model_id, load_in_4bit=False):
class LogMetrics(TrainerCallback):
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+ if metrics is None:
+ return
+ if args.process_index == 0: # only log once in distributed
+ payload = {k: v for k, v in metrics.items()}
+ payload["tag"] = "eval"
+ payload["step"] = state.global_step
+ client.run.log(payload)
+
def on_step_end(self, args, state, control, **kwargs):
try:
if len(state.log_history) == 0:
return
+            payload = {k: v for k, v in state.log_history[-1].items()}
+            payload["tag"] = "train"
-            client.run.log(state.log_history[-1])
+            client.run.log(payload)
except Exception as e:
# Sometimes there are connection errors to supabase etc that we can ignore
diff --git a/openweights/jobs/unsloth/logp_callback.py b/openweights/jobs/unsloth/logp_callback.py
index c2a358b..0147536 100644
--- a/openweights/jobs/unsloth/logp_callback.py
+++ b/openweights/jobs/unsloth/logp_callback.py
@@ -1,16 +1,13 @@
-import math
import json
import os
-import torch
-import torch.nn.functional as F
from transformers import TrainerCallback
from utils import client
-from logprobs import get_logprobs
+from logprobs import get_logprobs, get_logprobs_blockwise
class LogTestLossCallback(TrainerCallback):
- def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_as='test_loss'):
+ def __init__(self, test_dataset, tokenizer, eval_steps, log_as, batch_size):
"""
A callback that evaluates model performance on a test dataset and logs the results.
@@ -18,7 +15,6 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_
test_dataset: Dataset with 'messages' field containing conversation messages
tokenizer: The tokenizer to use for encoding conversations
eval_steps: Evaluate every `eval_steps` training steps
- output_dir: Directory where token-level logP data will be saved
batch_size: Batch size to use during evaluation
log_as: Key to use when logging the loss metric
"""
@@ -27,38 +23,65 @@ def __init__(self, test_dataset, tokenizer, eval_steps='log', batch_size=8, log_
self.eval_steps = eval_steps
self.batch_size = batch_size
self.log_as = log_as
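+        # Detect "block" content format: messages whose 'content' is a list of
+        # parts rather than a plain string; these go through the blockwise path.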
+ self.is_block_format = False
+ if 'messages' in self.test_dataset.column_names and len(self.test_dataset) > 0:
+ first_example = self.test_dataset[0]
+ if 'messages' in first_example and len(first_example['messages']) > 0:
+ first_message = first_example['messages'][0]
+ if 'content' in first_message and isinstance(first_message['content'], list):
+ self.is_block_format = True
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+
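+    # Cache the model at init so a step-0 baseline can be logged in on_train_begin.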
+ def on_init_end(self, args, state, control, **kwargs):
+ self.model = kwargs["model"]
+
+ def on_train_begin(self, args, state, control, **kwargs):
+        self.run(model=self.model, step=0)
+
def on_step_end(self, args, state, control, **kwargs):
"""Called at the end of each training step."""
- print(f"Evaluating every {self.eval_steps} steps")
if state.global_step % self.eval_steps != 0:
return
-
- # Get the model from kwargs
- model = kwargs["model"]
+ self.run(kwargs['model'], state.global_step)
+
+    def run(self, model, step):
+        """Compute log-probs on the test dataset and log the results."""
# Set model to eval mode
model.eval()
- token_logp, total_loss = get_logprobs(model, self.tokenizer, self.test_dataset, self.batch_size)
+ if self.is_block_format:
+ dataset_with_logprobs = get_logprobs_blockwise(model, self.tokenizer, self.test_dataset, self.batch_size)
+ with open(f'logp_{self.log_as}_{step}.json', 'w') as f:
+ json.dump(dataset_with_logprobs, f)
+ with open(f'logp_{self.log_as}_{step}.json', 'rb') as f:
+ logprobs_file = client.files.create(f, purpose="logp_blockwise")
+
+ # For blockwise, we don't have a simple loss value, just log the file
+ client.run.log({
+ "type": "logprobs_blockwise",
+ "step": step,
+ "file": logprobs_file['id'],
+ "tag": self.log_as
+ })
+ else:
+ token_logp, total_loss = get_logprobs(model, self.tokenizer, self.test_dataset, self.batch_size)
- # Calculate average loss across all batches
- avg_loss = total_loss / (len(self.test_dataset) / self.batch_size)
+ # Calculate average loss across all batches
+ avg_loss = total_loss / (len(self.test_dataset) / self.batch_size)
- with open(f'logp_{self.log_as}_{state.global_step}.json', 'w') as f:
- json.dump(token_logp, f)
- with open(f'logp_{self.log_as}_{state.global_step}.json', 'rb') as f:
- logprobs_file = client.files.create(f, purpose="logp")
+ with open(f'logp_{self.log_as}_{step}.json', 'w') as f:
+ json.dump(token_logp, f)
+ with open(f'logp_{self.log_as}_{step}.json', 'rb') as f:
+ logprobs_file = client.files.create(f, purpose="logp")
- # Log the test loss
- client.run.log({
- "type": "logprobs",
- self.log_as: avg_loss,
- "step": state.global_step,
- "file": logprobs_file['id']
- })
+ # Log the test loss
+ client.run.log({
+ "type": "logprobs",
+ self.log_as: avg_loss,
+ "step": step,
+ "file": logprobs_file['id'],
+ "tag": self.log_as
+ })
# Return model to training mode
model.train()
diff --git a/openweights/jobs/unsloth/sampling_callback.py b/openweights/jobs/unsloth/sampling_callback.py
index f9282ea..f1c4b25 100644
--- a/openweights/jobs/unsloth/sampling_callback.py
+++ b/openweights/jobs/unsloth/sampling_callback.py
@@ -72,28 +72,29 @@ def __init__(self, dataset, tokenizer, eval_steps='log', batch_size=8, tag='samp
self.tag = tag
self.temperature = temperature
self.max_tokens = max_tokens
+
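+    # Cache the model at init so a step-0 baseline sample can be logged in on_train_begin.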
+ def on_init_end(self, args, state, control, **kwargs):
+        self.model = kwargs["model"]
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        self.run(model=self.model, step=0)
+
def on_step_end(self, args, state, control, **kwargs):
"""Called at the end of each training step."""
- eval_steps = 10 ** int(math.log10(max(1, state.global_step)))
- if self.eval_steps == 'log':
- eval_steps = eval_steps
- else:
- eval_steps = min(eval_steps, self.eval_steps)
-
- if state.global_step % eval_steps != 0:
+ if state.global_step % self.eval_steps != 0:
return
-
+ self.run(kwargs['model'], state.global_step)
+
+ def run(self, model, step):
+ """Called at the end of each training step."""
# Get the model from kwargs
- model = kwargs["model"]
FastLanguageModel.for_inference(model)
completions = sample(
model, self.tokenizer, self.dataset, batch_size=self.batch_size,
max_tokens=self.max_tokens, temperature=self.temperature)
- results_file = f'samples_{self.tag}_{state.global_step}.jsonl'
+ results_file = f'samples_{self.tag}_{step}.jsonl'
with open(results_file, 'w') as f:
for row, completion in zip(self.dataset, completions):
row['completion'] = completion
@@ -105,7 +106,7 @@ def on_step_end(self, args, state, control, **kwargs):
# Log the test loss
client.run.log({
"type": "samples",
- "step": state.global_step,
+ "step": step,
"file": samples_file['id'],
"tag": self.tag
})
diff --git a/openweights/jobs/unsloth/sft.py b/openweights/jobs/unsloth/sft.py
index a14c5ba..8dd1d5e 100644
--- a/openweights/jobs/unsloth/sft.py
+++ b/openweights/jobs/unsloth/sft.py
@@ -1,8 +1,5 @@
-import json
-import os
from os.path import commonprefix
-from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import is_bfloat16_supported
@@ -123,7 +120,7 @@ def apply_chat_template(examples):
if training_cfg.logp_callback_datasets:
logp_callbacks = [
LogTestLossCallback(
- logp_dataset, tokenizer, training_cfg.eval_every_n_steps, log_as=key
+ logp_dataset, tokenizer, training_cfg.eval_every_n_steps, log_as=key, batch_size=training_cfg.eval_batch_size
)
for key, logp_dataset in logp_datasets.items()
]
diff --git a/openweights/jobs/unsloth/training.py b/openweights/jobs/unsloth/training.py
index d825daf..1cb31a7 100644
--- a/openweights/jobs/unsloth/training.py
+++ b/openweights/jobs/unsloth/training.py
@@ -77,15 +77,14 @@ def train(training_cfg, skip_client_logging: bool = False):
else:
raise ValueError(f"Unknown loss function: {training_cfg.loss}")
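+    # Run a baseline evaluation before the first training step.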
+ trainer.evaluate()
trainer.train()
finetuned_model_id = training_cfg.finetuned_model_id or f"{training_cfg.model}:ft-{client.run.id}"
push_model(training_cfg,finetuned_model_id, model, tokenizer)
try:
- eval_results = trainer.evaluate()
- if not skip_client_logging:
- client.run.log(eval_results)
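+        # Metrics are now logged by the LogMetrics.on_evaluate callback.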
+ trainer.evaluate()
except Exception as e:
print(f"Error evaluating model: {e}. The model has already been pushed to the hub.")
diff --git a/openweights/jobs/unsloth/utils.py b/openweights/jobs/unsloth/utils.py
index 21c603d..c80fdf9 100644
--- a/openweights/jobs/unsloth/utils.py
+++ b/openweights/jobs/unsloth/utils.py
@@ -2,13 +2,11 @@
import os
import torch
-from dotenv import load_dotenv
from transformers import AutoTokenizer, TrainerCallback
from functools import wraps
from openweights.client import OpenWeights
-load_dotenv()
client = OpenWeights()
@@ -38,10 +36,21 @@ def load_model_and_tokenizer(model_id, load_in_4bit=False, max_seq_length=2048):
class LogMetrics(TrainerCallback):
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+ if metrics is None:
+ return
+ if args.process_index == 0: # only log once in distributed
+ payload = {k: v for k, v in metrics.items()}
+ payload["tag"] = "eval"
+ payload["step"] = state.global_step
+ client.run.log(payload)
+
def on_step_end(self, args, state, control, **kwargs):
try:
if len(state.log_history) == 0:
return
+ payload = {k: v for k, v in state.log_history[-1].items()}
+ payload["tag"] = "train"
-            client.run.log(state.log_history[-1])
+            client.run.log(payload)
except Exception as e:
# Sometimes there are connection errors to supabase etc that we can ignore
diff --git a/openweights/jobs/unsloth/validate.py b/openweights/jobs/unsloth/validate.py
index 1ac4893..29840b4 100644
--- a/openweights/jobs/unsloth/validate.py
+++ b/openweights/jobs/unsloth/validate.py
@@ -312,7 +312,7 @@ class LogProbJobModel(BaseModel):
class SamplingCallbackModel(BaseModel):
dataset: str
- eval_steps: Union[Literal["log"], int] = "log"
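+    # Note: the updated callbacks apply integer modulo to eval_steps; "log" is
+    # kept in the type (presumably for existing configs) but no longer special-cased.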
+ eval_steps: Union[Literal["log"], int] = 10
batch_size: int = 8
tag: str = "samples"
temperature: float = 0
diff --git a/openweights/worker/main.py b/openweights/worker/main.py
index 58aa9af..f3c86b9 100644
--- a/openweights/worker/main.py
+++ b/openweights/worker/main.py
@@ -290,13 +290,6 @@ def shutdown_handler(self):
if self.current_run:
self.current_run.update(logfile=log_response["id"])
- # Update worker record with logfile ID
- with open("logs/main", "rb") as log_file:
- log_response = self.files.create(log_file, purpose="logs")
- self.supabase.table("worker").update(
- {"logfile": log_response["id"]}
- ).eq("id", self.worker_id).execute()
-
# Mark job as 'pending' only if it is still 'in_progress' by this worker
self.update_job_status_if_in_progress(
self.current_job["id"],
@@ -308,6 +301,16 @@ def shutdown_handler(self):
self.current_run.update(status="canceled")
except Exception as e:
logging.error(f"Error updating job status during shutdown: {e}")
+
+ try:
+ # Update worker record with logfile ID
+ with open("logs/main", "rb") as log_file:
+ log_response = self.files.create(log_file, purpose="logs")
+ self.supabase.table("worker").update(
+ {"logfile": log_response["id"]}
+ ).eq("id", self.worker_id).execute()
+ except Exception as e:
+ logging.error(f"Error updating worker logs during shutdown: {e}")
# Update worker status
try:
diff --git a/openweights/worker/services/__init__.py b/openweights/worker/services/__init__.py
new file mode 100755
index 0000000..5ccc672
--- /dev/null
+++ b/openweights/worker/services/__init__.py
@@ -0,0 +1 @@
+# Services for worker processes
\ No newline at end of file
diff --git a/openweights/worker/services/checkout.py b/openweights/worker/services/checkout.py
new file mode 100755
index 0000000..f0e50b5
--- /dev/null
+++ b/openweights/worker/services/checkout.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+"""
+Repository checkout service for specific commits
+"""
+import os
+import subprocess
+import sys
+import shutil
+
+def checkout_commit():
+ """Checkout specific commit if OW_COMMIT is set"""
+ commit = os.environ.get('OW_COMMIT')
+ if not commit:
+ print("No OW_COMMIT specified, skipping repository checkout")
+ return True
+
+ try:
+ print(f"Starting repository checkout for commit: {commit}")
+
+ # Remove existing openweights directory
+ if os.path.exists('openweights'):
+ shutil.rmtree('openweights')
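+        # Note: if the clone below fails, the removed directory is not restored.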
+
+ # Clone repository
+ subprocess.run(['git', 'clone', 'https://github.com/longtermrisk/openweights.git', 'openweights_dev'], check=True)
+
+ # Checkout specific commit
+ os.chdir('openweights_dev')
+ subprocess.run(['git', 'checkout', commit], check=True)
+
+ # Move openweights directory back
+ os.chdir('..')
+ shutil.move('openweights_dev/openweights', 'openweights')
+ shutil.rmtree('openweights_dev')
+
+ print("Repository checkout completed")
+ return True
+
+ except subprocess.CalledProcessError as e:
+ print(f"Failed to checkout repository: {e}")
+ return False
+ except Exception as e:
+ print(f"Unexpected error during checkout: {e}")
+ return False
+
+if __name__ == "__main__":
+ success = checkout_commit()
+ sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/openweights/worker/services/hf_login.py b/openweights/worker/services/hf_login.py
new file mode 100755
index 0000000..f77f8b5
--- /dev/null
+++ b/openweights/worker/services/hf_login.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+"""
+Hugging Face login service
+"""
+import os
+import sys
+from huggingface_hub.hf_api import HfFolder
+
+def login():
+ """Login to Hugging Face using HF_TOKEN from environment"""
+ try:
+ hf_token = os.environ.get('HF_TOKEN')
+ if not hf_token:
+ print("Warning: HF_TOKEN not found in environment")
+ return False
+
+ HfFolder.save_token(hf_token)
+ print("Successfully logged in to Hugging Face")
+ return True
+ except Exception as e:
+ print(f"Failed to login to Hugging Face: {e}")
+ return False
+
+if __name__ == "__main__":
+ success = login()
+ sys.exit(0 if success else 1)
\ No newline at end of file
diff --git a/openweights/worker/services/log_server.py b/openweights/worker/services/log_server.py
new file mode 100755
index 0000000..e1decb5
--- /dev/null
+++ b/openweights/worker/services/log_server.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+"""
+HTTP log server that serves log files on port 10101
+"""
+import http.server
+import socketserver
+import os
+import sys
+
+class LogHandler(http.server.SimpleHTTPRequestHandler):
+ def do_GET(self):
+ # If path is /logs, serve logs/main
+ if self.path == "/logs" or self.path == "/logs/":
+ file_path = "logs/main"
+ else:
+ # Remove leading slash and ensure path is within logs directory
+ path = self.path.lstrip("/")
+ file_path = os.path.join("logs", path)
+
+ # Check if file exists and is within logs directory
+        # Use commonpath (not commonprefix, which is purely string-based and
+        # would accept sibling directories like "logs2") to confine access to logs/.
+        resolved = os.path.abspath(file_path)
+        logs_root = os.path.abspath("logs")
+        if os.path.exists(file_path) and os.path.commonpath([resolved, logs_root]) == logs_root:
+ self.send_response(200)
+ self.send_header("Content-type", "text/plain")
+ self.end_headers()
+ with open(file_path, "rb") as f:
+ self.wfile.write(f.read())
+ else:
+ self.send_error(404, "File not found")
+
+def start_server():
+ """Start the log server on port 10101"""
+ try:
+ # Ensure logs directory exists
+ os.makedirs("logs", exist_ok=True)
+
+ with socketserver.TCPServer(("", 10101), LogHandler) as httpd:
+ print("Starting HTTP log server on port 10101")
+ httpd.serve_forever()
+ except Exception as e:
+ print(f"Failed to start log server: {e}")
+ sys.exit(1)
+
+if __name__ == "__main__":
+ start_server()
\ No newline at end of file
diff --git a/openweights/worker/services/ttl_manager.py b/openweights/worker/services/ttl_manager.py
new file mode 100755
index 0000000..35a8915
--- /dev/null
+++ b/openweights/worker/services/ttl_manager.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+"""
+TTL management utility for extending or checking pod TTL
+"""
+import os
+import datetime
+import argparse
+
+def get_shutdown_time():
+ """Get the current shutdown time"""
+ shutdown_file = os.path.expanduser('~/shutdown.txt')
+ if not os.path.exists(shutdown_file):
+ return None
+
+ with open(shutdown_file, 'r') as f:
+ shutdown_time_str = f.read().strip()
+
+ try:
+ return datetime.datetime.fromisoformat(shutdown_time_str)
+ except ValueError:
+ return None
+
+def set_shutdown_time(shutdown_time):
+ """Set a new shutdown time"""
+ shutdown_file = os.path.expanduser('~/shutdown.txt')
+ with open(shutdown_file, 'w') as f:
+ f.write(shutdown_time.isoformat())
+
+def extend_ttl(hours):
+ """Extend TTL by specified hours from now"""
+ new_shutdown_time = datetime.datetime.now() + datetime.timedelta(hours=hours)
+ set_shutdown_time(new_shutdown_time)
+ return new_shutdown_time
+
+def set_ttl(hours):
+ """Set TTL to specified hours from now"""
+ return extend_ttl(hours)
+
+def check_ttl():
+ """Check current TTL status"""
+ shutdown_time = get_shutdown_time()
+ if not shutdown_time:
+ return "No TTL set"
+
+ current_time = datetime.datetime.now()
+ time_left = shutdown_time - current_time
+
+ if time_left.total_seconds() <= 0:
+ return f"TTL expired at {shutdown_time}"
+ else:
+ hours_left = time_left.total_seconds() / 3600
+ return f"TTL expires at {shutdown_time} ({hours_left:.1f} hours remaining)"
+
+def main():
+ parser = argparse.ArgumentParser(description="Manage pod TTL (Time To Live)")
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument('--check', action='store_true', help='Check current TTL status')
+ group.add_argument('--extend', type=float, help='Extend TTL by specified hours from now')
+ group.add_argument('--set', type=float, help='Set TTL to specified hours from now')
+
+ args = parser.parse_args()
+
+ if args.check:
+ print(check_ttl())
+    elif args.extend is not None:
+ new_time = extend_ttl(args.extend)
+ print(f"TTL extended by {args.extend} hours. New shutdown time: {new_time}")
+    elif args.set is not None:
+ new_time = set_ttl(args.set)
+ print(f"TTL set to {args.set} hours from now. Shutdown time: {new_time}")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/openweights/worker/services/ttl_monitor.py b/openweights/worker/services/ttl_monitor.py
new file mode 100755
index 0000000..abf591e
--- /dev/null
+++ b/openweights/worker/services/ttl_monitor.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+"""
+TTL monitoring service that terminates pods after expiration
+"""
+import os
+import time
+import datetime
+import sys
+
+def setup_ttl():
+ """Setup initial TTL based on environment variable"""
+ ttl_hours = float(os.environ.get('TTL_HOURS', '24'))
+ shutdown_time = datetime.datetime.now() + datetime.timedelta(hours=ttl_hours)
+
+ shutdown_file = os.path.expanduser('~/shutdown.txt')
+ with open(shutdown_file, 'w') as f:
+ f.write(shutdown_time.isoformat())
+
+ print(f"TTL set to {ttl_hours} hours. Shutdown scheduled for: {shutdown_time}")
+ return shutdown_file
+
+def get_pod_id():
+ """Get the current pod ID from RunPod metadata or environment"""
+ # First try environment variable
+ pod_id = os.environ.get('RUNPOD_POD_ID')
+ if pod_id:
+ return pod_id
+
+ # Try to get from RunPod metadata API
+ try:
+ import httpx
+ with httpx.Client() as client:
+ response = client.get('http://metadata.runpod.ai/v1/instance/id', timeout=10)
+ if response.status_code == 200:
+ return response.text.strip()
+ except Exception as e:
+ print(f"Could not get pod ID from metadata API: {e}")
+
+ return None
+
+def terminate_pod():
+ """Terminate the current pod using RunPod API"""
+ try:
+ import runpod
+
+ api_key = os.environ.get('RUNPOD_API_KEY')
+ pod_id = get_pod_id()
+
+ if not api_key:
+ print("ERROR: RUNPOD_API_KEY not found in environment")
+ return False
+
+ if not pod_id:
+ print("ERROR: Could not determine pod ID")
+ return False
+
+ runpod.api_key = api_key
+ result = runpod.terminate_pod(pod_id)
+ print(f"Pod termination initiated for {pod_id}: {result}")
+ return True
+
+ except ImportError:
+ print("ERROR: runpod package not available")
+ return False
+ except Exception as e:
+ print(f"ERROR: Failed to terminate pod: {e}")
+ return False
+
+def monitor_ttl():
+ """Monitor TTL and terminate pod when expired"""
+ shutdown_file = setup_ttl()
+
+ print("Starting TTL monitoring service...")
+
+ while True:
+ try:
+ if os.path.exists(shutdown_file):
+ with open(shutdown_file, 'r') as f:
+ shutdown_time_str = f.read().strip()
+
+ try:
+ shutdown_time = datetime.datetime.fromisoformat(shutdown_time_str)
+ current_time = datetime.datetime.now()
+
+ if current_time >= shutdown_time:
+ print(f"TTL expired at {shutdown_time}. Current time: {current_time}")
+ print("Initiating pod termination...")
+
+ if terminate_pod():
+ print("Pod termination successful")
+ break
+ else:
+ print("Pod termination failed, will retry in 60 seconds")
+ else:
+ time_left = shutdown_time - current_time
+ print(f"TTL check: {time_left} remaining until shutdown")
+
+ except ValueError as e:
+ print(f"Invalid shutdown time format in {shutdown_file}: {e}")
+ # Re-setup TTL if file is corrupted
+ shutdown_file = setup_ttl()
+ else:
+ print(f"Shutdown file {shutdown_file} not found, recreating...")
+ shutdown_file = setup_ttl()
+
+ except Exception as e:
+ print(f"Error in TTL monitoring: {e}")
+
+ time.sleep(60) # Check every minute
+
+if __name__ == "__main__":
+ try:
+ monitor_ttl()
+ except KeyboardInterrupt:
+ print("TTL monitoring service stopped")
+ sys.exit(0)
+ except Exception as e:
+ print(f"TTL monitoring service failed: {e}")
+ sys.exit(1)
\ No newline at end of file
diff --git a/ow-default.Dockerfile b/ow-default.Dockerfile
index f9e2a82..71bc138 100644
--- a/ow-default.Dockerfile
+++ b/ow-default.Dockerfile
@@ -35,5 +35,6 @@ RUN ln -s /usr/bin/python3 /usr/bin/python
EXPOSE 22
EXPOSE 8000
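+# HTTP log server (openweights/worker/services/log_server.py)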
+EXPOSE 10101
ENTRYPOINT ["/my_app/entrypoint.sh"]
\ No newline at end of file
diff --git a/supabase/migrations/20250917150000_update_call_sites_to_app_auth.sql b/supabase/migrations/20250917150000_update_call_sites_to_app_auth.sql
new file mode 100644
index 0000000..c320315
--- /dev/null
+++ b/supabase/migrations/20250917150000_update_call_sites_to_app_auth.sql
@@ -0,0 +1,76 @@
+-- Switch call sites from auth.check_if_service_account() -> app_auth.check_if_service_account()
+-- Keep SECURITY DEFINER and add a safe search_path that includes app_auth.
+
+-- 1) public.get_organization_from_token (final version from 2024-12-05-00:00:10)
+CREATE OR REPLACE FUNCTION public.get_organization_from_token()
+RETURNS uuid
+LANGUAGE plpgsql
+SECURITY DEFINER
+SET search_path = public, auth, app_auth
+AS $$
+DECLARE
+ org_id uuid;
+BEGIN
+ -- Only handle service account tokens
+ IF NOT app_auth.check_if_service_account() THEN
+ RAISE EXCEPTION 'Only service account tokens are supported';
+ END IF;
+
+ -- Get org from claims
+ org_id := (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid;
+
+ -- Update last_used_at in tokens table
+ UPDATE public.tokens
+ SET last_used_at = now()
+ WHERE id = (current_setting('request.jwt.claims', true)::json->>'token_id')::uuid;
+
+ RETURN org_id;
+END;
+$$;
+
+-- 2) public.is_organization_member(uuid)
+CREATE OR REPLACE FUNCTION public.is_organization_member(org_id uuid)
+RETURNS boolean
+LANGUAGE plpgsql
+SECURITY DEFINER
+SET search_path = public, auth, app_auth
+AS $$
+BEGIN
+ -- If this is a service account, check the organization claim
+ IF app_auth.check_if_service_account() THEN
+ RETURN (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid = org_id;
+ END IF;
+
+ -- Otherwise check normal membership
+ RETURN EXISTS (
+ SELECT 1
+ FROM public.organization_members
+ WHERE organization_id = org_id
+ AND user_id = auth.uid()
+ );
+END;
+$$;
+
+-- 3) public.is_organization_admin(uuid)
+CREATE OR REPLACE FUNCTION public.is_organization_admin(org_id uuid)
+RETURNS boolean
+LANGUAGE plpgsql
+SECURITY DEFINER
+SET search_path = public, auth, app_auth
+AS $$
+BEGIN
+ -- Service accounts have admin access to their organization
+ IF app_auth.check_if_service_account() THEN
+ RETURN (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid = org_id;
+ END IF;
+
+ -- Otherwise check normal admin membership
+ RETURN EXISTS (
+ SELECT 1
+ FROM public.organization_members
+ WHERE organization_id = org_id
+ AND user_id = auth.uid()
+ AND role = 'admin'
+ );
+END;
+$$;
diff --git a/supabase/migrations/20250917153500_storage_policies_move_to_public.sql b/supabase/migrations/20250917153500_storage_policies_move_to_public.sql
new file mode 100644
index 0000000..534da42
--- /dev/null
+++ b/supabase/migrations/20250917153500_storage_policies_move_to_public.sql
@@ -0,0 +1,117 @@
+-- All app-owned helpers live in public. Storage policies reference these.
+-- No objects created/modified inside 'storage' or 'app_storage'.
+
+-- 1) Path helper: organizations/<org_id>/...
+CREATE OR REPLACE FUNCTION public.get_path_organization_id(path text)
+RETURNS uuid
+LANGUAGE plpgsql
+STABLE
+AS $$
+DECLARE
+ parts text[];
+ org_id uuid;
+BEGIN
+ parts := string_to_array(path, '/');
+
+ IF array_length(parts, 1) IS NULL OR parts[1] <> 'organizations' THEN
+ RETURN NULL;
+ END IF;
+
+ BEGIN
+ org_id := parts[2]::uuid;
+ RETURN org_id;
+ EXCEPTION WHEN others THEN
+ RETURN NULL;
+ END;
+END;
+$$;
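+
+-- Example: for 'organizations/123e4567-e89b-12d3-a456-426614174000/files/a.txt'
+-- this returns the embedded UUID; any other prefix yields NULL.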
+
+-- 2) Access helper: service-account claim OR membership
+CREATE OR REPLACE FUNCTION public.has_organization_access(org_id uuid)
+RETURNS boolean
+LANGUAGE plpgsql
+SECURITY DEFINER
+SET search_path = public, auth, app_auth
+STABLE
+AS $$
+BEGIN
+ -- Service account?
+ IF app_auth.check_if_service_account() THEN
+ RETURN (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid = org_id;
+ END IF;
+
+ -- Otherwise, membership
+ RETURN EXISTS (
+ SELECT 1
+ FROM public.organization_members
+ WHERE organization_id = org_id
+ AND user_id = auth.uid()
+ );
+END;
+$$;
+
+-- Permissions (optional but harmless)
+GRANT EXECUTE ON FUNCTION public.get_path_organization_id(text) TO anon, authenticated, service_role, postgres;
+GRANT EXECUTE ON FUNCTION public.has_organization_access(uuid) TO anon, authenticated, service_role, postgres;
+
+-- 3) Re-create Storage policies to reference public.* helpers
+
+-- Read
+DROP POLICY IF EXISTS "Organization members can read their files" ON storage.objects;
+CREATE POLICY "Organization members can read their files"
+ON storage.objects FOR SELECT
+USING (
+ bucket_id = 'files'
+ AND (
+ name LIKE '%.keep'
+ OR (
+ name LIKE 'organizations/%'
+ AND public.has_organization_access(public.get_path_organization_id(name))
+ )
+ )
+);
+
+-- Insert
+DROP POLICY IF EXISTS "Organization members can upload files" ON storage.objects;
+CREATE POLICY "Organization members can upload files"
+ON storage.objects FOR INSERT
+WITH CHECK (
+ bucket_id = 'files'
+ AND (
+ name LIKE '%.keep'
+ OR (
+ name LIKE 'organizations/%'
+ AND public.has_organization_access(public.get_path_organization_id(name))
+ )
+ )
+);
+
+-- Update
+DROP POLICY IF EXISTS "Organization members can update their files" ON storage.objects;
+CREATE POLICY "Organization members can update their files"
+ON storage.objects FOR UPDATE
+USING (
+ bucket_id = 'files'
+ AND (
+ name LIKE '%.keep'
+ OR (
+ name LIKE 'organizations/%'
+ AND public.has_organization_access(public.get_path_organization_id(name))
+ )
+ )
+);
+
+-- Delete
+DROP POLICY IF EXISTS "Organization members can delete their files" ON storage.objects;
+CREATE POLICY "Organization members can delete their files"
+ON storage.objects FOR DELETE
+USING (
+ bucket_id = 'files'
+ AND (
+ name LIKE '%.keep'
+ OR (
+ name LIKE 'organizations/%'
+ AND public.has_organization_access(public.get_path_organization_id(name))
+ )
+ )
+);
diff --git a/supabase/migrations_dev/20241120094415_remote_schema.sql b/supabase/migrations_dev/20241120094415_remote_schema.sql
deleted file mode 100644
index d09ba13..0000000
--- a/supabase/migrations_dev/20241120094415_remote_schema.sql
+++ /dev/null
@@ -1,161 +0,0 @@
-
-SET statement_timeout = 0;
-SET lock_timeout = 0;
-SET idle_in_transaction_session_timeout = 0;
-SET client_encoding = 'UTF8';
-SET standard_conforming_strings = on;
-SELECT pg_catalog.set_config('search_path', '', false);
-SET check_function_bodies = false;
-SET xmloption = content;
-SET client_min_messages = warning;
-SET row_security = off;
-CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
-COMMENT ON SCHEMA "public" IS 'standard public schema';
-CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
-CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
-CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
-CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
-CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
-CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
-CREATE TYPE "public"."job_status" AS ENUM (
- 'pending',
- 'in_progress',
- 'completed',
- 'failed',
- 'canceled'
-);
-ALTER TYPE "public"."job_status" OWNER TO "postgres";
-CREATE TYPE "public"."job_type" AS ENUM (
- 'fine-tuning',
- 'inference',
- 'script'
-);
-ALTER TYPE "public"."job_type" OWNER TO "postgres";
-SET default_tablespace = '';
-SET default_table_access_method = "heap";
-CREATE TABLE IF NOT EXISTS "public"."events" (
- "id" integer NOT NULL,
- "run_id" integer,
- "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP,
- "data" "jsonb" NOT NULL,
- "file" "text"
-);
-ALTER TABLE "public"."events" OWNER TO "postgres";
-CREATE SEQUENCE IF NOT EXISTS "public"."events_id_seq"
- AS integer
- START WITH 1
- INCREMENT BY 1
- NO MINVALUE
- NO MAXVALUE
- CACHE 1;
-ALTER TABLE "public"."events_id_seq" OWNER TO "postgres";
-ALTER SEQUENCE "public"."events_id_seq" OWNED BY "public"."events"."id";
-CREATE TABLE IF NOT EXISTS "public"."files" (
- "id" "text" NOT NULL,
- "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP,
- "filename" "text" NOT NULL,
- "purpose" "text" NOT NULL,
- "bytes" integer NOT NULL
-);
-ALTER TABLE "public"."files" OWNER TO "postgres";
-CREATE TABLE IF NOT EXISTS "public"."jobs" (
- "id" "text" NOT NULL,
- "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP,
- "type" "public"."job_type" NOT NULL,
- "model" "text",
- "params" "jsonb",
- "script" "text",
- "outputs" "jsonb",
- "requires_vram_gb" integer DEFAULT 24,
- "status" "public"."job_status" DEFAULT 'pending'::"public"."job_status",
- "worker_id" "text"
-);
-ALTER TABLE "public"."jobs" OWNER TO "postgres";
-CREATE TABLE IF NOT EXISTS "public"."runs" (
- "id" integer NOT NULL,
- "job_id" "text",
- "worker_id" "text",
- "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP,
- "status" "public"."job_status",
- "log_file" "text"
-);
-ALTER TABLE "public"."runs" OWNER TO "postgres";
-CREATE SEQUENCE IF NOT EXISTS "public"."runs_id_seq"
- AS integer
- START WITH 1
- INCREMENT BY 1
- NO MINVALUE
- NO MAXVALUE
- CACHE 1;
-ALTER TABLE "public"."runs_id_seq" OWNER TO "postgres";
-ALTER SEQUENCE "public"."runs_id_seq" OWNED BY "public"."runs"."id";
-CREATE TABLE IF NOT EXISTS "public"."worker" (
- "id" "text" NOT NULL,
- "created_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP,
- "status" "text",
- "cached_models" "text"[],
- "vram_gb" integer,
- "pod_id" "text",
- "ping" timestamp with time zone
-);
-ALTER TABLE "public"."worker" OWNER TO "postgres";
-ALTER TABLE ONLY "public"."events" ALTER COLUMN "id" SET DEFAULT "nextval"('"public"."events_id_seq"'::"regclass");
-ALTER TABLE ONLY "public"."runs" ALTER COLUMN "id" SET DEFAULT "nextval"('"public"."runs_id_seq"'::"regclass");
-ALTER TABLE ONLY "public"."events"
- ADD CONSTRAINT "events_pkey" PRIMARY KEY ("id");
-ALTER TABLE ONLY "public"."files"
- ADD CONSTRAINT "files_pkey" PRIMARY KEY ("id");
-ALTER TABLE ONLY "public"."jobs"
- ADD CONSTRAINT "jobs_pkey" PRIMARY KEY ("id");
-ALTER TABLE ONLY "public"."runs"
- ADD CONSTRAINT "runs_pkey" PRIMARY KEY ("id");
-ALTER TABLE ONLY "public"."worker"
- ADD CONSTRAINT "worker_pkey" PRIMARY KEY ("id");
-CREATE INDEX "events_run_id_idx" ON "public"."events" USING "btree" ("run_id");
-ALTER TABLE ONLY "public"."events"
- ADD CONSTRAINT "events_run_id_fkey" FOREIGN KEY ("run_id") REFERENCES "public"."runs"("id") ON DELETE CASCADE;
-ALTER TABLE ONLY "public"."jobs"
- ADD CONSTRAINT "jobs_worker_id_fkey" FOREIGN KEY ("worker_id") REFERENCES "public"."worker"("id");
-ALTER TABLE ONLY "public"."runs"
- ADD CONSTRAINT "runs_job_id_fkey" FOREIGN KEY ("job_id") REFERENCES "public"."jobs"("id") ON DELETE CASCADE;
-ALTER TABLE ONLY "public"."runs"
- ADD CONSTRAINT "runs_worker_id_fkey" FOREIGN KEY ("worker_id") REFERENCES "public"."worker"("id") ON DELETE SET NULL;
-ALTER PUBLICATION "supabase_realtime" OWNER TO "postgres";
-GRANT USAGE ON SCHEMA "public" TO "postgres";
-GRANT USAGE ON SCHEMA "public" TO "anon";
-GRANT USAGE ON SCHEMA "public" TO "authenticated";
-GRANT USAGE ON SCHEMA "public" TO "service_role";
-GRANT ALL ON TABLE "public"."events" TO "anon";
-GRANT ALL ON TABLE "public"."events" TO "authenticated";
-GRANT ALL ON TABLE "public"."events" TO "service_role";
-GRANT ALL ON SEQUENCE "public"."events_id_seq" TO "anon";
-GRANT ALL ON SEQUENCE "public"."events_id_seq" TO "authenticated";
-GRANT ALL ON SEQUENCE "public"."events_id_seq" TO "service_role";
-GRANT ALL ON TABLE "public"."files" TO "anon";
-GRANT ALL ON TABLE "public"."files" TO "authenticated";
-GRANT ALL ON TABLE "public"."files" TO "service_role";
-GRANT ALL ON TABLE "public"."jobs" TO "anon";
-GRANT ALL ON TABLE "public"."jobs" TO "authenticated";
-GRANT ALL ON TABLE "public"."jobs" TO "service_role";
-GRANT ALL ON TABLE "public"."runs" TO "anon";
-GRANT ALL ON TABLE "public"."runs" TO "authenticated";
-GRANT ALL ON TABLE "public"."runs" TO "service_role";
-GRANT ALL ON SEQUENCE "public"."runs_id_seq" TO "anon";
-GRANT ALL ON SEQUENCE "public"."runs_id_seq" TO "authenticated";
-GRANT ALL ON SEQUENCE "public"."runs_id_seq" TO "service_role";
-GRANT ALL ON TABLE "public"."worker" TO "anon";
-GRANT ALL ON TABLE "public"."worker" TO "authenticated";
-GRANT ALL ON TABLE "public"."worker" TO "service_role";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "postgres";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "anon";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "authenticated";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON SEQUENCES TO "service_role";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "postgres";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "anon";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "authenticated";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON FUNCTIONS TO "service_role";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "postgres";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "anon";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "authenticated";
-ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TABLES TO "service_role";
-RESET ALL;
diff --git a/supabase/migrations_dev/20241120161019_add_gpu_columns.sql b/supabase/migrations_dev/20241120161019_add_gpu_columns.sql
deleted file mode 100644
index c0821b1..0000000
--- a/supabase/migrations_dev/20241120161019_add_gpu_columns.sql
+++ /dev/null
@@ -1,9 +0,0 @@
--- Add gpu_type and gpu_count columns to worker table
-ALTER TABLE "public"."worker"
- ADD COLUMN "gpu_type" text,
- ADD COLUMN "gpu_count" integer;
-
--- Grant permissions to match existing table permissions
-GRANT ALL ON TABLE "public"."worker" TO "anon";
-GRANT ALL ON TABLE "public"."worker" TO "authenticated";
-GRANT ALL ON TABLE "public"."worker" TO "service_role";
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241127132826_add_docker_image.sql b/supabase/migrations_dev/20241127132826_add_docker_image.sql
deleted file mode 100644
index 93aa998..0000000
--- a/supabase/migrations_dev/20241127132826_add_docker_image.sql
+++ /dev/null
@@ -1,8 +0,0 @@
--- Add docker_image column to jobs table
-ALTER TABLE "public"."jobs"
- ADD COLUMN "docker_image" text;
-
--- Grant permissions to match existing table permissions
-GRANT ALL ON TABLE "public"."jobs" TO "anon";
-GRANT ALL ON TABLE "public"."jobs" TO "authenticated";
-GRANT ALL ON TABLE "public"."jobs" TO "service_role";
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241127153031_add_docker_image_to_workers.sql b/supabase/migrations_dev/20241127153031_add_docker_image_to_workers.sql
deleted file mode 100644
index 09dbfd5..0000000
--- a/supabase/migrations_dev/20241127153031_add_docker_image_to_workers.sql
+++ /dev/null
@@ -1 +0,0 @@
-alter table "public"."worker" add column "docker_image" text;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241202233213_add_updated_at_columns.sql b/supabase/migrations_dev/20241202233213_add_updated_at_columns.sql
deleted file mode 100644
index 3f8ca4c..0000000
--- a/supabase/migrations_dev/20241202233213_add_updated_at_columns.sql
+++ /dev/null
@@ -1,52 +0,0 @@
--- Add updated_at columns to tables
-ALTER TABLE "public"."worker"
- ADD COLUMN "updated_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP;
-
-ALTER TABLE "public"."jobs"
- ADD COLUMN "updated_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP;
-
-ALTER TABLE "public"."runs"
- ADD COLUMN "updated_at" timestamp with time zone DEFAULT CURRENT_TIMESTAMP;
-
--- Update existing rows to set updated_at = created_at
-UPDATE "public"."worker" SET "updated_at" = "created_at";
-UPDATE "public"."jobs" SET "updated_at" = "created_at";
-UPDATE "public"."runs" SET "updated_at" = "created_at";
-
--- Create a function to automatically set updated_at
-CREATE OR REPLACE FUNCTION public.handle_updated_at()
-RETURNS TRIGGER AS $$
-BEGIN
- NEW.updated_at = CURRENT_TIMESTAMP;
- RETURN NEW;
-END;
-$$ language 'plpgsql';
-
--- Create triggers for each table
-CREATE TRIGGER set_updated_at_worker
- BEFORE UPDATE ON public.worker
- FOR EACH ROW
- EXECUTE FUNCTION public.handle_updated_at();
-
-CREATE TRIGGER set_updated_at_jobs
- BEFORE UPDATE ON public.jobs
- FOR EACH ROW
- EXECUTE FUNCTION public.handle_updated_at();
-
-CREATE TRIGGER set_updated_at_runs
- BEFORE UPDATE ON public.runs
- FOR EACH ROW
- EXECUTE FUNCTION public.handle_updated_at();
-
--- Grant permissions to match existing table permissions
-GRANT ALL ON TABLE "public"."worker" TO "anon";
-GRANT ALL ON TABLE "public"."worker" TO "authenticated";
-GRANT ALL ON TABLE "public"."worker" TO "service_role";
-
-GRANT ALL ON TABLE "public"."jobs" TO "anon";
-GRANT ALL ON TABLE "public"."jobs" TO "authenticated";
-GRANT ALL ON TABLE "public"."jobs" TO "service_role";
-
-GRANT ALL ON TABLE "public"."runs" TO "anon";
-GRANT ALL ON TABLE "public"."runs" TO "authenticated";
-GRANT ALL ON TABLE "public"."runs" TO "service_role";
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241203000542_add_api_job_type.sql b/supabase/migrations_dev/20241203000542_add_api_job_type.sql
deleted file mode 100644
index 517015b..0000000
--- a/supabase/migrations_dev/20241203000542_add_api_job_type.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- Add 'api' to the job_type enum
-ALTER TYPE "public"."job_type" ADD VALUE IF NOT EXISTS 'api';
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000000_add_organizations.sql b/supabase/migrations_dev/20241205000000_add_organizations.sql
deleted file mode 100644
index 5d9eaae..0000000
--- a/supabase/migrations_dev/20241205000000_add_organizations.sql
+++ /dev/null
@@ -1,206 +0,0 @@
--- Create role enum type
-create type "public"."organization_role" as enum ('admin', 'user');
-
--- Create organizations table
-create table if not exists "public"."organizations" (
- "id" uuid not null default gen_random_uuid(),
- "created_at" timestamp with time zone default timezone('utc'::text, now()) not null,
- "updated_at" timestamp with time zone default timezone('utc'::text, now()) not null,
- "name" text not null,
- primary key (id)
-);
-
--- Create organization_members junction table with role
-create table if not exists "public"."organization_members" (
- "organization_id" uuid not null references organizations(id) on delete cascade,
- "user_id" uuid not null references auth.users(id) on delete cascade,
- "role" organization_role not null default 'user',
- "created_at" timestamp with time zone default timezone('utc'::text, now()) not null,
- primary key (organization_id, user_id)
-);
-
--- Create third_party_api_keys table for RunPod, HuggingFace, etc.
-create table if not exists "public"."third_party_api_keys" (
- "organization_id" uuid not null references organizations(id) on delete cascade,
- "service" text not null, -- 'runpod', 'huggingface'
- "api_key" text not null,
- "created_at" timestamp with time zone default timezone('utc'::text, now()) not null,
- "updated_at" timestamp with time zone default timezone('utc'::text, now()) not null,
- primary key (organization_id, service)
-);
-
--- Create a default organization for existing data
-insert into "public"."organizations" (id, name)
-values ('00000000-0000-0000-0000-000000000000', 'Default Organization');
-
--- Add organization ownership to jobs
-alter table "public"."jobs"
- add column "organization_id" uuid references organizations(id) on delete cascade;
-
--- Add organization ownership to workers
-alter table "public"."worker"
- add column "organization_id" uuid references organizations(id) on delete cascade;
-
--- Migrate existing jobs and workers to default organization
-update "public"."jobs"
-set "organization_id" = '00000000-0000-0000-0000-000000000000'
-where "organization_id" is null;
-
-update "public"."worker"
-set "organization_id" = '00000000-0000-0000-0000-000000000000'
-where "organization_id" is null;
-
--- Make organization_id not null after migration
-alter table "public"."jobs"
- alter column "organization_id" set not null;
-
-alter table "public"."worker"
- alter column "organization_id" set not null;
-
--- Enable RLS on all tables
-alter table "public"."organizations" enable row level security;
-alter table "public"."organization_members" enable row level security;
-alter table "public"."third_party_api_keys" enable row level security;
-alter table "public"."jobs" enable row level security;
-alter table "public"."runs" enable row level security;
-alter table "public"."worker" enable row level security;
-alter table "public"."events" enable row level security;
-
--- Organizations: members can read, admins can update
-create policy "Members can view their organizations"
- on organizations for select
- using (
- exists (
- select 1 from organization_members
- where organization_id = organizations.id
- and user_id = auth.uid()
- )
- );
-
-create policy "Admins can update their organizations"
- on organizations for update
- using (
- exists (
- select 1 from organization_members
- where organization_id = organizations.id
- and user_id = auth.uid()
- and role = 'admin'
- )
- );
-
--- Organization members: members can view, admins can manage
-create policy "Members can view other members in their organizations"
- on organization_members for select
- using (
- exists (
- select 1 from organization_members as om
- where om.organization_id = organization_members.organization_id
- and om.user_id = auth.uid()
- )
- );
-
-create policy "Admins can manage members in their organizations"
- on organization_members for all
- using (
- exists (
- select 1 from organization_members
- where organization_id = organization_members.organization_id
- and user_id = auth.uid()
- and role = 'admin'
- )
- );
-
--- Jobs: members can CRUD jobs in their organizations
-create policy "Members can manage jobs in their organizations"
- on jobs for all
- using (
- exists (
- select 1 from organization_members
- where organization_id = jobs.organization_id
- and user_id = auth.uid()
- )
- );
-
--- Runs: members can manage runs of jobs in their organizations
-create policy "Members can manage runs in their organizations"
- on runs for all
- using (
- exists (
- select 1 from organization_members om
- join jobs j on j.organization_id = om.organization_id
- where j.id = runs.job_id
- and om.user_id = auth.uid()
- )
- );
-
--- Workers: members can manage workers in their organizations
-create policy "Members can manage workers in their organizations"
- on worker for all
- using (
- exists (
- select 1 from organization_members
- where organization_id = worker.organization_id
- and user_id = auth.uid()
- )
- );
-
--- Events: members can manage events of runs in their organizations
-create policy "Members can manage events in their organizations"
- on events for all
- using (
- exists (
- select 1 from organization_members om
- join jobs j on j.organization_id = om.organization_id
- join runs r on r.job_id = j.id
- where r.id = events.run_id
- and om.user_id = auth.uid()
- )
- );
-
--- Third-party API Keys: members can view, admins can manage
-create policy "Members can view their organization's API keys"
- on third_party_api_keys for select
- using (
- exists (
- select 1 from organization_members
- where organization_id = third_party_api_keys.organization_id
- and user_id = auth.uid()
- )
- );
-
-create policy "Admins can manage their organization's API keys"
- on third_party_api_keys for all
- using (
- exists (
- select 1 from organization_members
- where organization_id = third_party_api_keys.organization_id
- and user_id = auth.uid()
- and role = 'admin'
- )
- );
-
--- Grant permissions
-grant usage on schema public to postgres, anon, authenticated, service_role;
-
-grant all privileges on all tables in schema public to postgres, service_role;
-grant all privileges on all sequences in schema public to postgres, service_role;
-
-grant select, insert, update, delete on public.organizations to anon, authenticated;
-grant select, insert, update, delete on public.organization_members to anon, authenticated;
-grant select, insert, update, delete on public.third_party_api_keys to anon, authenticated;
-grant select, insert, update, delete on public.jobs to anon, authenticated;
-grant select, insert, update, delete on public.runs to anon, authenticated;
-grant select, insert, update, delete on public.worker to anon, authenticated;
-grant select, insert, update, delete on public.events to anon, authenticated;
-
--- Add updated_at trigger for organizations
-create trigger set_updated_at_organizations
- before update on public.organizations
- for each row
- execute function public.handle_updated_at();
-
--- Add updated_at trigger for third_party_api_keys
-create trigger set_updated_at_third_party_api_keys
- before update on public.third_party_api_keys
- for each row
- execute function public.handle_updated_at();
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000002_fix_organization_policies.sql b/supabase/migrations_dev/20241205000002_fix_organization_policies.sql
deleted file mode 100644
index 7446f83..0000000
--- a/supabase/migrations_dev/20241205000002_fix_organization_policies.sql
+++ /dev/null
@@ -1,47 +0,0 @@
--- Drop existing policies that might cause recursion
-drop policy if exists "Members can manage members in their organizations" on organization_members;
-drop policy if exists "Admins can manage members in their organizations" on organization_members;
-
--- Create new, more specific policies for organization members
-create policy "Members can view organization members"
- on organization_members for select
- using (
- exists (
- select 1 from organization_members as om
- where om.organization_id = organization_members.organization_id
- and om.user_id = auth.uid()
- )
- );
-
-create policy "Admins can insert organization members"
- on organization_members for insert
- with check (
- exists (
- select 1 from organization_members as om
- where om.organization_id = organization_members.organization_id
- and om.user_id = auth.uid()
- and om.role = 'admin'
- )
- );
-
-create policy "Admins can update organization members"
- on organization_members for update
- using (
- exists (
- select 1 from organization_members as om
- where om.organization_id = organization_members.organization_id
- and om.user_id = auth.uid()
- and om.role = 'admin'
- )
- );
-
-create policy "Admins can delete organization members"
- on organization_members for delete
- using (
- exists (
- select 1 from organization_members as om
- where om.organization_id = organization_members.organization_id
- and om.user_id = auth.uid()
- and om.role = 'admin'
- )
- );
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000003_fix_organization_member_policies.sql b/supabase/migrations_dev/20241205000003_fix_organization_member_policies.sql
deleted file mode 100644
index 86a76ba..0000000
--- a/supabase/migrations_dev/20241205000003_fix_organization_member_policies.sql
+++ /dev/null
@@ -1,24 +0,0 @@
--- Drop all existing policies on organization_members to start fresh
-drop policy if exists "Members can view organization members" on organization_members;
-drop policy if exists "Admins can insert organization members" on organization_members;
-drop policy if exists "Admins can update organization members" on organization_members;
-drop policy if exists "Admins can delete organization members" on organization_members;
-drop policy if exists "Members can view other members in their organizations" on organization_members;
-drop policy if exists "Admins can manage members in their organizations" on organization_members;
-
--- Create new simplified policies that avoid recursion
-create policy "Enable read access for organization members"
- on organization_members for select
- using (auth.uid() = user_id);
-
-create policy "Enable write access for organization admins"
- on organization_members for all
- using (
- auth.uid() in (
- select user_id
- from organization_members
- where organization_id = organization_members.organization_id
- and role = 'admin'
- and user_id = auth.uid()
- )
- );
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000004_fix_all_organization_policies.sql b/supabase/migrations_dev/20241205000004_fix_all_organization_policies.sql
deleted file mode 100644
index e28c87f..0000000
--- a/supabase/migrations_dev/20241205000004_fix_all_organization_policies.sql
+++ /dev/null
@@ -1,100 +0,0 @@
--- First, drop all existing policies that might cause recursion
-drop policy if exists "Members can view their organizations" on organizations;
-drop policy if exists "Admins can update their organizations" on organizations;
-drop policy if exists "Enable read access for organization members" on organization_members;
-drop policy if exists "Enable write access for organization admins" on organization_members;
-drop policy if exists "Members can manage jobs in their organizations" on jobs;
-drop policy if exists "Members can manage runs in their organizations" on runs;
-drop policy if exists "Members can manage workers in their organizations" on worker;
-drop policy if exists "Members can manage events in their organizations" on events;
-drop policy if exists "Members can view their organization's API keys" on third_party_api_keys;
-drop policy if exists "Admins can manage their organization's API keys" on third_party_api_keys;
-
--- Create a function to check organization membership
-create or replace function is_organization_member(org_id uuid)
-returns boolean as $$
-begin
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- );
-end;
-$$ language plpgsql security definer;
-
--- Create a function to check organization admin status
-create or replace function is_organization_admin(org_id uuid)
-returns boolean as $$
-begin
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- );
-end;
-$$ language plpgsql security definer;
-
--- Organizations policies
-create policy "Enable read access for organization members"
- on organizations for select
- using (is_organization_member(id));
-
-create policy "Enable write access for organization admins"
- on organizations for all
- using (is_organization_admin(id));
-
--- Organization members policies
-create policy "Enable read access for members"
- on organization_members for select
- using (is_organization_member(organization_id));
-
-create policy "Enable write access for admins"
- on organization_members for all
- using (is_organization_admin(organization_id));
-
--- Jobs policies
-create policy "Enable access for organization members"
- on jobs for all
- using (is_organization_member(organization_id));
-
--- Runs policies (through jobs)
-create policy "Enable access for organization members"
- on runs for all
- using (
- exists (
- select 1
- from jobs
- where jobs.id = runs.job_id
- and is_organization_member(jobs.organization_id)
- )
- );
-
--- Workers policies
-create policy "Enable access for organization members"
- on worker for all
- using (is_organization_member(organization_id));
-
--- Events policies (through runs and jobs)
-create policy "Enable access for organization members"
- on events for all
- using (
- exists (
- select 1
- from runs
- join jobs on jobs.id = runs.job_id
- where runs.id = events.run_id
- and is_organization_member(jobs.organization_id)
- )
- );
-
--- Third-party API keys policies
-create policy "Enable read for members"
- on third_party_api_keys for select
- using (is_organization_member(organization_id));
-
-create policy "Enable write for admins"
- on third_party_api_keys for all
- using (is_organization_admin(organization_id));
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000005_add_organization_functions.sql b/supabase/migrations_dev/20241205000005_add_organization_functions.sql
deleted file mode 100644
index 14291e2..0000000
--- a/supabase/migrations_dev/20241205000005_add_organization_functions.sql
+++ /dev/null
@@ -1,51 +0,0 @@
--- Function to get organization members with their email addresses
-create or replace function public.get_organization_members(org_id uuid)
-returns table (
- user_id uuid,
- email text,
- role public.organization_role
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- return query
- select
- om.user_id,
- au.email,
- om.role
- from public.organization_members om
- join auth.users au on au.id = om.user_id
- where om.organization_id = org_id
- and exists (
- select 1
- from public.organization_members viewer
- where viewer.organization_id = org_id
- and viewer.user_id = auth.uid()
- );
-end;
-$$;
-
--- Function to get user by email
-create or replace function public.get_user_by_email(user_email text)
-returns table (
- id uuid,
- email text
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- return query
- select
- au.id,
- au.email
- from auth.users au
- where au.email = user_email
- limit 1;
-end;
-$$;
-
--- Grant execute permissions on the functions
-grant execute on function public.get_organization_members(uuid) to authenticated;
-grant execute on function public.get_user_by_email(text) to authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000006_fix_organization_function_types.sql b/supabase/migrations_dev/20241205000006_fix_organization_function_types.sql
deleted file mode 100644
index c939638..0000000
--- a/supabase/migrations_dev/20241205000006_fix_organization_function_types.sql
+++ /dev/null
@@ -1,55 +0,0 @@
--- Drop existing functions
-drop function if exists public.get_organization_members(uuid);
-drop function if exists public.get_user_by_email(text);
-
--- Recreate functions with correct types
-create or replace function public.get_organization_members(org_id uuid)
-returns table (
- user_id uuid,
- email varchar(255),
- role public.organization_role
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- return query
- select
- om.user_id,
- au.email,
- om.role
- from public.organization_members om
- join auth.users au on au.id = om.user_id
- where om.organization_id = org_id
- and exists (
- select 1
- from public.organization_members viewer
- where viewer.organization_id = org_id
- and viewer.user_id = auth.uid()
- );
-end;
-$$;
-
--- Function to get user by email
-create or replace function public.get_user_by_email(user_email varchar(255))
-returns table (
- id uuid,
- email varchar(255)
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- return query
- select
- au.id,
- au.email
- from auth.users au
- where au.email = user_email
- limit 1;
-end;
-$$;
-
--- Grant execute permissions on the functions
-grant execute on function public.get_organization_members(uuid) to authenticated;
-grant execute on function public.get_user_by_email(varchar) to authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000007_fix_get_user_by_email_param.sql b/supabase/migrations_dev/20241205000007_fix_get_user_by_email_param.sql
deleted file mode 100644
index 91d04af..0000000
--- a/supabase/migrations_dev/20241205000007_fix_get_user_by_email_param.sql
+++ /dev/null
@@ -1,25 +0,0 @@
--- Drop existing function
-drop function if exists public.get_user_by_email(varchar);
-
--- Recreate function with parameter named 'email'
-create or replace function public.get_user_by_email(email varchar(255))
-returns table (
- user_id uuid,
- user_email varchar(255)
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- return query
- select
- au.id as user_id,
- au.email as user_email
- from auth.users au
- where au.email = get_user_by_email.email
- limit 1;
-end;
-$$;
-
--- Grant execute permissions on the function
-grant execute on function public.get_user_by_email(varchar) to authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000008_improve_get_user_by_email.sql b/supabase/migrations_dev/20241205000008_improve_get_user_by_email.sql
deleted file mode 100644
index 6ca47b2..0000000
--- a/supabase/migrations_dev/20241205000008_improve_get_user_by_email.sql
+++ /dev/null
@@ -1,33 +0,0 @@
--- Drop existing function
-drop function if exists public.get_user_by_email(varchar);
-
--- Recreate function with parameter named 'email' and better error handling
-create or replace function public.get_user_by_email(email varchar(255))
-returns table (
- user_id uuid,
- user_email varchar(255)
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- if email is null then
- raise exception 'Email parameter cannot be null';
- end if;
-
- return query
- select
- au.id as user_id,
- au.email as user_email
- from auth.users au
- where lower(au.email) = lower(get_user_by_email.email)
- limit 1;
-
- if not found then
- raise exception 'User with email % not found', email;
- end if;
-end;
-$$;
-
--- Grant execute permissions on the function
-grant execute on function public.get_user_by_email(varchar) to authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000009_add_invite_member_function.sql b/supabase/migrations_dev/20241205000009_add_invite_member_function.sql
deleted file mode 100644
index 4fcd423..0000000
--- a/supabase/migrations_dev/20241205000009_add_invite_member_function.sql
+++ /dev/null
@@ -1,60 +0,0 @@
--- Function to invite a member to an organization
-create or replace function public.invite_organization_member(
- org_id uuid,
- member_email varchar(255),
- member_role public.organization_role
-)
-returns table (
- user_id uuid,
- email varchar(255),
- role public.organization_role
-) security definer
-set search_path = public
-language plpgsql
-as $$
-declare
- v_user_id uuid;
- v_email varchar(255);
-begin
- -- Check if the inviter is an admin of the organization
- if not exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- ) then
- raise exception 'Only organization admins can invite members';
- end if;
-
- -- Get the user ID for the email
- select au.id, au.email
- into v_user_id, v_email
- from auth.users au
- where lower(au.email) = lower(member_email);
-
- if v_user_id is null then
- raise exception 'User with email % not found', member_email;
- end if;
-
- -- Check if user is already a member
- if exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = v_user_id
- ) then
- raise exception 'User is already a member of this organization';
- end if;
-
- -- Insert the new member
- insert into organization_members (organization_id, user_id, role)
- values (org_id, v_user_id, member_role)
- returning user_id, v_email as email, role into user_id, email, role;
-
- return next;
-end;
-$$;
-
--- Grant execute permissions on the function
-grant execute on function public.invite_organization_member(uuid, varchar, public.organization_role) to authenticated;
\ No newline at end of file
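Given the signature above, an organization admin would have invoked the invite like this (placeholder UUID; `'admin'` is the only role value these migrations confirm):

```sql
select * from public.invite_organization_member(
  '00000000-0000-0000-0000-000000000000'::uuid,  -- org_id (placeholder)
  'new.member@example.com',                      -- member_email (placeholder)
  'admin'::public.organization_role              -- member_role
);
```

Note that the `returning ... into user_id, email, role` line above is the column-reference ambiguity the next migration in this series fixes.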
diff --git a/supabase/migrations_dev/20241205000010_fix_invite_member_function.sql b/supabase/migrations_dev/20241205000010_fix_invite_member_function.sql
deleted file mode 100644
index 127a4fe..0000000
--- a/supabase/migrations_dev/20241205000010_fix_invite_member_function.sql
+++ /dev/null
@@ -1,67 +0,0 @@
--- Drop existing function
-drop function if exists public.invite_organization_member(uuid, varchar, public.organization_role);
-
--- Recreate function with fixed column references
-create or replace function public.invite_organization_member(
- org_id uuid,
- member_email varchar(255),
- member_role public.organization_role
-)
-returns table (
- user_id uuid,
- email varchar(255),
- role public.organization_role
-) security definer
-set search_path = public
-language plpgsql
-as $$
-declare
- v_user_id uuid;
- v_email varchar(255);
-begin
- -- Check if the inviter is an admin of the organization
- if not exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- ) then
- raise exception 'Only organization admins can invite members';
- end if;
-
- -- Get the user ID for the email
- select au.id, au.email
- into v_user_id, v_email
- from auth.users au
- where lower(au.email) = lower(member_email);
-
- if v_user_id is null then
- raise exception 'User with email % not found', member_email;
- end if;
-
- -- Check if user is already a member
- if exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = v_user_id
- ) then
- raise exception 'User is already a member of this organization';
- end if;
-
- -- Insert the new member
- insert into organization_members (organization_id, user_id, role)
- values (org_id, v_user_id, member_role);
-
- -- Return the result
- return query
- select
- v_user_id as user_id,
- v_email as email,
- member_role as role;
-end;
-$$;
-
--- Grant execute permissions on the function
-grant execute on function public.invite_organization_member(uuid, varchar, public.organization_role) to authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000011_fix_get_members_function.sql b/supabase/migrations_dev/20241205000011_fix_get_members_function.sql
deleted file mode 100644
index adb4a3a..0000000
--- a/supabase/migrations_dev/20241205000011_fix_get_members_function.sql
+++ /dev/null
@@ -1,33 +0,0 @@
--- Drop existing function
-drop function if exists public.get_organization_members(uuid);
-
--- Recreate function with explicit column references
-create or replace function public.get_organization_members(org_id uuid)
-returns table (
- user_id uuid,
- email varchar(255),
- role public.organization_role
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- return query
- select
- om.user_id,
- au.email,
- om.role
- from public.organization_members om
- join auth.users au on au.id = om.user_id
- where om.organization_id = org_id
- and exists (
- select 1
- from public.organization_members viewer
- where viewer.organization_id = org_id
- and viewer.user_id = auth.uid()
- );
-end;
-$$;
-
--- Grant execute permissions on the function
-grant execute on function public.get_organization_members(uuid) to authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000012_fix_organization_check_functions.sql b/supabase/migrations_dev/20241205000012_fix_organization_check_functions.sql
deleted file mode 100644
index 970f14a..0000000
--- a/supabase/migrations_dev/20241205000012_fix_organization_check_functions.sql
+++ /dev/null
@@ -1,102 +0,0 @@
--- First drop all policies that use these functions
-drop policy if exists "Enable read access for organization members" on organizations;
-drop policy if exists "Enable write access for organization admins" on organizations;
-drop policy if exists "Enable read access for members" on organization_members;
-drop policy if exists "Enable write access for admins" on organization_members;
-drop policy if exists "Enable access for organization members" on jobs;
-drop policy if exists "Enable access for organization members" on runs;
-drop policy if exists "Enable access for organization members" on worker;
-drop policy if exists "Enable access for organization members" on events;
-drop policy if exists "Enable read for members" on third_party_api_keys;
-drop policy if exists "Enable write for admins" on third_party_api_keys;
-
--- Now we can safely drop the functions
-drop function if exists public.is_organization_member(uuid);
-drop function if exists public.is_organization_admin(uuid);
-
--- Recreate is_organization_member function with explicit column references
-create or replace function is_organization_member(org_id uuid)
-returns boolean as $$
-begin
- return exists (
- select 1
- from organization_members om
- where om.organization_id = org_id
- and om.user_id = auth.uid()
- );
-end;
-$$ language plpgsql security definer;
-
--- Recreate is_organization_admin function with explicit column references
-create or replace function is_organization_admin(org_id uuid)
-returns boolean as $$
-begin
- return exists (
- select 1
- from organization_members om
- where om.organization_id = org_id
- and om.user_id = auth.uid()
- and om.role = 'admin'
- );
-end;
-$$ language plpgsql security definer;
-
--- Recreate policies with explicit table references
-create policy "Enable read access for organization members"
- on organizations for select
- using (is_organization_member(id));
-
-create policy "Enable write access for organization admins"
- on organizations for all
- using (is_organization_admin(id));
-
-create policy "Enable read access for members"
- on organization_members for select
- using (is_organization_member(organization_id));
-
-create policy "Enable write access for admins"
- on organization_members for all
- using (is_organization_admin(organization_id));
-
-create policy "Enable access for organization members"
- on jobs for all
- using (is_organization_member(organization_id));
-
-create policy "Enable access for organization members"
- on runs for all
- using (
- exists (
- select 1
- from jobs j
- where j.id = runs.job_id
- and is_organization_member(j.organization_id)
- )
- );
-
-create policy "Enable access for organization members"
- on worker for all
- using (is_organization_member(organization_id));
-
-create policy "Enable access for organization members"
- on events for all
- using (
- exists (
- select 1
- from runs r
- join jobs j on j.id = r.job_id
- where r.id = events.run_id
- and is_organization_member(j.organization_id)
- )
- );
-
-create policy "Enable read for members"
- on third_party_api_keys for select
- using (is_organization_member(organization_id));
-
-create policy "Enable write for admins"
- on third_party_api_keys for all
- using (is_organization_admin(organization_id));
-
--- Grant execute permissions
-grant execute on function public.is_organization_member(uuid) to authenticated;
-grant execute on function public.is_organization_admin(uuid) to authenticated;
\ No newline at end of file
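Every policy recreated above funnels through the two boolean helpers, so access problems can be debugged by calling them directly as the user under test (placeholder org UUID):

```sql
select is_organization_member('00000000-0000-0000-0000-000000000000'::uuid) as member,
       is_organization_admin('00000000-0000-0000-0000-000000000000'::uuid)  as admin;
```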
diff --git a/supabase/migrations_dev/20241205000013_fix_all_user_id_references.sql b/supabase/migrations_dev/20241205000013_fix_all_user_id_references.sql
deleted file mode 100644
index 5f2d2e5..0000000
--- a/supabase/migrations_dev/20241205000013_fix_all_user_id_references.sql
+++ /dev/null
@@ -1,96 +0,0 @@
--- Drop and recreate get_user_by_email function with explicit references
-drop function if exists public.get_user_by_email(varchar);
-
-create or replace function public.get_user_by_email(email varchar(255))
-returns table (
- user_id uuid,
- user_email varchar(255)
-) security definer
-set search_path = public
-language plpgsql
-as $$
-begin
- if email is null then
- raise exception 'Email parameter cannot be null';
- end if;
-
- return query
- select
- au.id as user_id,
- au.email as user_email
- from auth.users au
- where lower(au.email) = lower(get_user_by_email.email)
- limit 1;
-
- if not found then
- raise exception 'User with email % not found', email;
- end if;
-end;
-$$;
-
--- Drop and recreate invite_organization_member function with explicit references
-drop function if exists public.invite_organization_member(uuid, varchar, public.organization_role);
-
-create or replace function public.invite_organization_member(
- org_id uuid,
- member_email varchar(255),
- member_role public.organization_role
-)
-returns table (
- user_id uuid,
- email varchar(255),
- role public.organization_role
-) security definer
-set search_path = public
-language plpgsql
-as $$
-declare
- v_user_id uuid;
- v_email varchar(255);
-begin
- -- Check if the inviter is an admin of the organization
- if not exists (
- select 1
- from organization_members om
- where om.organization_id = org_id
- and om.user_id = auth.uid()
- and om.role = 'admin'
- ) then
- raise exception 'Only organization admins can invite members';
- end if;
-
- -- Get the user ID for the email
- select au.id, au.email
- into v_user_id, v_email
- from auth.users au
- where lower(au.email) = lower(member_email);
-
- if v_user_id is null then
- raise exception 'User with email % not found', member_email;
- end if;
-
- -- Check if user is already a member
- if exists (
- select 1
- from organization_members om
- where om.organization_id = org_id
- and om.user_id = v_user_id
- ) then
- raise exception 'User is already a member of this organization';
- end if;
-
- -- Insert the new member
- insert into organization_members (organization_id, user_id, role)
- values (org_id, v_user_id, member_role);
-
- -- Return the result explicitly
- user_id := v_user_id;
- email := v_email;
- role := member_role;
- return next;
-end;
-$$;
-
--- Grant execute permissions
-grant execute on function public.get_user_by_email(varchar) to authenticated;
-grant execute on function public.invite_organization_member(uuid, varchar, public.organization_role) to authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000014_add_storage_policies.sql b/supabase/migrations_dev/20241205000014_add_storage_policies.sql
deleted file mode 100644
index f1ec8ef..0000000
--- a/supabase/migrations_dev/20241205000014_add_storage_policies.sql
+++ /dev/null
@@ -1,79 +0,0 @@
--- Create storage policies for the files bucket
--- Note: The bucket itself needs to be created manually in the dashboard
-
--- Drop any existing policies to avoid conflicts
-begin;
- drop policy if exists "Organization members can read their files" on storage.objects;
- drop policy if exists "Organization members can upload files" on storage.objects;
- drop policy if exists "Organization members can delete their files" on storage.objects;
-
- -- Policy for reading files:
- -- Allow if user is member of the organization that owns the file
- create policy "Organization members can read their files"
- on storage.objects for select
- using (
- -- Check if file is in an organization folder
- (storage.foldername(name))[1] = 'organizations'
- and (
- -- User is member of the organization specified in the path
- exists (
- select 1
- from public.organization_members
- where organization_id = uuid((storage.foldername(name))[2])
- and user_id = auth.uid()
- )
- )
- );
-
- -- Policy for inserting files:
- -- Allow if user is member of the organization and file path matches organization
- create policy "Organization members can upload files"
- on storage.objects for insert
- with check (
- -- Ensure file is being uploaded to an organization folder
- (storage.foldername(name))[1] = 'organizations'
- and (
- -- User is member of the organization specified in the path
- exists (
- select 1
- from public.organization_members
- where organization_id = uuid((storage.foldername(name))[2])
- and user_id = auth.uid()
- )
- )
- );
-
- -- Policy for deleting files:
- -- Allow if user is admin of the organization that owns the file
- create policy "Organization members can delete their files"
- on storage.objects for delete
- using (
- -- Check if file is in an organization folder
- (storage.foldername(name))[1] = 'organizations'
- and (
- -- User is admin of the organization specified in the path
- exists (
- select 1
- from public.organization_members
- where organization_id = uuid((storage.foldername(name))[2])
- and user_id = auth.uid()
- and role = 'admin'
- )
- )
- );
-
- -- Create a function to get the correct organization path for file storage
- create or replace function public.get_organization_storage_path(
- org_id uuid,
- filename text
- ) returns text
- language sql
- stable
- as $$
- select 'organizations/' || org_id || '/' || filename;
- $$;
-
- -- Grant execute permissions
- grant execute on function public.get_organization_storage_path(uuid, text) to authenticated;
-
-commit;
\ No newline at end of file
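The path helper is deterministic, so its contract is easy to pin down (placeholder values):

```sql
select public.get_organization_storage_path(
  '00000000-0000-0000-0000-000000000000'::uuid,
  'dataset.jsonl'
);
-- => organizations/00000000-0000-0000-0000-000000000000/dataset.jsonl
```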
diff --git a/supabase/migrations_dev/20241205000015_migrate_existing_files.sql b/supabase/migrations_dev/20241205000015_migrate_existing_files.sql
deleted file mode 100644
index e7451b2..0000000
--- a/supabase/migrations_dev/20241205000015_migrate_existing_files.sql
+++ /dev/null
@@ -1,40 +0,0 @@
--- Function to migrate existing files to default organization
-create or replace function migrate_files_to_default_organization()
-returns void
-language plpgsql
-security definer
-as $$
-declare
- file_record record;
- new_path text;
- default_org_id uuid := '00000000-0000-0000-0000-000000000000';
-begin
- -- Loop through all files that are not in an organizations folder
- for file_record in
- select name, id
- from storage.objects
- where (storage.foldername(name))[1] != 'organizations'
- loop
- -- Generate new path
- new_path := 'organizations/' || default_org_id || '/' ||
- case
- when position('/' in file_record.name) > 0
- then substring(file_record.name from position('/' in file_record.name) + 1)
- else file_record.name
- end;
-
- -- Move file to new location
- update storage.objects
- set name = new_path
- where id = file_record.id;
-
- raise notice 'Moved file % to %', file_record.name, new_path;
- end loop;
-end;
-$$;
-
--- Execute the migration function
-select migrate_files_to_default_organization();
-
--- Drop the migration function as it's no longer needed
-drop function migrate_files_to_default_organization();
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000016_fix_file_migration.sql b/supabase/migrations_dev/20241205000016_fix_file_migration.sql
deleted file mode 100644
index ffebc6d..0000000
--- a/supabase/migrations_dev/20241205000016_fix_file_migration.sql
+++ /dev/null
@@ -1,67 +0,0 @@
--- Function to ensure folder exists and migrate files
-create or replace function migrate_files_to_default_organization_v2()
-returns void
-language plpgsql
-security definer
-as $$
-declare
- file_record record;
- new_path text;
- default_org_id uuid := '00000000-0000-0000-0000-000000000000';
-begin
- -- Create the organizations folder if it doesn't exist
- insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version)
- values ('files', 'organizations/.keep', auth.uid(), now(), now(), '1')
- on conflict do nothing;
-
- -- Create the default organization folder if it doesn't exist
- insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version)
- values ('files', 'organizations/' || default_org_id || '/.keep', auth.uid(), now(), now(), '1')
- on conflict do nothing;
-
- -- Loop through all files that are not in an organizations folder
- for file_record in
- select name, id, owner, created_at, updated_at, version, metadata
- from storage.objects
- where bucket_id = 'files'
- and name not like 'organizations/%'
- and name not like '%.keep'
- loop
- -- Generate new path
- new_path := 'organizations/' || default_org_id || '/' || file_record.name;
-
- -- Insert file with new path
- insert into storage.objects (
- bucket_id,
- name,
- owner,
- created_at,
- updated_at,
- version,
- metadata
- )
- values (
- 'files',
- new_path,
- file_record.owner,
- file_record.created_at,
- file_record.updated_at,
- file_record.version,
- file_record.metadata
- );
-
- -- Delete old file entry
- delete from storage.objects
- where id = file_record.id;
-
- raise notice 'Moved file % to %', file_record.name, new_path;
- end loop;
-end;
-$$;
-
--- Execute the migration function
-select migrate_files_to_default_organization_v2();
-
--- Drop the migration functions as they're no longer needed
-drop function if exists migrate_files_to_default_organization();
-drop function migrate_files_to_default_organization_v2();
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000017_migrate_file_references.sql b/supabase/migrations_dev/20241205000017_migrate_file_references.sql
deleted file mode 100644
index 30d54eb..0000000
--- a/supabase/migrations_dev/20241205000017_migrate_file_references.sql
+++ /dev/null
@@ -1,76 +0,0 @@
--- Function to get the organization storage path
-create or replace function get_file_path(file_id text)
-returns text
-language sql
-stable
-as $$
- select 'organizations/00000000-0000-0000-0000-000000000000/' || file_id;
-$$;
-
--- Function to update file references in a JSON object
-create or replace function update_file_refs_in_json(data jsonb)
-returns jsonb
-language plpgsql
-as $$
-declare
- result jsonb := data;
- file_id text;
-begin
- -- Handle direct 'file' key
- if result ? 'file' and (result->>'file' is not null) then
- file_id := result->>'file';
- result := jsonb_set(result, '{file}', to_jsonb(get_file_path(file_id)));
- end if;
-
- -- Handle input_file_id
- if result ? 'input_file_id' and (result->>'input_file_id' is not null) then
- file_id := result->>'input_file_id';
- result := jsonb_set(result, '{input_file_id}', to_jsonb(get_file_path(file_id)));
- end if;
-
- -- Handle training_file
- if result ? 'training_file' and (result->>'training_file' is not null) then
- file_id := result->>'training_file';
- result := jsonb_set(result, '{training_file}', to_jsonb(get_file_path(file_id)));
- end if;
-
- -- Handle test_file
- if result ? 'test_file' and (result->>'test_file' is not null) then
- file_id := result->>'test_file';
- result := jsonb_set(result, '{test_file}', to_jsonb(get_file_path(file_id)));
- end if;
-
- return result;
-end;
-$$;
-
--- Begin the migration
-do $$
-declare
- r record;
-begin
- -- Update runs.log_file
- update runs
- set log_file = get_file_path(log_file)
- where log_file is not null
- and log_file not like 'organizations/%';
-
- -- Update jobs.outputs
- update jobs
- set outputs = update_file_refs_in_json(outputs)
- where outputs is not null
- and outputs::text not like '%organizations/%';
-
- -- Update jobs.params
- update jobs
- set params = update_file_refs_in_json(params)
- where params is not null
- and params::text not like '%organizations/%';
-
- raise notice 'File references migration completed';
-end;
-$$;
-
--- Clean up temporary functions
-drop function update_file_refs_in_json(jsonb);
-drop function get_file_path(text);
\ No newline at end of file
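Before being dropped at the end of the migration, the JSON rewriter only touched the four known keys and left everything else alone; a behavior sketch on an illustrative payload:

```sql
select update_file_refs_in_json('{"training_file": "abc123", "epochs": 3}'::jsonb);
-- => {"epochs": 3, "training_file": "organizations/00000000-0000-0000-0000-000000000000/abc123"}
```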
diff --git a/supabase/migrations_dev/20241205000018_revert_file_references.sql b/supabase/migrations_dev/20241205000018_revert_file_references.sql
deleted file mode 100644
index 0b609fd..0000000
--- a/supabase/migrations_dev/20241205000018_revert_file_references.sql
+++ /dev/null
@@ -1,62 +0,0 @@
--- Revert any changes to file references
-update runs
-set log_file = split_part(log_file, '/', 3)
-where log_file like 'organizations/%';
-
--- For jobs.outputs and jobs.params, we need to extract the file ID from the path
-create or replace function extract_file_id_from_path(path text)
-returns text
-language sql
-immutable
-as $$
- select split_part(path, '/', 3)
- where path like 'organizations/%';
-$$;
-
-create or replace function revert_file_refs_in_json(data jsonb)
-returns jsonb
-language plpgsql
-as $$
-declare
- result jsonb := data;
- file_path text;
-begin
- -- Handle direct 'file' key
- if result ? 'file' and (result->>'file' like 'organizations/%') then
- file_path := result->>'file';
- result := jsonb_set(result, '{file}', to_jsonb(extract_file_id_from_path(file_path)));
- end if;
-
- -- Handle input_file_id
- if result ? 'input_file_id' and (result->>'input_file_id' like 'organizations/%') then
- file_path := result->>'input_file_id';
- result := jsonb_set(result, '{input_file_id}', to_jsonb(extract_file_id_from_path(file_path)));
- end if;
-
- -- Handle training_file
- if result ? 'training_file' and (result->>'training_file' like 'organizations/%') then
- file_path := result->>'training_file';
- result := jsonb_set(result, '{training_file}', to_jsonb(extract_file_id_from_path(file_path)));
- end if;
-
- -- Handle test_file
- if result ? 'test_file' and (result->>'test_file' like 'organizations/%') then
- file_path := result->>'test_file';
- result := jsonb_set(result, '{test_file}', to_jsonb(extract_file_id_from_path(file_path)));
- end if;
-
- return result;
-end;
-$$;
-
--- Revert changes in jobs table
-update jobs
-set
- outputs = revert_file_refs_in_json(outputs),
- params = revert_file_refs_in_json(params)
-where
- (outputs::text like '%organizations/%' or params::text like '%organizations/%');
-
--- Clean up functions
-drop function revert_file_refs_in_json(jsonb);
-drop function extract_file_id_from_path(text);
\ No newline at end of file
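The revert relies on the path layout being exactly `organizations/<org_id>/<file_id>`, so `split_part(..., '/', 3)` recovers the original ID:

```sql
select extract_file_id_from_path('organizations/00000000-0000-0000-0000-000000000000/abc123');
-- => abc123   (returns null for paths that do not start with organizations/)
```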
diff --git a/supabase/migrations_dev/20241205000019_verify_file_locations.sql b/supabase/migrations_dev/20241205000019_verify_file_locations.sql
deleted file mode 100644
index dc8e2e3..0000000
--- a/supabase/migrations_dev/20241205000019_verify_file_locations.sql
+++ /dev/null
@@ -1,115 +0,0 @@
--- Create a function to list all files in storage
-create or replace function list_storage_files()
-returns table (
- id bigint,
- name text,
- bucket_id text,
- owner uuid,
- created_at timestamptz,
- updated_at timestamptz,
- last_accessed_at timestamptz,
- metadata jsonb,
- version text,
- size bigint
-)
-security definer
-language plpgsql
-as $$
-begin
- return query
- select *
- from storage.objects
- where bucket_id = 'files'
- order by name;
-end;
-$$;
-
--- Grant execute permission
-grant execute on function list_storage_files() to postgres;
-
--- Create a function to check and fix file locations
-create or replace function verify_and_fix_file_locations()
-returns table (
- file_name text,
- old_path text,
- new_path text,
- status text
-)
-language plpgsql
-security definer
-as $$
-declare
- file_record record;
- new_path text;
- default_org_id uuid := '00000000-0000-0000-0000-000000000000';
-begin
- create temp table if not exists migration_log (
- file_name text,
- old_path text,
- new_path text,
- status text
- );
-
- for file_record in
- select * from storage.objects
- where bucket_id = 'files'
- and name not like 'organizations/%'
- and name not like '%.keep'
- loop
- new_path := 'organizations/' || default_org_id || '/' || file_record.name;
-
- begin
- -- Copy the file to new location
- insert into storage.objects (
- bucket_id,
- name,
- owner,
- created_at,
- updated_at,
- version,
- size,
- metadata
- )
- select
- bucket_id,
- new_path,
- owner,
- created_at,
- updated_at,
- version,
- size,
- metadata
- from storage.objects
- where id = file_record.id;
-
- -- Log successful copy
- insert into migration_log values (
- file_record.name,
- file_record.name,
- new_path,
- 'copied'
- );
-
- exception when others then
- -- Log failed copy
- insert into migration_log values (
- file_record.name,
- file_record.name,
- new_path,
- 'failed: ' || SQLERRM
- );
- end;
- end loop;
-
- return query select * from migration_log;
-
- drop table migration_log;
-end;
-$$;
-
--- Execute the verification and get results
-select * from verify_and_fix_file_locations();
-
--- Clean up
-drop function verify_and_fix_file_locations();
-drop function list_storage_files();
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000020_fix_storage_locations.sql b/supabase/migrations_dev/20241205000020_fix_storage_locations.sql
deleted file mode 100644
index 37ed383..0000000
--- a/supabase/migrations_dev/20241205000020_fix_storage_locations.sql
+++ /dev/null
@@ -1,102 +0,0 @@
--- Create a function to move files in storage
-create or replace function move_files_in_storage()
-returns table (
- file_name text,
- old_path text,
- new_path text,
- status text
-)
-language plpgsql
-security definer
-as $$
-declare
- file_record record;
- new_path text;
- default_org_id uuid := '00000000-0000-0000-0000-000000000000';
-begin
- create temp table if not exists migration_log (
- file_name text,
- old_path text,
- new_path text,
- status text
- );
-
- -- First ensure the organization folders exist
- insert into storage.objects (
- bucket_id,
- name,
- owner,
- created_at,
- updated_at,
- version,
- metadata
- ) values
- ('files', 'organizations/.keep', auth.uid(), now(), now(), '1', '{}'),
- ('files', 'organizations/' || default_org_id || '/.keep', auth.uid(), now(), now(), '1', '{}')
- on conflict (bucket_id, name) do nothing;
-
- -- Move each file
- for file_record in
- select * from storage.objects
- where bucket_id = 'files'
- and name not like 'organizations/%'
- and name not like '%.keep'
- loop
- new_path := 'organizations/' || default_org_id || '/' || file_record.name;
-
- begin
- -- Copy the object row to the new path (the original row is deleted below)
- insert into storage.objects (
- bucket_id,
- name,
- owner,
- created_at,
- updated_at,
- version,
- metadata
- )
- select
- bucket_id,
- new_path,
- owner,
- created_at,
- updated_at,
- version,
- metadata
- from storage.objects
- where id = file_record.id
- returning name into new_path;
-
- -- If copy successful, delete the old object
- delete from storage.objects where id = file_record.id;
-
- -- Log successful move
- insert into migration_log values (
- file_record.name,
- file_record.name,
- new_path,
- 'moved'
- );
-
- exception when others then
- -- Log failed move
- insert into migration_log values (
- file_record.name,
- file_record.name,
- new_path,
- 'failed: ' || SQLERRM
- );
- end;
- end loop;
-
- return query select * from migration_log;
-
- drop table migration_log;
-end;
-$$;
-
--- Execute the move operation and get results
-select * from move_files_in_storage();
-
--- Clean up
-drop function move_files_in_storage();
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000021_check_storage.sql b/supabase/migrations_dev/20241205000021_check_storage.sql
deleted file mode 100644
index 39cceef..0000000
--- a/supabase/migrations_dev/20241205000021_check_storage.sql
+++ /dev/null
@@ -1,6 +0,0 @@
--- Check storage objects
-select id, bucket_id, name
-from storage.objects
-where name like '%14e5eef3bbad%'
-or name like '%organizations%'
-order by name;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000022_fix_storage_policies.sql b/supabase/migrations_dev/20241205000022_fix_storage_policies.sql
deleted file mode 100644
index 2a9b80b..0000000
--- a/supabase/migrations_dev/20241205000022_fix_storage_policies.sql
+++ /dev/null
@@ -1,71 +0,0 @@
--- Drop existing policies to avoid conflicts
-drop policy if exists "Organization members can read their files" on storage.objects;
-drop policy if exists "Organization members can upload files" on storage.objects;
-drop policy if exists "Organization members can delete their files" on storage.objects;
-
--- Enable RLS on storage.objects
-alter table storage.objects enable row level security;
-
--- Policy for reading files:
-create policy "Organization members can read their files"
-on storage.objects for select
-using (
- bucket_id = 'files'
- and (
- -- Allow access to .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user is a member
- name like 'organizations/%'
- and exists (
- select 1
- from public.organization_members
- where organization_id = uuid(split_part(name, '/', 2))
- and user_id = auth.uid()
- )
- )
- )
-);
-
--- Policy for inserting files:
-create policy "Organization members can upload files"
-on storage.objects for insert
-with check (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user is a member
- name like 'organizations/%'
- and exists (
- select 1
- from public.organization_members
- where organization_id = uuid(split_part(name, '/', 2))
- and user_id = auth.uid()
- )
- )
- )
-);
-
--- Policy for deleting files:
-create policy "Organization members can delete their files"
-on storage.objects for delete
-using (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user is an admin
- name like 'organizations/%'
- and exists (
- select 1
- from public.organization_members
- where organization_id = uuid(split_part(name, '/', 2))
- and user_id = auth.uid()
- and role = 'admin'
- )
- )
- )
-);
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241205000023_add_api_tokens.sql b/supabase/migrations_dev/20241205000023_add_api_tokens.sql
deleted file mode 100644
index d969703..0000000
--- a/supabase/migrations_dev/20241205000023_add_api_tokens.sql
+++ /dev/null
@@ -1,122 +0,0 @@
--- Create API tokens table
-create table if not exists "public"."api_tokens" (
- "id" uuid not null default gen_random_uuid(),
- "organization_id" uuid not null references organizations(id) on delete cascade,
- "name" text not null,
- "token" text not null,
- "created_at" timestamp with time zone default timezone('utc'::text, now()) not null,
- "created_by" uuid not null references auth.users(id) on delete cascade,
- "last_used_at" timestamp with time zone,
- primary key (id)
-);
-
--- Enable RLS
-alter table "public"."api_tokens" enable row level security;
-
--- Create policy for reading tokens
-create policy "Organization members can view their tokens"
- on api_tokens for select
- using (
- exists (
- select 1 from organization_members
- where organization_id = api_tokens.organization_id
- and user_id = auth.uid()
- )
- );
-
--- Create policy for managing tokens (admin only)
-create policy "Organization admins can manage tokens"
- on api_tokens for all
- using (
- exists (
- select 1 from organization_members
- where organization_id = api_tokens.organization_id
- and user_id = auth.uid()
- and role = 'admin'
- )
- );
-
--- Function to generate a secure random token
-create or replace function generate_api_token()
-returns text
-language plpgsql
-as $$
-declare
- token text;
-begin
- -- Generate a random UUID and encode it to base64
- token := encode(digest(gen_random_uuid()::text || now()::text, 'sha256'), 'base64');
- -- Remove any non-alphanumeric characters and trim to 32 characters
- token := regexp_replace(token, '[^a-zA-Z0-9]', '', 'g');
- return substring(token, 1, 32);
-end;
-$$;
-
--- Function to create a new API token
-create or replace function create_api_token(
- org_id uuid,
- token_name text
-)
-returns table (
- id uuid,
- name text,
- token text,
- created_at timestamptz
-)
-language plpgsql
-security definer
-as $$
-declare
- new_token text;
- result record;
-begin
- -- Check if user is an admin of the organization
- if not exists (
- select 1 from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- ) then
- raise exception 'Only organization admins can create API tokens';
- end if;
-
- -- Generate new token
- new_token := generate_api_token();
-
- -- Insert new token
- insert into api_tokens (organization_id, name, token, created_by)
- values (org_id, token_name, new_token, auth.uid())
- returning api_tokens.id, api_tokens.name, api_tokens.token, api_tokens.created_at into result;
-
- return query select result.id, result.name, result.token, result.created_at;
-end;
-$$;
-
--- Function to delete an API token
-create or replace function delete_api_token(token_id uuid)
-returns void
-language plpgsql
-security definer
-as $$
-begin
- -- Check if user is an admin of the organization that owns the token
- if not exists (
- select 1 from api_tokens t
- join organization_members m on m.organization_id = t.organization_id
- where t.id = token_id
- and m.user_id = auth.uid()
- and m.role = 'admin'
- ) then
- raise exception 'Only organization admins can delete API tokens';
- end if;
-
- -- Delete the token
- delete from api_tokens where id = token_id;
-end;
-$$;
-
--- Grant necessary permissions
-grant all on table api_tokens to postgres, authenticated;
-grant execute on function generate_api_token() to postgres, authenticated;
-grant execute on function create_api_token(uuid, text) to postgres, authenticated;
-grant execute on function delete_api_token(uuid) to postgres, authenticated;
\ No newline at end of file
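With the definitions above, the token lifecycle reduces to two calls; a sketch with placeholder values, run as an organization admin:

```sql
select * from create_api_token('00000000-0000-0000-0000-000000000000'::uuid, 'ci-token');
-- Revoke later using the id returned above:
-- select delete_api_token('<returned token id>'::uuid);
```

Note that `generate_api_token()` calls `digest(...)` from the pgcrypto extension, which these migrations assume is already enabled.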
diff --git a/supabase/migrations_dev/20241205000024_add_token_validation.sql b/supabase/migrations_dev/20241205000024_add_token_validation.sql
deleted file mode 100644
index 4e3ed2e..0000000
--- a/supabase/migrations_dev/20241205000024_add_token_validation.sql
+++ /dev/null
@@ -1,28 +0,0 @@
--- Function to validate an API token
-create or replace function validate_api_token(token_text text)
-returns table (
- organization_id uuid
-)
-language plpgsql
-security definer
-as $$
-begin
- -- Update last_used_at timestamp
- update api_tokens
- set last_used_at = now()
- where token = token_text;
-
- -- Return the organization_id if token is valid
- return query
- select t.organization_id
- from api_tokens t
- where t.token = token_text;
-
- if not found then
- raise exception 'Invalid API token';
- end if;
-end;
-$$;
-
--- Grant execute permission
-grant execute on function validate_api_token(text) to authenticated, anon;
\ No newline at end of file
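A backend exchanges a raw token for an organization in one round trip; invalid tokens raise rather than returning an empty set:

```sql
-- Placeholder token string:
select organization_id from validate_api_token('replace-with-an-issued-token');
-- raises 'Invalid API token' when no matching row exists
```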
diff --git a/supabase/migrations_dev/20241219000001_add_long_lived_tokens.sql b/supabase/migrations_dev/20241219000001_add_long_lived_tokens.sql
deleted file mode 100644
index f9d1e32..0000000
--- a/supabase/migrations_dev/20241219000001_add_long_lived_tokens.sql
+++ /dev/null
@@ -1,40 +0,0 @@
--- Create tokens table to keep track of issued tokens
-create table if not exists "public"."tokens" (
- "id" uuid not null default gen_random_uuid(),
- "organization_id" uuid not null references organizations(id) on delete cascade,
- "name" text not null,
- "expires_at" timestamp with time zone,
- "created_at" timestamp with time zone default timezone('utc'::text, now()) not null,
- "created_by" uuid not null references auth.users(id) on delete cascade,
- "last_used_at" timestamp with time zone,
- primary key (id)
-);
-
--- Enable RLS
-alter table "public"."tokens" enable row level security;
-
--- Create policy for reading tokens
-create policy "Organization members can view their tokens"
- on tokens for select
- using (
- exists (
- select 1 from organization_members
- where organization_id = tokens.organization_id
- and user_id = auth.uid()
- )
- );
-
--- Create policy for managing tokens (admin only)
-create policy "Organization admins can manage tokens"
- on tokens for all
- using (
- exists (
- select 1 from organization_members
- where organization_id = tokens.organization_id
- and user_id = auth.uid()
- and role = 'admin'
- )
- );
-
--- Grant necessary permissions
-grant all on table tokens to postgres, authenticated;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241219000002_add_token_hash.sql b/supabase/migrations_dev/20241219000002_add_token_hash.sql
deleted file mode 100644
index 7cef74b..0000000
--- a/supabase/migrations_dev/20241219000002_add_token_hash.sql
+++ /dev/null
@@ -1,3 +0,0 @@
--- Add token_hash column to tokens table
-alter table "public"."tokens"
- add column "token_hash" text;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241219000003_update_rls_for_custom_tokens.sql b/supabase/migrations_dev/20241219000003_update_rls_for_custom_tokens.sql
deleted file mode 100644
index a2caed0..0000000
--- a/supabase/migrations_dev/20241219000003_update_rls_for_custom_tokens.sql
+++ /dev/null
@@ -1,100 +0,0 @@
--- Function to get organization ID from token
-create or replace function get_organization_from_token()
-returns uuid
-language plpgsql
-security definer
-as $$
-declare
- auth_token text;
- org_id uuid;
-begin
- -- Get the Authorization header value
- auth_token := coalesce(
- current_setting('request.headers', true)::json->>'authorization',
- ''
- );
-
- -- Extract token from "Bearer "
- auth_token := replace(auth_token, 'Bearer ', '');
-
- -- If it's a custom token (starts with ow_)
- if auth_token like 'ow_%' then
- -- Look up organization ID from tokens table
- select organization_id into org_id
- from tokens
- where token_hash = auth_token
- and (expires_at is null or expires_at > now());
-
- if found then
- -- Update last_used_at
- update tokens
- set last_used_at = now()
- where token_hash = auth_token;
-
- return org_id;
- end if;
- end if;
-
- -- If not a custom token or token not found,
- -- return null to fall back to normal auth
- return null;
-end;
-$$;
-
--- Function to check if current token has access to an organization
-create or replace function has_organization_access(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-declare
- token_org_id uuid;
-begin
- -- First check custom tokens
- token_org_id := get_organization_from_token();
- if token_org_id is not null then
- return token_org_id = org_id;
- end if;
-
- -- Fall back to checking organization membership
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- );
-end;
-$$;
-
--- Update all RLS policies to use the new function
-create or replace function is_organization_member(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-begin
- return has_organization_access(org_id);
-end;
-$$;
-
-create or replace function is_organization_admin(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-begin
- -- Custom API tokens are treated as having admin access
- if get_organization_from_token() is not null then
- return true;
- end if;
-
- -- Otherwise check normal admin membership
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- );
-end;
-$$;
\ No newline at end of file
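Because `get_organization_from_token()` reads the PostgREST request headers, it can be exercised in a plain SQL session by faking that setting. A testing sketch, assuming an `ow_`-prefixed token is already stored in `tokens.token_hash`:

```sql
begin;
-- Simulate the incoming Authorization header for this transaction only:
select set_config('request.headers',
                  '{"authorization": "Bearer ow_exampletoken"}',
                  true);
select get_organization_from_token();
rollback;
```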
diff --git a/supabase/migrations_dev/20241219000004_update_token_auth.sql b/supabase/migrations_dev/20241219000004_update_token_auth.sql
deleted file mode 100644
index cc452ed..0000000
--- a/supabase/migrations_dev/20241219000004_update_token_auth.sql
+++ /dev/null
@@ -1,128 +0,0 @@
--- Drop existing functions that we're replacing
-drop function if exists get_organization_from_token();
-drop function if exists has_organization_access(uuid);
-
--- Function to get organization ID from token
-create or replace function get_organization_from_token()
-returns uuid
-language plpgsql
-security definer
-as $$
-declare
- auth_token text;
- custom_token text;
- org_id uuid;
-begin
- -- First try custom token header
- custom_token := coalesce(
- current_setting('request.headers', true)::json->>'x-openweights-token',
- ''
- );
-
- if custom_token != '' and custom_token like 'ow_%' then
- -- Look up organization ID from tokens table
- select organization_id into org_id
- from tokens
- where token_hash = custom_token
- and (expires_at is null or expires_at > now());
-
- if found then
- -- Update last_used_at
- update tokens
- set last_used_at = now()
- where token_hash = custom_token;
-
- return org_id;
- end if;
- end if;
-
- -- If no custom token, try normal auth
- auth_token := coalesce(
- current_setting('request.headers', true)::json->>'authorization',
- ''
- );
-
- -- Extract token from "Bearer "
- auth_token := replace(auth_token, 'Bearer ', '');
-
- -- If it's a valid JWT, auth.uid() will work
- if auth_token != '' then
- -- Return organization ID from membership
- select organization_id into org_id
- from organization_members
- where user_id = auth.uid()
- limit 1;
-
- return org_id;
- end if;
-
- -- No valid token found
- return null;
-end;
-$$;
-
--- Function to check if current token has access to an organization
-create or replace function has_organization_access(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-declare
- token_org_id uuid;
-begin
- -- First check custom tokens
- token_org_id := get_organization_from_token();
- if token_org_id is not null then
- return token_org_id = org_id;
- end if;
-
- -- Fall back to checking organization membership
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- );
-end;
-$$;
-
--- Update existing functions to use the new logic
-create or replace function is_organization_member(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-begin
- return has_organization_access(org_id);
-end;
-$$;
-
-create or replace function is_organization_admin(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-declare
- custom_token text;
-begin
- -- First try custom token header
- custom_token := coalesce(
- current_setting('request.headers', true)::json->>'x-openweights-token',
- ''
- );
-
- -- API tokens have admin access
- if custom_token != '' and custom_token like 'ow_%' then
- return has_organization_access(org_id);
- end if;
-
- -- Otherwise check normal admin membership
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- );
-end;
-$$;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241219000005_service_accounts.sql b/supabase/migrations_dev/20241219000005_service_accounts.sql
deleted file mode 100644
index 9295255..0000000
--- a/supabase/migrations_dev/20241219000005_service_accounts.sql
+++ /dev/null
@@ -1,121 +0,0 @@
--- Add service account metadata to auth.users
-create or replace function auth.check_if_service_account()
-returns boolean as $$
- select coalesce(
- current_setting('request.jwt.claims', true)::json->>'is_service_account',
- 'false'
- )::boolean;
-$$ language sql security definer;
-
--- Update RLS functions to handle service accounts
-create or replace function get_organization_from_token()
-returns uuid
-language plpgsql
-security definer
-as $$
-declare
- org_id uuid;
-begin
- -- If this is a service account token, get org from claims
- if auth.check_if_service_account() then
- org_id := (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid;
-
- -- Update last_used_at in tokens table
- update tokens
- set last_used_at = now()
- where id = (current_setting('request.jwt.claims', true)::json->>'token_id')::uuid;
-
- return org_id;
- end if;
-
- -- Otherwise, get organization from membership
- select organization_id into org_id
- from organization_members
- where user_id = auth.uid()
- limit 1;
-
- return org_id;
-end;
-$$;
-
--- Update is_organization_admin to give service accounts admin access
-create or replace function is_organization_admin(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-begin
- -- Service accounts have admin access to their organization
- if auth.check_if_service_account() then
- return get_organization_from_token() = org_id;
- end if;
-
- -- Otherwise check normal admin membership
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- );
-end;
-$$;
-
--- Add function to create service account
-create or replace function create_service_account_token(
- org_id uuid,
- token_name text,
- expires_at timestamp with time zone default null
-)
-returns table (
- token_id uuid,
- jwt_token text
-)
-language plpgsql
-security definer
-set search_path = public
-as $$
-declare
- v_token_id uuid;
- v_user_id uuid;
- v_jwt_secret text;
- v_jwt_token text;
-begin
- -- Get JWT secret from vault
- select decrypted_secret into v_jwt_secret
- from vault.decrypted_secrets
- where name = 'jwt_secret'
- limit 1;
-
- if v_jwt_secret is null then
- raise exception 'JWT secret not found in vault';
- end if;
-
- -- Create token record
- insert into tokens (organization_id, name, expires_at, created_by)
- values (org_id, token_name, expires_at, auth.uid())
- returning id into v_token_id;
-
- -- Create JWT token with service account claims
- v_jwt_token := extensions.sign(
- json_build_object(
- 'role', 'authenticated',
- 'iss', 'supabase',
- 'iat', extract(epoch from now())::integer,
- 'exp', case
- when expires_at is null then extract(epoch from now() + interval '10 years')::integer
- else extract(epoch from expires_at)::integer
- end,
- 'is_service_account', true,
- 'organization_id', org_id,
- 'token_id', v_token_id
- )::json,
- v_jwt_secret
- );
-
- return query select v_token_id, v_jwt_token;
-end;
-$$;
-
--- Grant execute permissions
-grant execute on function create_service_account_token(uuid, text, timestamp with time zone) to authenticated;
\ No newline at end of file
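With the vault-backed secret in place, minting a long-lived service-account credential is a single call (placeholder org UUID; requires a `jwt_secret` entry in Vault and the `extensions.sign(...)` JWT helper, exactly as the function body expects):

```sql
select token_id, jwt_token
from create_service_account_token(
  '00000000-0000-0000-0000-000000000000'::uuid,
  'worker-bot'
);
-- Omitting expires_at yields the 10-year default expiry encoded above.
```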
diff --git a/supabase/migrations_dev/20241219000006_cleanup_token_tables.sql b/supabase/migrations_dev/20241219000006_cleanup_token_tables.sql
deleted file mode 100644
index ad886a6..0000000
--- a/supabase/migrations_dev/20241219000006_cleanup_token_tables.sql
+++ /dev/null
@@ -1,9 +0,0 @@
--- Drop the older api_tokens table and related functions
-drop function if exists public.generate_api_token();
-drop function if exists public.create_api_token(uuid, text);
-drop function if exists public.delete_api_token(uuid);
-drop table if exists public.api_tokens;
-
--- Keep the newer tokens table and its functions
--- But add an index on token_hash for faster lookups
-create index if not exists idx_tokens_token_hash on public.tokens(token_hash);
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241219000007_fix_token_creation.sql b/supabase/migrations_dev/20241219000007_fix_token_creation.sql
deleted file mode 100644
index 90f37ef..0000000
--- a/supabase/migrations_dev/20241219000007_fix_token_creation.sql
+++ /dev/null
@@ -1,60 +0,0 @@
--- Drop existing function
-drop function if exists create_service_account_token(uuid, text, timestamp with time zone);
-
--- Recreate with user_id parameter
-create or replace function create_service_account_token(
- org_id uuid,
- token_name text,
- created_by uuid,
- expires_at timestamp with time zone default null
-)
-returns table (
- token_id uuid,
- jwt_token text
-)
-language plpgsql
-security definer
-set search_path = public
-as $$
-declare
- v_token_id uuid;
- v_jwt_secret text;
- v_jwt_token text;
-begin
- -- Get JWT secret from env var for now (we'll move this to vault later)
- v_jwt_secret := current_setting('app.jwt_secret', true);
-
- if v_jwt_secret is null then
- raise exception 'JWT secret not configured';
- end if;
-
- -- Create token record
- insert into tokens (organization_id, name, expires_at, created_by)
- values (org_id, token_name, expires_at, created_by)
- returning id into v_token_id;
-
- -- Create JWT token with service account claims
- v_jwt_token := extensions.sign(
- json_build_object(
- 'role', 'authenticated',
- 'iss', 'supabase',
- 'iat', extract(epoch from now())::integer,
- 'exp', case
- when expires_at is null then extract(epoch from now() + interval '10 years')::integer
- else extract(epoch from expires_at)::integer
- end,
- 'is_service_account', true,
- 'organization_id', org_id,
- 'token_id', v_token_id
- )::json,
- v_jwt_secret
- );
-
- -- Update token hash
- update tokens
- set token_hash = v_jwt_token
- where id = v_token_id;
-
- return query select v_token_id, v_jwt_token;
-end;
-$$;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241219000009_jwt_secret_function.sql b/supabase/migrations_dev/20241219000009_jwt_secret_function.sql
deleted file mode 100644
index bbd0e43..0000000
--- a/supabase/migrations_dev/20241219000009_jwt_secret_function.sql
+++ /dev/null
@@ -1,69 +0,0 @@
--- Create a function to get the JWT secret
-create or replace function get_jwt_secret()
-returns text
-language plpgsql
-security definer
-as $$
-begin
- -- This is a placeholder - in production, you'd want to store this more securely
- return 'AsQjcwl78lW6ND4aXiFXGg5bEuEfw7fcnte8opUfUrTR65Mz83YuksM+kRCcneAdp+yW/5NNDCZ6Gb2mZ+VJrw==';
-end;
-$$;
-
--- Update the token creation function to use the new get_jwt_secret function
-create or replace function create_service_account_token(
- org_id uuid,
- token_name text,
- created_by uuid,
- expires_at timestamp with time zone default null
-)
-returns table (
- token_id uuid,
- jwt_token text
-)
-language plpgsql
-security definer
-set search_path = public
-as $$
-declare
- v_token_id uuid;
- v_jwt_secret text;
- v_jwt_token text;
-begin
- -- Get JWT secret
- v_jwt_secret := get_jwt_secret();
-
- if v_jwt_secret is null then
- raise exception 'JWT secret not configured';
- end if;
-
- -- Create token record
- insert into tokens (organization_id, name, expires_at, created_by)
- values (org_id, token_name, expires_at, created_by)
- returning id into v_token_id;
-
- -- Create JWT token with service account claims
- v_jwt_token := extensions.sign(
- json_build_object(
- 'role', 'authenticated',
- 'iss', 'supabase',
- 'iat', extract(epoch from now())::integer,
- 'exp', case
- when expires_at is null then extract(epoch from now() + interval '10 years')::integer
- else extract(epoch from expires_at)::integer
- end,
- 'is_service_account', true,
- 'organization_id', org_id,
- 'token_id', v_token_id
- )::json,
- v_jwt_secret
- );
-
- -- Update token hash
- update tokens
- set token_hash = v_jwt_token
- where id = v_token_id;
-
- return query select v_token_id, v_jwt_token;
-end;
-$$;
\ No newline at end of file
diff --git a/supabase/migrations_dev/20241219000010_fix_token_validation.sql b/supabase/migrations_dev/20241219000010_fix_token_validation.sql
deleted file mode 100644
index 0ccd0e5..0000000
--- a/supabase/migrations_dev/20241219000010_fix_token_validation.sql
+++ /dev/null
@@ -1,221 +0,0 @@
--- First drop all policies that use these functions
-drop policy if exists "Enable read access for organization members" on organizations;
-drop policy if exists "Enable write access for organization admins" on organizations;
-drop policy if exists "Enable read access for members" on organization_members;
-drop policy if exists "Enable write access for admins" on organization_members;
-drop policy if exists "Enable access for organization members" on jobs;
-drop policy if exists "Enable access for organization members" on runs;
-drop policy if exists "Enable access for organization members" on worker;
-drop policy if exists "Enable access for organization members" on events;
-drop policy if exists "Enable read for members" on third_party_api_keys;
-drop policy if exists "Enable write for admins" on third_party_api_keys;
-
--- Now we can safely drop the functions
-drop function if exists get_organization_from_token();
-drop function if exists has_organization_access(uuid);
-drop function if exists is_organization_member(uuid);
-drop function if exists is_organization_admin(uuid);
-
--- Function to update token last used timestamp
-create or replace function update_token_last_used(token_id uuid)
-returns void
-language plpgsql
-security definer
-as $$
-begin
- -- Update last_used_at; the exception handler below swallows any errors
- update tokens
- set last_used_at = now()
- where id = token_id;
-exception
- when others then
- -- Ignore any errors during update
- null;
-end;
-$$;
-
--- Function to get organization ID from token
-create or replace function get_organization_from_token()
-returns uuid
-language plpgsql
-security definer
-as $$
-declare
- auth_token text;
- custom_token text;
- org_id uuid;
- token_id uuid;
-begin
- -- First try custom token header
- custom_token := coalesce(
- current_setting('request.headers', true)::json->>'x-openweights-token',
- ''
- );
-
- if custom_token != '' and custom_token like 'ow_%' then
- -- Look up organization ID from tokens table
- select t.organization_id, t.id into org_id, token_id
- from tokens t
- where t.token_hash = custom_token
- and (t.expires_at is null or t.expires_at > now());
-
- if found then
- -- Best-effort update of last_used_at (the helper ignores errors)
- perform update_token_last_used(token_id);
- return org_id;
- end if;
- end if;
-
- -- If no custom token, try normal auth
- auth_token := coalesce(
- current_setting('request.headers', true)::json->>'authorization',
- ''
- );
-
- -- Extract token from "Bearer "
- auth_token := replace(auth_token, 'Bearer ', '');
-
- -- If it's a valid JWT, auth.uid() will work
- if auth_token != '' then
- -- Return organization ID from membership
- select organization_id into org_id
- from organization_members
- where user_id = auth.uid()
- limit 1;
-
- return org_id;
- end if;
-
- -- No valid token found
- return null;
-end;
-$$;
-
--- Function to check if current token has access to an organization
-create or replace function has_organization_access(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-declare
- token_org_id uuid;
-begin
- -- First check custom tokens
- token_org_id := get_organization_from_token();
- if token_org_id is not null then
- return token_org_id = org_id;
- end if;
-
- -- Fall back to checking organization membership
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- );
-end;
-$$;
-
--- Update existing functions to use the new logic
-create or replace function is_organization_member(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-begin
- return has_organization_access(org_id);
-end;
-$$;
-
-create or replace function is_organization_admin(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-declare
- custom_token text;
-begin
- -- First try custom token header
- custom_token := coalesce(
- current_setting('request.headers', true)::json->>'x-openweights-token',
- ''
- );
-
- -- API tokens have admin access
- if custom_token != '' and custom_token like 'ow_%' then
- return has_organization_access(org_id);
- end if;
-
- -- Otherwise check normal admin membership
- return exists (
- select 1
- from organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- and role = 'admin'
- );
-end;
-$$;
-
--- Recreate policies with explicit table references
-create policy "Enable read access for organization members"
- on organizations for select
- using (is_organization_member(id));
-
-create policy "Enable write access for organization admins"
- on organizations for all
- using (is_organization_admin(id));
-
-create policy "Enable read access for members"
- on organization_members for select
- using (is_organization_member(organization_id));
-
-create policy "Enable write access for admins"
- on organization_members for all
- using (is_organization_admin(organization_id));
-
-create policy "Enable access for organization members"
- on jobs for all
- using (is_organization_member(organization_id));
-
-create policy "Enable access for organization members"
- on runs for all
- using (
- exists (
- select 1
- from jobs j
- where j.id = runs.job_id
- and is_organization_member(j.organization_id)
- )
- );
-
-create policy "Enable access for organization members"
- on worker for all
- using (is_organization_member(organization_id));
-
-create policy "Enable access for organization members"
- on events for all
- using (
- exists (
- select 1
- from runs r
- join jobs j on j.id = r.job_id
- where r.id = events.run_id
- and is_organization_member(j.organization_id)
- )
- );
-
-create policy "Enable read for members"
- on third_party_api_keys for select
- using (is_organization_member(organization_id));
-
-create policy "Enable write for admins"
- on third_party_api_keys for all
- using (is_organization_admin(organization_id));
-
--- Grant execute permissions
-grant execute on function update_token_last_used(uuid) to postgres, authenticated, anon;
-grant execute on function get_organization_from_token() to postgres, authenticated, anon;
-grant execute on function has_organization_access(uuid) to postgres, authenticated, anon;
-grant execute on function is_organization_member(uuid) to postgres, authenticated, anon;
-grant execute on function is_organization_admin(uuid) to postgres, authenticated, anon;
\ No newline at end of file
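
A minimal sketch of the header path the deleted get_organization_from_token() relied on, assuming PostgREST publishes request headers through the request.headers setting; the set_config call only simulates a request for local testing, and 'ow_example' is a placeholder token:

-- Simulate a request carrying the custom token header (transaction-local)
select set_config(
  'request.headers',
  json_build_object('x-openweights-token', 'ow_example')::text,
  true
);

-- Resolution order: x-openweights-token lookup first, then auth.uid() membership
select get_organization_from_token();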
diff --git a/supabase/migrations_dev/20241219000011_fix_job_policies.sql b/supabase/migrations_dev/20241219000011_fix_job_policies.sql
deleted file mode 100644
index 35df091..0000000
--- a/supabase/migrations_dev/20241219000011_fix_job_policies.sql
+++ /dev/null
@@ -1,22 +0,0 @@
--- Drop existing job policy
-drop policy if exists "Enable access for organization members" on jobs;
-
--- Create separate policies for read and write operations
-create policy "Organization members can read jobs"
- on jobs for select
- using (is_organization_member(organization_id));
-
-create policy "Organization members can insert jobs"
- on jobs for insert
- with check (
- -- For new jobs, organization_id must match the user's organization
- organization_id = get_organization_from_token()
- );
-
-create policy "Organization members can update their jobs"
- on jobs for update
- using (is_organization_member(organization_id));
-
-create policy "Organization members can delete their jobs"
- on jobs for delete
- using (is_organization_member(organization_id));
\ No newline at end of file
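
The split above matters mainly for inserts: with check evaluates against the new row, so a job can only be created inside the caller's own organization. A hedged example (the status column and its 'pending' value are assumptions about the jobs schema):

-- Accepted: the new row's organization_id matches the caller's token org
insert into jobs (organization_id, status)
values (get_organization_from_token(), 'pending');

-- Rejected by the with check clause: a foreign organization_id
insert into jobs (organization_id, status)
values ('00000000-0000-0000-0000-000000000000', 'pending');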
diff --git a/supabase/migrations_dev/20241219000012_add_get_org_function.sql b/supabase/migrations_dev/20241219000012_add_get_org_function.sql
deleted file mode 100644
index d162aa5..0000000
--- a/supabase/migrations_dev/20241219000012_add_get_org_function.sql
+++ /dev/null
@@ -1,31 +0,0 @@
--- Function to get organization ID from token
-create or replace function get_organization_from_token()
-returns uuid
-language plpgsql
-security definer
-as $$
-declare
- org_id uuid;
-begin
- -- If this is a service account token, get org from claims
- if auth.check_if_service_account() then
- org_id := (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid;
- return org_id;
- end if;
-
- -- Otherwise, get organization from membership
- select organization_id into org_id
- from organization_members
- where user_id = auth.uid()
- limit 1;
-
- if org_id is null then
- raise exception 'No organization found for current user/token';
- end if;
-
- return org_id;
-end;
-$$;
-
--- Grant execute permissions
-grant execute on function get_organization_from_token() to authenticated, anon;
\ No newline at end of file
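
For service accounts the function reduces to a single cast out of the JWT claims; a minimal illustration, with a placeholder organization UUID and set_config standing in for the gateway that normally populates the claims:

-- Simulate service-account claims for the current transaction
select set_config(
  'request.jwt.claims',
  '{"organization_id": "11111111-1111-1111-1111-111111111111"}',
  true
);

-- The cast used inside the function body
select (current_setting('request.jwt.claims', true)::json->>'organization_id')::uuid;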
diff --git a/supabase/migrations_dev/20241219000013_fix_storage_policies.sql b/supabase/migrations_dev/20241219000013_fix_storage_policies.sql
deleted file mode 100644
index 0f67801..0000000
--- a/supabase/migrations_dev/20241219000013_fix_storage_policies.sql
+++ /dev/null
@@ -1,119 +0,0 @@
--- Drop existing storage policies
-drop policy if exists "Organization members can read their files" on storage.objects;
-drop policy if exists "Organization members can upload files" on storage.objects;
-drop policy if exists "Organization members can delete their files" on storage.objects;
-
--- Enable RLS on storage.objects
-alter table storage.objects enable row level security;
-
--- Function to extract organization ID from storage path
-create or replace function storage.get_path_organization_id(path text)
-returns uuid
-language plpgsql
-as $$
-declare
- parts text[];
- org_id uuid;
-begin
- -- Split path into parts
- parts := string_to_array(path, '/');
-
- -- Check if path starts with 'organizations'
- if parts[1] != 'organizations' then
- return null;
- end if;
-
- -- Try to convert second part to UUID
- begin
- org_id := parts[2]::uuid;
- return org_id;
- exception when others then
- return null;
- end;
-end;
-$$;
-
--- Policy for reading files:
-create policy "Organization members can read their files"
-on storage.objects for select
-using (
- bucket_id = 'files'
- and (
- -- Allow access to .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user is a member
- name like 'organizations/%'
- and exists (
- select 1
- from public.organization_members
- where organization_id = storage.get_path_organization_id(name)
- and user_id = auth.uid()
- )
- )
- )
-);
-
--- Policy for inserting files:
-create policy "Organization members can upload files"
-on storage.objects for insert
-with check (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user is a member
- name like 'organizations/%'
- and exists (
- select 1
- from public.organization_members
- where organization_id = storage.get_path_organization_id(name)
- and user_id = auth.uid()
- )
- )
- )
-);
-
--- Policy for updating files:
-create policy "Organization members can update their files"
-on storage.objects for update
-using (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user is a member
- name like 'organizations/%'
- and exists (
- select 1
- from public.organization_members
- where organization_id = storage.get_path_organization_id(name)
- and user_id = auth.uid()
- )
- )
- )
-);
-
--- Policy for deleting files:
-create policy "Organization members can delete their files"
-on storage.objects for delete
-using (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user is an admin
- name like 'organizations/%'
- and exists (
- select 1
- from public.organization_members
- where organization_id = storage.get_path_organization_id(name)
- and user_id = auth.uid()
- and role = 'admin'
- )
- )
- )
-);
\ No newline at end of file
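
All four policies hinge on the organizations/<org-uuid>/... path convention that storage.get_path_organization_id parses; a quick check with a placeholder UUID:

-- Returns the embedded organization id
select storage.get_path_organization_id(
  'organizations/11111111-1111-1111-1111-111111111111/weights/model.bin'
);

-- Returns null: not under organizations/, or no valid UUID in the second segment
select storage.get_path_organization_id('tmp/scratch.txt');
select storage.get_path_organization_id('organizations/not-a-uuid/file.txt');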
diff --git a/supabase/migrations_dev/20241219000014_fix_storage_setup.sql b/supabase/migrations_dev/20241219000014_fix_storage_setup.sql
deleted file mode 100644
index f4c0f44..0000000
--- a/supabase/migrations_dev/20241219000014_fix_storage_setup.sql
+++ /dev/null
@@ -1,141 +0,0 @@
--- Drop existing storage policies
-drop policy if exists "Organization members can read their files" on storage.objects;
-drop policy if exists "Organization members can upload files" on storage.objects;
-drop policy if exists "Organization members can update their files" on storage.objects;
-drop policy if exists "Organization members can delete their files" on storage.objects;
-
--- Enable RLS on storage.objects
-alter table storage.objects enable row level security;
-
--- Function to extract organization ID from storage path
-create or replace function storage.get_path_organization_id(path text)
-returns uuid
-language plpgsql
-as $$
-declare
- parts text[];
- org_id uuid;
-begin
- -- Split path into parts
- parts := string_to_array(path, '/');
-
- -- Check if path starts with 'organizations'
- if parts[1] != 'organizations' then
- return null;
- end if;
-
- -- Try to convert second part to UUID
- begin
- org_id := parts[2]::uuid;
- return org_id;
- exception when others then
- return null;
- end;
-end;
-$$;
-
--- Function to check if user has access to organization
-create or replace function storage.has_organization_access(org_id uuid)
-returns boolean
-language plpgsql
-security definer
-as $$
-begin
- -- First check custom tokens
- if auth.jwt() ? 'organization_id' then
- return (auth.jwt()->>'organization_id')::uuid = org_id;
- end if;
-
- -- Otherwise check organization membership
- return exists (
- select 1
- from public.organization_members
- where organization_id = org_id
- and user_id = auth.uid()
- );
-end;
-$$;
-
--- Policy for reading files:
-create policy "Organization members can read their files"
-on storage.objects for select
-using (
- bucket_id = 'files'
- and (
- -- Allow access to .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user has access
- name like 'organizations/%'
- and storage.has_organization_access(storage.get_path_organization_id(name))
- )
- )
-);
-
--- Policy for inserting files:
-create policy "Organization members can upload files"
-on storage.objects for insert
-with check (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user has access
- name like 'organizations/%'
- and storage.has_organization_access(storage.get_path_organization_id(name))
- )
- )
-);
-
--- Policy for updating files:
-create policy "Organization members can update their files"
-on storage.objects for update
-using (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user has access
- name like 'organizations/%'
- and storage.has_organization_access(storage.get_path_organization_id(name))
- )
- )
-);
-
--- Policy for deleting files:
-create policy "Organization members can delete their files"
-on storage.objects for delete
-using (
- bucket_id = 'files'
- and (
- -- Allow .keep files
- name like '%.keep'
- or (
- -- Check if file is in an organization folder and user has access
- name like 'organizations/%'
- and storage.has_organization_access(storage.get_path_organization_id(name))
- )
- )
-);
-
--- Create organization folders
-do $$
-declare
- org record;
-begin
- -- Create organizations folder
- insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version)
- values ('files', 'organizations/.keep', auth.uid(), now(), now(), '1')
- on conflict do nothing;
-
- -- Create folder for each organization
- for org in select id from public.organizations
- loop
- insert into storage.objects (bucket_id, name, owner, created_at, updated_at, version)
- values ('files', 'organizations/' || org.id || '/.keep', auth.uid(), now(), now(), '1')
- on conflict do nothing;
- end loop;
-end;
-$$;
\ No newline at end of file
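
storage.has_organization_access short-circuits on a JWT claim before falling back to membership; the jsonb existence operator (?) does the claim test. A standalone illustration of the guard that picks the branch:

-- Claim present: the token path decides access
select '{"organization_id": "11111111-1111-1111-1111-111111111111"}'::jsonb ? 'organization_id';  -- true

-- Claim absent: falls through to the organization_members check
select '{"role": "authenticated"}'::jsonb ? 'organization_id';  -- false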
diff --git a/supabase/migrations_dev/20241219000015_fix_run_policies.sql b/supabase/migrations_dev/20241219000015_fix_run_policies.sql
deleted file mode 100644
index 42941a8..0000000
--- a/supabase/migrations_dev/20241219000015_fix_run_policies.sql
+++ /dev/null
@@ -1,75 +0,0 @@
--- Drop existing run policy
-drop policy if exists "Enable access for organization members" on runs;
-
--- Create separate policies for read and write operations
-create policy "Organization members can read runs"
- on runs for select
- using (
- exists (
- select 1
- from jobs j
- where j.id = runs.job_id
- and (
- is_organization_member(j.organization_id)
- or
- (
- auth.jwt() ? 'is_service_account'
- and (auth.jwt()->>'organization_id')::uuid = j.organization_id
- )
- )
- )
- );
-
-create policy "Organization members can insert runs"
- on runs for insert
- with check (
- exists (
- select 1
- from jobs j
- where j.id = job_id
- and (
- is_organization_member(j.organization_id)
- or
- (
- auth.jwt() ? 'is_service_account'
- and (auth.jwt()->>'organization_id')::uuid = j.organization_id
- )
- )
- )
- );
-
-create policy "Organization members can update runs"
- on runs for update
- using (
- exists (
- select 1
- from jobs j
- where j.id = runs.job_id
- and (
- is_organization_member(j.organization_id)
- or
- (
- auth.jwt() ? 'is_service_account'
- and (auth.jwt()->>'organization_id')::uuid = j.organization_id
- )
- )
- )
- );
-
-create policy "Organization members can delete runs"
- on runs for delete
- using (
- exists (
- select 1
- from jobs j
- where j.id = runs.job_id
- and (
- is_organization_member(j.organization_id)
- or
- (
- auth.jwt() ? 'is_service_account'
- and (auth.jwt()->>'organization_id')::uuid = j.organization_id
- )
- )
- )
- );
\ No newline at end of file
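
All four run policies repeat the same job-level predicate; a hypothetical helper (not part of these migrations) that would express it once, so each policy could reduce to using (run_job_accessible(job_id)):

-- Hypothetical consolidation of the shared predicate; name is illustrative
create or replace function run_job_accessible(p_job_id uuid)
returns boolean
language sql
security definer
as $$
  select exists (
    select 1
    from jobs j
    where j.id = p_job_id
      and (
        is_organization_member(j.organization_id)
        or (
          auth.jwt() ? 'is_service_account'
          and (auth.jwt()->>'organization_id')::uuid = j.organization_id
        )
      )
  );
$$;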
diff --git a/supabase/seed.sql b/supabase/seed.sql
new file mode 100644
index 0000000..d7cd575
--- /dev/null
+++ b/supabase/seed.sql
@@ -0,0 +1,6 @@
+-- Ensure the 'files' bucket exists in the shadow DB used by `supabase db pull`.
+-- This runs before migrations during local/shadow setup and is a no-op on subsequent runs.
+
+insert into storage.buckets (id, name, public)
+values ('files', 'files', false)
+on conflict (id) do nothing;
\ No newline at end of file