Skip to content

Commit f6b065c

Browse files
committed
Support for remote proteinmpnn calls
1 parent 81a5108 commit f6b065c

10 files changed

+83
-19
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ Common interface to protein design tools and structure predictors
44
The goal is to provide a Python API that wraps recent protein design and structure prediction tools. We want to
55
be able to experiment with multiple approaches for the same problem without rewriting our task for each tool's API.
66

7-
This package also provides a large docker image with the tools installed (see [docker/base](docker/base)) as well as an
8-
image that extends it to include a jupyter environment for interactive work (see [docker/full](docker/full)).
7+
This package also provides a huge docker image with the tools installed, including
8+
model weights.
99

1010
This is all very much WIP. Contributions are welcome.
1111

api.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from flask_restful import reqparse, abort, Api, Resource, inputs
1212

1313
import proteopt.alphafold
14+
import proteopt.proteinmpnn
1415
import proteopt.mock_tool
1516
from proteopt.common import serialize, deserialize
1617

@@ -38,6 +39,7 @@ def add_argument(parser, arg_name, info, append=False):
3839
TOOL_CLASSES = [
3940
proteopt.mock_tool.MockTool,
4041
proteopt.alphafold.AlphaFold,
42+
proteopt.proteinmpnn.ProteinMPNN,
4143
]
4244
TOOLS = dict((cls.tool_name, cls) for cls in TOOL_CLASSES)
4345

@@ -157,7 +159,8 @@ def post(self, tool_name):
157159
if __name__ == '__main__':
158160
args = arg_parser.parse_args(sys.argv[1:])
159161

160-
tool_configs = collections.defaultdict(dict) # tool name -> dict
162+
# tool name -> dict
163+
tool_configs = dict((tool_name, {}) for tool_name in TOOLS.keys())
161164
for (arg, (tool, parameter)) in arg_names_to_tool_configs.items():
162165
tool_configs[tool][parameter] = getattr(args, arg)
163166

docker/Dockerfile

+4-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ RUN ln -s /data/static/alphafold-params . \
4545
&& mkdir -p ~/.cache \
4646
&& ln -s /data/static/omegafold_ckpt ~/.cache/omegafold_ckpt
4747

48+
# Was getting some errors from wandb without this:
49+
ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
50+
51+
4852
# ****************************
4953
FROM minimal as base
5054
RUN mkdir -p /data/static
@@ -217,8 +221,6 @@ RUN cd /data/static \
217221
# Switch to the design-env on start:
218222
RUN echo "conda activate design-env" >> ~/.bashrc
219223

220-
# Was getting some errors from wandb without this:
221-
ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
222224

223225
# ****************************
224226
# Jupyter setup for interactive usage

docker/container-files/deploy_local_server.sh

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ set -e
22
set -x
33

44
NUM="${1:-2}"
5-
NUM_PER=1
5+
NUM_PER="${2:-1}"
66

7-
source ../paths.sh
7+
ALPHAFOLD_WEIGHTS_DIR="/software/mlfold/alphafold-data"
88

99
ENDPOINTS_FILE=/tmp/PROTEOPT_ENDPOINTS.TXT
1010
PIDS_FILE=/tmp/PROTEOPT_ENDPOINTS.PIDS.TXT
@@ -18,7 +18,7 @@ do
1818
for j in $(seq $NUM_PER)
1919
do
2020
CUDA_VISIBLE_DEVICES=$(expr $i - 1) python \
21-
~/git/proteopt/api.py \
21+
~/proteopt/api.py \
2222
--debug \
2323
--alphafold-data-dir "$ALPHAFOLD_WEIGHTS_DIR" \
2424
--write-endpoint-to-file /tmp/proteopt_endpoint.txt &
File renamed without changes.

docker/run.shell.sh

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
docker run -it \
44
-v "$(realpath ~):/root/host" \
55
-v "$(realpath ..):/root/proteopt" \
6-
-p 8888:8888 \
76
--gpus all \
87
timodonnell/proteopt-complete:latest \
98
bash -c "cd /root/proteopt/docker ; bash update.sh ; cd ; bash"

