当前位置: 首页>>代码示例>>Python>>正文


Python torch.load方法代码示例

本文整理汇总了Python中torch.load方法的典型用法代码示例。如果您正苦于以下问题:Python torch.load方法的具体用法是什么?Python torch.load怎么用?有哪些Python torch.load的使用示例?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。


在下文中一共展示了torch.load方法的15个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更好的Python代码示例。

示例1: convert

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def convert(src, dst):
    """Translate a pycls pretrained RegNet checkpoint into mmdet key style.

    Args:
        src: Path of the checkpoint to read; the mapping stored under its
            ``'model_state'`` key is what gets converted.
        dst: Path the converted checkpoint is written to, as a dict with a
            single ``'state_dict'`` entry.
    """
    source_weights = torch.load(src)['model_state']

    state_dict = OrderedDict()
    converted_names = set()
    for name, tensor in source_weights.items():
        # Check order matters: 'stem' is tested before 'head', and both are
        # tested before the generic leading-'s' rule.
        if 'stem' in name:
            handler = convert_stem
        elif 'head' in name:
            handler = convert_head
        elif name.startswith('s'):
            handler = convert_reslayer
        else:
            continue
        handler(name, tensor, state_dict, converted_names)

    # Report every source key that no converter recorded as handled.
    for key in source_weights:
        if key not in converted_names:
            print(f'not converted: {key}')

    torch.save({'state_dict': state_dict}, dst)
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:26,代码来源:regnet2mmdet.py

示例2: get_model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def get_model(load_weights = True):
    """Build the ``deepsea_cpu`` network and wrap it behind ``ReCodeAlphabet``.

    Args:
        load_weights: When true, the state dict stored in
            'model_files/deepsea_cpu.pth' is loaded into the network.

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), deepsea_cpu)``.
    """
    def flatten(x):
        # Keep the first dimension and collapse everything else.
        return x.view(x.size(0), -1)

    def as_row(x):
        # A 1-D input becomes a single row; any other input passes through.
        return x.view(1, -1) if len(x.size()) == 1 else x

    # The position of each layer determines its state-dict key, so the order
    # and nesting below must not change.
    layers = [
        nn.Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4)),
        nn.Dropout(p=0.2),
        nn.Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4)),
        nn.Dropout(p=0.2),
        nn.Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.Dropout(p=0.5),
        Lambda(flatten),
        nn.Sequential(Lambda(as_row), nn.Linear(50880, 925)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.Sequential(Lambda(as_row), nn.Linear(925, 919)),
        nn.Sigmoid(),
    ]
    deepsea_cpu = nn.Sequential(*layers)

    if load_weights:
        weights = torch.load('model_files/deepsea_cpu.pth')
        deepsea_cpu.load_state_dict(weights)
    return nn.Sequential(ReCodeAlphabet(), deepsea_cpu)
开发者ID:kipoi,项目名称:models,代码行数:24,代码来源:model_architecture.py

示例3: from_snapshot

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def from_snapshot(self, sfile, nfile):
    """Restore network weights and training bookkeeping from a snapshot pair.

    Args:
        sfile: Path of the weights file; loaded with ``torch.load`` and handed
            to ``self.net.load_state_dict``.
        nfile: Path of a pickle file holding six consecutive records: the
            NumPy random state, ``data_layer._cur``, ``data_layer._perm``,
            ``data_layer_val._cur``, ``data_layer_val._perm`` and the last
            snapshot iteration.

    Returns:
        The last snapshot iteration read from ``nfile``.
    """
    print('Restoring model snapshots from {:s}'.format(sfile))
    weights = torch.load(str(sfile))
    self.net.load_state_dict(weights)
    print('Restored.')

    # Besides the weights, the random state and data-layer positions are
    # restored so training can resume where it stopped. The original author
    # noted that the Tensorflow state is not available here.
    # NOTE(review): pickle can execute arbitrary code; presumably nfile is
    # always produced by this project -- confirm before loading foreign files.
    with open(nfile, 'rb') as fid:
      # The records were dumped back to back, so read them in that order.
      (rng_state, train_cur, train_perm,
       val_cur, val_perm, last_snapshot_iter) = [
          pickle.load(fid) for _ in range(6)]

    np.random.set_state(rng_state)

    train_layer = self.data_layer
    train_layer._cur = train_cur
    train_layer._perm = train_perm

    val_layer = self.data_layer_val
    val_layer._cur = val_cur
    val_layer._perm = val_perm

    return last_snapshot_iter
