當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.hub方法代碼示例

本文整理匯總了Python中torch.hub方法的典型用法代碼示例。如果您正苦於以下問題：Python torch.hub方法的具體用法？Python torch.hub怎麼用？Python torch.hub使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊torch的用法示例。


在下文中一共展示了torch.hub方法的8個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: config

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def config(*args, **kwargs):
    r"""Load a model configuration via ``AutoConfig.from_pretrained``.

    All positional and keyword arguments are forwarded unchanged to
    ``AutoConfig.from_pretrained`` and its result is returned as-is.

    Example::

        # Using torch.hub !
        import torch

        config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased')  # Download configuration from S3 and cache.
        config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/bert_saved_model/')`
        config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/my_configuration.json')
        config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attentions=True, foo=False)
        assert config.output_attentions
        config, unused_kwargs = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
        assert config.output_attentions
        assert unused_kwargs == {'foo': False}
    """
    return AutoConfig.from_pretrained(*args, **kwargs)
開發者ID:bhoov,項目名稱:exbert,代碼行數:19,代碼來源:hubconf.py

示例2: model

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def model(*args, **kwargs):
    r"""Load a base model via ``AutoModel.from_pretrained``.

    All positional and keyword arguments are forwarded unchanged to
    ``AutoModel.from_pretrained`` and its result is returned as-is.

    Example::

        # Using torch.hub !
        import torch

        model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = torch.hub.load('huggingface/transformers', 'model', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
        model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
        assert model.config.output_attentions
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
        model = torch.hub.load('huggingface/transformers', 'model', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModel.from_pretrained(*args, **kwargs)
開發者ID:bhoov,項目名稱:exbert,代碼行數:18,代碼來源:hubconf.py

示例3: modelForSequenceClassification

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def modelForSequenceClassification(*args, **kwargs):
    r"""Load a sequence-classification model via
    ``AutoModelForSequenceClassification.from_pretrained``.

    All positional and keyword arguments are forwarded unchanged to
    ``AutoModelForSequenceClassification.from_pretrained`` and its result is
    returned as-is.

    Example::

        # Using torch.hub !
        import torch

        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
        assert model.config.output_attentions
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
        model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
開發者ID:bhoov,項目名稱:exbert,代碼行數:18,代碼來源:hubconf.py

示例4: ig_resnext101_32x8d

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def ig_resnext101_32x8d(**kwargs):
    r"""Build a ResNeXt-101 32x8d model pre-trained on weakly-supervised data
    and fine-tuned on ImageNet, as in Figure 5 of
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_.

    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/

    All keyword arguments are forwarded unchanged to ``_resnet``.
    """
    arch_name = "ig_resnext101_32x8d"
    return _resnet(arch_name, **kwargs)
開發者ID:bonlime,項目名稱:pytorch-tools,代碼行數:8,代碼來源:resnet.py

示例5: tokenizer

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def tokenizer(*args, **kwargs):
    r"""Return a tokenizer built by ``AutoTokenizer.from_pretrained``.

    Every positional and keyword argument is passed straight through, and the
    loader's result is returned unchanged.

    Example::

        # Using torch.hub !
        import torch

        tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')    # Download vocabulary from S3 and cache.
        tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', './test/bert_saved_model/')  # E.g. tokenizer was saved using `save_pretrained('./test/bert_saved_model/')`
    """
    build = AutoTokenizer.from_pretrained
    return build(*args, **kwargs)
開發者ID:bhoov,項目名稱:exbert,代碼行數:13,代碼來源:hubconf.py

示例6: modelWithLMHead

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def modelWithLMHead(*args, **kwargs):
    r"""Load a model with a language-modeling head via
    ``AutoModelWithLMHead.from_pretrained``.

    All positional and keyword arguments are forwarded unchanged to
    ``AutoModelWithLMHead.from_pretrained`` and its result is returned as-is.

    Example::

        # Using torch.hub !
        import torch

        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased')    # Download model and configuration from S3 and cache.
        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/bert_model/')`
        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
        assert model.config.output_attentions
        # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
        model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelWithLMHead.from_pretrained(*args, **kwargs)
開發者ID:bhoov,項目名稱:exbert,代碼行數:17,代碼來源:hubconf.py

示例7: serialize

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def serialize(self):
    """Serialize the model into a plain dictionary.

    Returns:
        dict with keys:
            - ``model_name``: the class name of this model,
            - ``model_args``: the filterbank config merged with the masker
              config,
            - ``state_dict``: the result of ``self.state_dict()``,
            - ``infos``: a dict holding ``software_versions`` (torch,
              pytorch_lightning and asteroid versions).

    Raises:
        AssertionError: if the filterbank config and the masker config share
            any key.
    """
    from .. import __version__ as asteroid_version  # Avoid circular imports
    import pytorch_lightning as pl  # Not used in torch.hub

    fb_config = self.encoder.filterbank.get_config()
    masknet_config = self.masker.get_config()
    # Both configs are flattened into a single ``model_args`` dict below; a
    # shared key would let the masker value silently overwrite the filterbank
    # one, so refuse to merge in that case.
    shared_keys = [k for k in masknet_config if k in fb_config]
    if shared_keys:
        raise AssertionError(
            "Filterbank and Mask network config share common keys {}. "
            "Merging them is not safe.".format(shared_keys)
        )

    model_conf = {
        'model_name': self.__class__.__name__,
        'model_args': {**fb_config, **masknet_config},
        'state_dict': self.state_dict(),
        # Additional infos
        'infos': {
            'software_versions': dict(
                torch_version=torch.__version__,
                pytorch_lightning_version=pl.__version__,
                asteroid_version=asteroid_version,
            ),
        },
    }
    return model_conf
開發者ID:mpariente,項目名稱:asteroid,代碼行數:30,代碼來源:base_models.py

示例8: main

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import hub [as 別名]
def main():
    """Evaluate a pretrained MNIST CNN fetched through ``torch.hub``.

    Parses command-line options, builds the MNIST test-set loader, loads the
    ``cnn`` entry point of the ``johnhany/torchhub`` repository (``master``
    branch) with ``pretrained=True``, moves it to the selected device, and
    runs ``test`` on it.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    # --batch-size and --log-interval are not read by main() itself; they are
    # kept so existing command lines (and ``test``, which receives ``args``)
    # keep working.
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--data-dir', default='/home/john/Data/mnist', metavar='DIR',
                        help='directory holding (or receiving) the MNIST data '
                             '(default: /home/john/Data/mnist)')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    # Presumably the MNIST pixel mean/std; unchanged from the original script.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Only the test split is needed because this script never trains.
    # download=True fetches the data when it is not already in args.data_dir.
    # Evaluation visits every sample once, so shuffling would be wasted work.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.data_dir, train=False, download=True,
                       transform=transform),
        batch_size=args.test_batch_size, shuffle=False, **loader_kwargs)

    # NOTE: hub.load downloads and runs hubconf.py from the remote GitHub
    # repository (re-downloaded every run because force_reload=True); only
    # load repositories you trust.
    model = hub.load("johnhany/torchhub:master", "cnn", force_reload=True, pretrained=True).to(device)

    test(args, model, device, test_loader)
開發者ID:PacktPublishing,項目名稱:Hands-On-Generative-Adversarial-Networks-with-PyTorch-1.x,代碼行數:40,代碼來源:mnist_hub.py


注：本文中的torch.hub方法示例由純淨天空整理自GitHub/MSDocs等開源代碼及文檔管理平台，相關代碼片段篩選自各路編程大神貢獻的開源項目，源碼版權歸原作者所有，傳播和使用請參考對應項目的License；未經允許，請勿轉載。