1
+ import argparse
2
+ import queue
3
+ import subprocess
4
+ import signal
5
+ import os
6
+ import sys
7
+ import glob
8
+ import logging
9
+ import socket
10
+ import time
11
+ import tempfile
12
+
13
+ from flask import Flask , request
14
+ from flask_restful import reqparse , abort , Api , Resource , inputs
15
+
16
+ import proteopt .client
17
+
18
+ from proteopt .common import serialize , deserialize
19
+
20
# Flask application and the flask_restful API wrapper that the proxy
# resources below register themselves onto.
app = Flask(__name__)
api = Api(app)
22
+
23
+
24
class Proxy(Resource):
    """Management resource: maintains the set of backend API-server endpoints
    and a lazily constructed proteopt client that fans requests out to them.

    All state is class-level (shared) because flask_restful instantiates a
    fresh Resource per request.
    """

    endpoints = set()    # registered backend base URLs
    max_retries = None   # set from --max-retries in __main__
    client = None        # lazily created shared proteopt.client.Client

    @classmethod
    def get_client(cls):
        """Return the shared client, creating it on first use.

        Raises:
            ValueError: if no backend endpoints have been registered yet.
        """
        if cls.client is None:
            if not cls.endpoints:
                raise ValueError("No endpoints")
            cls.client = proteopt.client.Client(
                endpoints=[e + "/tool" for e in cls.endpoints],
                max_retries=cls.max_retries)
        return cls.client

    def get(self, action, name=None):
        """Dispatch a management action.

        Supported actions: add-endpoint, remove-endpoint (both take an
        'endpoint' query arg), status, clear.

        NOTE(review): the registered route is '/proxy/<action>', which supplies
        only 'action'. The original signature required a 'name' argument too,
        so every request raised TypeError; 'name' now defaults to None for
        backward compatibility.
        """
        if action == "add-endpoint":
            endpoint = request.args.get('endpoint')
            self.endpoints.add(endpoint)
            return f"Added endpoint {endpoint}"
        elif action == "remove-endpoint":
            endpoint = request.args.get('endpoint')
            if endpoint in self.endpoints:
                self.endpoints.remove(endpoint)
                return f"Removed endpoint {endpoint}"
            else:
                return f"No such endpoint {endpoint}"
        elif action == "status":
            # One endpoint per line, sorted for deterministic output.
            return "\n".join(sorted(self.endpoints))
        elif action == "clear":
            self.endpoints.clear()
            return "Cleared endpoints"
        # Fix: the original fell through to str(self.MODEL_CACHE.keys()), but
        # MODEL_CACHE is not defined anywhere -> AttributeError (HTTP 500).
        # Report the unknown action explicitly instead.
        abort(404, message=f"Unknown action: {action}")
59
+
60
class Tool(Resource):
    """Forwards tool requests to the backend API servers via the shared client."""

    def get(self, tool_name):
        """Describe the proxy: its endpoints and available parallelism.

        Falls back to a parallelism of 8 when the client cannot be queried
        (e.g. no endpoints registered yet).
        """
        try:
            parallelism = Proxy.get_client().max_parallelism
        except Exception as e:
            logging.warning("Couldn't get parallelism: %s", e)
            parallelism = 8
        info = {
            'description': 'proxy',
            'endpoints': sorted(Proxy.endpoints),
            'max_parallelism': parallelism,
        }
        return info, 200

    def post(self, tool_name):
        """Submit one tool invocation and block until its result arrives.

        The payload is tagged with id 0 and pushed onto the client's work
        queue together with a private reply queue; we then wait on that queue
        for the matching response.
        """
        payload = request.get_json()
        payload['tool_name'] = tool_name

        reply_queue = queue.Queue()
        Proxy.get_client().work_queue.put((0, payload, reply_queue))
        payload_id, reply = reply_queue.get()
        assert payload_id == 0
        return reply, 200
84
+
85
+
86
# Route registration: management actions and tool invocation.
api.add_resource(Proxy, '/proxy/<action>')
api.add_resource(Tool, '/tool/<tool_name>')
88
+
89
+
90
# Command-line interface for the proxy server.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--no-cleanup", action="store_true", default=False)
arg_parser.add_argument("--max-retries", default=2, type=int)
arg_parser.add_argument("--endpoints", nargs="+")
arg_parser.add_argument("--host", default="127.0.0.1")
arg_parser.add_argument("--port", type=int)
arg_parser.add_argument("--write-endpoint-to-file")
arg_parser.add_argument("--debug", default=False, action="store_true")
arg_parser.add_argument(
    "--launch-servers",
    metavar="N",
    type=int,
    help="Launch N API servers. If N=-1, then one server is launched per GPU and "
    "the CUDA_VISIBLE_DEVICES parameter is set accordingly for each server.")
