@@ -9,11 +9,12 @@
 from collections import Counter
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, Sequence
+from typing import Sequence
 
 import numpy as np
 import torch
 import torch.export
+from dataclasses import dataclass
 from einops import einsum, rearrange, repeat
 from torch import Tensor
 from tqdm import tqdm
@@ -136,6 +137,18 @@ def forward(
         return result
 
 
+@dataclass
+class Model:
+    """A dataclass for model weights."""
+
+    feature_embedding: ModuleWrapper
+    bond_loss_input_proj: ModuleWrapper
+    token_input_embedder: ModuleWrapper
+    trunk: ModuleWrapper | None
+    diffusion_module: ModuleWrapper | None
+    confidence_head: ModuleWrapper
+
+
 def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper:
     torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
     local_path = chai1_component(comp_key)
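Note: the dataclass replaces the string-keyed dict the old code built inline; trunk and diffusion_module are declared Optional because they are released mid-run to save memory. A self-contained sketch of the practical difference, using stand-in types (hypothetical names, illustration only):

from dataclasses import dataclass

class StubWrapper:  # stand-in for ModuleWrapper, so the sketch runs outside chai-lab
    pass

@dataclass
class StubModel:  # mirrors the Model fields that matter here
    trunk: StubWrapper | None
    confidence_head: StubWrapper

# Before: dict access -- a mistyped key fails only at runtime, and freeing
# a component means `del model["trunk"]`.
# After: typed attributes -- statically checkable, and a component is
# released by assignment rather than deletion:
m = StubModel(trunk=StubWrapper(), confidence_head=StubWrapper())
m.trunk = None  # release one component; confidence_head stays non-None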
@@ -478,6 +491,17 @@ def make_all_atom_feature_context(
     return feature_context
 
 
+def get_model() -> Model:
+    return Model(
+        feature_embedding=load_exported("feature_embedding.pt", torch_device),
+        bond_loss_input_proj=load_exported("bond_loss_input_proj.pt", torch_device),
+        token_input_embedder=load_exported("token_embedder.pt", torch_device),
+        trunk=load_exported("trunk.pt", torch_device),
+        diffusion_module=load_exported("diffusion_module.pt", torch_device),
+        confidence_head=load_exported("confidence_head.pt", torch_device),
+    )
+
+
 @torch.no_grad()
 def run_inference(
     fasta_file: Path,
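With loading factored into get_model(), a caller can pay the deserialization cost once and reuse the weights across runs. A sketch of the intended pattern (paths are placeholders; get_model() references torch_device, which this sketch assumes is defined at module scope):

model = get_model()  # load every exported component once

for fasta in [Path("a.fasta"), Path("b.fasta")]:
    # Passing a pre-loaded model marks it as cached, so the run will not
    # None-out trunk / diffusion_module when it is done with them.
    run_inference(
        fasta_file=fasta,
        output_dir=Path("preds") / fasta.stem,
        model=model,
    )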
@@ -499,7 +523,7 @@ def run_inference(
     num_trunk_samples: int = 1,
     seed: int | None = None,
     device: str | None = None,
-    model: Dict[str, ModuleWrapper] | None = None,
+    model: Model | None = None,
     low_memory: bool = True,
 ) -> StructureCandidates:
     assert num_trunk_samples > 0 and num_diffn_samples > 0
@@ -523,6 +547,13 @@ def run_inference(
         esm_device=torch_device,
     )
 
+    ##
+    ## Load exported models
+    ##
+
+    model_cached = model is not None
+    model = model or get_model()
+
     all_candidates: list[StructureCandidates] = []
     for trunk_idx in range(num_trunk_samples):
         logging.info(f"Trunk sample {trunk_idx + 1}/{num_trunk_samples}")
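One detail worth noting: `model = model or get_model()` only works as a None-check because a dataclass instance is always truthy (Model defines neither __bool__ nor __len__), so a caller-supplied model is never silently replaced. A quick illustration with a stand-in dataclass (hypothetical, for illustration):

from dataclasses import dataclass

@dataclass
class Probe:
    trunk: object | None = None

assert bool(Probe()) is True  # truthy even with every field None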
@@ -533,14 +564,15 @@ def run_inference(
                 if num_trunk_samples > 1
                 else output_dir
             ),
+            model=model,
             num_trunk_recycles=num_trunk_recycles,
             num_diffn_timesteps=num_diffn_timesteps,
             num_diffn_samples=num_diffn_samples,
             recycle_msa_subsample=recycle_msa_subsample,
             seed=seed + trunk_idx if seed is not None else None,
             device=torch_device,
-            model=model,
             low_memory=low_memory,
+            model_cached=model_cached,
         )
         all_candidates.append(cand)
     return StructureCandidates.concat(all_candidates)
@@ -555,6 +587,7 @@ def run_folding_on_context(
     feature_context: AllAtomFeatureContext,
     *,
     output_dir: Path,
+    model: Model,
     # expose some params for easy tweaking
     recycle_msa_subsample: int = 0,
     num_trunk_recycles: int = 3,
@@ -563,15 +596,13 @@ def run_folding_on_context(
     num_diffn_samples: int = 5,
     seed: int | None = None,
     device: torch.device | None = None,
-    model: Dict[str, ModuleWrapper] | None = None,
     low_memory: bool,
+    model_cached: bool,
 ) -> StructureCandidates:
     """
     Function for in-depth explorations.
     User completely controls folding inputs.
     """
-    model_provided = model is not None
-
     # Set seed
     if seed is not None:
         set_seed([seed])
@@ -628,29 +659,14 @@ def run_folding_on_context(
     )
     block_atom_pair_mask = inputs["block_atom_pair_mask"]
 
-    ##
-    ## Load exported models
-    ##
-
     _, _, model_size = msa_mask.shape
     assert model_size in AVAILABLE_MODEL_SIZES
 
-    # Maybe load model weights
-    if not model_provided:
-        model = {
-            "feature_embedding": load_exported("feature_embedding.pt", device),
-            "bond_loss_input_proj": load_exported("bond_loss_input_proj.pt", device),
-            "token_input_embedder": load_exported("token_embedder.pt", device),
-            "trunk": load_exported("trunk.pt", device),
-            "diffusion_module": load_exported("diffusion_module.pt", device),
-            "confidence_head": load_exported("confidence_head.pt", device),
-        }
-
     ##
     ## Run the features through the feature embedder
     ##
 
-    embedded_features = model["feature_embedding"].forward(
+    embedded_features = model.feature_embedding.forward(
         crop_size=model_size,
         move_to_device=device,
         return_on_cpu=low_memory,
@@ -676,7 +692,7 @@ def run_folding_on_context(
 
     bond_ft_gen = TokenBondRestraint()
     bond_ft = bond_ft_gen.generate(batch=batch).data
-    trunk_bond_feat, structure_bond_feat = model["bond_loss_input_proj"].forward(
+    trunk_bond_feat, structure_bond_feat = model.bond_loss_input_proj.forward(
         return_on_cpu=low_memory,
         move_to_device=device,
         crop_size=model_size,
@@ -689,7 +705,7 @@ def run_folding_on_context(
     ## Run the inputs through the token input embedder
     ##
 
-    token_input_embedder_outputs: tuple[Tensor, ...] = model["token_input_embedder"].forward(
+    token_input_embedder_outputs: tuple[Tensor, ...] = model.token_input_embedder.forward(
         return_on_cpu=low_memory,
         move_to_device=device,
         token_single_input_feats=token_single_input_feats,
@@ -724,7 +740,7 @@ def run_folding_on_context(
                     msa_mask,
                 )
             )
-        (token_single_trunk_repr, token_pair_trunk_repr) = model["trunk"].forward(
+        (token_single_trunk_repr, token_pair_trunk_repr) = model.trunk.forward(
             move_to_device=device,
             token_single_trunk_initial_repr=token_single_initial_repr,
             token_pair_trunk_initial_repr=token_pair_initial_repr,
@@ -746,8 +762,8 @@ def run_folding_on_context(
         )
 
     # We won't be using the trunk anymore; remove it from memory
-    if not model_provided:
-        del model["trunk"]
+    if not model_cached:
+        model.trunk = None
     torch.cuda.empty_cache()
 
     ##
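Assigning None frees GPU memory here because it drops the last Python reference to the wrapped trunk module; torch.cuda.empty_cache() can then hand the cached blocks back to the driver. A minimal illustration of the same pattern (requires a CUDA device):

import torch

if torch.cuda.is_available():
    t = torch.empty(1024, 1024, device="cuda")
    t = None                  # drop the reference, as `model.trunk = None` does
    torch.cuda.empty_cache()  # allocator returns the freed blocks to the driver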
@@ -780,7 +796,7 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
             atom_pos, "(b ds) ... -> b ds ...", ds=ds
         ).contiguous()
         noise_sigma = repeat(sigma, " -> b ds", b=batch_size, ds=ds)
-        return model["diffusion_module"].forward(
+        return model.diffusion_module.forward(
             atom_noised_coords=atom_noised_coords.float(),
             noise_sigma=noise_sigma.float(),
             crop_size=model_size,
@@ -851,16 +867,17 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, ds: int) -> Tensor:
         atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
 
     # We won't be running diffusion anymore
-    if not model_provided:
-        del model["diffusion_module"], static_diffusion_inputs
+    if not model_cached:
+        model.diffusion_module = None
+        del static_diffusion_inputs
     torch.cuda.empty_cache()
 
     ##
     ## Run the confidence model
     ##
 
     confidence_outputs: list[tuple[Tensor, ...]] = [
-        model["confidence_head"].forward(
+        model.confidence_head.forward(
             move_to_device=device,
             token_single_input_repr=token_single_initial_repr,
             token_single_trunk_repr=token_single_trunk_repr,
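Taken together, the low-level entry point now receives the weights and the caching flag explicitly. A sketch of a direct call under the new signature (illustrative values; feature_context built via make_all_atom_feature_context as above):

model = get_model()
candidates = run_folding_on_context(
    feature_context,
    output_dir=Path("preds"),
    model=model,
    num_trunk_recycles=3,
    num_diffn_samples=5,
    seed=42,
    device=torch.device("cuda:0"),
    low_memory=True,
    model_cached=True,  # we still hold `model`, so keep its components intact
)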