Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 221 additions & 19 deletions omlx/admin/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ class MSRetryRequest(BaseModel):
ms_token: str = ""


class EnginePackageInstallRequest(BaseModel):
    """Request model for installing/updating an engine package.

    Used as the request body of the engine-package install endpoint, which
    chooses the pip command to run based on the shape of ``version``.
    """

    # Distribution name of the engine package to install.
    package: str  # e.g., "mlx-vlm"
    # Requested version: an exact release, "latest" (or an empty string) to
    # upgrade to the newest release, or a "git+" URL handed to pip unchanged.
    version: str  # e.g., "0.5.0" or "latest" or "git+https://...@commit"


class OQStartRequest(BaseModel):
"""Request model for starting an oQ quantization task."""

Expand Down Expand Up @@ -2273,7 +2280,9 @@ def _get_engine_info() -> dict:
Fallback chain:
1. PEP 610 direct_url.json (pip install git+https://...)
2. _engine_commits.json (generated by build.py for app bundle)
3. Parse pyproject.toml at runtime (dev environment)

Note: pyproject.toml is NOT used as fallback because it contains dev pins,
not what was actually installed via PyPI.
"""
import importlib.metadata

Expand All @@ -2285,7 +2294,8 @@ def _get_engine_info() -> dict:
"mlx-audio": "https://github.com/Blaizzy/mlx-audio",
}

fallback_commits = _load_fallback_commits(packages)
# Only load _engine_commits.json (app bundle), NOT pyproject.toml
fallback_commits = _load_engine_commits_json(packages)

for pkg_name, default_url in packages.items():
info = {"name": pkg_name, "version": None, "commit": None, "url": None}
Expand All @@ -2296,7 +2306,7 @@ def _get_engine_info() -> dict:
# Method 1: PEP 610 direct_url.json
commit_info = _get_commit_from_direct_url(dist, default_url)
if not commit_info:
# Methods 2+3: _engine_commits.json or pyproject.toml
# Method 2: _engine_commits.json (app bundle only)
commit_info = fallback_commits.get(pkg_name)

if commit_info:
Expand All @@ -2310,17 +2320,26 @@ def _get_engine_info() -> dict:


def _get_commit_from_direct_url(dist, default_url: str) -> dict | None:
"""Extract commit SHA from PEP 610 direct_url.json."""
"""Extract commit SHA from PEP 610 direct_url.json.

Only trusts vcs_info if the URL indicates a VCS install (git+, hg+, etc).
Plain PyPI URLs should not have vcs_info, but some pip edge cases can leave
stale vcs_info from previous installs.
"""
import json

try:
direct_url_text = dist.read_text("direct_url.json")
if direct_url_text:
direct_url = json.loads(direct_url_text)
url = direct_url.get("url", "")
# Only trust vcs_info for actual VCS installs (git+https://, etc)
if not any(url.startswith(prefix) for prefix in ("git+", "hg+", "svn+", "bzr+")):
return None
vcs_info = direct_url.get("vcs_info", {})
commit = vcs_info.get("commit_id")
if commit:
repo_url = direct_url.get("url", default_url).rstrip("/")
repo_url = url.rstrip("/")
if repo_url.endswith(".git"):
repo_url = repo_url[:-4]
return {"commit": commit, "url": f"{repo_url}/commit/{commit}"}
Expand All @@ -2329,20 +2348,20 @@ def _get_commit_from_direct_url(dist, default_url: str) -> dict | None:
return None


def _load_fallback_commits(packages: dict[str, str]) -> dict:
"""Load commit SHAs from fallback sources.
def _load_engine_commits_json(packages: dict[str, str]) -> dict:
"""Load commit SHAs from _engine_commits.json (app bundle only).

Tries in order:
1. _engine_commits.json (generated by build.py, lives in omlx package dir)
2. pyproject.toml (dev environment, lives one level above package dir)
This file is generated by build.py and contains the exact commits
that were installed in the app bundle. For PyPI installs, this file
won't exist and no fallback commits will be used.
"""
import json
from pathlib import Path

# This file is at omlx/admin/routes.py → package dir is omlx/
pkg_dir = Path(__file__).resolve().parent.parent

# Method 2: _engine_commits.json (written by build.py for app bundle)
# _engine_commits.json (written by build.py for app bundle)
commits_file = pkg_dir / "_engine_commits.json"
if commits_file.is_file():
try:
Expand All @@ -2360,14 +2379,6 @@ def _load_fallback_commits(packages: dict[str, str]) -> dict:
except Exception:
pass

# Method 3: Parse pyproject.toml (dev environment)
pyproject = pkg_dir.parent / "pyproject.toml"
if pyproject.is_file():
try:
return _parse_commits_from_pyproject(pyproject, packages)
except Exception:
pass

return {}


Expand Down Expand Up @@ -2772,6 +2783,197 @@ async def clear_ssd_cache(is_admin: bool = Depends(require_admin)):
return {"status": "ok", "total_deleted": total_deleted}


# =============================================================================
# Engine Package Management API Routes
# =============================================================================

# Registry of package install tasks, keyed by task id. The values are
# EnginePackageInstallTask instances (not plain dicts). The registry is a
# module-level dict, so it only lives as long as the process does.
_install_tasks: dict[str, "EnginePackageInstallTask"] = {}


class EnginePackageInstallTask:
    """State of a single engine-package installation job.

    Attributes:
        id: Unique identifier that is safe to use as one URL path segment.
        package: Name of the package being installed.
        version: Requested version string, stored exactly as supplied.
        status: One of "pending", "running", "completed" or "failed".
        output: Text output of the install; empty until the runner sets it.
        error: Error text of the install; empty until the runner sets it.
        started_at: Start timestamp (``time.time()``), or None if not started.
        completed_at: Finish timestamp (``time.time()``), or None if unfinished.
    """

    def __init__(self, package: str, version: str):
        import uuid

        # The id is used as a path parameter by the task routes, so it must
        # not contain "/" or ":" (a "git+https://..." version would add both).
        # The random suffix keeps ids unique even when two tasks for the same
        # package are created within the same second.
        safe_package = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "._-" else "-"
            for ch in package
        )
        self.id = f"{safe_package}-{int(time.time())}-{uuid.uuid4().hex[:8]}"
        self.package = package
        self.version = version
        self.status = "pending"  # pending, running, completed, failed
        self.output = ""
        self.error = ""
        self.started_at: float | None = None
        self.completed_at: float | None = None

    def to_dict(self) -> dict:
        """Return a JSON-serialisable snapshot of the task's current state."""
        return {
            "id": self.id,
            "package": self.package,
            "version": self.version,
            "status": self.status,
            "output": self.output,
            "error": self.error,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
        }


