當前位置: 首頁>>代碼示例>>Python>>正文


Python nn.DataParallel方法代碼示例

本文整理匯總了Python中torch.nn.DataParallel方法的典型用法代碼示例。如果您正苦於以下問題：Python nn.DataParallel方法的具體用法？Python nn.DataParallel怎麼用？Python nn.DataParallel使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.nn的用法示例。


在下文中一共展示了nn.DataParallel方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: eval

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def eval(self, evaluate, model_file, model):
        """Run `evaluate` on every batch of an iterator and collect the results.

        If `model_file` is truthy, `self.model` is put in eval mode,
        `self.load(model_file, None)` is called, and the model moved to
        `self.device` (wrapped in `nn.DataParallel` when
        `self.cfg.data_parallel` is set) replaces the `model` argument;
        batches are then drawn from `self.sup_iter`.  Otherwise the `model`
        passed by the caller is used unchanged and batches come from a deep
        copy of `self.eval_iter`.

        `evaluate(model, batch)` must return an `(accuracy, result)` pair.
        Returns the list of all `result` values, one per batch.
        """
        if model_file:
            self.model.eval()
            self.load(model_file, None)
            model = self.model.to(self.device)
            if self.cfg.data_parallel:
                model = nn.DataParallel(model)
            batches = self.sup_iter
        else:
            batches = deepcopy(self.eval_iter)

        progress = tqdm(batches)
        results = []
        for batch in progress:
            on_device = [tensor.to(self.device) for tensor in batch]

            # Gradients are not needed while evaluating.
            with torch.no_grad():
                accuracy, result = evaluate(model, on_device)
            results.append(result)

            progress.set_description('Eval Acc=%5.3f' % accuracy)
        return results
開發者ID:SanghunYun,項目名稱:UDA_pytorch,代碼行數:23,代碼來源:train.py

示例2: get_cnn

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def get_cnn(self, arch, pretrained):
        """Instantiate the model named `arch` and parallelize it over GPUs.

        The constructor is looked up by name in `models.__dict__` and called
        with `pretrained=True` when `pretrained` is truthy, with no arguments
        otherwise.  For architectures whose name starts with 'alexnet' or
        'vgg' only the `features` sub-module is wrapped in `nn.DataParallel`;
        every other architecture is wrapped as a whole.  The returned model
        has been moved to CUDA.
        """
        if pretrained:
            banner = "=> using pre-trained model '{}'"
            ctor_kwargs = {'pretrained': True}
        else:
            banner = "=> creating model '{}'"
            ctor_kwargs = {}
        print(banner.format(arch))
        model = models.__dict__[arch](**ctor_kwargs)

        if arch.startswith(('alexnet', 'vgg')):
            model.features = nn.DataParallel(model.features)
            model.cuda()
            return model

        return nn.DataParallel(model).cuda()
開發者ID:ExplorerFreda,項目名稱:VSE-C,代碼行數:19,代碼來源:model.py

