torch.hub¶
PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
Publishing models¶
PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights)
to a GitHub repository by adding a simple hubconf.py
file;
hubconf.py
can have multiple entrypoints. Each entrypoint is defined as a Python function
(for example, a pre-trained model you want to publish).
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
How to implement an entrypoint?¶
Here is a code snippet that specifies an entrypoint for the resnet18
model if we expand the implementation in pytorch/vision/hubconf.py.
In most cases importing the right function in hubconf.py
is sufficient. Here we
just want to use the expanded version as an example to show how it works.
You can see the full script in the
pytorch/vision repo
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of the entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
- The dependencies variable is a list of package names required to run the model.
- args and kwargs are passed along to the real callable function.
- The docstring of the function works as a help message. It explains what the model does and what the allowed positional/keyword arguments are. It's highly recommended to add a few examples here (see the sketch after this list).
- The entrypoint function should ALWAYS return a model (nn.Module).
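For instance, a docstring that includes a short usage example might look like the following sketch (the wording and the example line are illustrative, not taken verbatim from pytorch/vision):
def resnet18(pretrained=False, **kwargs):
    """Resnet18 model

    pretrained (bool): load pretrained weights into the model

    Example:
        >>> model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
        >>> model.eval()
    """
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model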
Pretrained weights can either be stored locally in the GitHub repo, or loadable by
torch.hub.load_state_dict_from_url()
. In the example above, torchvision.models.resnet.resnet18
handles pretrained
; alternatively you can put the following logic in the entrypoint definition.
if pretrained:
    # For a checkpoint saved in the local repo
    model.load_state_dict(torch.load(<path_to_saved_checkpoint>))

    # For a checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
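For checkpoints committed to the repo itself, one way to resolve the path is relative to hubconf.py. A minimal sketch, assuming a hypothetical weights/resnet18.pth file next to hubconf.py:
import os
import torch
from torchvision.models.resnet import resnet18 as _resnet18

def resnet18(pretrained=False, **kwargs):
    model = _resnet18(pretrained=False, **kwargs)
    if pretrained:
        # Resolve the checkpoint relative to this hubconf.py;
        # 'weights/resnet18.pth' is a hypothetical path used for illustration.
        checkpoint = os.path.join(os.path.dirname(__file__), 'weights', 'resnet18.pth')
        model.load_state_dict(torch.load(checkpoint))
    return model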
Important Notice¶
The published models should be at least in a branch or tag; they can't be a random commit.
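For example, users can pin to such a branch/tag through the 'repo_owner/repo_name:tag_name' form (the hub name below mirrors the example string used in the API docs and may not exist in the real repo):
import torch

# ':hub' selects a branch or tag; 'hub' here is only an illustrative name.
model = torch.hub.load('pytorch/vision:hub', 'resnet18', pretrained=True)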
Loading models from Hub¶
PyTorch Hub provides convenient APIs to explore all available models in hub through torch.hub.list(),
show docstrings and examples through torch.hub.help(),
and load the pre-trained models using torch.hub.load().

- torch.hub.list(github, force_reload=False)[source]¶
List all entrypoints available in the github hubconf.
- Parameters
github – Required, a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
force_reload – Optional, whether to discard the existing cache and force a fresh download. Default is False.
- Returns
a list of available entrypoint names
- Return type
entrypoints
Example
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
- torch.hub.help(github, model, force_reload=False)[source]¶
Show the docstring of entrypoint model.
- Parameters
github – Required, a string with format <repo_owner/repo_name[:tag_name]> with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
model – Required, a string of entrypoint name defined in repo’s hubconf.py
force_reload – Optional, whether to discard the existing cache and force a fresh download. Default is False.
Example
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
- torch.hub.load(github, model, *args, **kwargs)[source]¶
Load a model from a github repo, with pretrained weights.
- Parameters
github – Required, a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
model – Required, a string of entrypoint name defined in repo’s hubconf.py
*args – Optional, the corresponding args for callable model.
force_reload – Optional, whether to force a fresh download of github repo unconditionally. Default is False.
**kwargs – Optional, the corresponding kwargs for callable model.
- Returns
a single model with corresponding pretrained weights.
Example
>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
Running a loaded model:¶
Note that *args, **kwargs
in torch.hub.load()
are used to instantiate a model.
After you have loaded a model, how can you find out what you can do with it?
A suggested workflow is (see the sketch after this list):
- dir(model) to see all available methods of the model.
- help(model.foo) to check what arguments model.foo takes to run.
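Putting that workflow together, a sketch assuming the resnet18 entrypoint from pytorch/vision and its usual 3x224x224 input:
import torch

# Load an entrypoint; resnet18 from pytorch/vision is just an example.
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)

# Inspect what the loaded model offers.
print(dir(model))        # all available attributes and methods
help(model.forward)      # what arguments model.forward takes

# Run a forward pass on a dummy batch.
model.eval()
with torch.no_grad():
    output = model(torch.randn(1, 3, 224, 224))
print(output.shape)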
Where are my downloaded models saved?¶
The locations are used in the order of:
- Calling hub.set_dir(<PATH_TO_HUB_DIR>)
- $TORCH_HOME/hub, if environment variable TORCH_HOME is set.
- $XDG_CACHE_HOME/torch/hub, if environment variable XDG_CACHE_HOME is set.
- ~/.cache/torch/hub
- torch.hub.set_dir(d)[source]¶
Optionally set hub_dir to a local dir to save downloaded models & weights.
If set_dir is not called, the default path is $TORCH_HOME/hub, where environment variable $TORCH_HOME defaults to $XDG_CACHE_HOME/torch. $XDG_CACHE_HOME follows the X Design Group specification of the Linux filesystem layout, with a default value ~/.cache if the environment variable is not set.
- Parameters
d – path to a local folder to save downloaded models & weights.
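For example, to redirect all Hub downloads to another directory (the path below is purely illustrative):
import torch

# '/data/torch_hub' is a hypothetical directory used for illustration.
torch.hub.set_dir('/data/torch_hub')

# Subsequent repos and weights are downloaded under that directory.
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)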
Caching logic¶
By default, we don't clean up files after loading them. Hub uses the cache by default if it already exists in hub_dir.
Users can force a reload by calling hub.load(..., force_reload=True). This will delete
the existing github folder and downloaded weights and reinitialize a fresh download. This is useful
when updates are published to the same branch, so users can keep up with the latest release.
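For instance, to pick up the latest commit on a published branch, the cached copy can be discarded like this:
import torch

# force_reload=True deletes the cached repo and weights and downloads them again.
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True, force_reload=True)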
Known limitations:¶
Torch hub works by importing the package as if it was installed. There are some side effects
introduced by importing in Python. For example, you can see new items in the Python caches
sys.modules
and sys.path_importer_cache,
which is normal Python behavior.
A known limitation worth mentioning here is that users CANNOT load two different branches of the same repo in the same Python process. It's just like installing two packages with the same name in Python, which is not good. The cache might join the party and give you surprises if you actually try that. Of course it's totally fine to load them in separate processes (see the sketch below).
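If two branches of the same repo really are needed, one workaround is to load each one in its own process, for example with multiprocessing (a sketch; the branch names master and hub are assumptions for illustration):
import torch
from multiprocessing import Process

def load_and_report(branch):
    # Each process has its own sys.modules, so the two branches never collide.
    model = torch.hub.load('pytorch/vision:' + branch, 'resnet18', pretrained=True)
    print(branch, sum(p.numel() for p in model.parameters()))

if __name__ == '__main__':
    for branch in ('master', 'hub'):
        p = Process(target=load_and_report, args=(branch,))
        p.start()
        p.join()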