def _version_sort_key(version: str) -> tuple:
    """Build a sort key that is comparable across all version strings.

    Versions that ``packaging`` can parse rank above those it cannot. Every
    key starts with an integer rank, so a parsed ``Version`` is never compared
    directly against the fallback tuple (which would raise ``TypeError``).
    """
    try:
        from packaging.version import InvalidVersion, Version
    except ImportError:
        parsed = None
    else:
        try:
            parsed = Version(version)
        except InvalidVersion:
            parsed = None

    if parsed is not None:
        return (1, parsed, version)

    # Fallback: numeric dot-separated parts sort numerically and before
    # non-numeric parts; each part is tagged so ints and strs never compare.
    parts = tuple(
        (0, int(part)) if part.isdecimal() else (1, part)
        for part in version.split(".")
    )
    return (0, parts, version)


def _fetch_pypi_versions(package_name: str) -> dict:
    """Fetch available versions for a package from PyPI.

    Args:
        package_name: Name of the package to look up on pypi.org.

    Returns:
        dict with keys: "versions" (list, newest first, at most 50),
        "latest" (str), "error" (str or None). Failures are reported through
        "error" and are never raised.
    """
    import requests

    try:
        resp = requests.get(
            f"https://pypi.org/pypi/{package_name}/json",
            timeout=10,
            headers={"User-Agent": "omlx-admin/1.0"},
        )
        if resp.status_code != 200:
            return {
                "versions": [],
                "latest": "",
                "error": f"PyPI returned {resp.status_code}",
            }

        data = resp.json()
        versions = sorted(
            data.get("releases", {}),
            key=_version_sort_key,
            reverse=True,
        )
        return {
            "versions": versions[:50],  # Limit to 50 most recent
            "latest": data.get("info", {}).get("version", ""),
            "error": None,
        }
    except Exception as e:
        # Deliberate catch-all: network, JSON and payload-shape failures are
        # all surfaced to the caller through the "error" field.
        return {"versions": [], "latest": "", "error": str(e)}


