Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions lib/umbridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,22 @@ namespace umbridge {
return supportsApplyHessian;
}

void terminate(const json& config_json = json::parse("{}")) {
json request_body;
request_body["name"] = name;

if (!config_json.empty())
request_body["config"] = config_json;

if (auto res = cli.Post("/Terminate", headers, request_body.dump(), "application/json")) {
json response_body = parse_result_with_error_handling(res);
std::string status = response_body["status"].get<std::string>();
std::cout << status << std::endl;
} else {
throw std::runtime_error("POST Terminate failed with error type '" + to_string(res.error()) + "'");
}
}

private:

mutable httplib::Client cli;
Expand Down Expand Up @@ -693,6 +709,19 @@ namespace umbridge {
res.set_content(response_body.dump(), "application/json");
});

svr.Post("/Terminate", [&](const httplib::Request &req, httplib::Response &res) {
json request_body = json::parse(req.body);
Model& model = get_model_from_name(models, request_body["name"]);
json empty_default_config;
json config_json = request_body.value("config", empty_default_config);

json response_body;
svr.stop();
response_body["status"] = "Model server terminated.";

res.set_content(response_body.dump(), "application/json");
});

std::cout << "Listening on port " << port << "..." << std::endl;

#ifdef LOGGING
Expand Down
26 changes: 26 additions & 0 deletions umbridge/um.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests
import asyncio
from concurrent.futures import ThreadPoolExecutor
import signal

class Model(object):

Expand Down Expand Up @@ -156,6 +157,14 @@ def apply_hessian(self, out_wrt, in_wrt1, in_wrt2, parameters, sens, vec, config
raise Exception(f'Model returned error of type {response["error"]["type"]}: {response["error"]["message"]}')
return response["output"]

def terminate(self):
inputParams = {}
inputParams["name"] = self.name

response = requests.post(f"{self.url}/Terminate", json=inputParams).json()
print(response["status"])


def serve_models(models, port=4242, max_workers=1, error_checks=True):

model_executor = ThreadPoolExecutor(max_workers=max_workers)
Expand Down Expand Up @@ -437,6 +446,23 @@ async def info(request):
return web.json_response(response_body)


@routes.post('/Terminate')
async def terminate(request):
req_json = await request.json()
model_name = req_json["name"]
model = get_model_from_name(model_name)

if model is None:
return model_not_found_response(req_json["name"])

print("Sending SIGTERM to model server")
signal.raise_signal(signal.SIGTERM)

return web.Response(text="{\"status\": \"Model server terminated.\"}")




app = web.Application(client_max_size=None)
app.add_routes(routes)
web.run_app(app, port=port)
Loading