We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 75f6abf commit bd77d0bCopy full SHA for bd77d0b
chai_lab/chai1.py
@@ -491,7 +491,7 @@ def make_all_atom_feature_context(
491
return feature_context
492
493
494
-def get_model() -> Model:
+def get_model(torch_device: torch.device) -> Model:
495
return Model(
496
feature_embedding=load_exported("feature_embedding.pt", torch_device),
497
bond_loss_input_proj=load_exported("bond_loss_input_proj.pt", torch_device),
@@ -552,7 +552,7 @@ def run_inference(
552
##
553
554
model_cached = model is not None
555
- model = model or get_model()
+ model = model or get_model(torch_device)
556
557
all_candidates: list[StructureCandidates] = []
558
for trunk_idx in range(num_trunk_samples):
0 commit comments