
Commit 75f6abf

Update chai1.py
1 parent eeafd2c · commit 75f6abf

1 file changed (+48 -31 lines)

chai_lab/chai1.py (+48 -31)
@@ -9,11 +9,12 @@
 from collections import Counter
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, Sequence
+from typing import Sequence
 
 import numpy as np
 import torch
 import torch.export
+from dataclasses import dataclass
 from einops import einsum, rearrange, repeat
 from torch import Tensor
 from tqdm import tqdm
@@ -136,6 +137,18 @@ def forward(
         return result
 
 
+@dataclass
+class Model:
+    """A dataclass for model weights."""
+
+    feature_embedding: ModuleWrapper
+    bond_loss_input_proj: ModuleWrapper
+    token_input_embedder: ModuleWrapper
+    trunk: ModuleWrapper | None
+    diffusion_module: ModuleWrapper | None
+    confidence_head: ModuleWrapper
+
+
 def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper:
     torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
     local_path = chai1_component(comp_key)
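
The new Model dataclass replaces the ad-hoc Dict[str, ModuleWrapper] used further down, turning string-keyed lookups into attribute accesses that a type checker can verify. trunk and diffusion_module are typed ModuleWrapper | None because later hunks set them to None to reclaim GPU memory. A minimal sketch of that benefit, with a stub standing in for ModuleWrapper (names here are illustrative, not from the commit):

    from dataclasses import dataclass

    class StubWrapper:  # stand-in for ModuleWrapper
        pass

    @dataclass
    class SketchModel:
        trunk: StubWrapper | None

    m = SketchModel(trunk=StubWrapper())
    m.trunk = None    # allowed: the field is declared Optional
    # m.trnk = None   # a typo like this is flagged by mypy/pyright;
                      # a misspelled dict key would only surface at runtime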
@@ -478,6 +491,17 @@ def make_all_atom_feature_context(
     return feature_context
 
 
+def get_model(torch_device: torch.device) -> Model:
+    return Model(
+        feature_embedding=load_exported("feature_embedding.pt", torch_device),
+        bond_loss_input_proj=load_exported("bond_loss_input_proj.pt", torch_device),
+        token_input_embedder=load_exported("token_embedder.pt", torch_device),
+        trunk=load_exported("trunk.pt", torch_device),
+        diffusion_module=load_exported("diffusion_module.pt", torch_device),
+        confidence_head=load_exported("confidence_head.pt", torch_device),
+    )
+
+
 @torch.no_grad()
 def run_inference(
     fasta_file: Path,
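
get_model() loads all six exported components eagerly on one device, and run_folding_on_context (below) now takes the resulting Model as a required argument. A hedged sketch of the expert-level call path, assuming a feature_context already built with make_all_atom_feature_context and that get_model accepts the target device:

    import torch
    from pathlib import Path
    from chai_lab.chai1 import get_model, run_folding_on_context

    device = torch.device("cuda:0")
    model = get_model(device)  # one-time load of all six components

    candidates = run_folding_on_context(
        feature_context,        # AllAtomFeatureContext prepared beforehand
        output_dir=Path("out"),
        model=model,
        device=device,
        low_memory=True,
        model_cached=True,      # keep trunk/diffusion loaded for reuse
    )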
@@ -499,7 +523,7 @@ def run_inference(
     num_trunk_samples: int = 1,
     seed: int | None = None,
     device: str | None = None,
-    model: Dict[str, ModuleWrapper] | None = None,
+    model: Model | None = None,
     low_memory: bool = True,
 ) -> StructureCandidates:
     assert num_trunk_samples > 0 and num_diffn_samples > 0
@@ -523,6 +547,13 @@
         esm_device=torch_device,
     )
 
+    ##
+    ## Load exported models
+    ##
+
+    model_cached = model is not None
+    model = model or get_model(torch_device)
+
     all_candidates: list[StructureCandidates] = []
     for trunk_idx in range(num_trunk_samples):
         logging.info(f"Trunk sample {trunk_idx + 1}/{num_trunk_samples}")
@@ -533,14 +564,15 @@
                 if num_trunk_samples > 1
                 else output_dir
             ),
+            model=model,
             num_trunk_recycles=num_trunk_recycles,
             num_diffn_timesteps=num_diffn_timesteps,
             num_diffn_samples=num_diffn_samples,
             recycle_msa_subsample=recycle_msa_subsample,
             seed=seed + trunk_idx if seed is not None else None,
             device=torch_device,
-            model=model,
             low_memory=low_memory,
+            model_cached=model_cached,
         )
         all_candidates.append(cand)
     return StructureCandidates.concat(all_candidates)
@@ -555,6 +587,7 @@ def run_folding_on_context(
     feature_context: AllAtomFeatureContext,
     *,
     output_dir: Path,
+    model: Model,
     # expose some params for easy tweaking
     recycle_msa_subsample: int = 0,
     num_trunk_recycles: int = 3,
@@ -563,15 +596,13 @@
     num_diffn_samples: int = 5,
     seed: int | None = None,
     device: torch.device | None = None,
-    model: Dict[str, ModuleWrapper] | None = None,
     low_memory: bool,
+    model_cached: bool,
 ) -> StructureCandidates:
     """
     Function for in-depth explorations.
     User completely controls folding inputs.
     """
-    model_provided = model is not None
-
     # Set seed
     if seed is not None:
         set_seed([seed])
@@ -628,29 +659,14 @@
     )
     block_atom_pair_mask = inputs["block_atom_pair_mask"]
 
-    ##
-    ## Load exported models
-    ##
-
     _, _, model_size = msa_mask.shape
     assert model_size in AVAILABLE_MODEL_SIZES
 
-    # Maybe load model weights
-    if not model_provided:
-        model = {
-            "feature_embedding": load_exported("feature_embedding.pt", device),
-            "bond_loss_input_proj": load_exported("bond_loss_input_proj.pt", device),
-            "token_input_embedder": load_exported("token_embedder.pt", device),
-            "trunk": load_exported("trunk.pt", device),
-            "diffusion_module": load_exported("diffusion_module.pt", device),
-            "confidence_head": load_exported("confidence_head.pt", device),
-        }
-
     ##
     ## Run the features through the feature embedder
     ##
 
-    embedded_features = model["feature_embedding"].forward(
+    embedded_features = model.feature_embedding.forward(
         crop_size=model_size,
         move_to_device=device,
         return_on_cpu=low_memory,
@@ -676,7 +692,7 @@
 
     bond_ft_gen = TokenBondRestraint()
     bond_ft = bond_ft_gen.generate(batch=batch).data
-    trunk_bond_feat, structure_bond_feat = model["bond_loss_input_proj"].forward(
+    trunk_bond_feat, structure_bond_feat = model.bond_loss_input_proj.forward(
         return_on_cpu=low_memory,
         move_to_device=device,
         crop_size=model_size,
@@ -689,7 +705,7 @@
     ## Run the inputs through the token input embedder
     ##
 
-    token_input_embedder_outputs: tuple[Tensor, ...] = model["token_input_embedder"].forward(
+    token_input_embedder_outputs: tuple[Tensor, ...] = model.token_input_embedder.forward(
         return_on_cpu=low_memory,
         move_to_device=device,
         token_single_input_feats=token_single_input_feats,
@@ -724,7 +740,7 @@
             msa_mask,
         )
     )
-    (token_single_trunk_repr, token_pair_trunk_repr) = model["trunk"].forward(
+    (token_single_trunk_repr, token_pair_trunk_repr) = model.trunk.forward(
         move_to_device=device,
         token_single_trunk_initial_repr=token_single_initial_repr,
         token_pair_trunk_initial_repr=token_pair_initial_repr,
@@ -746,8 +762,8 @@
     )
 
     # We won't be using the trunk anymore; remove it from memory
-    if not model_provided:
-        del model["trunk"]
+    if not model_cached:
+        model.trunk = None
     torch.cuda.empty_cache()
 
     ##
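
Assigning None instead of del-ing a dict entry frees the same memory while keeping model a structurally valid Model for any caller that still holds it. A small sketch of the general release pattern (the helper name is illustrative, not from the commit):

    import torch

    def release_submodule(holder: object, name: str) -> None:
        # Drop the last strong reference so the weights become collectable...
        setattr(holder, name, None)
        # ...then return freed blocks from PyTorch's caching allocator
        # to the CUDA driver.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()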
@@ -780,7 +796,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
             atom_pos, "(b ds) ... -> b ds ...", ds=ds
         ).contiguous()
         noise_sigma = repeat(sigma, " -> b ds", b=batch_size, ds=ds)
-        return model["diffusion_module"].forward(
+        return model.diffusion_module.forward(
            atom_noised_coords=atom_noised_coords.float(),
            noise_sigma=noise_sigma.float(),
            crop_size=model_size,
@@ -851,16 +867,17 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
         atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
 
     # We won't be running diffusion anymore
-    if not model_provided:
-        del model["diffusion_module"], static_diffusion_inputs
+    if not model_cached:
+        model.diffusion_module = None
+        del static_diffusion_inputs
     torch.cuda.empty_cache()
 
     ##
     ## Run the confidence model
     ##
 
     confidence_outputs: list[tuple[Tensor, ...]] = [
-        model["confidence_head"].forward(
+        model.confidence_head.forward(
             move_to_device=device,
             token_single_input_repr=token_single_initial_repr,
             token_single_trunk_repr=token_single_trunk_repr,
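
Taken together, the commit gives the weights two lifetimes. With no model argument, each call still loads the components itself and drops the trunk and diffusion module mid-run to cap peak GPU memory. With a cached Model, every submodule survives the call and later runs skip loading entirely. A hedged end-to-end sketch (paths are placeholders):

    import torch
    from pathlib import Path
    from chai_lab.chai1 import get_model, run_inference

    # One-shot mode: weights load inside the call; trunk/diffusion freed early.
    run_inference(fasta_file=Path("a.fasta"), output_dir=Path("out/a"))

    # Cached mode: load once, reuse across calls.
    model = get_model(torch.device("cuda:0"))
    run_inference(fasta_file=Path("b.fasta"), output_dir=Path("out/b"), model=model)
    assert model.trunk is not None  # cached submodules remain intact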
