Skip to content

Commit e91c1ac

Browse files
committed
Add proxy
1 parent 0bd818d commit e91c1ac

File tree

5 files changed

+300
-9
lines changed

5 files changed

+300
-9
lines changed

api.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,12 @@ def get_model(self, tool_name, args):
9191
return model
9292

9393
def get(self, tool_name):
    """Describe this runner: fixed parallelism of 1 plus the cached model keys."""
    info = {
        'description': 'runner',
        'max_parallelism': 1,
        'model_cache_keys': list(self.MODEL_CACHE.keys()),
    }
    return info, 200
95100

96101
def post(self, tool_name):
97102
tool_class = TOOLS[tool_name]

proteopt/client.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,26 @@ def __init__(self, endpoints, max_retries=2):
1414
self.endpoints = endpoints
1515
self.work_queue = Queue()
1616
self.max_retries = max_retries
17-
1817
self.threads = []
18+
1919
for endpoint in endpoints:
20-
thread = threading.Thread(
21-
target=self.worker_thread,
22-
name="thread_%s" % endpoint,
23-
daemon=True,
24-
args=(endpoint,))
25-
thread.start()
20+
session = requests.Session()
21+
full_endpoint = endpoint + "/info"
22+
info = session.get(full_endpoint)
23+
if info.status_code != 200:
24+
raise IOError(f"Couldn't get info for {full_endpoint}: {info.status_code} {info.text}")
25+
max_parallelism = info.json()['max_parallelism']
26+
print(f"Client: endpoint {endpoint} will use max_parallelism {max_parallelism}")
27+
for i in range(max_parallelism):
28+
thread = threading.Thread(
29+
target=self.worker_thread,
30+
name=f"thread_{i}_{endpoint}",
31+
daemon=True,
32+
args=(endpoint,))
33+
self.threads.append(thread)
34+
thread.start()
35+
36+
self.max_parallelism = max(1, len(self.threads))
2637

2738
def shutdown(self):
2839
work_queue = self.work_queue

proxy.py

+218
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import argparse
2+
import queue
3+
import subprocess
4+
import signal
5+
import os
6+
import sys
7+
import glob
8+
import logging
9+
import socket
10+
import time
11+
import tempfile
12+
13+
from flask import Flask, request
14+
from flask_restful import reqparse, abort, Api, Resource, inputs
15+
16+
import proteopt.client
17+
18+
from proteopt.common import serialize, deserialize
19+
20+
app = Flask(__name__)
21+
api = Api(app)
22+
23+
24+
class Proxy(Resource):
    """Management resource for the proxy: maintains the set of backend
    API-server endpoints and lazily builds the shared proteopt client.

    Registered at /proxy/<action>; supported actions are add-endpoint,
    remove-endpoint, status, and clear.
    """
    # Class-level state shared across all requests.
    endpoints = set()
    max_retries = None
    client = None

    @classmethod
    def get_client(cls):
        """Return the shared client, constructing it on first use.

        Raises:
            ValueError: if no endpoints have been registered yet.

        NOTE(review): the client is cached after first construction, so
        endpoints added or removed later are not picked up — confirm that
        is intended.
        """
        if cls.client is None:
            if not cls.endpoints:
                raise ValueError("No endpoints")
            cls.client = proteopt.client.Client(
                endpoints=[e + "/tool" for e in cls.endpoints],
                max_retries=cls.max_retries)
        return cls.client

    def get(self, action, name=None):
        """Handle a management request.

        `name` defaults to None: the route only supplies <action>, so the
        previously-required second parameter made every request fail with
        a TypeError.
        """
        if action == "add-endpoint":
            endpoint = request.args.get('endpoint')
            self.endpoints.add(endpoint)
            return f"Added endpoint {endpoint}"
        elif action == "remove-endpoint":
            endpoint = request.args.get('endpoint')
            if endpoint in self.endpoints:
                self.endpoints.remove(endpoint)
                return f"Removed endpoint {endpoint}"
            else:
                return f"No such endpoint {endpoint}"
        elif action == "status":
            lines = []
            lines.extend(sorted(self.endpoints))
            return "\n".join(lines)
        elif action == "clear":
            self.endpoints.clear()
            return "Cleared endpoints"
        # The original fell through to a copy-pasted `self.MODEL_CACHE`
        # access (which Proxy does not have), raising AttributeError.
        # Report the unrecognized action instead.
        return f"Unknown action: {action}", 404
59+
60+
class Tool(Resource):
    """Forwards tool requests to the backend API servers through the
    shared client. Registered at /tool/<tool_name>.
    """

    def get(self, tool_name):
        # Report the proxy's aggregate parallelism; fall back to a fixed
        # value if the client cannot be constructed yet.
        try:
            parallelism = Proxy.get_client().max_parallelism
        except Exception as e:
            logging.warning("Couldn't get parallelism: %s", e)
            parallelism = 8
        return {
            'description': 'proxy',
            'endpoints': sorted(Proxy.endpoints),
            'max_parallelism': parallelism,
        }, 200

    def post(self, tool_name):
        # Tag the payload with its tool name, hand it to a worker thread,
        # then block until the matching result comes back.
        payload = request.get_json()
        payload['tool_name'] = tool_name

        reply_queue = queue.Queue()
        Proxy.get_client().work_queue.put((0, payload, reply_queue))
        payload_id, reply = reply_queue.get()
        assert payload_id == 0
        return reply, 200
