Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions nbs_tests/model/model_interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,7 @@
"outputs": [],
"source": [
"from peptdeep.model.ms2 import pDeepModel\n",
"from peptdeep.pretrained_models import MODEL_ZIP_FILE_PATH\n",
"from peptdeep.pretrained_models import download_models"
"from peptdeep.pretrained_models import download_models, get_model_zip_file_path"
]
},
{
Expand Down Expand Up @@ -664,7 +663,7 @@
"download_models()\n",
"ms2_model = pDeepModel()\n",
"ms2_model.build_from_py_codes(\n",
" MODEL_ZIP_FILE_PATH,\n",
" get_model_zip_file_path(),\n",
" 'generic/ms2.pth.model.py',\n",
" include_model_params_yaml=True\n",
")\n",
Expand Down
15 changes: 9 additions & 6 deletions nbs_tests/pretrained_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from peptdeep.pretrained_models import *"
]
"source": "from peptdeep.pretrained_models import download_models, is_model_zip, get_model_zip_file_path, ModelManager"
},
{
"cell_type": "code",
Expand All @@ -57,7 +55,7 @@
"source": [
"#| hide\n",
"download_models()\n",
"assert is_model_zip(MODEL_ZIP_FILE_PATH)"
"assert is_model_zip(get_model_zip_file_path())"
]
},
{
Expand Down Expand Up @@ -98,9 +96,12 @@
"metadata": {},
"outputs": [],
"source": [
"from zipfile import ZipFile\n",
"import os\n",
"\n",
"#| hide\n",
"assert os.path.isfile(MODEL_ZIP_FILE_PATH)\n",
"with ZipFile(MODEL_ZIP_FILE_PATH) as _zip:\n",
"assert os.path.isfile(get_model_zip_file_path())\n",
"with ZipFile(get_model_zip_file_path()) as _zip:\n",
" with _zip.open('generic/ms2.pth'):\n",
" pass\n",
" with _zip.open('generic/rt.pth'):\n",
Expand Down Expand Up @@ -130,6 +131,8 @@
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"#| hide\n",
"\n",
"matched_df = pd.read_csv(\n",
Expand Down
4 changes: 2 additions & 2 deletions peptdeep/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def _gui(port, settings_yaml):
help="If overwrite existing model file.",
)
def _install_model(model_file, overwrite):
from peptdeep.pretrained_models import download_models, MODEL_URL
from peptdeep.pretrained_models import download_models, get_model_url

if not model_file:
download_models(MODEL_URL, overwrite=overwrite)
download_models(get_model_url(), overwrite=overwrite)
else:
download_models(model_file, overwrite=overwrite)

Expand Down
20 changes: 16 additions & 4 deletions peptdeep/hla/hla_class1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import peptdeep.model.building_block as building_block
from peptdeep.model.model_interface import ModelInterface, append_nAA_column_if_missing
from peptdeep.model.featurize import get_ascii_indices
from peptdeep.pretrained_models import PRETRAIN_DIR, download_models, global_settings
from peptdeep.pretrained_models import (
get_pretrain_dir,
download_models,
global_settings,
)

from .hla_utils import (
get_random_sequences,
Expand Down Expand Up @@ -132,9 +136,17 @@ class HLA1_Binding_Classifier(ModelInterface):
Class to predict HLA-binding probabilities of peptides.
"""

_model_zip_name = global_settings["local_hla_model_zip_name"]
_model_url = global_settings["hla_model_url"]
_model_zip = os.path.join(PRETRAIN_DIR, _model_zip_name)
@property
def _model_zip_name(self) -> str:
    """Name of the local HLA model zip, read from ``global_settings`` on every access."""
    zip_name = global_settings["local_hla_model_zip_name"]
    return zip_name

@property
def _model_url(self) -> str:
    """URL of the HLA model, read from ``global_settings`` on every access."""
    hla_url = global_settings["hla_model_url"]
    return hla_url

@property
def _model_zip(self) -> str:
    """Path of the HLA model zip: the pretrained-models directory joined with the zip name."""
    pretrain_dir = get_pretrain_dir()
    return os.path.join(pretrain_dir, self._model_zip_name)

def __init__(
self,
Expand Down
137 changes: 85 additions & 52 deletions peptdeep/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,51 +39,66 @@

from peptdeep.settings import global_settings, update_global_settings

PRETRAIN_DIR = os.path.join(
os.path.join(

def get_pretrain_dir() -> str:
    """Return the pretrained-models directory derived from the current settings.

    The ``PEPTDEEP_HOME`` setting is user-expanded (``~``) and the
    ``pretrained_models`` sub-directory name is appended. The value is
    computed on every call, so later changes to the settings are picked up.
    """
    peptdeep_home = os.path.expanduser(global_settings["PEPTDEEP_HOME"])
    return os.path.join(peptdeep_home, "pretrained_models")
)


def get_local_model_zip_name() -> str:
    """Return the ``local_model_zip_name`` entry of the current settings."""
    zip_name = global_settings["local_model_zip_name"]
    return zip_name


def get_model_url() -> str:
    """Return the ``model_url`` entry of the current settings."""
    model_url = global_settings["model_url"]
    return model_url


def get_model_zip_file_path() -> str:
    """Return the model zip path: pretrained-models directory plus local zip name.

    Both parts are looked up at call time via ``get_pretrain_dir()`` and
    ``get_local_model_zip_name()``.
    """
    pretrain_dir = get_pretrain_dir()
    zip_name = get_local_model_zip_name()
    return os.path.join(pretrain_dir, zip_name)


sys.modules[__name__].__class__ = ModuleWithDeprecations

LOCAL_MODEL_ZIP_NAME = global_settings["local_model_zip_name"]
MODEL_URL = global_settings["model_url"]
MODEL_ZIP_FILE_PATH = os.path.join(PRETRAIN_DIR, LOCAL_MODEL_ZIP_NAME)

ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "PRETRAIN_DIR")
ModuleWithDeprecations.deprecate(__name__, "model_zip_name", "LOCAL_MODAL_ZIP_NAME")
ModuleWithDeprecations.deprecate(__name__, "model_url", "MODEL_URL")
ModuleWithDeprecations.deprecate(__name__, "model_zip", "MODEL_ZIP_FILE_PATH")

MODEL_DOWNLOAD_INSTRUCTIONS = (
"Please download the "
f'zip or tar file by yourself from "{MODEL_URL}",'
" and use \n"
f'"peptdeep --install-model /path/to/{LOCAL_MODEL_ZIP_NAME}.zip"\n'
" to install the models"
ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "get_pretrain_dir()")
ModuleWithDeprecations.deprecate(
__name__, "model_zip_name", "get_local_model_zip_name()"
)
ModuleWithDeprecations.deprecate(__name__, "model_url", "get_model_url()")
ModuleWithDeprecations.deprecate(__name__, "model_zip", "get_model_zip_file_path()")


def get_model_download_instructions() -> str:
    """Build the manual-installation hint used in download error messages.

    The model URL and the local zip name are read from the settings at call
    time via ``get_model_url()`` and ``get_local_model_zip_name()``.

    Returns
    -------
    str
        Instructions telling the user where to download the model archive
        and which command installs it.
    """
    zip_name = str(get_local_model_zip_name())
    # Only append ".zip" when the configured name does not already end with
    # it, so the suggested path never reads "<name>.zip.zip".
    if not zip_name.lower().endswith(".zip"):
        zip_name = f"{zip_name}.zip"
    return (
        "Please download the "
        f'zip or tar file by yourself from "{get_model_url()}",'
        " and use \n"
        f'"peptdeep install-models --model-file /path/to/{zip_name}"\n'
        " to install the models"
    )


def is_model_zip(downloaded_zip):
    """Tell whether a file is a zip archive that contains the generic MS2 model.

    Parameters
    ----------
    downloaded_zip : str
        Path of the file to inspect (passed straight to ``ZipFile``).

    Returns
    -------
    bool
        True if the file can be opened as a zip archive and has a
        ``generic/ms2.pth`` member. False if that member is missing or the
        file is not a zip archive at all.
    """
    # Function-scope import: only ``ZipFile`` is known to be imported at file level.
    from zipfile import BadZipFile

    try:
        with ZipFile(downloaded_zip) as archive:
            return "generic/ms2.pth" in archive.namelist()
    except BadZipFile:
        # A corrupt or non-zip file is simply "not a model zip"; returning
        # False lets callers raise their own, more helpful error.
        return False


def download_models(
url: str = MODEL_URL, target_path: str = MODEL_ZIP_FILE_PATH, overwrite: bool = True
):
def download_models(url: str = None, target_path: str = None, overwrite: bool = True):
"""
Parameters
----------
url : str, optional
Remote or local path.
Defaults to :data:`peptdeep.pretrained_models.model_url`
Defaults to None, which will take the default using get_model_url()

target_path : str, optional
Target file path after download.
Defaults to :data:`peptdeep.pretrained_models.MODEL_ZIP_FILE_PATH`
Defaults to None, which will take the default using get_model_zip_file_path()

overwrite : bool, optional
overwrite old model files.
Expand All @@ -94,9 +109,21 @@ def download_models(
FileNotFoundError
If remote url is not accessible.
"""
if url is None:
url = get_model_url()
if target_path is None:
target_path = get_model_zip_file_path()

if not overwrite and os.path.exists(target_path):
raise FileExistsError(f"Model file already exists: {target_path}")

if url is None:
raise ValueError(
"Cannot download models: 'model_url' is not set in settings. "
"Please either set 'model_url' in your settings file, or ensure "
"the model file already exists at the expected location."
)

if not os.path.isfile(url):
logging.info(f"Downloading pretrained models from {url} to {target_path} ...")
try:
Expand All @@ -107,7 +134,7 @@ def download_models(
f.write(requests.read())
except Exception as e:
raise FileNotFoundError(
f"Downloading model failed: {e}.\n" + MODEL_DOWNLOAD_INSTRUCTIONS
f"Downloading model failed: {e}.\n" + get_model_download_instructions()
) from e
else:
logging.info(f"Copying pretrained models from {url} to {target_path} ...")
Expand All @@ -116,16 +143,18 @@ def download_models(
logging.info(f"Successfully downloaded pretrained models.")


def _download_models(model_zip_file_path: str = None) -> None:
    """Download models if not done yet, then validate the local zip file.

    Parameters
    ----------
    model_zip_file_path : str, optional
        Path where the model zip is expected.
        Defaults to None, which will take the default using
        get_model_zip_file_path()

    Raises
    ------
    ValueError
        If ``is_model_zip`` rejects the file at ``model_zip_file_path``.
    """
    if model_zip_file_path is None:
        model_zip_file_path = get_model_zip_file_path()
    os.makedirs(get_pretrain_dir(), exist_ok=True)
    if not os.path.exists(model_zip_file_path):
        # Download to the very path that is validated below; a bare
        # download_models() would always write to the default location.
        download_models(target_path=model_zip_file_path)
    if not is_model_zip(model_zip_file_path):
        raise ValueError(
            f"Local model file is not a valid zip: {model_zip_file_path}.\n"
            f"Please delete this file and try again.\n"
            f"Or: {get_model_download_instructions()}"
        )


Expand Down Expand Up @@ -206,38 +235,41 @@ def _sample(psm_df, n):


def load_phos_models(mask_modloss=True):
_download_models(MODEL_ZIP_FILE_PATH)
model_zip_file_path = get_model_zip_file_path()
_download_models(model_zip_file_path)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/ms2_phos.pth")
ms2_model.load(model_zip_file_path, model_path_in_zip="phospho/ms2_phos.pth")
rt_model = AlphaRTModel()
rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth")
rt_model.load(model_zip_file_path, model_path_in_zip="phospho/rt_phos.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth")
ccs_model.load(model_zip_file_path, model_path_in_zip="generic/ccs.pth")
return ms2_model, rt_model, ccs_model


def load_models(mask_modloss=True):
_download_models(MODEL_ZIP_FILE_PATH)
model_zip_file_path = get_model_zip_file_path()
_download_models(model_zip_file_path)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth")
ms2_model.load(model_zip_file_path, model_path_in_zip="generic/ms2.pth")
rt_model = AlphaRTModel()
rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth")
rt_model.load(model_zip_file_path, model_path_in_zip="generic/rt.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth")
ccs_model.load(model_zip_file_path, model_path_in_zip="generic/ccs.pth")
return ms2_model, rt_model, ccs_model


def load_models_by_model_type_in_zip(model_type_in_zip: str, mask_modloss=True):
_download_models(MODEL_ZIP_FILE_PATH)
model_zip_file_path = get_model_zip_file_path()
_download_models(model_zip_file_path)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ms2.pth"
model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/ms2.pth"
)
rt_model = AlphaRTModel()
rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/rt.pth")
rt_model.load(model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/rt.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ccs.pth"
model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/ccs.pth"
)
return ms2_model, rt_model, ccs_model

Expand Down Expand Up @@ -312,7 +344,7 @@ def __init__(
if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
Defaults to 'gpu'
"""
_download_models(MODEL_ZIP_FILE_PATH)
_download_models(get_model_zip_file_path())

self._train_psm_logging = True

Expand Down Expand Up @@ -484,18 +516,19 @@ def load_installed_models(self, model_type: str = "generic"):
It could be 'digly', 'phospho', 'HLA', or 'generic'.
Defaults to 'generic'.
"""
model_zip_file_path = get_model_zip_file_path()
if model_type.lower() in ["phospho", "phos", "phosphorylation"]:
self.ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
model_zip_file_path, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth"
model_zip_file_path, model_path_in_zip="phospho/rt_phos.pth"
)
self.ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
model_zip_file_path, model_path_in_zip="generic/ccs.pth"
)
self.charge_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/charge.pth"
model_zip_file_path, model_path_in_zip="generic/charge.pth"
)
elif model_type.lower() in [
"digly",
Expand All @@ -505,27 +538,27 @@ def load_installed_models(self, model_type: str = "generic"):
"ubiquitinylation",
]:
self.ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
model_zip_file_path, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="digly/rt_digly.pth"
model_zip_file_path, model_path_in_zip="digly/rt_digly.pth"
)
self.ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
model_zip_file_path, model_path_in_zip="generic/ccs.pth"
)
self.charge_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/charge.pth"
model_zip_file_path, model_path_in_zip="generic/charge.pth"
)
elif model_type.lower() in ["regular", "common", "generic"]:
self.ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
model_zip_file_path, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth")
self.rt_model.load(model_zip_file_path, model_path_in_zip="generic/rt.pth")
self.ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
model_zip_file_path, model_path_in_zip="generic/ccs.pth"
)
self.charge_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/charge.pth"
model_zip_file_path, model_path_in_zip="generic/charge.pth"
)
elif model_type.lower() in ["hla", "unspecific", "non-specific", "nonspecific"]:
self.load_installed_models(model_type="generic")
Expand Down
4 changes: 4 additions & 0 deletions peptdeep/utils/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ def __getattr__(self, name: str) -> Any:
module_deprecations = self._deprecations[self.__name__]
if name in module_deprecations:
new_name = module_deprecations[name]

msg = f"{name} is deprecated! Use '{new_name}' instead."
if new_name.endswith("()"): # hack to support functions
raise AttributeError(msg)

warn(msg, DeprecationWarning)
print(f"WARNING: {msg}")
return self.__getattribute__(new_name)
Expand Down
Loading