当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch load用法及代码示例


本文简要介绍python语言中 torch.hub.load 的用法。

用法:

torch.hub.load(repo_or_dir, model, *args, source='github', force_reload=False, verbose=True, skip_validation=False, **kwargs)

参数

  • repo_or_dir(string) -如果source 是‘github’,这应该对应于格式为repo_owner/repo_name[:tag_name] 的github repo,带有可选的标签/分支,例如'pytorch/vision:0.10'。如果未指定tag_name,则默认分支假定为main(如果存在),否则为master。如果 source 是 ‘local’ 那么它应该是本地目录的路径。

  • model(string) -在 repo/dir 的 hubconf.py 中定义的可调用(入口点)的名称。

  • *args(可选的) -可调用 model 的相应参数。

  • source(string,可选的) -‘github’或‘local’。指定如何解释repo_or_dir。默认为‘github’。

  • force_reload(bool,可选的) -是否无条件强制重新下载 github repo。如果 source = 'local' 没有任何效果。默认为 False

  • verbose(bool,可选的) -如果 False ,则静音有关命中本地缓存的消息。请注意,关于首次下载的消息无法静音。如果 source = 'local' 没有任何效果。默认为 True

  • skip_validation(bool,可选的) -如果 False ,torchhub 将检查 github 参数指定的分支或提交是否正确属于存储库所有者。这将向 GitHub API 发出请求;您可以通过设置 GITHUB_TOKEN 环境变量来指定非默认 GitHub 令牌。默认为 False

  • **kwargs(可选的) -可调用 model 的相应 kwargs。

返回

当使用给定的 *args**kwargs 调用时,model 可调用的输出。

从 github 存储库或本地目录加载模型。

注意:加载模型是典型的用例,但这也可用于加载其他对象,例如分词器、损失函数等。

如果source 是‘github’,则repo_or_dir 应该是带有可选标签/分支的repo_owner/repo_name[:tag_name] 形式。

如果source 是‘local’,则repo_or_dir 应该是本地目录的路径。

示例

>>> # from a github repo
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
>>> # from a local directory
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', pretrained=True)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.hub.load。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。