84+
85+
86+
api.add_resource(Proxy, '/proxy/<action>')
87+
api.add_resource(Tool, '/tool/<tool_name>')
88+
89+
90+
# Run the test server
91+
arg_parser = argparse.ArgumentParser()
92+
arg_parser.add_argument("--no-cleanup", action="store_true", default=False)
93+
arg_parser.add_argument("--max-retries", default=2, type=int)
94+
arg_parser.add_argument("--endpoints", nargs="+")
95+
arg_parser.add_argument("--host", default="127.0.0.1")
96+
arg_parser.add_argument("--port", type=int)
97+
arg_parser.add_argument("--write-endpoint-to-file")
98+
arg_parser.add_argument(
99+
"--debug",
100+
default=False,
101+
action="store_true")
102+
103+
arg_parser.add_argument(
104+
"--launch-servers",
105+
metavar="N",
106+
type=int,
107+
help="Launch N API servers. If N=-1, then one server is launched per GPU and "
108+
"the CUDA_VISIBLE_DEVICES parameter is set accordingly for each server.")
109+
arg_parser.add_argument(
110+
"--launch-args",
111+
nargs=argparse.REMAINDER,
112+
help="All following args are args for launched API servers.")
113+
114+
if __name__ == '__main__':
115+
args = arg_parser.parse_args(sys.argv[1:])
116+
logging.basicConfig(level=logging.INFO)
117+
118+
endpoint_to_process = {}
119+
work_dir = None
120+
if args.launch_servers:
121+
print(args)
122+
num_to_launch = args.launch_servers
123+
set_cuda_visible_devices = False
124+
if args.launch_servers == -1:
125+
gpu_lines = subprocess.check_output(["nvidia-smi", "-L"]).decode().split("\n")
126+
gpu_lines = [g.strip() for g in gpu_lines]
127+
gpu_lines = [g for g in gpu_lines if g.startswith("GPU ")]
128+
num_to_launch = len(gpu_lines)
129+
print(f"Detected {num_to_launch} GPUs.")
130+
set_cuda_visible_devices = True
131+
132+
work_dir = tempfile.TemporaryDirectory(prefix="proteopt_proxy_")
133+
for i in range(num_to_launch):
134+
endpoint_file = os.path.join(work_dir.name, f"endpoint.{i}.txt")
135+
sub_args = [
136+
"python",
137+
os.path.join(os.path.dirname(__file__), "api.py"),
138+
]
139+
sub_args.extend(args.launch_args)
140+
sub_args.extend(["--write-endpoint-to-file", endpoint_file])
141+
if set_cuda_visible_devices:
142+
sub_args.extend(["--cuda-visible-devices", str(i)])
143+
print(f"Launching API server {i} / {num_to_launch} with args:")
144+
print(sub_args)
145+
146+
logfile = os.path.join(work_dir.name, f"log.{i}.txt")
147+
logfile_fd = open(logfile, "w+b")
148+
process = subprocess.Popen(
149+
sub_args, stderr=logfile_fd, stdout=logfile_fd)
150+
while process.poll() is None and not os.path.exists(endpoint_file):
151+
time.sleep(0.1)
152+
try:
153+
endpoint = open(endpoint_file).read().strip()
154+
except IOError:
155+
print("Failed to load endpoint file. Process log:")
156+
logfile_fd.seek(0)
157+
for line in logfile_fd.readlines():
158+
print(line)
159+
raise
160+
print(f"API server {i} at endpoint {endpoint} will log to {logfile}")
161+
endpoint_to_process[endpoint] = process
162+
Proxy.endpoints.update(list(endpoint_to_process))
163+
164+
Proxy.max_retries = args.max_retries
165+
if args.endpoints:
166+
Proxy.endpoints.update(args.endpoints)
167+
168+
print("Initialized proxy with endpoints: ", Proxy.endpoints)
169+
170+
port = args.port
171+
if not port:
172+
# Identify an available port
173+
# Based on https://stackoverflow.com/questions/5085656/how-to-select-random-port-number-in-flask
174+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
175+
sock.bind((args.host, 0))
176+
port = sock.getsockname()[1]
177+
sock.close()
178+
179+
endpoint = "http://%s:%d" % (args.host, port)
180+
print("Endpoint will be", endpoint)
181+
if args.write_endpoint_to_file:
182+
with open(args.write_endpoint_to_file, "w") as fd:
183+
fd.write(endpoint)
184+
fd.write("\n")
185+
print("Wrote", args.write_endpoint_to_file)
186+
187+
def cleanup(sig, frame):
    """SIGINT handler: dump launched-server logs (in debug mode), remove
    the temporary work dir, terminate child servers, and exit.

    Fixes: removed a leftover `import ipdb ; ipdb.set_trace()` that
    dropped every shutdown into an interactive debugger, and guarded the
    debug log dump against work_dir being None (the case when --debug is
    used without --launch-servers).
    """
    if args.debug and work_dir is not None:
        print("Dumping logs.")
        for g in glob.glob(os.path.join(work_dir.name, "*.txt")):
            print("*" * 40)
            print(g)
            print("*" * 40)
            for line in open(g).readlines():
                print("---", line.rstrip())

    if work_dir is not None and not args.no_cleanup:
        print(f"Cleaning up {work_dir}")
        work_dir.cleanup()

    while endpoint_to_process:
        endpoint, process = endpoint_to_process.popitem()
        print(f"Terminating process with endpoint {endpoint}")
        process.terminate()
        # terminate() is asynchronous; escalate to kill() if the child
        # has not yet exited.
        if process.poll() is None:
            process.kill()
    print("Done.")
    sys.exit(0)