示例3: main

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def main():
    """Build the CIFAR-10 training pipeline and start training.

    Reads `batch_size`, `num_worker` and `lr` from the module-level `args`,
    downloads CIFAR-10 into '../data' if necessary, wraps the network
    returned by `pyramidnet()` in `nn.DataParallel`, and hands everything to
    `train()`.  Runs on 'cuda' when available, otherwise on 'cpu'.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('==> Preparing data..')
    # Training-time augmentation followed by per-channel normalization
    # (presumably the CIFAR-10 mean / std -- confirm against the dataset).
    transforms_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    dataset_train = CIFAR10(root='../data', train=True, download=True,
                            transform=transforms_train)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_worker)

    print('==> Making model..')

    # Wrap first, then move: DataParallel keeps the wrapped module's
    # parameters, so `.to(device)` on the wrapper moves them as well.
    net = pyramidnet()
    net = nn.DataParallel(net)
    net = net.to(device)
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # Alternative kept from the original for reference:
    # optimizer = optim.SGD(net.parameters(), lr=args.lr,
    #                       momentum=0.9, weight_decay=1e-4)

    train(net, criterion, optimizer, train_loader, device)

示例4: net_from_chkpt_

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def net_from_chkpt_(self):
        """Rebuild the network from `self.checkpoint` and return it in eval mode.

        The checkpoint's storages are mapped to CUDA or CPU depending on
        `self.cuda`.  The number of output classes is the length of
        `self.dataset["common"]["classes"]`.  The `UNet` is wrapped in
        `nn.DataParallel` before the checkpoint's "state_dict" entry is
        loaded into it.
        """
        def place_storage(storage, _location):
            if self.cuda:
                return storage.cuda()
            return storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        checkpoint = torch.load(self.checkpoint, map_location=place_storage)

        class_count = len(self.dataset["common"]["classes"])
        model = nn.DataParallel(UNet(class_count).to(self.device))

        if self.cuda:
            torch.backends.cudnn.benchmark = True

        model.load_state_dict(checkpoint["state_dict"])
        model.eval()
        return model
開發者ID:mapbox,項目名稱:robosat,代碼行數:21,代碼來源:serve.py

示例5: build_model

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def build_model(self):
        """Create generator, discriminator, their Adam optimizers and the class loss.

        Both networks are moved to `self.device`.  When `self.parallel` is
        set, `self.gpus` (a comma-separated string of GPU ids) is parsed and
        both networks are wrapped in `nn.DataParallel` on those devices.
        Only parameters with `requires_grad` set are handed to the
        optimizers.  Finally both networks are printed.
        """
        # code_dim=100, n_class=1000
        self.G = Generator(self.z_dim, self.n_class, chn=self.chn).to(self.device)
        self.D = Discriminator(self.n_class, chn=self.chn).to(self.device)

        if self.parallel:
            print('use parallel...')
            print('gpuids ', self.gpus)
            device_ids = list(map(int, self.gpus.split(',')))

            self.G = nn.DataParallel(self.G, device_ids=device_ids)
            self.D = nn.DataParallel(self.D, device_ids=device_ids)

        # Loss and optimizer
        trainable_g = (p for p in self.G.parameters() if p.requires_grad)
        self.g_optimizer = torch.optim.Adam(
            trainable_g, self.g_lr, [self.beta1, self.beta2])

        trainable_d = (p for p in self.D.parameters() if p.requires_grad)
        self.d_optimizer = torch.optim.Adam(
            trainable_d, self.d_lr, [self.beta1, self.beta2])

        self.c_loss = torch.nn.CrossEntropyLoss()

        # print networks
        for network in (self.G, self.D):
            print(network)
開發者ID:sxhxliang,項目名稱:BigGAN-pytorch,代碼行數:26,代碼來源:trainer.py

示例6: _save

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def _save(self, step):
        """Write a checkpoint for `step` and return `(checkpoint, checkpoint_path)`.

        The model and its generator are unwrapped from `nn.DataParallel` if
        needed.  The model's state dict is stored without any entry whose key
        contains 'generator'; the generator's state dict is stored separately.
        The file is written to '<self.base_path>_step_<step>.pt'.
        """
        def unwrap(module):
            # DataParallel keeps the wrapped network under `.module`.
            if isinstance(module, nn.DataParallel):
                return module.module
            return module

        real_model = unwrap(self.model)
        real_generator = unwrap(real_model.generator)

        model_state_dict = {
            key: value
            for key, value in real_model.state_dict().items()
            if 'generator' not in key
        }
        checkpoint = {
            'model': model_state_dict,
            'generator': real_generator.state_dict(),
            'vocab': onmt.inputters.save_fields_to_vocab(self.fields),
            'opt': self.model_opt,
            'optim': self.optim,
        }

        logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
        checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path
開發者ID:lizekang,項目名稱:ITDD,代碼行數:26,代碼來源:model_saver.py

示例7: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def __init__(self, nClass, nCam, model_client, use_flow, task_dir, raw_model_dir, is_image_dataset, recorder):
        """Build the model, its DataParallel wrappers, the optimizer and the LR scheduler.

        Args:
            nClass: stored as ``self.nClass`` and forwarded to ``model_client``.
            nCam: stored as ``self.nCam`` and forwarded to ``model_client``.
            model_client: callable that builds the model; it is called with
                ``(nClass, nCam, use_flow, is_image_dataset, raw_model_dir, logger)``.
            use_flow: forwarded to ``model_client`` unchanged.
            task_dir: stored as ``self.task_dir``.
            raw_model_dir: forwarded to ``model_client`` unchanged.
            is_image_dataset: stored and forwarded to ``model_client``.
            recorder: object exposing ``visual`` and ``logger`` attributes.
        """
        self.nClass = nClass
        self.nCam = nCam
        self.recorder = recorder
        self.visual = self.recorder.visual
        self.logger = self.recorder.logger
        self._mode = 'Train'
        self.is_image_dataset = is_image_dataset
        self.task_dir = task_dir

        self.model = model_client(self.nClass, self.nCam, use_flow, self.is_image_dataset, raw_model_dir, self.logger)
        # Two CUDA DataParallel wrappers: one around the whole model, and one
        # around its `feature` sub-module, attached to the first wrapper.
        self.model_parallel = DataParallel(self.model).cuda()
        self.model_parallel.feature = DataParallel(self.model.feature).cuda()

        self.net_info = []
        # NOTE(review): `line_name`, `lr_decay_step` and `gamma` used below are
        # not assigned in this method -- presumably set by const_options() /
        # init_options(); confirm before reordering these calls.
        self.const_options()
        self.init_options()
        self.loss_mean = AverageMeter(len(self.line_name))

        self.net_info.extend(self.model.net_info)
        self.optimizer = self.init_optimizer()
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.lr_decay_step, gamma=self.gamma)
        self.idx = 0
        self.best_performance = 0.0 
開發者ID:yolomax,項目名稱:person-reid-lib,代碼行數:26,代碼來源:netbase.py

示例8: multi_gpu

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def multi_gpu(model, devices):
    """Return `model` wrapped in `nn.DataParallel` using `devices` as `device_ids`."""
    parallel_model = nn.DataParallel(model, device_ids=devices)
    return parallel_model
開發者ID:atcbosselut,項目名稱:comet-commonsense,代碼行數:4,代碼來源:models.py

示例9: setup_networks

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def setup_networks(self):
    '''
    Create the DDPAE sub-networks and register them.

    Every network is moved to CUDA and wrapped in nn.DataParallel, stored as
    an attribute, and recorded in self.nets.  The pose model and the object
    encoder are also recorded in self.guide_modules; the content LSTM and the
    object decoder in self.model_modules.
    '''
    def on_gpus(module):
        # Move to CUDA first, then wrap for multi-GPU execution.
        return nn.DataParallel(module.cuda())

    self.nets = {}
    # These will be registered in model() and guide() with pyro.module().
    self.model_modules = {}
    self.guide_modules = {}

    # Backbone, Pose RNN
    self.pose_model = on_gpus(PoseRNN(
        self.n_components, self.n_frames_output, self.n_channels,
        self.image_size, self.image_latent_size, self.hidden_size,
        self.ngf, self.pose_latent_size, self.independent_components))
    self.nets['pose_model'] = self.pose_model
    self.guide_modules['pose_model'] = self.pose_model

    # Content LSTM
    self.content_lstm = on_gpus(SequenceEncoder(
        self.content_latent_size, self.hidden_size,
        self.content_latent_size * 2))
    self.nets['content_lstm'] = self.content_lstm
    self.model_modules['content_lstm'] = self.content_lstm

    # Image encoder and decoder; depth is derived from the object size.
    n_layers = int(np.log2(self.object_size)) - 1
    raw_encoder = ImageEncoder(self.n_channels, self.content_latent_size,
                               self.ngf, n_layers)
    raw_decoder = ImageDecoder(self.content_latent_size, self.n_channels,
                               self.ngf, n_layers, 'sigmoid')
    self.object_encoder = on_gpus(raw_encoder)
    self.object_decoder = on_gpus(raw_decoder)
    self.nets['object_encoder'] = self.object_encoder
    self.nets['object_decoder'] = self.object_decoder
    self.model_modules['decoder'] = self.object_decoder
    self.guide_modules['encoder'] = self.object_encoder
開發者ID:jthsieh,項目名稱:DDPAE-video-prediction,代碼行數:39,代碼來源:DDPAE.py

示例10: set_train_model

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def set_train_model(self, model):
        """Instantiate the training model and, if none is set yet, its optimizer.

        Args:
            model: model class (or factory); it is called as
                ``model(config=self)`` and the result is stored in
                ``self.trainModel``.  The class itself is kept in
                ``self.model``.

        The model is moved to the module-level ``device``.  If
        ``self.optimizer`` is already set it is left untouched; otherwise the
        optimizer is chosen from ``self.opt_method`` ("Adagrad"/"adagrad",
        "Adadelta"/"adadelta", "Adam"/"adam"), and any other value falls
        back to SGD.
        """
        print("Initializing training model...")
        self.model = model
        self.trainModel = self.model(config=self)
        # Multi-GPU wrapping was left disabled in the original:
        # self.trainModel = nn.DataParallel(self.trainModel, device_ids=[2,3,4])

        self.trainModel.to(device)

        # Identity check: `is None` is the correct test for "no optimizer yet"
        # and cannot be fooled by objects that overload comparison operators.
        if self.optimizer is None:
            if self.opt_method in ("Adagrad", "adagrad"):
                self.optimizer = optim.Adagrad(
                    self.trainModel.parameters(),
                    lr=self.alpha,
                    lr_decay=self.lr_decay,
                    weight_decay=self.weight_decay,
                )
            elif self.opt_method in ("Adadelta", "adadelta"):
                self.optimizer = optim.Adadelta(
                    self.trainModel.parameters(),
                    lr=self.alpha,
                    weight_decay=self.weight_decay,
                )
            elif self.opt_method in ("Adam", "adam"):
                self.optimizer = optim.Adam(
                    self.trainModel.parameters(),
                    lr=self.alpha,
                    weight_decay=self.weight_decay,
                )
            else:
                self.optimizer = optim.SGD(
                    self.trainModel.parameters(),
                    lr=self.alpha,
                    weight_decay=self.weight_decay,
                )
        print("Finish initializing")
開發者ID:daiquocnguyen,項目名稱:ConvKB,代碼行數:37,代碼來源:Config.py

示例11: forward

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def forward(self, x):
        """
        Encode `x` and quantize the result.

        :param x: NCHW
        :return: EncOut holding the three outputs of `self.q`, `self.L`, and
                 the residual feature map computed before `self.to_q`
        """
        downsampled = self.down(x)
        # Residual connection around the body.
        features = self.body(downsampled) + downsampled
        quantizer_input = self.to_q(features)
        # assert self.summarizer is not None
        x_soft, x_hard, symbols_hard = self.q(quantizer_input)
        # TODO(parallel): To support nn.DataParallel, this must be changed, as it not a tensor
        return EncOut(x_soft, x_hard, symbols_hard, self.L, features)
開發者ID:fab-jul,項目名稱:L3C-PyTorch,代碼行數:15,代碼來源:net.py

示例12: train

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def train(start_epoch):
    """Run the train / evaluate / checkpoint loop starting at `start_epoch`.

    Uses the module-level `FLAGS`, `MAX_EPOCH`, `LOG_DIR`, `net`, `optimizer`
    and `bnm_scheduler`, and updates the global `EPOCH_CNT` each epoch.

    When `FLAGS.dump_results` is set, exactly one epoch is processed, no
    training is done and no checkpoint is written; evaluation runs once.
    Otherwise epochs `start_epoch .. MAX_EPOCH - 1` are trained, and a
    checkpoint is saved to 'checkpoint.tar' in `LOG_DIR` after every epoch.
    """
    global EPOCH_CNT
    loss = 0
    local_epoch = start_epoch + 1 if FLAGS.dump_results else MAX_EPOCH

    for epoch in range(start_epoch, local_epoch):
        EPOCH_CNT = epoch
        log_string('**** EPOCH %03d ****' % (epoch))
        log_string('Current learning rate: %f'%(get_current_lr(epoch)))
        log_string('Current BN decay momentum: %f'%(bnm_scheduler.lmbd(bnm_scheduler.last_epoch)))
        log_string(str(datetime.now()))

        # Reset numpy seed.
        # REF: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()
        if not FLAGS.dump_results:
            train_one_epoch()

        # Evaluate at epochs 29 and 59, at every epoch ending in 9 past
        # epoch 70, and always when collecting data or dumping results.
        # NOTE(review): the type of FLAGS.get_data is not visible here, so its
        # explicit comparison is kept exactly as in the original.
        periodic_eval = EPOCH_CNT % 10 == 9 and EPOCH_CNT > 70
        if (EPOCH_CNT == 29 or EPOCH_CNT == 59 or periodic_eval
                or FLAGS.get_data == True or FLAGS.dump_results):
            loss = evaluate_one_epoch()

        # Save checkpoint
        if not FLAGS.dump_results:
            save_dict = {'epoch': epoch+1, # after training one epoch, the start_epoch should be epoch+1
                         'optimizer_state_dict': optimizer.state_dict(),
                         'loss': loss,
            }
            # With nn.DataParallel() the net is a submodule (`.module`) of the
            # wrapper; fall back to the net itself when it is not wrapped.
            # An explicit lookup replaces the former bare `except:`, which
            # also swallowed unrelated errors raised while saving.
            save_dict['model_state_dict'] = getattr(net, 'module', net).state_dict()
            torch.save(save_dict, os.path.join(LOG_DIR, 'checkpoint.tar'))
            if periodic_eval:
                torch.save(save_dict, os.path.join(LOG_DIR, 'checkpoint_eval%d.tar' % EPOCH_CNT))
開發者ID:zaiweizhang,項目名稱:H3DNet,代碼行數:40,代碼來源:train.py

示例13: test_from_last_checkpoint_model

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def test_from_last_checkpoint_model(self):
        """A checkpoint saved in a directory is found and loaded from that directory.

        Exercised for every combination of plain and nn.DataParallel-wrapped
        models on the saving and loading side.  After loading, the parameters
        must be distinct tensor objects with equal contents.
        """
        # test that loading works even if they differ by a prefix
        model_pairs = [
            (self.create_model(), self.create_model()),
            (nn.DataParallel(self.create_model()), self.create_model()),
            (self.create_model(), nn.DataParallel(self.create_model())),
            (
                nn.DataParallel(self.create_model()),
                nn.DataParallel(self.create_model()),
            ),
        ]
        for trained_model, fresh_model in model_pairs:
            with TemporaryDirectory() as save_dir:
                saver = Checkpointer(
                    trained_model, save_dir=save_dir, save_to_disk=True
                )
                saver.save("checkpoint_file")

                # in the same folder
                loader = Checkpointer(fresh_model, save_dir=save_dir)
                self.assertTrue(loader.has_checkpoint())
                expected_file = os.path.join(save_dir, "checkpoint_file.pth")
                self.assertEqual(loader.get_checkpoint_file(), expected_file)
                loader.load()

            param_pairs = zip(trained_model.parameters(), fresh_model.parameters())
            for trained_p, loaded_p in param_pairs:
                # different tensor references
                self.assertIsNot(trained_p, loaded_p)
                # same content
                self.assertTrue(trained_p.equal(loaded_p))
開發者ID:Res2Net,項目名稱:Res2Net-maskrcnn,代碼行數:36,代碼來源:checkpoint.py

示例14: test_from_name_file_model

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def test_from_name_file_model(self):
        """A checkpoint can be loaded by explicit path from a different directory.

        Exercised for every combination of plain and nn.DataParallel-wrapped
        models on the saving and loading side.  The loading checkpointer
        points at an empty directory, so it must report no checkpoint of its
        own before the file is loaded by name.
        """
        # test that loading works even if they differ by a prefix
        model_pairs = [
            (self.create_model(), self.create_model()),
            (nn.DataParallel(self.create_model()), self.create_model()),
            (self.create_model(), nn.DataParallel(self.create_model())),
            (
                nn.DataParallel(self.create_model()),
                nn.DataParallel(self.create_model()),
            ),
        ]
        for trained_model, fresh_model in model_pairs:
            with TemporaryDirectory() as save_dir:
                saver = Checkpointer(
                    trained_model, save_dir=save_dir, save_to_disk=True
                )
                saver.save("checkpoint_file")

                # on different folders
                with TemporaryDirectory() as other_dir:
                    loader = Checkpointer(fresh_model, save_dir=other_dir)
                    self.assertFalse(loader.has_checkpoint())
                    self.assertEqual(loader.get_checkpoint_file(), "")
                    saved_file = os.path.join(save_dir, "checkpoint_file.pth")
                    loader.load(saved_file)

            param_pairs = zip(trained_model.parameters(), fresh_model.parameters())
            for trained_p, loaded_p in param_pairs:
                # different tensor references
                self.assertIsNot(trained_p, loaded_p)
                # same content
                self.assertTrue(trained_p.equal(loaded_p))
開發者ID:Res2Net,項目名稱:Res2Net-maskrcnn,代碼行數:33,代碼來源:checkpoint.py

示例15: test_complex_model_loaded

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import DataParallel [as 別名]
def test_complex_model_loaded(self):
        """`load_state_dict` copies the stored tensors into the model.

        Checked once for the plain model and once with the model wrapped in
        nn.DataParallel.  The model's tensors must be distinct objects from
        the stored ones while holding equal contents.
        """
        for add_data_parallel in (False, True):
            model, state_dict = self.create_complex_model()
            if add_data_parallel:
                model = nn.DataParallel(model)

            load_state_dict(model, state_dict)

            loaded_tensors = model.state_dict().values()
            for loaded, stored in zip(loaded_tensors, state_dict.values()):
                # different tensor references
                self.assertIsNot(loaded, stored)
                # same content
                self.assertTrue(loaded.equal(stored))
開發者ID:Res2Net,項目名稱:Res2Net-maskrcnn,代碼行數:14,代碼來源:checkpoint.py


注:本文中的torch.nn.DataParallel方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。