Skip to content

Commit e54cb03

Browse files
Ailing Zhangfacebook-github-bot
Ailing Zhang
authored andcommitted
add/move a few apis in torch.hub (pytorch#18758)
Summary: * `torch.hub.list('pytorch/vision')` - show all available hub models in `pytorch/vision` * `torch.hub.show('pytorch/vision', 'resnet18')` - show docstring & example for `resnet18` in `pytorch/vision` * Moved `torch.utils.model_zoo.load_url` to `torch.hub.load_state_dict_from_url` and deprecate `torch.utils.model_zoo` * We have too many env to control where the cache dir is, it's not very necessary. I actually want to unify `TORCH_HUB_DIR`, `TORCH_HOME` and `TORCH_MODEL_ZOO`, but haven't done it. (more suggestions are welcome!) * Simplify `pytorch/vision` example in doc, it was used to show how how hub entrypoint can be written so had some confusing unnecessary args. An example of hub usage is shown below ``` In [1]: import torch In [2]: torch.hub.list('pytorch/vision', force_reload=True) Downloading: "https://github.com/pytorch/vision/archive/master.zip" to /private/home/ailzhang/.torch/hub/master.zip Out[2]: ['resnet18', 'resnet50'] In [3]: torch.hub.show('pytorch/vision', 'resnet18') Using cache found in /private/home/ailzhang/.torch/hub/vision_master Resnet18 model pretrained (bool): a recommended kwargs for all entrypoints args & kwargs are arguments for the function In [4]: model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True) Using cache found in /private/home/ailzhang/.torch/hub/vision_master ``` Pull Request resolved: pytorch#18758 Differential Revision: D14883651 Pulled By: ailzhang fbshipit-source-id: 6db6ab708a74121782a9154c44b0e190b23e8309
1 parent 5164622 commit e54cb03

File tree

6 files changed

+357
-267
lines changed

6 files changed

+357
-267
lines changed

docs/source/hub.rst

+35-34
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ Publishing models
88
Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights)
99
to a github repository by adding a simple ``hubconf.py`` file;
1010

11-
``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function with
12-
the following signature.
11+
``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function
12+
(example: a pre-trained model you want to publish).
1313

1414
::
1515

16-
def entrypoint_name(pretrained=False, *args, **kwargs):
16+
def entrypoint_name(*args, **kwargs):
17+
# args & kwargs are optional, for models which take positional/keyword arguments.
1718
...
1819

1920
How to implement an entrypoint?
@@ -24,70 +25,70 @@ for ``resnet18`` model. You can see a full script in
2425

2526
::
2627

27-
dependencies = ['torch', 'math']
28+
dependencies = ['torch']
2829

29-
def resnet18(pretrained=False, *args, **kwargs):
30+
def resnet18(pretrained=False, **kwargs):
3031
"""
3132
Resnet18 model
32-
pretrained (bool): a recommended kwargs for all entrypoints
33-
args & kwargs are arguments for the function
33+
pretrained (bool): kwargs, load pretrained weights into the model
3434
"""
35-
######## Call the model in the repo ###############
35+
# Call the model in the repo
3636
from torchvision.models.resnet import resnet18 as _resnet18
37-
model = _resnet18(*args, **kwargs)
38-
######## End of call ##############################
39-
# The following logic is REQUIRED
40-
if pretrained:
41-
# For weights saved in local repo
42-
# model.load_state_dict(<path_to_saved_file>)
43-
44-
# For weights saved elsewhere
45-
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
46-
model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
37+
model = _resnet18(pretrained=pretrained, **kwargs)
4738
return model
4839

40+
4941
- ``dependencies`` variable is a **list** of package names required to to run the model.
50-
- Pretrained weights can either be stored local in the github repo, or loadable by
51-
``model_zoo.load()``.
5242
- ``pretrained`` controls whether to load the pre-trained weights provided by repo owners.
5343
- ``args`` and ``kwargs`` are passed along to the real callable function.
54-
- Docstring of the function works as a help message, explaining what does the model do and what
55-
are the allowed arguments.
44+
- Docstring of the function works as a help message. It explains what does the model do and what
45+
are the allowed positional/keyword arguments. It's highly recommended to add a few examples here.
5646
- Entrypoint function should **ALWAYS** return a model(nn.module).
47+
- Pretrained weights can either be stored local in the github repo, or loadable by
48+
``torch.hub.load_state_dict_from_url()``. In the example above ``torchvision.models.resnet.resnet18``
49+
handles ``pretrained``, alternatively you can put the following logic in the entrypoint.
50+
51+
::
52+
if kwargs.get('pretrained', False):
53+
# For checkpoint saved in local repo
54+
model.load_state_dict(<path_to_saved_checkpoint>)
55+
56+
# For checkpoint saved elsewhere
57+
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
58+
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
59+
5760

5861
Important Notice
5962
^^^^^^^^^^^^^^^^
6063

6164
- The published models should be at least in a branch/tag. It can't be a random commit.
6265

66+
6367
Loading models from Hub
6468
-----------------------
6569

66-
Users can load the pre-trained models using ``torch.hub.load()`` API.
70+
Pytorch Hub provides convenient APIs to explore all available models in hub through ``torch.hub.list()``,
71+
show docstring and examples through ``torch.hub.help()`` and load the pre-trained models using ``torch.hub.load()``
6772

6873

6974
.. automodule:: torch.hub
70-
.. autofunction:: load
7175

72-
Here's an example loading ``resnet18`` entrypoint from ``pytorch/vision`` repo.
76+
.. autofunction:: list
7377

74-
::
78+
.. autofunction:: help
7579

76-
hub_model = hub.load(
77-
'pytorch/vision:master', # repo_owner/repo_name:branch
78-
'resnet18', # entrypoint
79-
1234, # args for callable [not applicable to resnet]
80-
pretrained=True) # kwargs for callable
80+
.. autofunction:: load
8181

82-
Where are my downloaded model & weights saved?
82+
Where are my downloaded models saved?
8383
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
8484

8585
The locations are used in the order of
8686

8787
- hub_dir: user specified path. It can be set in the following ways:
88-
- Setting the environment variable ``TORCH_HUB_DIR``
8988
- Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``
90-
- ``~/.torch/hub``
89+
- ``$TORCH_HOME/hub``, if environment variable ``TORCH_HOME`` is set.
90+
- ``$XDG_CACHE_HOME/torch/hub``, if environment variable ``XDG_CACHE_HOME` is set.
91+
- ``~/.cache/torch/hub``
9192

9293
.. autofunction:: set_dir
9394

docs/source/model_zoo.rst

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
torch.utils.model_zoo
22
===================================
33

4+
Moved to `torch.hub`.
5+
46
.. automodule:: torch.utils.model_zoo
57
.. autofunction:: load_url

test/test_utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,9 @@ def test_set_dir(self):
518518
assert os.path.exists(temp_dir + '/vision_master')
519519
shutil.rmtree(temp_dir + '/vision_master')
520520

521+
def test_list_entrypoints(self):
522+
entry_lists = hub.list('pytorch/vision', force_reload=True)
523+
self.assertObjectIn('resnet18', entry_lists)
521524

522525
if __name__ == '__main__':
523526
run_tests()

torch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def manager_path():
289289
import torch.utils.backcompat
290290
import torch.onnx
291291
import torch.jit
292+
import torch.hub
292293
import torch.random
293294
import torch.distributions
294295
import torch.testing

0 commit comments

Comments
 (0)