开发者ID:Sunarker,项目名称:Collaborative-Learning-for-Weakly-Supervised-Object-Detection,代码行数:24,代码来源:train_val.py

示例4: get_seqpred_model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def get_seqpred_model(load_weights = True):
    """Build the ``deepsea_cpu`` network wrapped for sequence prediction.

    Args:
        load_weights: When true, the state dict stored in
            'model_files/deepsea_cpu.pth' is loaded into the network.

    Returns:
        ``nn.Sequential`` chaining ``ReCodeAlphabet()``, ``ConcatenateRC()``,
        the network and ``AverageRC()``, in that order.
    """
    def flatten(x):
        # Keep the first dimension and collapse everything else.
        return x.view(x.size(0), -1)

    def as_row(x):
        # A 1-D input becomes a single row; any other input passes through.
        return x.view(1, -1) if len(x.size()) == 1 else x

    # The position of each layer determines its state-dict key, so the order
    # and nesting below must not change.
    layers = [
        nn.Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4)),
        nn.Dropout(p=0.2),
        nn.Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4)),
        nn.Dropout(p=0.2),
        nn.Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.Dropout(p=0.5),
        Lambda(flatten),
        nn.Sequential(Lambda(as_row), nn.Linear(50880, 925)),
        nn.Threshold(threshold=0, value=1e-06),
        nn.Sequential(Lambda(as_row), nn.Linear(925, 919)),
        nn.Sigmoid(),
    ]
    deepsea_cpu = nn.Sequential(*layers)

    if load_weights:
        weights = torch.load('model_files/deepsea_cpu.pth')
        deepsea_cpu.load_state_dict(weights)
    return nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())
开发者ID:kipoi,项目名称:models,代码行数:24,代码来源:model_architecture.py

示例5: load_test_model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def load_test_model(model, config):
    """
    Load saved weights into ``model`` for testing.

    :param model:  model whose ``load_state_dict`` receives the weights
    :param config:  config; ``t_model`` is an explicit checkpoint path, and
        when it is None the path is ``save_best_model_dir``/``model_name``.pt
    :return:  the same model object, after loading
    """
    user_path = config.t_model
    if user_path is not None:
        test_model_path = user_path
        print("load user model from {}".format(test_model_path))
    else:
        # Fall back to the best-model file saved under the configured name.
        test_model_path = os.path.join(
            config.save_best_model_dir, "{}.pt".format(config.model_name))
        print("load default model from {}".format(test_model_path))

    weights = torch.load(test_model_path)
    model.load_state_dict(weights)
    return model
开发者ID:bamtercelboo,项目名称:pytorch_NER_BiLSTM_CNN_CRF,代码行数:18,代码来源:test.py

示例6: load_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def load_checkpoint(self, checkpoint):
        """Rebuild the VSE model from a checkpoint file and freeze its encoders.

        Args:
            checkpoint: Path (or file object) accepted by ``torch.load``. The
                loaded object must provide ``'opt'`` and ``'model'`` entries.

        Sets ``self.model`` and ``self.projector`` (the loaded vocabulary).
        """
        state = torch.load(checkpoint)

        opt = state['opt']
        opt.use_external_captions = False
        vocab_file = pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
        vocab = Vocab.from_pickle(vocab_file)
        opt.vocab_size = len(vocab)

        from model import VSE
        self.model = VSE(opt)
        self.model.load_state_dict(state['model'])
        self.projector = vocab

        # Put both encoders in eval mode first, then switch off gradients for
        # all of their parameters.
        for encoder in (self.model.img_enc, self.model.txt_enc):
            encoder.eval()
        for encoder in (self.model.img_enc, self.model.txt_enc):
            for param in encoder.parameters():
                param.requires_grad = False
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:20,代码来源:saliency_visualization.py

