diff --git a/lib/umbridge.h b/lib/umbridge.h
index a779112e..b7a707c9 100644
--- a/lib/umbridge.h
+++ b/lib/umbridge.h
@@ -271,6 +271,22 @@ namespace umbridge {
       return supportsApplyHessian;
     }
 
+    void terminate(const json& config_json = json::parse("{}")) {
+      json request_body;
+      request_body["name"] = name;
+
+      if (!config_json.empty())
+        request_body["config"] = config_json;
+
+      if (auto res = cli.Post("/Terminate", headers, request_body.dump(), "application/json")) {
+        json response_body = parse_result_with_error_handling(res);
+        std::string status = response_body["status"].get<std::string>();
+        std::cout << status << std::endl;
+      } else {
+        throw std::runtime_error("POST Terminate failed with error type '" + to_string(res.error()) + "'");
+      }
+    }
+
   private:
 
     mutable httplib::Client cli;
@@ -693,6 +709,19 @@ namespace umbridge {
       res.set_content(response_body.dump(), "application/json");
     });
 
+    svr.Post("/Terminate", [&](const httplib::Request &req, httplib::Response &res) {
+      json request_body = json::parse(req.body);
+      Model& model = get_model_from_name(models, request_body["name"]);
+      json empty_default_config;
+      json config_json = request_body.value("config", empty_default_config);
+
+      json response_body;
+      svr.stop();
+      response_body["status"] = "Model server terminated.";
+
+      res.set_content(response_body.dump(), "application/json");
+    });
+
     std::cout << "Listening on port " << port << "..."
               << std::endl;
 #ifdef LOGGING
diff --git a/umbridge/um.py b/umbridge/um.py
index 7aabb051..6e829906 100755
--- a/umbridge/um.py
+++ b/umbridge/um.py
@@ -2,6 +2,7 @@
 import requests
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
+import signal
 
 
 class Model(object):
@@ -156,6 +157,14 @@ def apply_hessian(self, out_wrt, in_wrt1, in_wrt2, parameters, sens, vec, config={}):
             raise Exception(f'Model returned error of type {response["error"]["type"]}: {response["error"]["message"]}')
         return response["output"]
 
+    def terminate(self):
+        inputParams = {}
+        inputParams["name"] = self.name
+
+        response = requests.post(f"{self.url}/Terminate", json=inputParams).json()
+        print(response["status"])
+
+
 def serve_models(models, port=4242, max_workers=1, error_checks=True):
 
     model_executor = ThreadPoolExecutor(max_workers=max_workers)
@@ -437,6 +446,23 @@ async def info(request):
         return web.json_response(response_body)
 
 
+    @routes.post('/Terminate')
+    async def terminate(request):
+        req_json = await request.json()
+        model_name = req_json["name"]
+        model = get_model_from_name(model_name)
+
+        if model is None:
+            return model_not_found_response(req_json["name"])
+
+        print("Sending SIGTERM to model server")
+        signal.raise_signal(signal.SIGTERM)
+
+        return web.json_response({"status": "Model server terminated."})
+
+
+
+
     app = web.Application(client_max_size=None)
     app.add_routes(routes)
     web.run_app(app, port=port)