Skip to content

Commit 823d321

Browse files
committed
Change rest-plugin to a builtin plugin
1 parent a756cf4 commit 823d321

File tree

11 files changed

+349
-179
lines changed

11 files changed

+349
-179
lines changed

pyproject.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,6 @@ pattern = '<picture>\s*|<source[^>]*>\s*|\s*</picture>|<video[^>]*>\s*|</video>\
7575
replacement = ''
7676
ignore-case = true
7777

78-
[tool.uv.workspace]
79-
members = ["src/plugins/rest_plugin"]
80-
8178
[dependency-groups]
8279
dev = [
8380
"build>=1.2.2.post1",
@@ -179,6 +176,3 @@ nebius = [
179176
all = [
180177
"dstack[gateway,server,aws,azure,gcp,datacrunch,kubernetes,lambda,nebius,oci]",
181178
]
182-
183-
[project.entry-points."dstack.plugins"]
184-
rest_plugin = "plugins.rest_plugin.src.rest_plugin:RESTPlugin"

src/dstack/_internal/server/services/plugins.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import itertools
22
from importlib import import_module
3+
from typing import Dict
34

45
from backports.entry_points_selectable import entry_points # backport for Python 3.9
56

@@ -12,50 +13,80 @@
1213

1314
_PLUGINS: list[Plugin] = []
1415

16+
_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"}
1517

16-
def load_plugins(enabled_plugins: list[str]):
17-
_PLUGINS.clear()
18-
plugins_entrypoints = entry_points(group="dstack.plugins")
19-
plugins_to_load = enabled_plugins.copy()
20-
for entrypoint in plugins_entrypoints:
21-
if entrypoint.name not in enabled_plugins:
22-
logger.info(
23-
("Found not enabled plugin %s. Plugin will not be loaded."),
24-
entrypoint.name,
25-
)
26-
continue
18+
19+
class PluginEntrypoint:
20+
def __init__(self, name: str, import_path: str, is_builtin: bool = False):
21+
self.name = name
22+
self.import_path = import_path
23+
self.is_builtin = is_builtin
24+
25+
def load(self):
26+
module_path, _, class_name = self.import_path.partition(":")
2727
try:
28-
module_path, _, class_name = entrypoint.value.partition(":")
2928
module = import_module(module_path)
29+
plugin_class = getattr(module, class_name, None)
30+
if plugin_class is None:
31+
logger.warning(
32+
("Failed to load plugin %s: plugin class %s not found in module %s."),
33+
self.name,
34+
class_name,
35+
module_path,
36+
)
37+
return None
38+
if not issubclass(plugin_class, Plugin):
39+
logger.warning(
40+
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
41+
self.name,
42+
class_name,
43+
)
44+
return None
45+
return plugin_class()
3046
except ImportError:
3147
logger.warning(
3248
(
3349
"Failed to load plugin %s when importing %s."
3450
" Ensure the module is on the import path."
3551
),
36-
entrypoint.name,
37-
entrypoint.value,
52+
self.name,
53+
self.import_path,
3854
)
39-
continue
40-
plugin_class = getattr(module, class_name, None)
41-
if plugin_class is None:
42-
logger.warning(
43-
("Failed to load plugin %s: plugin class %s not found in module %s."),
55+
return None
56+
57+
58+
def load_plugins(enabled_plugins: list[str]):
59+
_PLUGINS.clear()
60+
entrypoints: dict[str, PluginEntrypoint] = {}
61+
plugins_to_load = enabled_plugins.copy()
62+
for entrypoint in entry_points(group="dstack.plugins"):
63+
if entrypoint.name not in enabled_plugins:
64+
logger.info(
65+
("Found not enabled plugin %s. Plugin will not be loaded."),
4466
entrypoint.name,
45-
class_name,
46-
module_path,
4767
)
4868
continue
49-
if not issubclass(plugin_class, Plugin):
50-
logger.warning(
51-
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
52-
entrypoint.name,
53-
class_name,
69+
else:
70+
entrypoints[entrypoint.name] = PluginEntrypoint(
71+
entrypoint.name, entrypoint.value, is_builtin=False
5472
)
55-
continue
56-
plugins_to_load.remove(entrypoint.name)
57-
_PLUGINS.append(plugin_class())
58-
logger.info("Loaded plugin %s", entrypoint.name)
73+
74+
for name, import_path in _BUILTIN_PLUGINS.items():
75+
if name not in enabled_plugins:
76+
logger.info(
77+
("Found not enabled builtin plugin %s. Plugin will not be loaded."),
78+
name,
79+
)
80+
else:
81+
entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True)
82+
83+
for plugin_name, plugin_entrypoint in entrypoints.items():
84+
plugin_instance = plugin_entrypoint.load()
85+
if plugin_instance is not None:
86+
_PLUGINS.append(plugin_instance)
87+
plugins_to_load.remove(plugin_name)
88+
logger.info("Loaded plugin %s", plugin_name)
89+
5990
if plugins_to_load:
6091
logger.warning("Enabled plugins not found: %s", plugins_to_load)
6192

@@ -65,7 +96,7 @@ def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec
6596
for policy in policies:
6697
try:
6798
spec = policy.on_apply(user=user, project=project, spec=spec)
68-
except ValueError as e:
99+
except Exception as e:
69100
msg = None
70101
if len(e.args) > 0:
71102
msg = e.args[0]
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import json
2+
import os
3+
from typing import Generic, TypeVar
4+
5+
import requests
6+
from pydantic import BaseModel, ValidationError
7+
8+
from dstack._internal.core.errors import ServerClientError
9+
from dstack._internal.core.models.fleets import FleetSpec
10+
from dstack._internal.core.models.gateways import GatewaySpec
11+
from dstack._internal.core.models.volumes import VolumeSpec
12+
from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger
13+
from dstack.plugins._models import ApplySpec
14+
15+
logger = get_plugin_logger(__name__)
16+
17+
PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI"
18+
PLUGIN_REQUEST_TIMEOUT = 8 # in seconds
19+
20+
SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)
21+
22+
23+
class SpecRequest(BaseModel, Generic[SpecType]):
24+
user: str
25+
project: str
26+
spec: SpecType
27+
28+
29+
RunSpecRequest = SpecRequest[RunSpec]
30+
FleetSpecRequest = SpecRequest[FleetSpec]
31+
VolumeSpecRequest = SpecRequest[VolumeSpec]
32+
GatewaySpecRequest = SpecRequest[GatewaySpec]
33+
34+
35+
class CustomApplyPolicy(ApplyPolicy):
36+
def __init__(self):
37+
self._plugin_service_uri = os.getenv(PLUGIN_SERVICE_URI_ENV_VAR_NAME)
38+
logger.info(f"Found plugin service at {self._plugin_service_uri}")
39+
if not self._plugin_service_uri:
40+
logger.error(
41+
f"Cannot create policy because {PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set"
42+
)
43+
raise ServerClientError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set")
44+
45+
def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> ApplySpec:
46+
response = None
47+
try:
48+
response = requests.post(
49+
f"{self._plugin_service_uri}{endpoint}",
50+
json=spec_request.dict(),
51+
headers={"accept": "application/json", "Content-Type": "application/json"},
52+
timeout=PLUGIN_REQUEST_TIMEOUT,
53+
)
54+
response.raise_for_status()
55+
spec_json = json.loads(response.text)
56+
return spec_json
57+
except requests.exceptions.ConnectionError as e:
58+
logger.error(
59+
f"Could not connect to plugin service at {self._plugin_service_uri}: %s", e
60+
)
61+
raise e
62+
except requests.RequestException as e:
63+
logger.error("Request to the plugin service failed: %s", e)
64+
if response:
65+
logger.error(f"Error response from plugin service:\n{response.text}")
66+
raise e
67+
except ValidationError as e:
68+
# Received 200 code but response body is invalid
69+
logger.exception(
70+
f"Plugin service returned invalid response:\n{response.text if response else None}"
71+
)
72+
raise e
73+
74+
def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
75+
spec_request = RunSpecRequest(user=user, project=project, spec=spec)
76+
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_run_apply")
77+
return RunSpec(**spec_json)
78+
79+
def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec:
80+
spec_request = FleetSpecRequest(user=user, project=project, spec=spec)
81+
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_fleet_apply")
82+
return FleetSpec(**spec_json)
83+
84+
def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec:
85+
spec_request = VolumeSpecRequest(user=user, project=project, spec=spec)
86+
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_volume_apply")
87+
return VolumeSpec(**spec_json)
88+
89+
def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec:
90+
spec_request = GatewaySpecRequest(user=user, project=project, spec=spec)
91+
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_gateway_apply")
92+
return GatewaySpec(**spec_json)
93+
94+
95+
class RESTPlugin(Plugin):
96+
def get_apply_policies(self) -> list[ApplyPolicy]:
97+
return [CustomApplyPolicy()]

src/plugins/rest_plugin/README.md

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/plugins/rest_plugin/__init__.py

Whitespace-only changes.

src/plugins/rest_plugin/pyproject.toml

Lines changed: 0 additions & 14 deletions
This file was deleted.

src/plugins/rest_plugin/src/__init__.py

Whitespace-only changes.

src/plugins/rest_plugin/src/rest_plugin.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)