@router.get("/api/engine-packages/available")
async def get_available_versions(
    package: str,
    is_admin: bool = Depends(require_admin),
):
    """Get available versions for an engine package from PyPI.

    Args:
        package: Package name (e.g., "mlx-vlm", "mlx-lm")

    Returns:
        The dict produced by ``_fetch_pypi_versions`` ("versions", "latest",
        "error").

    Raises:
        HTTPException: 400 if ``package`` is not a supported engine package.
    """
    # Validate package name
    valid_packages = {"mlx-lm", "mlx-vlm", "mlx-embeddings", "mlx-audio"}
    if package not in valid_packages:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid package. Must be one of: {', '.join(sorted(valid_packages))}",
        )

    # The PyPI lookup is a blocking HTTP call; run it in a worker thread so
    # the event loop keeps serving other requests while it waits.
    return await asyncio.to_thread(_fetch_pypi_versions, package)


@router.get("/api/engine-packages/tasks")
async def list_install_tasks(is_admin: bool = Depends(require_admin)):
    """List package installation tasks (the 20 most recently registered).

    Returns:
        ``{"tasks": [...]}`` where each entry is a plain dict snapshot of a
        task, in registry insertion order.
    """
    recent = list(_install_tasks.values())[-20:]
    # Serialise explicitly: registry values are task objects, and clients
    # should get the same shape the install endpoint returns via to_dict().
    return {
        "tasks": [
            entry.to_dict() if hasattr(entry, "to_dict") else entry
            for entry in recent
        ]
    }


# Strong references to in-flight background install jobs. The event loop only
# keeps weak references to tasks, so each job is held here until it finishes.
_install_jobs: set = set()


@router.post("/api/engine-packages/install")
async def install_engine_package(
    request: EnginePackageInstallRequest,
    is_admin: bool = Depends(require_admin),
):
    """Install or update an engine package.

    The pip command runs in a background thread; this handler returns
    immediately with a snapshot of the newly registered task.

    Args:
        request: Contains package name and version to install

    Returns:
        ``{"task": {...}}`` with the task's state at registration time.

    Raises:
        HTTPException: 400 if the package is not a supported engine package.
    """
    import subprocess

    valid_packages = {"mlx-lm", "mlx-vlm", "mlx-embeddings", "mlx-audio"}
    if request.package not in valid_packages:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid package. Must be one of: {', '.join(sorted(valid_packages))}",
        )

    # Create task
    task = EnginePackageInstallTask(request.package, request.version)
    _install_tasks[task.id] = task

    # Run installation in background thread
    def _run_install():
        task.status = "running"
        task.started_at = time.time()
        _install_tasks[task.id] = task

        try:
            # Build pip install command (argument list, never a shell string)
            if request.version in ("latest", ""):
                cmd = [sys.executable, "-m", "pip", "install", "-U", request.package]
            elif request.version.startswith("git+"):
                # NOTE(security): this branch installs whatever the URL points
                # to; request.package is not part of the command, so the
                # allowlist above does not constrain what gets installed.
                cmd = [sys.executable, "-m", "pip", "install", request.version]
            else:
                cmd = [sys.executable, "-m", "pip", "install", f"{request.package}=={request.version}"]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=300,  # 5 minute timeout
            )

            task.output = result.stdout
            task.error = result.stderr
            task.status = "completed" if result.returncode == 0 else "failed"
        except subprocess.TimeoutExpired:
            task.status = "failed"
            task.error = "Installation timed out (5 minute limit)"
        except Exception as e:
            # Boundary of the background job: record the failure on the task
            # so it is visible to clients instead of being lost in the thread.
            task.status = "failed"
            task.error = str(e)
        finally:
            task.completed_at = time.time()
            # Re-register so the final state is listed even if the entry was
            # removed from the registry while the install was running.
            _install_tasks[task.id] = task

    # Run in thread pool to not block the request. Keep a reference to the
    # job until it completes so it cannot be garbage collected mid-run.
    job = asyncio.create_task(asyncio.to_thread(_run_install))
    _install_jobs.add(job)
    job.add_done_callback(_install_jobs.discard)

    return {"task": task.to_dict()}


@router.delete("/api/engine-packages/task/{task_id}")
async def delete_install_task(
    task_id: str,
    is_admin: bool = Depends(require_admin),
):
    """Remove an install task from the in-memory history.

    The entry is dropped whatever its status, and an unknown ``task_id`` is
    silently ignored; the response is always ``{"status": "ok"}``.
    """
    _install_tasks.pop(task_id, None)
    return {"status": "ok"}


# =============================================================================
# HuggingFace Downloader API Routes
# =============================================================================
Expand Down
Loading