当前位置: 首页>>代码示例>>Python>>正文


Python torch.load函数代码示例

本文整理汇总了Python中torch.load函数的典型用法代码示例。如果您正苦于以下问题：Python torch.load函数的具体用法？torch.load怎么用？torch.load的使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了torch.load函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _load

def _load(checkpoint_path):
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint
开发者ID:Saiuz,项目名称:autokeras,代码行数:7,代码来源:model_helper.py

示例2: init_model

def init_model(word2id, opt):
    """Build the Seq2Seq LSTM-attention model and optionally restore its weights.

    Args:
        word2id: vocabulary mapping; ``word2id[pykp.io.PAD_WORD]`` supplies the
            padding index used for both the source and the target side.
        opt: options object providing the hyper-parameters read below and
            ``train_from``, a checkpoint path (falsy to start from scratch).

    Returns:
        The constructed model, after its parameters have been reported via
        ``utils.tally_parameters``.
    """
    pad_token = word2id[pykp.io.PAD_WORD]
    model = Seq2SeqLSTMAttention(
        emb_dim=opt.word_vec_size,
        vocab_size=opt.vocab_size,
        src_hidden_dim=opt.rnn_size,
        trg_hidden_dim=opt.rnn_size,
        ctx_hidden_dim=opt.rnn_size,
        attention_mode='dot',
        batch_size=opt.batch_size,
        bidirectional=opt.bidirectional,
        pad_token_src=pad_token,
        pad_token_trg=pad_token,
        nlayers_src=opt.enc_layers,
        nlayers_trg=opt.dec_layers,
        dropout=opt.dropout,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        scheduled_sampling=opt.scheduled_sampling,
        scheduled_sampling_batches=opt.scheduled_sampling_batches
    )

    logging.info('======================  Model Parameters  =========================')
    if opt.train_from:
        logging.info("loading previous checkpoint from %s", opt.train_from)
        # Open the checkpoint in a context manager so the file handle is closed
        # as soon as deserialization finishes.
        with open(opt.train_from, 'rb') as f:
            if torch.cuda.is_available():
                state_dict = torch.load(f)
            else:
                # Keep every storage on the CPU so GPU-saved weights still load.
                state_dict = torch.load(f, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)
    utils.tally_parameters(model)

    return model
开发者ID:zhhengcs,项目名称:seq2seq-keyphrase-pytorch,代码行数:32,代码来源:train(old,no+copy,max+entropy+loss).py

示例3: generate