proteopt/proteinmpnn.py

+27-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tempfile
22
import os
33
import json
4-
import pickle
4+
from typing import Optional
55

66
import numpy
77
import pandas
@@ -12,20 +12,25 @@
1212

1313
from ProteinMPNN import protein_mpnn_run
1414

15+
from .common import args_from_function_signature
1516

1617
class ProteinMPNN(object):
18+
tool_name = "proteinmpnn"
19+
config_args = {}
20+
model_args = {}
21+
1722
def __init__(self):
1823
pass
1924

2025
def run(
2126
self,
22-
structure,
23-
fixed=None,
24-
num=1,
25-
ca_only=False,
26-
sampling_temp=0.1,
27-
batch_size=1,
28-
verbose=False):
27+
structure : prody.Atomic,
28+
fixed : Optional[prody.Atomic] = None,
29+
num : int = 1,
30+
ca_only : bool = False,
31+
sampling_temp : float = 0.1,
32+
batch_size : int = 1,
33+
verbose : bool = False):
2934

3035
# Reset resnums to avoid gaps
3136
chains = numpy.unique(structure.getChids())
@@ -124,5 +129,19 @@ def run(
124129
temp_dir.cleanup()
125130
return result_df
126131

132+
def run_multiple(self, list_of_dicts, show_progress=False):
133+
results = []
134+
135+
iterator = list_of_dicts
136+
if show_progress:
137+
import tqdm
138+
iterator = tqdm.tqdm(list_of_dicts)
139+
140+
for kwargs in iterator:
141+
result = self.run(**kwargs)
142+
assert result is not None
143+
results.append(result)
144+
return results
127145

146+
run_args = args_from_function_signature(run)
128147

test/test_alphafold_remote.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_multiple_endpoints_mock(multiple_running_server_endpoints):
4646

4747
def test_basic_mock(running_server_endpoint):
4848
client = proteopt.client.Client(endpoints=[running_server_endpoint])
49-
model = client.remote_model(
49+
model = client.premote_model(
5050
proteopt.alphafold.AlphaFold,
5151
model_name="MOCK",
5252
max_length=16,

test/test_proteinmpnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
ALPHAFOLD_WEIGHTS_DIR = "/software/mlfold/alphafold-data"
1414

1515

16-
def Xtest_basic():
16+
def test_basic():
1717
region1 = prody.parsePDB(
1818
os.path.join(DATA_DIR, "1MBN.pdb")
1919
).select("protein chain A and resid 10 to 39")

test/test_proteinmpnn_remote.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import warnings
2+
3+
import numpy.testing
4+
5+
warnings.filterwarnings("ignore")
6+
7+
import os
8+
import prody
9+
10+
import pytest
11+
12+
import proteopt
13+
import proteopt.client
14+
import proteopt.proteinmpnn
15+
16+
DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data")
17+
18+
from .util import running_server_endpoint
19+
20+
def test_basic(running_server_endpoint):
21+
client = proteopt.client.Client(endpoints=[running_server_endpoint])
22+
runner = client.remote_model(
23+
proteopt.proteinmpnn.ProteinMPNN)
24+
25+
region1 = prody.parsePDB(
26+
os.path.join(DATA_DIR, "1MBN.pdb")
27+
).select("protein chain A and resid 10 to 39")
28+
sequence = region1.ca.getSequence()
29+
fixed_region = region1.select("resid 25 to 28 or resid 35")
30+
31+
results = runner.run(region1, num=5, fixed=fixed_region)
32+
print(results)
33+
34+
assert results.shape[0] == 5
35+
assert list(results.seq.str.len().unique()) == [len(region1.ca)]
36+
37+
for _, row in results.iterrows():
38+
assert row.seq[:10] != sequence[:10]
39+
assert row.seq[15:19] == sequence[15:19]
40+
assert row.seq[25] == sequence[25]
41+

0 commit comments

Comments
 (0)