Skip to content

Commit bd77d0b

Browse files
authored
Update chai1.py
1 parent 75f6abf commit bd77d0b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

chai_lab/chai1.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def make_all_atom_feature_context(
491491
return feature_context
492492

493493

494-
def get_model() -> Model:
494+
def get_model(torch_device: torch.device) -> Model:
495495
return Model(
496496
feature_embedding=load_exported("feature_embedding.pt", torch_device),
497497
bond_loss_input_proj=load_exported("bond_loss_input_proj.pt", torch_device),
@@ -552,7 +552,7 @@ def run_inference(
552552
##
553553

554554
model_cached = model is not None
555-
model = model or get_model()
555+
model = model or get_model(torch_device)
556556

557557
all_candidates: list[StructureCandidates] = []
558558
for trunk_idx in range(num_trunk_samples):

0 commit comments

Comments
 (0)