diff --git a/nbs_tests/model/model_interface.ipynb b/nbs_tests/model/model_interface.ipynb index 59602bb0..995d0faf 100644 --- a/nbs_tests/model/model_interface.ipynb +++ b/nbs_tests/model/model_interface.ipynb @@ -485,8 +485,7 @@ "outputs": [], "source": [ "from peptdeep.model.ms2 import pDeepModel\n", - "from peptdeep.pretrained_models import MODEL_ZIP_FILE_PATH\n", - "from peptdeep.pretrained_models import download_models" + "from peptdeep.pretrained_models import download_models, get_model_zip_file_path" ] }, { @@ -664,7 +663,7 @@ "download_models()\n", "ms2_model = pDeepModel()\n", "ms2_model.build_from_py_codes(\n", - " MODEL_ZIP_FILE_PATH,\n", + " get_model_zip_file_path(),\n", " 'generic/ms2.pth.model.py',\n", " include_model_params_yaml=True\n", ")\n", diff --git a/nbs_tests/pretrained_models.ipynb b/nbs_tests/pretrained_models.ipynb index e73eddb7..ad3e3c69 100644 --- a/nbs_tests/pretrained_models.ipynb +++ b/nbs_tests/pretrained_models.ipynb @@ -45,9 +45,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from peptdeep.pretrained_models import *" - ] + "source": "from peptdeep.pretrained_models import download_models, is_model_zip, get_model_zip_file_path, ModelManager" }, { "cell_type": "code", @@ -57,7 +55,7 @@ "source": [ "#| hide\n", "download_models()\n", - "assert is_model_zip(MODEL_ZIP_FILE_PATH)" + "assert is_model_zip(get_model_zip_file_path())" ] }, { @@ -98,9 +96,12 @@ "metadata": {}, "outputs": [], "source": [ + "from zipfile import ZipFile\n", + "import os\n", + "\n", "#| hide\n", - "assert os.path.isfile(MODEL_ZIP_FILE_PATH)\n", - "with ZipFile(MODEL_ZIP_FILE_PATH) as _zip:\n", + "assert os.path.isfile(get_model_zip_file_path())\n", + "with ZipFile(get_model_zip_file_path()) as _zip:\n", " with _zip.open('generic/ms2.pth'):\n", " pass\n", " with _zip.open('generic/rt.pth'):\n", @@ -130,6 +131,8 @@ "metadata": {}, "outputs": [], "source": [ + "import pandas as pd\n", + "\n", "#| hide\n", "\n", "matched_df = pd.read_csv(\n", diff --git a/peptdeep/cli.py b/peptdeep/cli.py index 50d4497a..20c08e03 100644 --- a/peptdeep/cli.py +++ b/peptdeep/cli.py @@ -89,10 +89,10 @@ def _gui(port, settings_yaml): help="If overwrite existing model file.", ) def _install_model(model_file, overwrite): - from peptdeep.pretrained_models import download_models, MODEL_URL + from peptdeep.pretrained_models import download_models, get_model_url if not model_file: - download_models(MODEL_URL, overwrite=overwrite) + download_models(get_model_url(), overwrite=overwrite) else: download_models(model_file, overwrite=overwrite) diff --git a/peptdeep/hla/hla_class1.py b/peptdeep/hla/hla_class1.py index 8cbd9d89..b2d99c40 100644 --- a/peptdeep/hla/hla_class1.py +++ b/peptdeep/hla/hla_class1.py @@ -8,7 +8,11 @@ import peptdeep.model.building_block as building_block from peptdeep.model.model_interface import ModelInterface, append_nAA_column_if_missing from peptdeep.model.featurize import get_ascii_indices -from peptdeep.pretrained_models import PRETRAIN_DIR, download_models, global_settings +from peptdeep.pretrained_models import ( + get_pretrain_dir, + download_models, + global_settings, +) from .hla_utils import ( get_random_sequences, @@ -132,9 +136,17 @@ class HLA1_Binding_Classifier(ModelInterface): Class to predict HLA-binding probabilities of peptides. """ - _model_zip_name = global_settings["local_hla_model_zip_name"] - _model_url = global_settings["hla_model_url"] - _model_zip = os.path.join(PRETRAIN_DIR, _model_zip_name) + @property + def _model_zip_name(self) -> str: + return global_settings["local_hla_model_zip_name"] + + @property + def _model_url(self) -> str: + return global_settings["hla_model_url"] + + @property + def _model_zip(self) -> str: + return os.path.join(get_pretrain_dir(), self._model_zip_name) def __init__( self, diff --git a/peptdeep/pretrained_models.py b/peptdeep/pretrained_models.py index f372d364..3ab8e80e 100644 --- a/peptdeep/pretrained_models.py +++ b/peptdeep/pretrained_models.py @@ -39,31 +39,48 @@ from peptdeep.settings import global_settings, update_global_settings -PRETRAIN_DIR = os.path.join( - os.path.join( + +def get_pretrain_dir() -> str: + """Get the pretrained models directory path dynamically from settings.""" + return os.path.join( os.path.expanduser(global_settings["PEPTDEEP_HOME"]), "pretrained_models" ) -) + + +def get_local_model_zip_name() -> str: + """Get the local model zip file name dynamically from settings.""" + return global_settings["local_model_zip_name"] + + +def get_model_url() -> str: + """Get the model URL dynamically from settings.""" + return global_settings["model_url"] + + +def get_model_zip_file_path() -> str: + """Get the full path to the model zip file dynamically from settings.""" + return os.path.join(get_pretrain_dir(), get_local_model_zip_name()) sys.modules[__name__].__class__ = ModuleWithDeprecations -LOCAL_MODEL_ZIP_NAME = global_settings["local_model_zip_name"] -MODEL_URL = global_settings["model_url"] -MODEL_ZIP_FILE_PATH = os.path.join(PRETRAIN_DIR, LOCAL_MODEL_ZIP_NAME) - -ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "PRETRAIN_DIR") -ModuleWithDeprecations.deprecate(__name__, "model_zip_name", "LOCAL_MODAL_ZIP_NAME") -ModuleWithDeprecations.deprecate(__name__, "model_url", "MODEL_URL") -ModuleWithDeprecations.deprecate(__name__, "model_zip", "MODEL_ZIP_FILE_PATH") - -MODEL_DOWNLOAD_INSTRUCTIONS = ( - "Please download the " - f'zip or tar file by yourself from "{MODEL_URL}",' - " and use \n" - f'"peptdeep --install-model /path/to/{LOCAL_MODEL_ZIP_NAME}.zip"\n' - " to install the models" +ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "get_pretrain_dir()") +ModuleWithDeprecations.deprecate( + __name__, "model_zip_name", "get_local_model_zip_name()" ) +ModuleWithDeprecations.deprecate(__name__, "model_url", "get_model_url()") +ModuleWithDeprecations.deprecate(__name__, "model_zip", "get_model_zip_file_path()") + + +def get_model_download_instructions() -> str: + """Get the model download instructions dynamically from settings.""" + return ( + "Please download the " + f'zip or tar file by yourself from "{get_model_url()}",' + " and use \n" + f'"peptdeep install-models --model-file /path/to/{get_local_model_zip_name()}.zip"\n' + " to install the models" + ) def is_model_zip(downloaded_zip): @@ -71,19 +88,17 @@ def is_model_zip(downloaded_zip): return any(x == "generic/ms2.pth" for x in zip.namelist()) -def download_models( - url: str = MODEL_URL, target_path: str = MODEL_ZIP_FILE_PATH, overwrite: bool = True -): +def download_models(url: str = None, target_path: str = None, overwrite: bool = True): """ Parameters ---------- url : str, optional Remote or local path. - Defaults to :data:`peptdeep.pretrained_models.model_url` + Defaults to None, which will take the default using get_model_url() target_path : str, optional Target file path after download. - Defaults to :data:`peptdeep.pretrained_models.MODEL_ZIP_FILE_PATH` + Defaults to None, which will take the default using get_model_zip_file_path() overwrite : bool, optional overwrite old model files. @@ -94,9 +109,21 @@ def download_models( FileNotFoundError If remote url is not accessible. """ + if url is None: + url = get_model_url() + if target_path is None: + target_path = get_model_zip_file_path() + if not overwrite and os.path.exists(target_path): raise FileExistsError(f"Model file already exists: {target_path}") + if url is None: + raise ValueError( + "Cannot download models: 'model_url' is not set in settings. " + "Please either set 'model_url' in your settings file, or ensure " + "the model file already exists at the expected location." + ) + if not os.path.isfile(url): logging.info(f"Downloading pretrained models from {url} to {target_path} ...") try: @@ -107,7 +134,7 @@ def download_models( f.write(requests.read()) except Exception as e: raise FileNotFoundError( - f"Downloading model failed: {e}.\n" + MODEL_DOWNLOAD_INSTRUCTIONS + f"Downloading model failed: {e}.\n" + get_model_download_instructions() ) from e else: logging.info(f"Copying pretrained models from {url} to {target_path} ...") @@ -116,16 +143,18 @@ def download_models( logging.info(f"Successfully downloaded pretrained models.") -def _download_models(model_zip_file_path: str) -> None: +def _download_models(model_zip_file_path: str = None) -> None: """Download models if not done yet.""" - os.makedirs(PRETRAIN_DIR, exist_ok=True) + if model_zip_file_path is None: + model_zip_file_path = get_model_zip_file_path() + os.makedirs(get_pretrain_dir(), exist_ok=True) if not os.path.exists(model_zip_file_path): download_models() if not is_model_zip(model_zip_file_path): raise ValueError( f"Local model file is not a valid zip: {model_zip_file_path}.\n" f"Please delete this file and try again.\n" - f"Or: {MODEL_DOWNLOAD_INSTRUCTIONS}" + f"Or: {get_model_download_instructions()}" ) @@ -206,38 +235,41 @@ def _sample(psm_df, n): def load_phos_models(mask_modloss=True): - _download_models(MODEL_ZIP_FILE_PATH) + model_zip_file_path = get_model_zip_file_path() + _download_models(model_zip_file_path) ms2_model = pDeepModel(mask_modloss=mask_modloss) - ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/ms2_phos.pth") + ms2_model.load(model_zip_file_path, model_path_in_zip="phospho/ms2_phos.pth") rt_model = AlphaRTModel() - rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth") + rt_model.load(model_zip_file_path, model_path_in_zip="phospho/rt_phos.pth") ccs_model = AlphaCCSModel() - ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth") + ccs_model.load(model_zip_file_path, model_path_in_zip="generic/ccs.pth") return ms2_model, rt_model, ccs_model def load_models(mask_modloss=True): - _download_models(MODEL_ZIP_FILE_PATH) + model_zip_file_path = get_model_zip_file_path() + _download_models(model_zip_file_path) ms2_model = pDeepModel(mask_modloss=mask_modloss) - ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth") + ms2_model.load(model_zip_file_path, model_path_in_zip="generic/ms2.pth") rt_model = AlphaRTModel() - rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth") + rt_model.load(model_zip_file_path, model_path_in_zip="generic/rt.pth") ccs_model = AlphaCCSModel() - ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth") + ccs_model.load(model_zip_file_path, model_path_in_zip="generic/ccs.pth") return ms2_model, rt_model, ccs_model def load_models_by_model_type_in_zip(model_type_in_zip: str, mask_modloss=True): - _download_models(MODEL_ZIP_FILE_PATH) + model_zip_file_path = get_model_zip_file_path() + _download_models(model_zip_file_path) ms2_model = pDeepModel(mask_modloss=mask_modloss) ms2_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ms2.pth" + model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/ms2.pth" ) rt_model = AlphaRTModel() - rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/rt.pth") + rt_model.load(model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/rt.pth") ccs_model = AlphaCCSModel() ccs_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ccs.pth" + model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/ccs.pth" ) return ms2_model, rt_model, ccs_model @@ -312,7 +344,7 @@ def __init__( if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'. Defaults to 'gpu' """ - _download_models(MODEL_ZIP_FILE_PATH) + _download_models(get_model_zip_file_path()) self._train_psm_logging = True @@ -484,18 +516,19 @@ def load_installed_models(self, model_type: str = "generic"): It could be 'digly', 'phospho', 'HLA', or 'generic'. Defaults to 'generic'. """ + model_zip_file_path = get_model_zip_file_path() if model_type.lower() in ["phospho", "phos", "phosphorylation"]: self.ms2_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth" + model_zip_file_path, model_path_in_zip="generic/ms2.pth" ) self.rt_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth" + model_zip_file_path, model_path_in_zip="phospho/rt_phos.pth" ) self.ccs_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth" + model_zip_file_path, model_path_in_zip="generic/ccs.pth" ) self.charge_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/charge.pth" + model_zip_file_path, model_path_in_zip="generic/charge.pth" ) elif model_type.lower() in [ "digly", @@ -505,27 +538,27 @@ def load_installed_models(self, model_type: str = "generic"): "ubiquitinylation", ]: self.ms2_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth" + model_zip_file_path, model_path_in_zip="generic/ms2.pth" ) self.rt_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="digly/rt_digly.pth" + model_zip_file_path, model_path_in_zip="digly/rt_digly.pth" ) self.ccs_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth" + model_zip_file_path, model_path_in_zip="generic/ccs.pth" ) self.charge_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/charge.pth" + model_zip_file_path, model_path_in_zip="generic/charge.pth" ) elif model_type.lower() in ["regular", "common", "generic"]: self.ms2_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth" + model_zip_file_path, model_path_in_zip="generic/ms2.pth" ) - self.rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth") + self.rt_model.load(model_zip_file_path, model_path_in_zip="generic/rt.pth") self.ccs_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth" + model_zip_file_path, model_path_in_zip="generic/ccs.pth" ) self.charge_model.load( - MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/charge.pth" + model_zip_file_path, model_path_in_zip="generic/charge.pth" ) elif model_type.lower() in ["hla", "unspecific", "non-specific", "nonspecific"]: self.load_installed_models(model_type="generic") diff --git a/peptdeep/utils/deprecations.py b/peptdeep/utils/deprecations.py index f8ba3927..d17318d2 100644 --- a/peptdeep/utils/deprecations.py +++ b/peptdeep/utils/deprecations.py @@ -22,7 +22,11 @@ def __getattr__(self, name: str) -> Any: module_deprecations = self._deprecations[self.__name__] if name in module_deprecations: new_name = module_deprecations[name] + msg = f"{name} is deprecated! Use '{new_name}' instead." + if new_name.endswith("()"): # hack to support functions + raise AttributeError(msg) + warn(msg, DeprecationWarning) print(f"WARNING: {msg}") return self.__getattribute__(new_name)