示例7: load_model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def load_model(model, optimizer, scheduler, path, num_epochs, start_time=None):
    """Restore the most recent saved epoch that is not later than ``num_epochs``.

    Looks for ``{path}_model_{epoch}.pth`` starting at ``num_epochs`` and
    counting down; the first file found is loaded into ``model``, together
    with the matching optimizer/scheduler files when those objects are given.

    Args:
        model: Object whose ``load_state_dict`` receives the saved model state.
        optimizer: Optional; receives ``{path}_optimizer_{epoch}.pth``.
        scheduler: Optional; its ``best``, ``cooldown_counter``,
            ``num_bad_epochs`` and ``last_epoch`` attributes are overwritten
            from ``{path}_scheduler_{epoch}.pth``.
        path: Filename prefix shared by all checkpoint files.
        num_epochs: Highest epoch to look for.
        start_time: ``time.time()``-style reference for the elapsed time that
            is printed. Defaults to the moment of the call. (It used to be
            evaluated once when the function was defined, which measured the
            time since import rather than the time of this load.)

    Returns:
        The epoch that was loaded, or the value the search stopped on (0 when
        counting down from a positive ``num_epochs`` finds no model file).
    """
    if start_time is None:
        start_time = time.time()

    def checkpoint_file(kind, at_epoch):
        # All checkpoint files share one naming scheme; build it in one place.
        return '{path}_{kind}_{epoch:d}.pth'.format(path=path, kind=kind, epoch=at_epoch)

    # Walk backwards until a saved model file is found.
    epoch = num_epochs
    while epoch > 0 and not os.path.isfile(checkpoint_file('model', epoch)):
        epoch -= 1

    if epoch > 0:
        model_path = checkpoint_file('model', epoch)
        model.load_state_dict(torch.load(model_path))
        if optimizer is not None:
            optimizer.load_state_dict(torch.load(checkpoint_file('optimizer', epoch)))
        if scheduler is not None:
            scheduler_state_dict = torch.load(checkpoint_file('scheduler', epoch))
            # Only these four fields are restored; the rest of the scheduler
            # keeps whatever configuration it was constructed with.
            scheduler.best = scheduler_state_dict['best']
            scheduler.cooldown_counter = scheduler_state_dict['cooldown_counter']
            scheduler.num_bad_epochs = scheduler_state_dict['num_bad_epochs']
            scheduler.last_epoch = scheduler_state_dict['last_epoch']
        print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch, num_epochs=num_epochs), end='')
        print('load {path}; '.format(path=model_path), end='')
        print('{time:8.3f} s'.format(time=time.time()-start_time))
    return epoch
开发者ID:kibok90,项目名称:cvpr2018-hnd,代码行数:24,代码来源:utils.py

示例8: init_truncated_normal

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def init_truncated_normal(model, aux_str=''):
    """Load cached initial weights for ``model``, or create and cache them.

    The cache file name is built from the module-level ``path``, the model's
    ``in_features`` / ``out_features`` and ``aux_str``.

    Args:
        model: Module to initialise; ``None`` is returned unchanged.
        aux_str: Extra suffix placed before '.pth' in the cache file name.

    Returns:
        ``model`` after loading or initialising, or ``None`` if it was ``None``.
    """
    if model is None:
        return None

    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth'.format(
        path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)

    if not os.path.isfile(init_path):
        # No cached weights yet: initialise now and store them for next time.
        if isinstance(model, nn.ModuleList):
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
    else:
        cached = torch.load(init_path)
        model.load_state_dict(cached)
        print('load init weight: {init_path}'.format(init_path=init_path))

    return model
开发者ID:kibok90,项目名称:cvpr2018-hnd,代码行数:19,代码来源:models.py

