Skip to content

Commit d81639a

Browse files
committed
Add REST plugin
1 parent c521a57 commit d81639a

File tree

8 files changed

+185
-0
lines changed

8 files changed

+185
-0
lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ pattern = '<picture>\s*|<source[^>]*>\s*|\s*</picture>|<video[^>]*>\s*|</video>\
7575
replacement = ''
7676
ignore-case = true
7777

78+
[tool.uv.workspace]
79+
members = ["src/plugins/rest_plugin"]
80+
7881
[dependency-groups]
7982
dev = [
8083
"build>=1.2.2.post1",
@@ -176,3 +179,6 @@ nebius = [
176179
all = [
177180
"dstack[gateway,server,aws,azure,gcp,datacrunch,kubernetes,lambda,nebius,oci]",
178181
]
182+
183+
[project.entry-points."dstack.plugins"]
184+
rest_plugin = "plugins.rest_plugin.src.rest_plugin:RESTPlugin"

src/plugins/rest_plugin/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[TODO]

src/plugins/rest_plugin/__init__.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[project]
2+
name = "rest-plugin"
3+
version = "0.1.0"
4+
description = "A dstack plugin that enables validation and mutation of run specifications via REST API"
5+
readme = "README.md"
6+
requires-python = ">=3.9"
7+
dependencies = []
8+
9+
[build-system]
10+
requires = ["hatchling"]
11+
build-backend = "hatchling.build"
12+
13+
[tool.hatch.build.targets.wheel]
14+
packages = ["src"]

src/plugins/rest_plugin/src/__init__.py

Whitespace-only changes.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import json
2+
import os
3+
import pydantic
4+
import requests
5+
from dstack._internal.core.errors import ServerError
6+
from dstack._internal.core.models.fleets import FleetSpec
7+
from dstack._internal.core.models.gateways import GatewaySpec
8+
from dstack._internal.core.models.volumes import VolumeSpec
9+
from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger
10+
from dstack.plugins._models import ApplySpec
11+
12+
logger = get_plugin_logger(__name__)
13+
14+
PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI"
15+
16+
class PreApplyPolicy(ApplyPolicy):
17+
def __init__(self):
18+
self._plugin_service_uri = os.getenv(PLUGIN_SERVICE_URI_ENV_VAR_NAME)
19+
if not self._plugin_service_uri:
20+
logger.error(f"Cannot create policy as {PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set")
21+
raise ServerError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set")
22+
23+
def _call_plugin_service(self, user: str, project: str, spec: ApplySpec, endpoint: str) -> ApplySpec:
24+
# Make request to plugin service with run params
25+
params = {
26+
"user": user,
27+
"project": project,
28+
"spec": spec.json()
29+
}
30+
response = None
31+
try:
32+
response = requests.post(f"{self._plugin_service_uri}/{endpoint}", json=json.dumps(params))
33+
response.raise_for_status()
34+
spec_json = json.loads(response.text)
35+
spec = RunSpec(**spec_json)
36+
except requests.RequestException as e:
37+
logger.error("Failed to call plugin service: %s", e)
38+
if response:
39+
logger.error(f"Error response from plugin service:\n{response.text}")
40+
logger.info("Returning original run spec")
41+
return spec
42+
except pydantic.ValidationError as e:
43+
logger.exception(f"Plugin service returned invalid response:\n{response.text if response else None}")
44+
logger.info("Returning original run spec")
45+
return spec
46+
logger.info(f"Using RunSpec from plugin service:\n{spec}")
47+
return spec
48+
def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
49+
return self._call_plugin_service(user, project, spec, '/runs/pre_apply')
50+
51+
def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec:
52+
return self._call_plugin_service(user, project, spec, '/fleets/pre_apply')
53+
54+
def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec:
55+
return self._call_plugin_service(user, project, spec, '/volumes/pre_apply')
56+
57+
def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec:
58+
return self._call_plugin_service(user, project, spec, '/gateways/pre_apply')
59+
60+
class RESTPlugin(Plugin):
61+
def get_apply_policies(self) -> list[ApplyPolicy]:
62+
return [PreApplyPolicy()]

src/tests/plugins/__init__.py

Whitespace-only changes.

src/tests/plugins/test_rest_plugin.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from dstack._internal.core.errors import ServerError
2+
from dstack._internal.server.models import ProjectModel, UserModel
3+
from plugins.rest_plugin.src.rest_plugin import PreApplyPolicy, PLUGIN_SERVICE_URI_ENV_VAR_NAME
4+
import pytest
5+
from sqlalchemy.ext.asyncio import AsyncSession
6+
from pydantic import parse_obj_as
7+
import os
8+
import json
9+
import requests
10+
from unittest.mock import Mock
11+
12+
from dstack._internal.core.models.runs import RunSpec
13+
from dstack._internal.core.models.configurations import ServiceConfiguration
14+
from dstack._internal.core.models.profiles import Profile
15+
from dstack._internal.core.models.resources import Range
16+
from dstack._internal.server.testing.common import (
17+
create_project,
18+
create_user,
19+
create_repo,
20+
get_run_spec,
21+
)
22+
from dstack._internal.server.testing.conf import session, test_db # noqa: F401
23+
from dstack._internal.server.services import encryption as encryption # import for side-effect
24+
import pytest_asyncio
25+
from unittest import mock
26+
27+
28+
async def create_run_spec(
29+
session: AsyncSession,
30+
project: ProjectModel,
31+
replicas: str = 1,
32+
) -> RunSpec:
33+
repo = await create_repo(session=session, project_id=project.id)
34+
run_name = "test-run"
35+
profile = Profile(name="test-profile")
36+
spec = get_run_spec(
37+
repo_id=repo.name,
38+
run_name=run_name,
39+
profile=profile,
40+
configuration=ServiceConfiguration(
41+
commands=["echo hello"],
42+
port=8000,
43+
replicas=parse_obj_as(Range[int], replicas)
44+
),
45+
)
46+
return spec
47+
48+
@pytest_asyncio.fixture
49+
async def project(session):
50+
return await create_project(session=session)
51+
52+
@pytest_asyncio.fixture
53+
async def user(session):
54+
return await create_user(session=session)
55+
56+
@pytest_asyncio.fixture
57+
async def run_spec(session, project):
58+
return await create_run_spec(session=session, project=project)
59+
60+
61+
class TestRESTPlugin:
62+
@pytest.mark.asyncio
63+
async def test_on_run_apply_plugin_service_uri_not_set(self):
64+
with pytest.raises(ServerError):
65+
policy = PreApplyPolicy()
66+
67+
@pytest.mark.asyncio
68+
@mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"})
69+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
70+
async def test_on_run_apply_plugin_service_returns_mutated_spec(self, test_db, user, project, run_spec):
71+
policy = PreApplyPolicy()
72+
mock_response = Mock()
73+
run_spec_dict = run_spec.dict()
74+
run_spec_dict["profile"]["tags"] = {"env": "test", "team": "qa"}
75+
mock_response.text = json.dumps(run_spec_dict)
76+
mock_response.raise_for_status = Mock()
77+
with mock.patch("requests.post", return_value=mock_response):
78+
result = policy.on_apply(user=user.name, project=project.name, spec=run_spec)
79+
assert result == RunSpec(**run_spec_dict)
80+
81+
@pytest.mark.asyncio
82+
@mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"})
83+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
84+
async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, project, run_spec):
85+
policy = PreApplyPolicy()
86+
with mock.patch("requests.post", side_effect=requests.RequestException("fail")):
87+
result = policy.on_apply(user=user.name, project=project.name, spec=run_spec)
88+
assert result == run_spec
89+
90+
@pytest.mark.asyncio
91+
@mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"})
92+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
93+
async def test_on_run_apply_plugin_service_returns_invalid_spec(self, test_db, user, project, run_spec):
94+
policy = PreApplyPolicy()
95+
mock_response = Mock()
96+
mock_response.text = json.dumps({"invalid-key": "abc"})
97+
mock_response.raise_for_status = Mock()
98+
with mock.patch("requests.post", return_value=mock_response):
99+
result = policy.on_apply(user.name, project=project.name, spec=run_spec)
100+
# return original run spec
101+
assert result == run_spec
102+

0 commit comments

Comments
 (0)