This article collects typical usage examples of `torch.hub` in Python: what the method does, how to call it, and what working code looks like. If you are facing those questions, the curated examples below may help; you can also explore further usage of the `torch` package in which the method lives.
The following presents 8 code examples of `torch.hub`, sorted by popularity by default.
Example 1: config
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
# Also requires: from transformers import AutoConfig
def config(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch
    config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased')  # Download configuration from S3 and cache.
    config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/')  # E.g. the config (or model) was saved using `save_pretrained('./test/saved_model/')`
    config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/my_configuration.json')
    config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attentions=True, foo=False)
    assert config.output_attentions is True
    config, unused_kwargs = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
    assert config.output_attentions is True
    assert unused_kwargs == {'foo': False}
    """
    return AutoConfig.from_pretrained(*args, **kwargs)
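For context, torch.hub.load('huggingface/transformers', 'config', ...) resolves the string 'config' to an entry point declared in a hubconf.py at the root of the repository. A minimal sketch of such a file follows (illustrative only; the real hubconf.py of huggingface/transformers is more elaborate):

# hubconf.py -- minimal torch.hub entry-point file (sketch)
dependencies = ['torch', 'transformers']  # pip packages torch.hub will check for

from transformers import AutoConfig

def config(*args, **kwargs):
    # torch.hub.load(repo, 'config', ...) forwards its remaining arguments here.
    return AutoConfig.from_pretrained(*args, **kwargs)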
Example 2: model
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
# Also requires: from transformers import AutoModel
def model(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch
    model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased')  # Download model and configuration from S3 and cache.
    model = torch.hub.load('huggingface/transformers', 'model', './test/bert_model/')  # E.g. the model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions is True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'model', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModel.from_pretrained(*args, **kwargs)
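A short usage sketch for the loaded model, assuming the 'tokenizer' entry point from Example 5 below; the tuple indexing follows the transformers API of this era, in which the model returns a tuple whose first element is the last hidden state:

import torch

tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')
model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased')
model.eval()

input_ids = torch.tensor([tokenizer.encode("Hello, torch.hub!")])  # batch of size 1
with torch.no_grad():
    outputs = model(input_ids)
last_hidden_state = outputs[0]  # shape: (batch, sequence_length, hidden_size)
print(last_hidden_state.shape)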
Example 3: modelForSequenceClassification
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
# Also requires: from transformers import AutoModelForSequenceClassification
def modelForSequenceClassification(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased')  # Download model and configuration from S3 and cache.
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './test/bert_model/')  # E.g. the model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions is True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
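An inference sketch for the loaded classifier. Note that on a base 'bert-base-uncased' checkpoint the classification head is randomly initialized, so the probabilities below are not meaningful until the model is fine-tuned:

import torch

tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')
model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased')
model.eval()

input_ids = torch.tensor([tokenizer.encode("A sentence to classify.")])
with torch.no_grad():
    logits = model(input_ids)[0]       # shape: (batch, num_labels)
probs = torch.softmax(logits, dim=-1)  # class probabilities (untrained head: not meaningful yet)
print(probs)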
Example 4: ig_resnext101_32x8d
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
def ig_resnext101_32x8d(**kwargs):
    r"""Constructs a ResNeXt-101 32x8d model pre-trained on weakly-supervised data
    and finetuned on ImageNet, from Figure 5 of
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_.
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    return _resnet("ig_resnext101_32x8d", **kwargs)  # `_resnet` is a builder defined in the same source file
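The same weights can also be pulled straight from the WSL-Images repository linked above; the entry-point name below is the one published on the PyTorch hub page:

import torch

model = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
model.eval()

# Sanity check with a dummy ImageNet-sized input
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) -- ImageNet logits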
Example 5: tokenizer
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
# Also requires: from transformers import AutoTokenizer
def tokenizer(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch
    tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')  # Download vocabulary from S3 and cache.
    tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', './test/bert_saved_model/')  # E.g. the tokenizer was saved using `save_pretrained('./test/saved_model/')`
    """
    return AutoTokenizer.from_pretrained(*args, **kwargs)
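A short round-trip sketch showing what the loaded tokenizer does; the method names are the standard transformers tokenizer API:

import torch

tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')

ids = tokenizer.encode("Hello, torch.hub!")    # text -> token ids
tokens = tokenizer.convert_ids_to_tokens(ids)  # ids -> WordPiece strings
print(ids)
print(tokens)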
Example 6: modelWithLMHead
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
# Also requires: from transformers import AutoModelWithLMHead
def modelWithLMHead(*args, **kwargs):
    r"""
    # Using torch.hub!
    import torch
    model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased')  # Download model and configuration from S3 and cache.
    model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/')  # E.g. the model was saved using `save_pretrained('./test/saved_model/')`
    model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
    assert model.config.output_attentions is True
    # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
    model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
    """
    return AutoModelWithLMHead.from_pretrained(*args, **kwargs)
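A masked-language-model sketch for the loaded model. The special tokens are written out by hand, and the decoding assumes the first element of the output tuple holds the token logits (as in the tuple-returning transformers API of this era):

import torch

tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased')
model.eval()

text = "[CLS] The capital of France is [MASK] . [SEP]"
ids = tokenizer.encode(text, add_special_tokens=False)  # special tokens already in `text`
mask_pos = ids.index(tokenizer.convert_tokens_to_ids('[MASK]'))

with torch.no_grad():
    logits = model(torch.tensor([ids]))[0]  # shape: (1, seq_len, vocab_size)
predicted_id = logits[0, mask_pos].argmax().item()
print(tokenizer.convert_ids_to_tokens([predicted_id]))  # e.g. ['paris']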
Example 7: serialize
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
# Also requires: import torch (for torch.__version__); `serialize` is a method of an Asteroid model class, hence `self`.
def serialize(self):
    """Serialize model and output dictionary.

    Returns:
        dict, serialized model with keys `model_args` and `state_dict`.
    """
    from .. import __version__ as asteroid_version  # Avoid circular imports
    import pytorch_lightning as pl  # Not used in torch.hub
    model_conf = dict()
    fb_config = self.encoder.filterbank.get_config()
    masknet_config = self.masker.get_config()
    # Assert the two config dicts are disjoint before merging them.
    if any(k in fb_config for k in masknet_config):
        raise AssertionError("Filterbank and Mask network config share "
                             "common keys. Merging them is not safe.")
    # Merge all args under model_args.
    model_conf['model_name'] = self.__class__.__name__
    model_conf['model_args'] = {**fb_config, **masknet_config}
    model_conf['state_dict'] = self.state_dict()
    # Additional info
    infos = dict()
    infos['software_versions'] = dict(
        torch_version=torch.__version__,
        pytorch_lightning_version=pl.__version__,
        asteroid_version=asteroid_version,
    )
    model_conf['infos'] = infos
    return model_conf
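A hedged sketch of how the returned dictionary could be saved and restored with plain torch I/O; `build_model` is a hypothetical factory standing in for however the library rebuilds a model from `model_args`:

import torch

model_conf = model.serialize()       # the dict produced above
torch.save(model_conf, "model.pth")  # persist args + weights in one file

conf = torch.load("model.pth")
new_model = build_model(conf["model_name"], **conf["model_args"])  # hypothetical factory
new_model.load_state_dict(conf["state_dict"])  # restore the trained weights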
Example 8: main
# Required import: import torch [as alias]
# Or: from torch import hub [as alias]
# Also requires: import argparse; from torchvision import datasets, transforms
# `test(...)` is defined elsewhere in the source file; a sketch follows this example.
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('/home/john/Data/mnist', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('/home/john/Data/mnist', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
    # Load the pretrained CNN published via torch.hub from johnhany/torchhub
    model = hub.load("johnhany/torchhub:master", "cnn", force_reload=True, pretrained=True).to(device)
    test(args, model, device, test_loader)
Developer ID: PacktPublishing, Project: Hands-On-Generative-Adversarial-Networks-with-PyTorch-1.x, Lines of code: 40, Source file: mnist_hub.py
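The `test(...)` call above refers to a function defined elsewhere in mnist_hub.py. A minimal evaluation loop consistent with that call signature might look like the following sketch (not the book's exact code; it assumes the hub model outputs log-probabilities, as the classic MNIST example does):

import torch
import torch.nn.functional as F

def test(args, model, device, test_loader):
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)  # assumed to be log-softmax outputs
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            correct += output.argmax(dim=1).eq(target).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))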