def generate(**kwargs):
    """
    Randomly generate anime avatars and keep the ones netd scores highest.

    Keyword arguments are written onto the global ``opt`` (overriding the
    same-named attributes). ``opt.gen_search_num`` noise vectors are sampled
    from N(opt.gen_mean, opt.gen_std); the ``opt.gen_num`` generated images with
    the highest discriminator scores are saved to ``opt.gen_img``.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    # Load weights onto the CPU first; the networks are moved to `device` below.
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Inference only: no_grad avoids recording an autograd graph (and keeping
    # every intermediate activation alive) for the whole search batch.
    with t.no_grad():
        # Generate images and compute their discriminator scores.
        fake_img = netg(noises)
        scores = netd(fake_img)

        # Pick the best few images.
        indexs = scores.topk(opt.gen_num)[1]
        result = [fake_img[ii] for ii in indexs]

    # Save the images.
    # NOTE(review): newer torchvision releases rename `range` to `value_range`;
    # confirm against the installed version.
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))
开发者ID:672401341,项目名称:pytorch-book,代码行数:31,代码来源:main.py

示例4: load_model

    def load_model(self):
        """Restore selector training state from a saved checkpoint, if any.

        Checkpoints are named ``<args.corpus>-selector-<iter>.pth`` inside
        ``args.save_dir``. Nothing happens when no such file exists. Otherwise
        the file for ``args.load_iter`` is loaded (or, when that is None, the
        one with the highest iteration number). The epoch/iteration counters,
        the running/min losses and the sentence-selector weights are restored.
        Unless ``args.is_coverage`` is set, the optimizer state is restored too
        and, when running on CUDA, its tensors are moved to the GPU.
        """
        # Glob once and reuse the result instead of scanning the directory twice.
        f_list = glob.glob(os.path.join(args.save_dir, args.corpus) + '-selector-*.pth')
        if not f_list:
            return

        if args.load_iter is None:
            # The iteration number is the text after the last '-' and before
            # the following '.' of each file name.
            start_iter = max(int(f.split('-')[-1].split('.')[0]) for f in f_list)
        else:
            start_iter = args.load_iter

        name = args.corpus + '-selector-{}.pth'.format(start_iter)
        model_file_path = os.path.join(args.save_dir, name)
        print("loading model", model_file_path)

        if opt.device == torch.device('cuda'):
            state = torch.load(model_file_path)
        else:
            state = torch.load(model_file_path, map_location=opt.device)

        self._epoch = state['epoch']
        self._iter = state['iter']
        self.running_avg_loss = state['current_loss']
        self.min_loss = state['min_loss']

        self.model.sentence_selector.load_state_dict(state['selector_state_dict'])

        if not args.is_coverage:
            self.optimizer.load_state_dict(state['optimizer'])
            if opt.device == torch.device('cuda'):
                # Use a distinct name so the loaded checkpoint `state` is not
                # shadowed by the per-parameter optimizer state dicts.
                for param_state in self.optimizer.state.values():
                    for k, v in list(param_state.items()):
                        if torch.is_tensor(v):
                            param_state[k] = v.cuda()
开发者ID:coder352,项目名称:shellscript,代码行数:34,代码来源:train_selector.py

示例5: load_checkpoint

def load_checkpoint(checkpoint):
    """Deserialize a checkpoint saved with ``torch.save``.

    Without CUDA, ``map_location`` returns each storage unchanged, so every
    tensor stays on the CPU and GPU-saved checkpoints still load. With CUDA,
    torch's default device mapping is used.

    Args:
        checkpoint: path (or file-like object) accepted by ``torch.load``.

    Returns:
        The object stored in the checkpoint.
    """
    cpu_only = not torch.cuda.is_available()
    if cpu_only:
        return torch.load(checkpoint, map_location=lambda storage, loc: storage)
    return torch.load(checkpoint)
开发者ID:Wilson-Sunshine,项目名称:Udacity_AI_Program_Basic,代码行数:7,代码来源:predict.py

示例6: run

def run(args, run_args, rank=0, world_size=1):
    """Prepare iterators, model and optimizer, optionally resume, then train.

    Args:
        args: parsed command-line options.
        run_args: tuple ``(field, train_sets, val_sets, save_dict)``. When
            ``save_dict`` is not None, the model weights are read from
            ``os.path.join(args.save, args.load)`` and, if ``args.resume`` is
            set, the per-rank optimizer state and start iteration as well.
        rank: index of this process; selects the per-rank optimizer file and
            restricts the summary writer to rank 0.
        world_size: number of processes, forwarded to iterators and model.
    """
    set_seed(args, rank=rank)
    logger = initialize_logger(args, rank)
    field, train_sets, val_sets, save_dict = run_args

    logger.start = time.time()

    logger.info('Preparing iterators')
    train_iters = [(name, to_iter(args, world_size, tok, x, token_testing=args.token_testing))
                      for name, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
    val_iters = [(name, to_iter(args, world_size, tok, x, train=False, token_testing=args.token_testing, sort=False if 'sql' in name else None))
                    for name, x, tok in zip(args.val_tasks, val_sets, args.val_batch_size)]

    logger.info('Initializing Writer')
    # Only rank 0 receives a writer, so don't create (and leave behind event
    # files for) writers on the other ranks.
    writer = SummaryWriter(log_dir=args.log_dir) if rank == 0 else None

    model = init_model(args, field, logger, world_size)
    opt = init_opt(args, model)
    start_iteration = 1

    if save_dict is not None:
        model_path = os.path.join(args.save, args.load)
        logger.info(f'Loading model from {model_path}')
        # NOTE(review): the save_dict passed in run_args only acts as a flag
        # here; the checkpoint is re-read from disk.
        save_dict = torch.load(model_path)
        model.load_state_dict(save_dict['model_state_dict'])
        if args.resume:
            optim_name = f'{os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth'
            logger.info(f'Resuming Training from {optim_name}')
            opt.load_state_dict(torch.load(os.path.join(args.save, optim_name)))
            # The start iteration is the second '_'-separated field of the
            # checkpoint's base name (without extension).
            start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1])

    logger.info('Begin Training')
    train(args, model, opt, train_iters, args.train_iterations, field, val_iters=val_iters,
        rank=rank, world_size=world_size,
        log_every=args.log_every, val_every=args.val_every, rounds=len(train_iters) > 1,
        writer=writer, save_every=args.save_every, start_iteration=start_iteration)
开发者ID:AhlamMD,项目名称:decaNLP,代码行数:34,代码来源:train.py

示例7: restore_model

 def restore_model(self, resume_iters):
     """Restore the trained generator and discriminator.

     Reads ``<resume_iters>-G.ckpt`` and ``<resume_iters>-D.ckpt`` from
     ``self.model_save_dir``, keeps every storage on the CPU while
     deserializing, and copies the weights into ``self.G`` and ``self.D``.
     """
     print('Loading the trained models from step {}...'.format(resume_iters))

     def keep_on_cpu(storage, loc):
         return storage

     for net, template in ((self.G, '{}-G.ckpt'), (self.D, '{}-D.ckpt')):
         ckpt_path = os.path.join(self.model_save_dir, template.format(resume_iters))
         net.load_state_dict(torch.load(ckpt_path, map_location=keep_on_cpu))
开发者ID:JacobLee121,项目名称:StarGAN,代码行数:7,代码来源:solver.py

示例8: get_pretrained_net

def _download_if_missing(path, url, label):
    """Download ``url`` to ``path`` with wget unless ``path`` already exists."""
    if not os.path.exists(path):
        print('Downloading ' + label)
        # NOTE(review): --no-check-certificate disables TLS certificate
        # verification for this download; consider removing it and verifying
        # a checksum of the downloaded file instead.
        os.system('wget -O ' + path + ' --no-check-certificate -nc ' + url)


def get_pretrained_net(name):
    """Loads pretrained network

    Args:
        name: one of ``'alexnet_caffe'``, ``'vgg19_caffe'``, ``'vgg16_caffe'``
            or ``'vgg19_pytorch_modified'``. The first three download their
            weight file into the working directory if it is missing.

    Returns:
        The loaded network.

    Raises:
        ValueError: if ``name`` is not a supported network. (An explicit raise
            is used because ``assert`` is stripped under ``python -O``.)
    """
    if name == 'alexnet_caffe':
        _download_if_missing('alexnet-torch_py3.pth',
                             'https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download',
                             'AlexNet')
        # NOTE(review): this file holds a pickled model, not a state dict; on
        # torch versions where torch.load defaults to weights_only=True this
        # load may need weights_only=False — confirm.
        return torch.load('alexnet-torch_py3.pth')
    if name == 'vgg19_caffe':
        _download_if_missing('vgg19-caffe-py3.pth',
                             'https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download',
                             'VGG-19')
        # Presumably get_vgg19_caffe() reads the file downloaded above — verify.
        return get_vgg19_caffe()
    if name == 'vgg16_caffe':
        _download_if_missing('vgg16-caffe-py3.pth',
                             'https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download',
                             'VGG-16')
        # Presumably get_vgg16_caffe() reads the file downloaded above — verify.
        return get_vgg16_caffe()
    if name == 'vgg19_pytorch_modified':
        # The weights are expected to exist locally; nothing is downloaded here.
        model = VGGModified(vgg19(pretrained=False), 0.2)
        model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict'])
        return model
    raise ValueError('Unknown pretrained network: {!r}'.format(name))
开发者ID:1exx,项目名称:deep-image-prior,代码行数:32,代码来源:perceptual_loss.py

示例9: get_vanilla_vgg_features

def get_vanilla_vgg_features(cut_idx=-1):
    """Return an eval-mode VGG-19 network built from pretrained weights.

    On first use (when ``vgg_features.pth`` is absent), ``vgg19-d01eb7cb.pth``
    is downloaded, its ``classifier.6.*`` keys are renamed to
    ``classifier.7.*``, and the weights are loaded into a torchvision VGG-19
    whose classifier is prefixed with a ``View()`` module. The resulting
    ``features`` and ``classifier`` modules are cached as ``vgg_features.pth``
    and ``vgg_classifier.pth``.

    Args:
        cut_idx: when greater than 36, the classifier modules are appended
            after the feature modules. NOTE(review): despite its name, the
            value is not otherwise used to truncate the network.

    Returns:
        The network (``vgg_features.pth`` contents, optionally extended with
        the classifier) in eval mode.
    """
    if not os.path.exists('vgg_features.pth'):
        os.system(
            'wget --no-check-certificate -N https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth')
        vgg_weights = torch.load('vgg19-d01eb7cb.pth')
        # fix compatibility issues: rename the classifier.6.* keys to classifier.7.*
        # (`renames` avoids shadowing the builtin `map`; `.items()` replaces the
        # Python-2-only `.iteritems()`).
        renames = {'classifier.6.weight': 'classifier.7.weight',
                   'classifier.6.bias': 'classifier.7.bias'}
        vgg_weights = OrderedDict((renames.get(k, k), v) for k, v in vgg_weights.items())

        model = models.vgg19()
        model.classifier = nn.Sequential(View(), *model.classifier._modules.values())

        model.load_state_dict(vgg_weights)

        torch.save(model.features, 'vgg_features.pth')
        torch.save(model.classifier, 'vgg_classifier.pth')

    # NOTE(review): these files hold pickled modules; on torch versions where
    # torch.load defaults to weights_only=True they may need
    # weights_only=False — confirm.
    vgg = torch.load('vgg_features.pth')
    if cut_idx > 36:
        vgg_classifier = torch.load('vgg_classifier.pth')
        # dict views cannot be concatenated with `+` on Python 3; convert first.
        vgg = nn.Sequential(*(list(vgg._modules.values()) + list(vgg_classifier._modules.values())))

    vgg.eval()

    return vgg
开发者ID:1exx,项目名称:deep-image-prior,代码行数:28,代码来源:feature_inversion_utils.py

示例10: load

 def load(self, filename, legacy=False, ignore_d=False):
     """Restore generator/discriminator (and optimizer) state from ``filename``.

     Args:
         filename: checkpoint path passed to ``torch.load``. Storages are kept
             on the CPU while deserializing unless ``self.use_cuda`` is truthy.
         legacy: if ``True``, the file holds a ``(g_state, d_state)`` pair;
             otherwise a dict with ``'g'``, ``'d'``, ``'optim_<key>'`` for each
             key of ``self.optim``, and ``'epoch'``.
         ignore_d: if ``True``, then don't load in the discriminator (nor, in
             the non-legacy format, the optimizer stored under key ``'d'``).
     """
     map_location = None if self.use_cuda else (lambda storage, loc: storage)
     checkpoint = torch.load(filename, map_location=map_location)

     if legacy:
         g_state, d_state = checkpoint
         self.g.load_state_dict(g_state)
         if not ignore_d:
             self.d.load_state_dict(d_state)
         return

     self.g.load_state_dict(checkpoint['g'])
     if not ignore_d:
         self.d.load_state_dict(checkpoint['d'])
     for key in self.optim:
         skip = ignore_d and key == 'd'
         if not skip:
             self.optim[key].load_state_dict(checkpoint['optim_' + key])
     self.last_epoch = checkpoint['epoch']
开发者ID:kazk1018,项目名称:manifold_mixup,代码行数:26,代码来源:base.py

示例11: load_network_stageI

    def load_network_stageI(self):
        """Build the Stage-I generator and discriminator, optionally loading weights.

        Both networks are initialised with ``weights_init`` and printed. A
        non-empty ``cfg.NET_G`` / ``cfg.NET_D`` names a state-dict file that is
        deserialized with storages kept on the CPU and loaded into the matching
        network. Both networks are moved to the GPU when ``cfg.CUDA`` is set.

        Returns:
            ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D

        def keep_on_cpu(storage, loc):
            return storage

        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        for net, weights_path in ((netG, cfg.NET_G), (netD, cfg.NET_D)):
            if weights_path == '':
                continue
            net.load_state_dict(torch.load(weights_path, map_location=keep_on_cpu))
            print('Load from: ', weights_path)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD
