当前位置: 首页>>代码示例>>Python>>正文


Python torch.hub方法代码示例

本文整理汇总了Python中torch.hub方法的典型用法代码示例。如果您正苦于以下问题：Python torch.hub方法的具体用法？Python torch.hub怎么用？Python torch.hub使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。


在下文中一共展示了torch.hub方法的8个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: config

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def config(*args, **kwargs):
    r"""Hub entry point that returns ``AutoConfig.from_pretrained(*args, **kwargs)``.

    All positional and keyword arguments are forwarded unchanged, and the
    object produced by ``AutoConfig.from_pretrained`` is returned as is.

    Usage through torch.hub::

        import torch

        # Download configuration from S3 and cache.
        config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased')
        # E.g. config (or model) was saved using `save_pretrained('./test/bert_saved_model/')`
        config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/')
        config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/my_configuration.json')
        config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attention=True, foo=False)
        assert config.output_attention == True
        config, unused_kwargs = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attention=True, foo=False, return_unused_kwargs=True)
        assert config.output_attention == True
        assert unused_kwargs == {'foo': False}
    """
    loaded_config = AutoConfig.from_pretrained(*args, **kwargs)
    return loaded_config
开发者ID:bhoov,项目名称:exbert,代码行数:19,代码来源:hubconf.py

示例2: model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def model(*args, **kwargs):
    r"""Hub entry point that returns ``AutoModel.from_pretrained(*args, **kwargs)``.

    All positional and keyword arguments are forwarded unchanged, and the
    object produced by ``AutoModel.from_pretrained`` is returned as is.

    Usage through torch.hub::

        import torch

        # Download model and configuration from S3 and cache.
        model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased')
        # E.g. model was saved using `save_pretrained('./test/bert_model/')`
        model = torch.hub.load('huggingface/transformers', 'model', './test/bert_model/')
        # Update configuration during loading
        model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased', output_attention=True)
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
        model = torch.hub.load('huggingface/transformers', 'model', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    loaded_model = AutoModel.from_pretrained(*args, **kwargs)
    return loaded_model
开发者ID:bhoov,项目名称:exbert,代码行数:18,代码来源:hubconf.py

示例3: modelForSequenceClassification

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def modelForSequenceClassification(*args, **kwargs):
    r"""Hub entry point that returns
    ``AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)``.

    All positional and keyword arguments are forwarded unchanged, and the
    object produced by ``from_pretrained`` is returned as is.

    Usage through torch.hub::

        import torch

        # Download model and configuration from S3 and cache.
        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased')
        # E.g. model was saved using `save_pretrained('./test/bert_model/')`
        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './test/bert_model/')
        # Update configuration during loading
        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased', output_attention=True)
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
    return classifier
开发者ID:bhoov,项目名称:exbert,代码行数:18,代码来源:hubconf.py

示例4: ig_resnext101_32x8d

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def ig_resnext101_32x8d(**kwargs):
    r"""Build the ``ig_resnext101_32x8d`` architecture via ``_resnet``.

    Per the original description: a ResNeXt-101 32x8 model pre-trained on
    weakly-supervised data and finetuned on ImageNet, from Figure 5 in
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_.
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/

    Every keyword argument is forwarded unchanged to ``_resnet``, and its
    return value is handed back as is.
    """
    arch_name = "ig_resnext101_32x8d"
    return _resnet(arch_name, **kwargs)
开发者ID:bonlime,项目名称:pytorch-tools,代码行数:8,代码来源:resnet.py

示例5: tokenizer

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def tokenizer(*args, **kwargs):
    r"""Hub entry point that returns ``AutoTokenizer.from_pretrained(*args, **kwargs)``.

    All positional and keyword arguments are forwarded unchanged, and the
    object produced by ``AutoTokenizer.from_pretrained`` is returned as is.

    Usage through torch.hub::

        import torch

        # Download vocabulary from S3 and cache.
        tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')
        # E.g. tokenizer was saved using `save_pretrained('./test/bert_saved_model/')`
        tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', './test/bert_saved_model/')
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
    return loaded_tokenizer
开发者ID:bhoov,项目名称:exbert,代码行数:13,代码来源:hubconf.py

示例6: modelWithLMHead

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def modelWithLMHead(*args, **kwargs):
    r"""Hub entry point that returns
    ``AutoModelWithLMHead.from_pretrained(*args, **kwargs)``.

    All positional and keyword arguments are forwarded unchanged, and the
    object produced by ``from_pretrained`` is returned as is.

    Usage through torch.hub::

        import torch

        # Download model and configuration from S3 and cache.
        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased')
        # E.g. model was saved using `save_pretrained('./test/bert_model/')`
        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/')
        # Update configuration during loading
        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attention=True)
        assert model.config.output_attention == True
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    lm_model = AutoModelWithLMHead.from_pretrained(*args, **kwargs)
    return lm_model