示例9: print_mutation

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def print_mutation(hyp, results, bucket=''):
    """Print mutation results and record them in evolve.txt.

    Intended for use with ``train.py --evolve`` (per the original comment).

    Args:
        hyp: Mapping of hyperparameter names to values.
        results: Tuple of result values (P, R, mAP, F1, test_loss per the
            original comment).
        bucket: Optional bucket name; when set, evolve.txt is downloaded from
            it before the update and uploaded to it afterwards via ``gsutil``.
    """
    count = len(hyp)
    header = ('%10s' * count) % tuple(hyp.keys())  # hyperparameter names
    values = ('%10.3g' * count) % tuple(hyp.values())  # hyperparameter values
    scores = ('%10.3g' * len(results)) % results  # result values
    print('\n%s\n%s\nEvolved fitness: %s\n' % (header, values, scores))

    # NOTE(review): the bucket name is interpolated into a shell command;
    # presumably it comes from a trusted command-line option -- confirm.
    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)  # download evolve.txt

    # Append this run as one row: results first, then hyperparameter values.
    with open('evolve.txt', 'a') as f:
        f.write(scores + values + '\n')

    # Drop duplicate rows, then rewrite the file ordered by descending fitness.
    rows = np.loadtxt('evolve.txt', ndmin=2)
    x = np.unique(rows, axis=0)
    order = np.argsort(-fitness(x))
    np.savetxt('evolve.txt', x[order], '%10.3g')

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt
开发者ID:zbyuan,项目名称:pruning_yolov3,代码行数:19,代码来源:utils.py