开发者ID:tensoralex,项目名称:StackGAN-Pytorch,代码行数:25,代码来源:trainer.py

示例12: __init__

    def __init__(self,
                 root, mnist_root="data",
                 train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Init MNIST-M dataset.

        Args:
            root: directory of the MNIST-M data (``~`` is expanded).
            mnist_root: directory of the MNIST data (``~`` is expanded).
            train: load the training split if True, otherwise the test split.
            transform: stored as ``self.transform``.
            target_transform: stored as ``self.target_transform``.
            download: if True, call ``self.download()`` before checking for data.

        Raises:
            RuntimeError: if ``self._check_exists()`` reports the data missing.
        """
        super(MNISTM, self).__init__()
        self.root, self.mnist_root = map(os.path.expanduser, (root, mnist_root))
        self.transform, self.target_transform = transform, target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # The processed file holds an (images, labels) pair for the chosen split.
        split_file = self.training_file if self.train else self.test_file
        images, labels = torch.load(
            os.path.join(self.root, self.processed_folder, split_file))
        if self.train:
            self.train_data, self.train_labels = images, labels
        else:
            self.test_data, self.test_labels = images, labels
开发者ID:hjpwhu,项目名称:PyTorch-GAN,代码行数:30,代码来源:mnistm.py

示例13: load_models

def load_models(load_path):
    """Rebuild the autoencoder and GAN generator/discriminator from ``load_path``.

    The directory must contain ``args.json`` (model hyper-parameters),
    ``vocab.json`` (word -> index), and the state-dict files
    ``autoencoder_model.pt``, ``gan_gen_model.pt`` and ``gan_disc_model.pt``.

    Returns:
        ``(model_args, idx2word, autoencoder, gan_gen, gan_disc)``, where
        ``idx2word`` inverts the vocabulary mapping.
    """
    # Context managers close the JSON files promptly instead of leaking handles.
    with open(os.path.join(load_path, "args.json"), "r") as f:
        model_args = json.load(f)
    with open(os.path.join(load_path, "vocab.json"), "r") as f:
        word2idx = json.load(f)
    idx2word = {v: k for k, v in word2idx.items()}

    autoencoder = Seq2Seq(emsize=model_args['emsize'],
                          nhidden=model_args['nhidden'],
                          ntokens=model_args['ntokens'],
                          nlayers=model_args['nlayers'],
                          hidden_init=model_args['hidden_init'])
    gan_gen = MLP_G(ninput=model_args['z_size'],
                    noutput=model_args['nhidden'],
                    layers=model_args['arch_g'])
    gan_disc = MLP_D(ninput=model_args['nhidden'],
                     noutput=1,
                     layers=model_args['arch_d'])

    print('Loading models from ' + load_path)
    ae_path = os.path.join(load_path, "autoencoder_model.pt")
    gen_path = os.path.join(load_path, "gan_gen_model.pt")
    disc_path = os.path.join(load_path, "gan_disc_model.pt")

    # The models above are freshly built (not moved to a GPU here), so keep the
    # saved storages on the CPU; this also lets GPU-saved weights load on
    # CPU-only machines.
    def keep_on_cpu(storage, loc):
        return storage

    autoencoder.load_state_dict(torch.load(ae_path, map_location=keep_on_cpu))
    gan_gen.load_state_dict(torch.load(gen_path, map_location=keep_on_cpu))
    gan_disc.load_state_dict(torch.load(disc_path, map_location=keep_on_cpu))
    return model_args, idx2word, autoencoder, gan_gen, gan_disc
开发者ID:wangwang110,项目名称:ARAE,代码行数:26,代码来源:models.py

示例14: demo

def demo(data, save, depth=40, growth_rate=12, batch_size=256):
    """
    Applies temperature scaling to a trained model.

    Takes a pretrained DenseNet-CIFAR100 model, and a validation set
    (parameterized by indices on train set).
    Applies temperature scaling, and saves a temperature scaled version.

    NB: the "save" parameter references a DIRECTORY, not a file.
    In that directory, there should be two files:
    - model.pth (model state dict)
    - valid_indices.pth (a list of indices corresponding to the validation set).

    data (str) - path to directory where data should be loaded from/downloaded
    save (str) - directory with necessary files (see above)
    depth (int) - network depth; (depth - 4) must be divisible by 6 because
        each of the three blocks gets (depth - 4) // 6 layers
    growth_rate (int) - growth rate passed to DenseNetEfficientMulti
    batch_size (int) - batch size of the validation loader

    Raises ValueError for an invalid depth (checked before any file is read or
    dataset downloaded) and RuntimeError if either required file is missing.
    The temperature-scaled state dict is written to
    <save>/model_with_temperature.pth.
    """
    # Validate the architecture first, before loading files or downloading
    # CIFAR-100. Each block gets (depth - 4) // 6 layers, so divisibility must
    # be checked against 6; checking against 3 let e.g. depth=43 through and
    # silently built a shallower network.
    if (depth - 4) % 6:
        raise ValueError('Invalid depth')
    block_config = [(depth - 4) // 6 for _ in range(3)]

    # Load model state dict
    model_filename = os.path.join(save, 'model.pth')
    if not os.path.exists(model_filename):
        raise RuntimeError('Cannot find file %s to load' % model_filename)
    state_dict = torch.load(model_filename)

    # Load validation indices
    valid_indices_filename = os.path.join(save, 'valid_indices.pth')
    if not os.path.exists(valid_indices_filename):
        raise RuntimeError('Cannot find file %s to load' % valid_indices_filename)
    valid_indices = torch.load(valid_indices_filename)

    # Regenerate validation set loader
    mean = [0.5071, 0.4867, 0.4408]
    stdv = [0.2675, 0.2565, 0.2761]
    test_transforms = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=mean, std=stdv),
    ])
    valid_set = tv.datasets.CIFAR100(data, train=True, transform=test_transforms, download=True)
    valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(valid_indices))

    # Load original model
    orig_model = DenseNetEfficientMulti(
        growth_rate=growth_rate,
        block_config=block_config,
        num_classes=100
    ).cuda()
    orig_model.load_state_dict(state_dict)

    # Now we're going to wrap the model with a decorator that adds temperature scaling
    model = ModelWithTemperature(orig_model)

    # Tune the model temperature, and save the results
    model.set_temperature(valid_loader)
    model_filename = os.path.join(save, 'model_with_temperature.pth')
    torch.save(model.state_dict(), model_filename)
    print('Temperature scaled model saved to %s' % model_filename)
    print('Done!')
    print('Done!')
开发者ID:zhenglm,项目名称:temperature_scaling,代码行数:59,代码来源:demo.py

示例15: __init__

 def __init__(self, file, labelFile):
     """Load images and labels saved with ``torch.save`` and normalize the images.

     Each entry of the loaded image tensor is viewed as shape (1, -1),
     normalized with mean 0.1307 / std 0.3081, and written back into
     ``self.train``. The whole tensor is finally viewed as (-1, 1, 28, 28).
     """
     normalize = transforms.Normalize((0.1307,), (0.3081,))
     self.train = torch.load(file)
     self.label = torch.load(labelFile)
     self.len = len(self.train)  # number of data points
     for idx in range(self.len):
         # NOTE(review): assumes self.train is a floating-point tensor; an
         # integer dtype would truncate the normalized values on assignment.
         flat = self.train[idx].view(1, -1)
         self.train[idx] = normalize(flat)
     self.train = self.train.view(-1, 1, 28, 28)
开发者ID:RobinROAR,项目名称:TensorflowTutorialsCode,代码行数:8,代码来源:utils.py


注：本文中的torch.load函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。