
Commit d32a08b: first release

1 parent 718fd9b
37 files changed: +13782, -0 lines

README.md

Lines changed: 2 additions & 0 deletions
# proteopt
Common interface to protein design tools and structure predictors

Details coming soon.

api.py

Lines changed: 194 additions & 0 deletions
import argparse
import collections
import traceback
import os
import sys
import logging
import socket
import time

from flask import Flask
from flask_restful import reqparse, abort, Api, Resource, inputs

import proteopt.alphafold
import proteopt.mock_tool
from proteopt.common import serialize, deserialize

app = Flask(__name__)
api = Api(app)


def add_argument(parser, arg_name, info, append=False):
    type = info['type']
    if type is object:
        type = str  # We will serialize objects

    d = {
        'type': type,
    }
    if 'default' in info:
        d["default"] = info["default"]
    else:
        d["required"] = True
    if append:
        d["action"] = "append"
    parser.add_argument(arg_name, **d)


TOOL_CLASSES = [
    proteopt.mock_tool.MockTool,
    proteopt.alphafold.AlphaFold,
]
TOOLS = dict((cls.tool_name, cls) for cls in TOOL_CLASSES)


class Tool(Resource):
    configuration = None  # this should be set when the app is launched

    tool_parsers = {}
    for (tool_name, tool_class) in TOOLS.items():
        tool_parsers[tool_name] = reqparse.RequestParser()
        for parameter, info in tool_class.model_args.items():
            add_argument(tool_parsers[tool_name], parameter, info)
        for parameter, info in tool_class.run_args.items():
            add_argument(
                tool_parsers[tool_name],
                parameter,
                info,
                append=not info['type'] is object)

    MODEL_CACHE = collections.OrderedDict()

    def get_model(self, tool_name, args):
        tool_class = TOOLS[tool_name]
        args_dict = dict(self.configuration[tool_name])
        cache_key = []
        for name, info in tool_class.model_args.items():
            value = getattr(args, name)
            cache_key.append((name, value))
            args_dict[name] = value

        cache_key = tuple(cache_key)

        try:
            return self.MODEL_CACHE[cache_key]
        except KeyError:
            pass

        logging.info("Loading new model: %s %s", tool_name, str(args_dict))

        model = tool_class(**args_dict)
        if len(self.MODEL_CACHE) >= self.configuration["model_cache_size"]:
            self.MODEL_CACHE.popitem(last=False)
        self.MODEL_CACHE[cache_key] = model
        return model

    def get(self, tool_name):
        return str(self.MODEL_CACHE.keys())

    def post(self, tool_name):
        tool_class = TOOLS[tool_name]

        parser = self.tool_parsers[tool_name]
        args = parser.parse_args()
        try:
            total_start = time.time()
            model = self.get_model(tool_name, args)
            init_seconds = time.time() - total_start

            run_arg_names = list(tool_class.run_args)

            for arg in run_arg_names:
                if tool_class.run_args[arg]['type'] is object:
                    setattr(args, arg, deserialize(getattr(args, arg)))

            example_run_arg = run_arg_names[0]
            list_of_input_dicts = []
            for i in range(len(getattr(args, example_run_arg))):
                d = dict((arg, getattr(args, arg)[i]) for arg in run_arg_names)
                list_of_input_dicts.append(d)

            start = time.time()
            results = model.run_multiple(list_of_input_dicts)
            assert not any(x is None for x in results)
            payload = {
                "success": True,
                "results": serialize(results),
                "init_seconds": init_seconds,
                "total_seconds": time.time() - start,
            }
            return payload, 200
        except Exception as e:
            exc_info = sys.exc_info()
            message = ''.join(traceback.format_exception(*exc_info))
            payload = {
                "success": False,
                "exception": (e.__class__.__name__, message),
            }
            return payload, 500


api.add_resource(Tool, '/tool/<tool_name>')

# Run the test server
arg_parser = argparse.ArgumentParser()

arg_parser.add_argument(
    "--debug",
    default=False,
    action="store_true")

arg_parser.add_argument(
    "--cuda-visible-devices")

arg_parser.add_argument("--host", default="127.0.0.1")
arg_parser.add_argument("--write-endpoint-to-file")
arg_parser.add_argument("--port", type=int)
arg_parser.add_argument("--model-cache-size", type=float, default=1.0)


arg_names_to_tool_configs = {}
for tool_name, tool_class in TOOLS.items():
    for parameter, info in tool_class.config_args.items():
        arg_name = "%s_%s" % (tool_name, parameter)
        arg_names_to_tool_configs[arg_name] = (tool_name, parameter)
        add_argument(arg_parser, "--" + arg_name.replace("_", "-"), info)

if __name__ == '__main__':
    args = arg_parser.parse_args(sys.argv[1:])

    tool_configs = collections.defaultdict(dict)  # tool name -> dict
    for (arg, (tool, parameter)) in arg_names_to_tool_configs.items():
        tool_configs[tool][parameter] = getattr(args, arg)

    print("Tool configuration parameters:")
    for name, d in tool_configs.items():
        print(name)
        for (k, v) in d.items():
            print("\t%15s = %15s" % (k, v))
        print()

    Tool.configuration = dict(tool_configs)
    Tool.configuration["model_cache_size"] = args.model_cache_size

    if args.cuda_visible_devices:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices

    port = args.port
    if not port:
        # Identify an available port
        # Based on https://stackoverflow.com/questions/5085656/how-to-select-random-port-number-in-flask
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.bind((args.host, 0))
        port = sock.getsockname()[1]
        sock.close()

    endpoint = "http://%s:%d" % (args.host, port)
    print("Endpoint will be", endpoint)
    if args.write_endpoint_to_file:
        with open(args.write_endpoint_to_file, "w") as fd:
            fd.write(endpoint)
            fd.write("\n")
        print("Wrote", args.write_endpoint_to_file)

    app.run(host=args.host, port=port, debug=args.debug, use_reloader=False)
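
For reference, api.py serves each registered tool at /tool/<tool_name>: GET reports the model cache, POST runs the tool. Below is a minimal client sketch using the requests library; it is not part of this commit, and the tool name and field names are placeholders, since the real ones come from each tool class's tool_name, model_args, and run_args.

# Hypothetical client sketch (not part of this commit). Assumes the api.py
# test server is running and wrote its endpoint to endpoint.txt via
# --write-endpoint-to-file, and that the requests library is installed.
import requests

with open("endpoint.txt") as fd:
    endpoint = fd.read().strip()

# GET returns a string describing the model-cache keys currently held.
print(requests.get(endpoint + "/tool/TOOL_NAME_PLACEHOLDER").text)

# POST runs the tool. Field names must match the tool class's model_args and
# run_args; these are placeholders. Run args of type object would need to be
# passed through proteopt.common.serialize first.
response = requests.post(
    endpoint + "/tool/TOOL_NAME_PLACEHOLDER",
    data={"a_model_arg": "value", "a_run_arg": ["input1", "input2"]})
print(response.status_code, response.json().get("success"))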

application.py

Lines changed: 38 additions & 0 deletions
from flask import Flask

# print a nice greeting.
def say_hello(username = "World"):
    return '<p>Hello %s!</p>\n' % username

# some bits of text for the page.
header_text = '''
    <html>\n<head> <title>EB Flask Test</title> </head>\n<body>'''
instructions = '''
    <p><em>Hint</em>: This is a RESTful web service! Append a username
    to the URL (for example: <code>/Thelonious</code>) to say hello to
    someone specific.</p>\n'''
home_link = '<p><a href="/">Back</a></p>\n'
footer_text = '</body>\n</html>'

# EB looks for an 'application' callable by default.
application = Flask(__name__)

# add a rule for the index page.
application.add_url_rule(
    '/',
    'index',
    lambda: header_text + say_hello() + instructions + footer_text)

# add a rule when the page is accessed with a name appended to the site
# URL.
application.add_url_rule(
    '/<username>',
    'hello',
    lambda username: header_text + say_hello(username) + home_link + footer_text)

# run the app.
if __name__ == "__main__":
    # Setting debug to True enables debug output. This line should be
    # removed before deploying a production app.
    application.debug = True
    application.run()
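
application.py is the stock Elastic Beanstalk Flask sample. A quick way to exercise its two routes without deploying it is Flask's built-in test client; this sketch is not part of this commit.

# Sketch (not part of this commit): hit application.py's routes with Flask's
# test client instead of a running server.
from application import application

client = application.test_client()

# Index page: header_text + say_hello() + instructions + footer_text.
print(client.get("/").data.decode())

# Named greeting route.
print(client.get("/Thelonious").data.decode())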

config.py

Lines changed: 4 additions & 0 deletions
"""
Flask configuration.
"""

docker/base/Dockerfile

Lines changed: 134 additions & 0 deletions
FROM --platform=linux/x86_64 nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04 as shrunk

ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PATH="/root/miniconda3/bin:${PATH}"
WORKDIR /root

RUN apt-get update \
    && apt-get install -y wget git vim \
    && rm -rf /var/lib/apt/lists/* \
    && wget -nv \
        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh \
    && conda --version

# Make links
RUN ln -s /data/static/alphafold-params . \
    && ln -s /data/static/RFDesign . \
    && ln -s /data/static/AlphaFold . \
    && ln -s /data/static/openfold . \
    && ln -s /data/static/OmegaFold . \
    && ln -s /data/static/ProteinMPNN . \
    && ln -s /data/static/design-env/ miniconda3/envs/design-env \
    && mkdir -p /software/mlfold/alphafold-data \
    && ln -s /data/static/alphafold-params /software/mlfold/alphafold-data/params \
    && mkdir -p /data/static/omegafold_ckpt \
    && mkdir -p ~/.cache \
    && ln -s /data/static/omegafold_ckpt ~/.cache/omegafold_ckpt

# ****************************
FROM shrunk as complete
RUN mkdir -p /data/static

RUN cd /data/static \
    && mkdir -p alphafold-params \
    && wget -nv --progress=dot:giga --show-progress https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar -O params.tar \
    && tar --extract --verbose --file=params.tar --directory=alphafold-params --preserve-permissions \
    && rm -f params.tar

# Note that we are using a different version of pytorch than recommended
# in the RFDesign readme (they recommend pytorch=1.10.1)
RUN conda update -n base -c defaults conda \
    && conda config --set ssl_verify no \
    && conda init bash \
    && conda clean -afy

RUN rm miniconda3/envs/design-env \
    && conda create -n design-env \
        python=3.8 \
        pytorch=1.11 \
        dgl-cuda11.3 \
        cudatoolkit=11.3 \
        cuda-toolkit \
        numpy scipy requests packaging pip \
        -c "nvidia/label/cuda-11.3.1" -c pytorch -c dglteam \
    && mv miniconda3/envs/design-env /data/static \
    && ln -s /data/static/design-env/ miniconda3/envs/design-env

RUN conda install -n design-env \
        pyg \
        openmm==7.5.1 \
        -c pyg -c conda-forge

RUN cd /data/static \
    && git clone https://github.com/deepmind/AlphaFold.git \
    && wget -q -P /data/static/AlphaFold/alphafold/common/ \
        https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt

# pytorch-lightning is for openfold
RUN /root/miniconda3/envs/design-env/bin/pip install \
        https://github.com/openmm/pdbfixer/archive/refs/tags/v1.7.tar.gz \
        icecream==2.1.3 \
        lie_learn==0.0.1.post1 \
        opt_einsum==3.3.0 \
        e3nn==0.3.4 \
    && /root/miniconda3/envs/design-env/bin/pip install \
        "jax[cuda]==0.3.25" \
        -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    && /root/miniconda3/envs/design-env/bin/pip install \
        dm-tree==0.1.6 \
        dm-haiku==0.0.9 \
        absl-py==1.0.0 \
        ml-collections==0.1.0 \
        tensorflow-gpu==2.11.0 \
        biopython==1.81 \
        pytorch-lightning==1.9.3 \
    && conda clean -afy \
    && /root/miniconda3/envs/design-env/bin/pip install -e /data/static/AlphaFold \
    && /root/miniconda3/envs/design-env/bin/pip cache purge

RUN cd /data/static \
    && git clone --branch main https://github.com/timodonnell/RFDesign.git \
    && cd RFDesign \
    && git remote add upstream https://github.com/RosettaCommons/RFDesign.git \
    && wget -nv -P rfdesign/hallucination/weights/rf_Nov05 http://files.ipd.uw.edu/pub/rfdesign/weights/BFF_last.pt \
    && wget -nv -P rfdesign/inpainting/weights/ http://files.ipd.uw.edu/pub/rfdesign/weights/BFF_mix_epoch25.pt \
    && /root/miniconda3/envs/design-env/bin/pip install -e .

RUN cd /data/static \
    && git clone https://github.com/timodonnell/ProteinMPNN.git \
    && cd ProteinMPNN \
    && git remote add upstream https://github.com/dauparas/ProteinMPNN.git \
    && /root/miniconda3/envs/design-env/bin/pip install -e .

# OpenFold
RUN cd /data/static \
    && git clone https://github.com/timodonnell/openfold.git \
    && cd openfold \
    && git remote add upstream https://github.com/aqlaboratory/openfold.git \
    && /root/miniconda3/envs/design-env/bin/pip install -e . \
    && /root/miniconda3/envs/design-env/bin/pip cache purge \
    && mkdir -p /data/static/openfold-params/ \
    && bash scripts/download_openfold_params_huggingface.sh /data/static/openfold-params/

RUN mkdir -p /data/static/example-data \
    && wget -nv -P /data/static/example-data https://files.rcsb.org/download/7SL5.pdb
COPY container-files/test_protein.fa /data/static/example-data/

# We run OmegaFold to force it to download the weights
RUN cd /data/static \
    && git clone https://github.com/timodonnell/OmegaFold \
    && cd OmegaFold \
    && git remote add upstream https://github.com/HeliXonProtein/OmegaFold \
    && /root/miniconda3/envs/design-env/bin/python setup.py install \
    && mkdir -p /tmp/omegafold_out \
    && /root/miniconda3/envs/design-env/bin/omegafold /data/static/example-data/test_protein.fa /tmp/omegafold_out \
    && ls -lh /tmp/omegafold_out \
    && rm -rf /tmp/omegafold_out

# Not sure why this is needed, but it seems to be:
RUN /root/miniconda3/envs/design-env/bin/pip install -e /data/static/AlphaFold

# Switch to the design-env on start:
RUN echo "conda activate design-env" >> ~/.bashrc

docker/base/build.sh

Lines changed: 6 additions & 0 deletions
set -e
set -x

time docker build -t timodonnell/proteopt-base-gpu:latest .
time docker build --target shrunk -t timodonnell/proteopt-base-gpu-shrunk:latest .

docker/base/container-files/test_protein.fa

Lines changed: 2 additions & 0 deletions

>ubiquitin
MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQRESTLHLVLRLRGG

docker/base/push_to_dockerhub.sh

Lines changed: 2 additions & 0 deletions
docker push timodonnell/proteopt-base-gpu:latest
docker push timodonnell/proteopt-base-gpu-shrunk:latest
