diff --git a/src/bioemu/model_utils.py b/src/bioemu/model_utils.py
index d70d513..53a8626 100644
--- a/src/bioemu/model_utils.py
+++ b/src/bioemu/model_utils.py
@@ -41,13 +41,13 @@ def maybe_download_checkpoint(
 
     except HTTPError as e:
         fs = HfFileSystem()
-        available_checkpoints = [
+        available_checkpoints = {
             Path(p).parent.name for p in fs.glob("microsoft/bioemu/checkpoints/*/checkpoint.ckpt")
-        ]
+        }
         available_configs = [
             Path(p).parent.name for p in fs.glob("microsoft/bioemu/checkpoints/*/config.yaml")
         ]
-        available_model_names = sorted(set(available_checkpoints).intersection(available_configs))
+        available_model_names = sorted(available_checkpoints.intersection(available_configs))
         raise ValueError(
             f"Model {model_name} not found. Available model names: " f"{available_model_names}"
         ) from e
diff --git a/src/bioemu/sample.py b/src/bioemu/sample.py
index 1ec9f50..2a5ed8b 100644
--- a/src/bioemu/sample.py
+++ b/src/bioemu/sample.py
@@ -193,9 +193,9 @@ def main(
 
     logger.info("Converting samples to .pdb and .xtc...")
     samples_files = sorted(list(output_dir.glob("batch_*.npz")))
-    sequences = [np.load(f)["sequence"].item() for f in samples_files]
-    if set(sequences) != {sequence}:
-        raise ValueError(f"Expected all sequences to be {sequence}, but got {set(sequences)}")
+    sequences = {np.load(f)["sequence"].item() for f in samples_files}
+    if sequences != {sequence}:
+        raise ValueError(f"Expected all sequences to be {sequence}, but got {sequences}")
     positions = torch.tensor(np.concatenate([np.load(f)["pos"] for f in samples_files]))
     node_orientations = torch.tensor(
         np.concatenate([np.load(f)["node_orientations"] for f in samples_files])
diff --git a/src/bioemu/training/foldedness.py b/src/bioemu/training/foldedness.py
index c8e4126..c0ce257 100644
--- a/src/bioemu/training/foldedness.py
+++ b/src/bioemu/training/foldedness.py
@@ -87,8 +87,8 @@ def compute_fnc_for_list(batch: list[ChemGraph], reference_info: ReferenceInfo)
     Returns:
         torch tensor of fraction of native contacts.
     """
-    seqs = [x.sequence for x in batch]
-    assert len(set(seqs)) == 1, "Batch should contain samples all from the same system."
-    sequence = seqs[0]
+    seqs = {x.sequence for x in batch}
+    assert len(seqs) == 1, "Batch should contain samples all from the same system."
+    sequence = next(iter(seqs))  # seqs is a set now; sets do not support indexing
     device = batch[0].pos.device
 
diff --git a/src/bioemu/training/loss.py b/src/bioemu/training/loss.py
index e972e3c..b7800cc 100644
--- a/src/bioemu/training/loss.py
+++ b/src/bioemu/training/loss.py
@@ -145,8 +145,8 @@ def _estimate_squared_mean_error(
         loss: an estimate of [(mean foldedness of samples) - (target mean foldedness)]^2.
     """
     assert isinstance(batch, list)  # Not a Batch!
-    sequences = [x.sequence for x in batch]
-    assert len(set(sequences)) == 1, "Batch must contain samples all from the same system."
+    sequences = {x.sequence for x in batch}
+    assert len(sequences) == 1, "Batch must contain samples all from the same system."
     n = len(batch)
     assert n >= 2, "Batch must contain at least two samples."
 