Skip to content

Add REST plugin for user-defined policies #2631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/docs/guides/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,14 @@ class ExamplePolicy(ApplyPolicy):

</div>

For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin).
## Built-in Plugins

### REST Plugin
`rest_plugin` is a builtin `dstack` plugin that allows writing your custom plugins as API servers, so you don't need to install plugins as Python packages.

Plugins implemented as API servers have advantages over plugins implemented as Python packages in some cases:
* No dependency conflicts with `dstack`.
* You can use any programming language.
* If you run the `dstack` server via Docker, you don't need to extend the `dstack` server image with plugins or map them via volumes.

To get started, check out the [plugin server example](/examples/plugins/example_plugin_server/README.md).
1 change: 1 addition & 0 deletions examples/plugins/example_plugin_server/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11
30 changes: 30 additions & 0 deletions examples/plugins/example_plugin_server/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
## Overview

If you wish to hook up your own plugin server through `dstack` builtin `rest_plugin`, here's a basic example on how to do so.

## Steps


1. Install required dependencies for the plugin server:

```bash
uv sync
```

1. Start the plugin server locally:

```bash
fastapi dev app/main.py
```

1. Enable `rest_plugin` in `server/config.yaml`:

```yaml
plugins:
- rest_plugin
```

1. Point the `dstack` server to your plugin server:
```bash
export DSTACK_PLUGIN_SERVICE_URI=http://127.0.0.1:8000
```
Empty file.
56 changes: 56 additions & 0 deletions examples/plugins/example_plugin_server/app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging

from fastapi import FastAPI

from app.utils import configure_logging
from dstack.plugins.builtin.rest_plugin import (
FleetSpecRequest,
FleetSpecResponse,
GatewaySpecRequest,
GatewaySpecResponse,
RunSpecRequest,
RunSpecResponse,
VolumeSpecRequest,
VolumeSpecResponse,
)

configure_logging()
logger = logging.getLogger(__name__)

app = FastAPI()


@app.post("/apply_policies/on_run_apply")
async def on_run_apply(request: RunSpecRequest) -> RunSpecResponse:
logger.info(
f"Received run spec request from user {request.user} and project {request.project}"
)
response = RunSpecResponse(spec=request.spec, error=None)
return response


@app.post("/apply_policies/on_fleet_apply")
async def on_fleet_apply(request: FleetSpecRequest) -> FleetSpecResponse:
logger.info(
f"Received fleet spec request from user {request.user} and project {request.project}"
)
response = FleetSpecResponse(request.spec, error=None)
return response


@app.post("/apply_policies/on_volume_apply")
async def on_volume_apply(request: VolumeSpecRequest) -> VolumeSpecResponse:
logger.info(
f"Received volume spec request from user {request.user} and project {request.project}"
)
response = VolumeSpecResponse(request.spec, error=None)
return response


@app.post("/apply_policies/on_gateway_apply")
async def on_gateway_apply(request: GatewaySpecRequest) -> GatewaySpecResponse:
logger.info(
f"Received gateway spec request from user {request.user} and project {request.project}"
)
response = GatewaySpecResponse(request.spec, error=None)
return response
7 changes: 7 additions & 0 deletions examples/plugins/example_plugin_server/app/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import logging
import os


def configure_logging():
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level)
10 changes: 10 additions & 0 deletions examples/plugins/example_plugin_server/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[project]
name = "dstack-plugin-server"
version = "0.1.0"
description = "Example plugin server"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"fastapi[standard]>=0.115.12",
"dstack>=0.19.8"
]
91 changes: 61 additions & 30 deletions src/dstack/_internal/server/services/plugins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from importlib import import_module
from typing import Dict

from backports.entry_points_selectable import entry_points # backport for Python 3.9

Expand All @@ -12,50 +13,80 @@

_PLUGINS: list[Plugin] = []

_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"}

def load_plugins(enabled_plugins: list[str]):
_PLUGINS.clear()
plugins_entrypoints = entry_points(group="dstack.plugins")
plugins_to_load = enabled_plugins.copy()
for entrypoint in plugins_entrypoints:
if entrypoint.name not in enabled_plugins:
logger.info(
("Found not enabled plugin %s. Plugin will not be loaded."),
entrypoint.name,
)
continue

class PluginEntrypoint:
def __init__(self, name: str, import_path: str, is_builtin: bool = False):
self.name = name
self.import_path = import_path
self.is_builtin = is_builtin