示例10: load_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def load_checkpoint(self, file_name):
        """Resume the agent's state from a checkpoint, if one can be read.

        Args:
            file_name: Name appended directly to ``self.config.checkpoint_dir``
                to form the checkpoint path.

        An ``OSError`` raised anywhere in the loading sequence is treated as
        "no checkpoint yet": it is logged and training starts from scratch.
        """
        ckpt_dir = self.config.checkpoint_dir
        filename = ckpt_dir + file_name
        log = self.logger.info
        try:
            log("Loading checkpoint '{}'".format(filename))
            state = torch.load(filename)

            # Restore the counters first, then the network and its optimizer.
            episode = state['episode']
            self.current_episode = episode
            iteration = state['iteration']
            self.current_iteration = iteration
            self.policy_model.load_state_dict(state['state_dict'])
            self.optim.load_state_dict(state['optimizer'])

            log("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                .format(ckpt_dir, state['episode'], state['iteration']))
        except OSError:
            log("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
            log("**First time to train**")
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:18,代码来源:dqn.py

示例11: load_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def load_checkpoint(self, filename):
        """Resume epoch, iteration, model and optimizer from a checkpoint.

        Args:
            filename: Name appended directly to ``self.config.checkpoint_dir``
                to form the checkpoint path.

        An ``OSError`` raised anywhere in the loading sequence is treated as
        "no checkpoint yet": it is logged and training starts from scratch.
        """
        ckpt_dir = self.config.checkpoint_dir
        ckpt_path = ckpt_dir + filename
        try:
            self.logger.info("Loading checkpoint '{}'".format(ckpt_path))
            state = torch.load(ckpt_path)

            # Progress counters first, then the weights and optimizer state.
            self.current_epoch = state['epoch']
            self.current_iteration = state['iteration']
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

            success = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
            self.logger.info(success.format(ckpt_dir, state['epoch'], state['iteration']))
        except OSError:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
            self.logger.info("**First time to train**")
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:18,代码来源:condensenet.py

示例12: load_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def load_checkpoint(self, filename):
        """Restore training progress, weights and optimizer from a checkpoint.

        Args:
            filename: Name appended directly to ``self.config.checkpoint_dir``
                to form the checkpoint path.

        If loading raises ``OSError`` the failure is only logged, so a first
        run without any checkpoint simply starts training from scratch.
        """
        directory = self.config.checkpoint_dir
        target = directory + filename
        info = self.logger.info
        try:
            info("Loading checkpoint '{}'".format(target))
            saved = torch.load(target)

            epoch = saved['epoch']
            self.current_epoch = epoch
            iteration = saved['iteration']
            self.current_iteration = iteration

            self.model.load_state_dict(saved['state_dict'])
            self.optimizer.load_state_dict(saved['optimizer'])

            info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                 .format(directory, saved['epoch'], saved['iteration']))
        except OSError:
            info("No checkpoint exists from '{}'. Skipping...".format(directory))
            info("**First time to train**")
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:18,代码来源:erfnet.py

示例13: load_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def load_checkpoint(self, file_name):
        """Restore generator, discriminator and run state from a checkpoint.

        Args:
            file_name: Name appended directly to ``self.config.checkpoint_dir``
                to form the checkpoint path.

        If loading raises ``OSError`` the failure is only logged, so a first
        run without any checkpoint simply starts training from scratch.
        """
        ckpt_dir = self.config.checkpoint_dir
        filename = ckpt_dir + file_name
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            state = torch.load(filename)

            self.current_epoch = state['epoch']
            self.current_iteration = state['iteration']

            # Each attribute is restored from the checkpoint entry paired with
            # it, in the same order as they were originally written out.
            restore_plan = (
                ('netG', 'G_state_dict'),
                ('optimG', 'G_optimizer'),
                ('netD', 'D_state_dict'),
                ('optimD', 'D_optimizer'),
            )
            for attr, key in restore_plan:
                getattr(self, attr).load_state_dict(state[key])

            self.fixed_noise = state['fixed_noise']
            self.manual_seed = state['manual_seed']

            success = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
            self.logger.info(success.format(ckpt_dir, state['epoch'], state['iteration']))
        except OSError:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
            self.logger.info("**First time to train**")
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:22,代码来源:dcgan.py

示例14: __getitem__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def __getitem__(self, idx):
        """Return the sample stored at position ``idx`` of ``data_set_list``.

        Without ``read_features`` the image and its walkable-surface
        segmentation are loaded from ``root_dir`` and returned as
        ``(image, segment, 0, 0, ['images' + fid])``.

        NOTE(review): with ``read_features`` set, the stacked features are
        computed but never returned, and ``image`` / ``segment`` are not bound,
        so the final ``return`` raises ``UnboundLocalError``. This looks like
        two code paths merged by mistake -- confirm the intended return value.
        """
        fid = self.data_set_list[idx]
        if not self.read_features:
            image_path = os.path.join(self.root_dir, 'images', fid)
            image = self.load_and_resize(image_path)
            segment_path = os.path.join(self.root_dir, 'walkable', fid)
            segment = self.load_and_resize_segmentation(segment_path)
        else:
            # One stored feature file per frame of the sequence starting at fid.
            features = [
                torch.load(os.path.join(
                    self.features_dir,
                    self.frames_metadata[fid + offset]['cur_frame'] + '.pytar'))
                for offset in range(self.sequence_length)
            ]
            stacked = torch.stack(features)

        # The two 0s are just place holders. They can be replaced by any values
        return (image, segment, 0, 0, ['images' + fid])
开发者ID:ehsanik,项目名称:dogTorch,代码行数:20,代码来源:nyu_walkable_surface_dataset.py

示例15: set_model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import load [as 别名]
def set_model(self):
        '''Create the RNNLM, its loss and optimizer, then optionally resume.

        When ``self.paras.load`` is set, the checkpoint at that path supplies
        the model weights, the optimizer state and the global step.
        '''
        # Model
        model_kwargs = self.config['model']
        self.model = RNNLM(self.vocab_size, **model_kwargs).to(self.device)
        self.verbose(self.model.create_msg())

        # Loss: targets equal to index 0 are ignored
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

        # Optimizer
        optimizer_kwargs = self.config['hparas']
        self.optimizer = Optimizer(self.model.parameters(), **optimizer_kwargs)

        # Enable AMP if needed
        self.enable_apex()

        # Nothing more to do unless a pre-trained checkpoint was requested.
        if not self.paras.load:
            return

        # NOTE(review): load_ckpt() runs first and the same path is then
        # loaded again below -- presumably redundant; confirm against load_ckpt.
        self.load_ckpt()
        ckpt = torch.load(self.paras.load, map_location=self.device)
        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_opt_state_dict(ckpt['optimizer'])
        self.step = ckpt['global_step']
        resume_msg = 'Load ckpt from {}, restarting at step {}'.format(
            self.paras.load, self.step)
        self.verbose(resume_msg)
开发者ID:Alexander-H-Liu,项目名称:End-to-end-ASR-Pytorch,代码行数:25,代码来源:train_lm.py


注:本文中的torch.load方法示例由纯净天空整理自GitHub、MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。