arg_parser.add_argument(
    "--launch-args",
    nargs=argparse.REMAINDER,
    help="All following args are args for launched API servers.")
113
+
114
if __name__ == '__main__':
    args = arg_parser.parse_args(sys.argv[1:])
    logging.basicConfig(level=logging.INFO)

    endpoint_to_process = {}
    work_dir = None
    if args.launch_servers:
        print(args)
        num_to_launch = args.launch_servers
        set_cuda_visible_devices = False
        if args.launch_servers == -1:
            # One server per GPU: count the "GPU N: ..." lines from nvidia-smi.
            gpu_lines = subprocess.check_output(
                ["nvidia-smi", "-L"]).decode().split("\n")
            gpu_lines = [g.strip() for g in gpu_lines]
            gpu_lines = [g for g in gpu_lines if g.startswith("GPU ")]
            num_to_launch = len(gpu_lines)
            print(f"Detected {num_to_launch} GPUs.")
            set_cuda_visible_devices = True

        work_dir = tempfile.TemporaryDirectory(prefix="proteopt_proxy_")
        for i in range(num_to_launch):
            endpoint_file = os.path.join(work_dir.name, f"endpoint.{i}.txt")
            sub_args = [
                "python",
                os.path.join(os.path.dirname(__file__), "api.py"),
            ]
            # Fix: --launch-args is optional; args.launch_args is None when it
            # was not given, and extend(None) raises TypeError.
            sub_args.extend(args.launch_args or [])
            sub_args.extend(["--write-endpoint-to-file", endpoint_file])
            if set_cuda_visible_devices:
                sub_args.extend(["--cuda-visible-devices", str(i)])
            print(f"Launching API server {i} / {num_to_launch} with args:")
            print(sub_args)

            logfile = os.path.join(work_dir.name, f"log.{i}.txt")
            logfile_fd = open(logfile, "w+b")
            process = subprocess.Popen(
                sub_args, stderr=logfile_fd, stdout=logfile_fd)
            # Block until the child writes its endpoint file or exits.
            while process.poll() is None and not os.path.exists(endpoint_file):
                time.sleep(0.1)
            try:
                endpoint = open(endpoint_file).read().strip()
            except IOError:
                # Child died before publishing its endpoint: show its log.
                print("Failed to load endpoint file. Process log:")
                logfile_fd.seek(0)
                for line in logfile_fd.readlines():
                    print(line)
                raise
            print(f"API server {i} at endpoint {endpoint} will log to {logfile}")
            endpoint_to_process[endpoint] = process
        Proxy.endpoints.update(list(endpoint_to_process))

    Proxy.max_retries = args.max_retries
    if args.endpoints:
        Proxy.endpoints.update(args.endpoints)

    print("Initialized proxy with endpoints: ", Proxy.endpoints)

    port = args.port
    if not port:
        # Identify an available port
        # Based on https://stackoverflow.com/questions/5085656/how-to-select-random-port-number-in-flask
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.bind((args.host, 0))
        port = sock.getsockname()[1]
        sock.close()

    endpoint = "http://%s:%d" % (args.host, port)
    print("Endpoint will be", endpoint)
    if args.write_endpoint_to_file:
        with open(args.write_endpoint_to_file, "w") as fd:
            fd.write(endpoint)
            fd.write("\n")
        print("Wrote", args.write_endpoint_to_file)

    def cleanup(sig, frame):
        """SIGINT handler: dump child logs (debug mode), remove the temp
        directory, and terminate all launched API server processes.

        Fix: removed leftover `import ipdb; ipdb.set_trace()`, which made
        every Ctrl-C drop into a debugger (or crash if ipdb was absent).
        """
        # Guard work_dir: it is None when no servers were launched.
        if args.debug and work_dir is not None:
            print("Dumping logs.")
            for g in glob.glob(os.path.join(work_dir.name, "*.txt")):
                print("*" * 40)
                print(g)
                print("*" * 40)
                for line in open(g).readlines():
                    print("---", line.rstrip())

        if work_dir is not None and not args.no_cleanup:
            print(f"Cleaning up {work_dir.name}")
            work_dir.cleanup()

        while endpoint_to_process:
            endpoint, process = endpoint_to_process.popitem()
            print(f"Terminating process with endpoint {endpoint}")
            process.terminate()
            try:
                # Fix: poll() immediately after terminate() is nearly always
                # None, so the old code hard-killed every child. Give it a
                # grace period to exit before resorting to SIGKILL.
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
        print("Done.")
        sys.exit(0)

    signal.signal(signal.SIGINT, cleanup)

    app.run(
        host=args.host,
        port=port,
        debug=args.debug,
        use_reloader=False,
        threaded=True)
0 commit comments