210+
211+
signal.signal(signal.SIGINT, cleanup)
212+
213+
app.run(
214+
host=args.host,
215+
port=port,
216+
debug=args.debug,
217+
use_reloader=False,
218+
threaded=True)

test/test_proxy.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import warnings
2+
3+
import numpy.testing
4+
5+
warnings.filterwarnings("ignore")
6+
7+
import proteopt
8+
import proteopt.client
9+
import proteopt.mock_tool
10+
11+
from .util import running_proxy_endpoint
12+
13+
14+
def test_basic(running_proxy_endpoint):
    """End-to-end check: run MockTool through the proxy and its launched servers."""
    client = proteopt.client.Client(endpoints=[running_proxy_endpoint])
    model = client.remote_model(proteopt.mock_tool.MockTool, greeting="hi")
    outputs = list(model.run_multiple([
        dict(name="tim", sleep_time=0.0),
        dict(name="joe", sleep_time=0.0, array=[1, 2]),
    ]))
    assert outputs == ["test-server: hi tim", "test-server: hi joe 3.00"]

test/util.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,35 @@ def run_server(port, sleep_seconds=3.0):
3434
os.unlink(endpoint_file)
3535
return (process, endpoint + "/tool")
3636

37+
38+
def run_proxy(port):
    """Launch proxy.py (which itself launches 3 API servers) and wait for
    it to publish its endpoint.

    Returns:
        (process, endpoint + "/tool") — the Popen handle and the tool URL.

    Raises:
        RuntimeError: if the proxy exits before writing its endpoint file.
    """
    # Unique per-process path: the previous hard-coded
    # /tmp/proteopt_endpoint.txt collided with run_server's file.
    endpoint_file = "/tmp/proteopt_proxy_endpoint.%d.txt" % os.getpid()
    try:
        os.unlink(endpoint_file)
    except OSError:
        pass
    process = subprocess.Popen(
        [
            "python",
            os.path.join(REPO_ROOT_DIR, "proxy.py"),
            "--debug",
            "--no-cleanup",
            "--port", str(port),
            "--launch-servers", "3",
            "--write-endpoint-to-file", endpoint_file,
            "--launch-args",
            "--mock-server-name", 'test-server',
            "--alphafold-data-dir", ALPHAFOLD_WEIGHTS_DIR,
            "--omegafold-data-dir", OMEGAFOLD_WEIGHTS_DIR,
            "--rfdiffusion-motif-models-dir", RFDIFFUSION_WEIGHTS_DIR,
        ])
    # Poll until the proxy either dies or writes its endpoint.
    while process.poll() is None and not os.path.exists(endpoint_file):
        time.sleep(0.1)
    if not os.path.exists(endpoint_file):
        # Previously this fell through to open() and raised an opaque
        # FileNotFoundError; fail with a diagnostic instead.
        raise RuntimeError(
            "Proxy process exited (code %s) before writing endpoint file"
            % process.poll())
    with open(endpoint_file) as fd:
        endpoint = fd.read().strip()
    os.unlink(endpoint_file)
    return (process, endpoint + "/tool")
65+
3766
@pytest.fixture
3867
def running_server_endpoint(port=0, sleep_seconds=0):
3968
(process, endpoint) = run_server(port)
@@ -54,4 +83,10 @@ def multiple_running_server_endpoints(ports=(0, 0), sleep_seconds=0):
5483
time.sleep(sleep_seconds)
5584
yield endpoints
5685
for process in processes:
57-
process.terminate()
86+
process.terminate()
87+
88+
@pytest.fixture
def running_proxy_endpoint(port=0):
    """Yield the tool endpoint of a freshly launched proxy; terminate it afterwards."""
    process, endpoint = run_proxy(port)
    yield endpoint
    process.terminate()

0 commit comments

Comments
 (0)