def load(self):
module_path, _, class_name = self.import_path.partition(":")
try:
module_path, _, class_name = entrypoint.value.partition(":")
module = import_module(module_path)
plugin_class = getattr(module, class_name, None)
if plugin_class is None:
logger.warning(
("Failed to load plugin %s: plugin class %s not found in module %s."),
self.name,
class_name,
module_path,
)
return None
if not issubclass(plugin_class, Plugin):
logger.warning(
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
self.name,
class_name,
)
return None
return plugin_class()
except ImportError:
logger.warning(
(
"Failed to load plugin %s when importing %s."
" Ensure the module is on the import path."
),
entrypoint.name,
entrypoint.value,
self.name,
self.import_path,
)
continue
plugin_class = getattr(module, class_name, None)
if plugin_class is None:
logger.warning(
("Failed to load plugin %s: plugin class %s not found in module %s."),
return None


def load_plugins(enabled_plugins: list[str]):
_PLUGINS.clear()
entrypoints: dict[str, PluginEntrypoint] = {}
plugins_to_load = enabled_plugins.copy()
for entrypoint in entry_points(group="dstack.plugins"):
if entrypoint.name not in enabled_plugins:
logger.info(
("Found not enabled plugin %s. Plugin will not be loaded."),
entrypoint.name,
class_name,
module_path,
)
continue
if not issubclass(plugin_class, Plugin):
logger.warning(
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
entrypoint.name,
class_name,
else:
entrypoints[entrypoint.name] = PluginEntrypoint(
entrypoint.name, entrypoint.value, is_builtin=False
)
continue
plugins_to_load.remove(entrypoint.name)
_PLUGINS.append(plugin_class())
logger.info("Loaded plugin %s", entrypoint.name)

for name, import_path in _BUILTIN_PLUGINS.items():
if name not in enabled_plugins:
logger.info(
("Found not enabled builtin plugin %s. Plugin will not be loaded."),
name,
)
else:
entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True)

for plugin_name, plugin_entrypoint in entrypoints.items():
plugin_instance = plugin_entrypoint.load()
if plugin_instance is not None:
_PLUGINS.append(plugin_instance)
plugins_to_load.remove(plugin_name)
logger.info("Loaded plugin %s", plugin_name)

if plugins_to_load:
logger.warning("Enabled plugins not found: %s", plugins_to_load)

Expand Down
Empty file.
18 changes: 18 additions & 0 deletions src/dstack/plugins/builtin/rest_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ruff: noqa: F401
from dstack.plugins.builtin.rest_plugin._models import (
FleetSpecRequest,
FleetSpecResponse,
GatewaySpecRequest,
GatewaySpecResponse,
RunSpecRequest,
RunSpecResponse,
SpecApplyRequest,
SpecApplyResponse,
VolumeSpecRequest,
VolumeSpecResponse,
)
from dstack.plugins.builtin.rest_plugin._plugin import (
PLUGIN_SERVICE_URI_ENV_VAR_NAME,
CustomApplyPolicy,
RESTPlugin,
)
40 changes: 40 additions & 0 deletions src/dstack/plugins/builtin/rest_plugin/_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Generic, TypeVar

from pydantic import BaseModel

from dstack._internal.core.models.fleets import FleetSpec
from dstack._internal.core.models.gateways import GatewaySpec
from dstack._internal.core.models.runs import RunSpec
from dstack._internal.core.models.volumes import VolumeSpec

SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)


class SpecApplyRequest(BaseModel, Generic[SpecType]):
user: str
project: str
spec: SpecType

# Override dict() to remove __orig_class__ attribute and avoid "TypeError: Object of type _GenericAlias is not JSON serializable"
# error. This issue doesn't happen though when running the code in pytest, only when running the server.
def dict(self, *args, **kwargs):
d = super().dict(*args, **kwargs)
d.pop("__orig_class__", None)
return d


RunSpecRequest = SpecApplyRequest[RunSpec]
FleetSpecRequest = SpecApplyRequest[FleetSpec]
VolumeSpecRequest = SpecApplyRequest[VolumeSpec]
GatewaySpecRequest = SpecApplyRequest[GatewaySpec]


class SpecApplyResponse(BaseModel, Generic[SpecType]):
spec: SpecType
error: str | None = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these fields be documented with a brief description of what they are and their semantics (e.g. what it means if error is populated)? (here and in the request object)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As shown by the CI, str | None is not supported for Python 3.9 which we target so please use Optional.



RunSpecResponse = SpecApplyResponse[RunSpec]
FleetSpecResponse = SpecApplyResponse[FleetSpec]
VolumeSpecResponse = SpecApplyResponse[VolumeSpec]
GatewaySpecResponse = SpecApplyResponse[GatewaySpec]
Loading
Loading