1
+ from dstack ._internal .core .errors import ServerError
2
+ from dstack ._internal .server .models import ProjectModel , UserModel
3
+ from plugins .rest_plugin .src .rest_plugin import PreApplyPolicy , PLUGIN_SERVICE_URI_ENV_VAR_NAME
4
+ import pytest
5
+ from sqlalchemy .ext .asyncio import AsyncSession
6
+ from pydantic import parse_obj_as
7
+ import os
8
+ import json
9
+ import requests
10
+ from unittest .mock import Mock
11
+
12
+ from dstack ._internal .core .models .runs import RunSpec
13
+ from dstack ._internal .core .models .configurations import ServiceConfiguration
14
+ from dstack ._internal .core .models .profiles import Profile
15
+ from dstack ._internal .core .models .resources import Range
16
+ from dstack ._internal .server .testing .common import (
17
+ create_project ,
18
+ create_user ,
19
+ create_repo ,
20
+ get_run_spec ,
21
+ )
22
+ from dstack ._internal .server .testing .conf import session , test_db # noqa: F401
23
+ from dstack ._internal .server .services import encryption as encryption # import for side-effect
24
+ import pytest_asyncio
25
+ from unittest import mock
26
+
27
+
28
+ async def create_run_spec (
29
+ session : AsyncSession ,
30
+ project : ProjectModel ,
31
+ replicas : str = 1 ,
32
+ ) -> RunSpec :
33
+ repo = await create_repo (session = session , project_id = project .id )
34
+ run_name = "test-run"
35
+ profile = Profile (name = "test-profile" )
36
+ spec = get_run_spec (
37
+ repo_id = repo .name ,
38
+ run_name = run_name ,
39
+ profile = profile ,
40
+ configuration = ServiceConfiguration (
41
+ commands = ["echo hello" ],
42
+ port = 8000 ,
43
+ replicas = parse_obj_as (Range [int ], replicas )
44
+ ),
45
+ )
46
+ return spec
47
+
48
+ @pytest_asyncio .fixture
49
+ async def project (session ):
50
+ return await create_project (session = session )
51
+
52
+ @pytest_asyncio .fixture
53
+ async def user (session ):
54
+ return await create_user (session = session )
55
+
56
+ @pytest_asyncio .fixture
57
+ async def run_spec (session , project ):
58
+ return await create_run_spec (session = session , project = project )
59
+
60
+
61
+ class TestRESTPlugin :
62
+ @pytest .mark .asyncio
63
+ async def test_on_run_apply_plugin_service_uri_not_set (self ):
64
+ with pytest .raises (ServerError ):
65
+ policy = PreApplyPolicy ()
66
+
67
+ @pytest .mark .asyncio
68
+ @mock .patch .dict (os .environ , {PLUGIN_SERVICE_URI_ENV_VAR_NAME : "http://mock" })
69
+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
70
+ async def test_on_run_apply_plugin_service_returns_mutated_spec (self , test_db , user , project , run_spec ):
71
+ policy = PreApplyPolicy ()
72
+ mock_response = Mock ()
73
+ run_spec_dict = run_spec .dict ()
74
+ run_spec_dict ["profile" ]["tags" ] = {"env" : "test" , "team" : "qa" }
75
+ mock_response .text = json .dumps (run_spec_dict )
76
+ mock_response .raise_for_status = Mock ()
77
+ with mock .patch ("requests.post" , return_value = mock_response ):
78
+ result = policy .on_apply (user = user .name , project = project .name , spec = run_spec )
79
+ assert result == RunSpec (** run_spec_dict )
80
+
81
+ @pytest .mark .asyncio
82
+ @mock .patch .dict (os .environ , {PLUGIN_SERVICE_URI_ENV_VAR_NAME : "http://mock" })
83
+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
84
+ async def test_on_run_apply_plugin_service_call_fails (self , test_db , user , project , run_spec ):
85
+ policy = PreApplyPolicy ()
86
+ with mock .patch ("requests.post" , side_effect = requests .RequestException ("fail" )):
87
+ result = policy .on_apply (user = user .name , project = project .name , spec = run_spec )
88
+ assert result == run_spec
89
+
90
+ @pytest .mark .asyncio
91
+ @mock .patch .dict (os .environ , {PLUGIN_SERVICE_URI_ENV_VAR_NAME : "http://mock" })
92
+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
93
+ async def test_on_run_apply_plugin_service_returns_invalid_spec (self , test_db , user , project , run_spec ):
94
+ policy = PreApplyPolicy ()
95
+ mock_response = Mock ()
96
+ mock_response .text = json .dumps ({"invalid-key" : "abc" })
97
+ mock_response .raise_for_status = Mock ()
98
+ with mock .patch ("requests.post" , return_value = mock_response ):
99
+ result = policy .on_apply (user .name , project = project .name , spec = run_spec )
100
+ # return original run spec
101
+ assert result == run_spec
102
+
0 commit comments