开发者ID:bhoov,项目名称:exbert,代码行数:17,代码来源:hubconf.py

示例7: serialize

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def serialize(self):
    """Serialize the model into a plain dictionary.

    The filterbank config (``self.encoder.filterbank.get_config()``) and the
    masker config (``self.masker.get_config()``) are merged into one flat
    ``model_args`` dict, so they must not share any key.

    Returns:
        dict with keys:
            ``model_name``: name of the class of ``self``.
            ``model_args``: merged filterbank and masker configs.
            ``state_dict``: result of ``self.state_dict()``.
            ``infos``: dict holding ``software_versions`` (torch,
            pytorch_lightning and asteroid version strings).

    Raises:
        AssertionError: if the filterbank and masker configs share a key.
    """
    from .. import __version__ as asteroid_version  # Avoid circular imports
    import pytorch_lightning as pl  # Not used in torch.hub

    fb_config = self.encoder.filterbank.get_config()
    masknet_config = self.masker.get_config()
    # A shared key would be silently overwritten by the merge below, so
    # refuse to serialize rather than lose one of the two values.
    if any(key in fb_config for key in masknet_config):
        raise AssertionError("Filterbank and Mask network config share "
                             "common keys. Merging them is not safe.")

    software_versions = dict(
        torch_version=torch.__version__,
        pytorch_lightning_version=pl.__version__,
        asteroid_version=asteroid_version,
    )
    # Key order mirrors the order in which the original built the dict.
    return {
        'model_name': self.__class__.__name__,
        # Merge all args under model_args.
        'model_args': {**fb_config, **masknet_config},
        'state_dict': self.state_dict(),
        # Additional infos
        'infos': {'software_versions': software_versions},
    }
开发者ID:mpariente,项目名称:asteroid,代码行数:30,代码来源:base_models.py

示例8: main

# 需要导入模块: import torch [as 别名]
# 或者: from torch import hub [as 别名]
def main():
    """Evaluate the hub-loaded "cnn" model on the MNIST test split.

    Parses the command-line options, seeds torch, picks the device, builds
    the MNIST train and test data loaders, loads the ``cnn`` entry point from
    ``johnhany/torchhub:master`` through ``hub.load`` and hands the model and
    the test loader to ``test``.
    """
    cli = argparse.ArgumentParser(description='PyTorch MNIST Example')
    cli.add_argument('--batch-size', type=int, default=64, metavar='N',
                     help='input batch size for training (default: 64)')
    cli.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                     help='input batch size for testing (default: 1000)')
    cli.add_argument('--no-cuda', action='store_true', default=False,
                     help='disables CUDA training')
    cli.add_argument('--seed', type=int, default=1, metavar='S',
                     help='random seed (default: 1)')
    cli.add_argument('--log-interval', type=int, default=10, metavar='N',
                     help='how many batches to wait before logging training status')
    args = cli.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    # Extra DataLoader options are only passed when running on CUDA.
    if use_cuda:
        loader_opts = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_opts = {}

    data_root = '/home/john/Data/mnist'
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def build_loader(dataset, batch_size):
        # Both loaders shuffle and share the same device-dependent options.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, **loader_opts)

    # train_loader is not used further down in this function; the training
    # dataset is still built (with download=True) exactly as before.
    train_loader = build_loader(
        datasets.MNIST(data_root, train=True, download=True,
                       transform=mnist_transform),
        args.batch_size)
    test_loader = build_loader(
        datasets.MNIST(data_root, train=False, transform=mnist_transform),
        args.test_batch_size)

    model = hub.load("johnhany/torchhub:master", "cnn",
                     force_reload=True, pretrained=True)
    model = model.to(device)

    test(args, model, device, test_loader)
开发者ID:PacktPublishing,项目名称:Hands-On-Generative-Adversarial-Networks-with-PyTorch-1.x,代码行数:40,代码来源:mnist_hub.py


注:本文中的torch.hub方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。