當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.set_grad_enabled方法代碼示例

本文整理匯總了Python中torch.set_grad_enabled方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.set_grad_enabled方法的具體用法?Python torch.set_grad_enabled怎麽用?Python torch.set_grad_enabled使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.set_grad_enabled方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: main_inference

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def main_inference():
    """Load a trained STAGE model, run inference on the dataset, and save
    the predictions next to the checkpoint as JSON."""
    print("Loading config...")
    opt = TestOptions().parse()
    print("Loading dataset...")
    dset = TVQADataset(opt, mode=opt.mode)
    print("Loading model...")
    model = STAGE(opt)
    model.to(opt.device)
    cudnn.benchmark = True
    # Restore the best checkpoint; key-matching strictness is configurable.
    ckpt_path = os.path.join("results", opt.model_dir, "best_valid.pth")
    model.load_state_dict(torch.load(ckpt_path), strict=not opt.no_strict)
    model.eval()
    model.inference_mode = True
    torch.set_grad_enabled(False)  # inference only: skip autograd bookkeeping
    print("Evaluation Starts:\n")
    predictions = inference(opt, dset, model)
    print("predictions {}".format(predictions.keys()))
    out_path = ckpt_path.replace("best_valid.pth",
                                 "{}_inference_predictions.json".format(opt.mode))
    save_json(predictions, out_path)
開發者ID:jayleicn,項目名稱:TVQAplus,代碼行數:23,代碼來源:inference.py

示例2: test_simple_inference

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
async def test_simple_inference(self):
    """Smoke-test async Mask R-CNN inference on the demo image.

    Requires a CUDA GPU; skipped otherwise.

    Bug fix: the body uses ``await``, which is only legal inside a
    coroutine, so the method must be declared ``async def`` — the
    original plain ``def`` was a SyntaxError.  (Mangled indentation
    normalized as well.)
    """
    if not torch.cuda.is_available():
        import pytest

        pytest.skip('test requires GPU and torch+cuda')

    ori_grad_enabled = torch.is_grad_enabled()
    # NOTE(review): dirname(__name__) looks like it was meant to be
    # dirname(__file__) — confirm against the repo layout before changing.
    root_dir = os.path.dirname(os.path.dirname(__name__))
    model_config = os.path.join(
        root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
    detector = MaskRCNNDetector(model_config)
    await detector.init()
    img_path = os.path.join(root_dir, 'demo/demo.jpg')
    bboxes, _ = await detector.apredict(img_path)
    self.assertTrue(bboxes)
    # asy inference detector will hack grad_enabled,
    # so restore here to avoid it to influence other tests
    torch.set_grad_enabled(ori_grad_enabled)
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:20,代碼來源:test_async.py

示例3: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def forward(self, x):
    """Encode an image batch with the ResNet-style backbone.

    Gradients flow through the backbone only when ``self.finetune`` is
    truthy.  When ``self.spatial_context`` is off, the 7x7 feature map is
    average-pooled down to a (batch, channels) vector; optional context
    transform/nonlinearity layers are applied if present on ``self``.
    """
    backbone = self.model
    with torch.set_grad_enabled(self.finetune):
        feats = backbone.conv1(x)
        feats = backbone.bn1(feats)
        feats = backbone.relu(feats)
        feats = backbone.maxpool(feats)

        for stage in (backbone.layer1, backbone.layer2,
                      backbone.layer3, backbone.layer4):
            feats = stage(feats)

    if not self.spatial_context:
        # Collapse the spatial map to a flat per-image feature vector.
        feats = F.avg_pool2d(feats, kernel_size=7).view(feats.size(0), feats.size(1))
    if hasattr(self, 'context_transform'):
        feats = self.context_transform(feats)
    if hasattr(self, 'context_nonlinearity'):
        feats = self.context_nonlinearity(feats)
    return feats
開發者ID:nadavbh12,項目名稱:Character-Level-Language-Modeling-with-Deeper-Self-Attention-pytorch,代碼行數:21,代碼來源:vision_encoders.py

示例4: sample

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def sample(seed, num_bits, num_samples, samples_per_row, _log, output_path=None):
    """Draw ``num_samples`` images from the trained flow and save them
    as a grid image at ``output_path`` (default ``samples.png``)."""
    torch.set_grad_enabled(False)  # pure sampling: autograd never needed

    if output_path is None:
        output_path = 'samples.png'

    # Seed both torch and numpy so samples are reproducible.
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = set_device()

    # Only the image dimensions are needed from the data pipeline.
    _, _, (c, h, w) = get_train_valid_data()

    flow = create_flow(c, h, w).to(device)
    flow.eval()

    # Map flow outputs back to pixel space before writing the grid.
    dequantizer = Preprocess(num_bits)
    images = dequantizer.inverse(flow.sample(num_samples))

    save_image(images.cpu(), output_path,
               nrow=samples_per_row,
               padding=0)
開發者ID:bayesiains,項目名稱:nsf,代碼行數:26,代碼來源:images.py

示例5: do_one_epoch

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def do_one_epoch(self, epoch, episodes):
    """Run one epoch of the no-action feed-forward predictor.

    Trains when ``self.naff`` is in training mode, otherwise evaluates;
    logs the mean loss and feeds the early stopper on validation epochs.
    """
    mode = "train" if self.naff.training else "val"
    total_loss = 0.
    steps = 0
    for x_t, x_tn in self.generate_batch(episodes):
        # Track gradients only while training.
        with torch.set_grad_enabled(mode == 'train'):
            prediction = self.naff(x_t)
            loss = self.loss_fn(prediction, x_tn)

        if mode == "train":
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        total_loss += loss.detach().item()
        steps += 1
    self.log_results(epoch, total_loss / steps, prefix=mode)
    if mode == "val":
        # Early stopping monitors the negated validation loss.
        self.early_stopper(-total_loss / steps, self.encoder)
開發者ID:mila-iqia,項目名稱:atari-representation-learning,代碼行數:21,代碼來源:no_action_feedforward_predictor.py

示例6: do_one_epoch

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def do_one_epoch(self, epoch, episodes):
    """Run one epoch of the VAE.

    Trains when ``self.VAE`` is in training mode, otherwise evaluates;
    logs the mean loss and feeds the early stopper on validation epochs.
    """
    mode = "train" if self.VAE.training else "val"
    total_loss = 0.
    steps = 0
    for x_t in self.generate_batch(episodes):
        # Track gradients only while training.
        with torch.set_grad_enabled(mode == 'train'):
            x_hat, mu, logvar = self.VAE(x_t)
            loss = self.loss_fn(x_t, x_hat, mu, logvar)

        if mode == "train":
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        total_loss += loss.detach().item()
        steps += 1
    self.log_results(epoch, total_loss / steps, prefix=mode)
    if mode == "val":
        self.early_stopper(-total_loss / steps, self.encoder)

    #             xim = x_hat.detach().cpu().numpy()[0].transpose(1,2,0)
    #             self.wandb.log({"example_reconstruction": [self.wandb.Image(xim, caption="")]}) 
開發者ID:mila-iqia,項目名稱:atari-representation-learning,代碼行數:24,代碼來源:vae.py

示例7: test

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def test():
    """Evaluate the global ``net`` on ``test_loader`` and record the mean
    test loss and top-1 accuracy in the global ``state`` dict."""
    torch.set_grad_enabled(False)  # evaluation only
    net.eval()
    total_loss = 0.0
    n_correct = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()

        # forward pass + cross-entropy loss
        output = net(data)
        loss = F.cross_entropy(output, target)

        # count top-1 correct predictions
        n_correct += output.data.max(1)[1].eq(target.data).sum().item()

        total_loss += loss.item()

    state['test_loss'] = total_loss / len(test_loader)
    state['test_accuracy'] = n_correct / len(test_loader.dataset)
    torch.set_grad_enabled(True)  # re-enable grads for subsequent training


# Main loop 
開發者ID:hendrycks,項目名稱:pre-training,代碼行數:27,代碼來源:train_glc.py

示例8: get_C_hat_transpose

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def get_C_hat_transpose():
    """Estimate the label-corruption matrix C from the gold loader.

    Row ``label`` of C is the mean softmax output of ``net`` over the gold
    examples that carry that (shifted) label.  Returns C transposed, as
    float32.
    """
    torch.set_grad_enabled(False)  # pure inference pass
    net.eval()
    probs = []
    count = 0
    for data, target in train_gold_deterministic_loader:
        # Gold labels were shifted up by num_classes so they could be
        # identified in train_phase2; undo the shift here.
        data, target = data.cuda(), (target - num_classes).cuda()
        count += target.shape[0]

        # forward pass -> per-class probabilities
        output = net(data)
        pred = F.softmax(output, dim=1)
        probs.extend(list(pred.data.cpu().numpy()))

    probs = np.array(probs, dtype=np.float32)
    C_hat = np.zeros((num_classes, num_classes))
    gold_labels = np.array(train_data_gold.train_labels) - num_classes
    for label in range(num_classes):
        indices = np.arange(len(train_data_gold.train_labels))[
            np.isclose(gold_labels, label)]
        C_hat[label] = np.mean(probs[indices], axis=0, keepdims=True)

    torch.set_grad_enabled(True)  # restore training-time grad mode
    return C_hat.T.astype(np.float32)
開發者ID:hendrycks,項目名稱:pre-training,代碼行數:26,代碼來源:train_glc.py

示例9: contractive_penalty

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def contractive_penalty(self, network, input, penalty_amount=0.5):
    """Contractive (gradient) penalty: mean squared gradient norm of the
    network output w.r.t. its detached inputs, scaled by ``penalty_amount``.

    Returns ``None`` when the penalty is disabled (``penalty_amount == 0.``).
    """
    if penalty_amount == 0.:
        return

    inputs = input if isinstance(input, (list, tuple)) else [input]
    # Detach so the penalty does not backprop into whatever graph produced
    # the inputs, then mark the detached leaves as requiring grad.
    inputs = [inp.detach().requires_grad_() for inp in inputs]

    with torch.set_grad_enabled(True):
        output = network(*inputs)
    gradient = self._get_gradient(inputs, output)
    flat_grad = gradient.view(gradient.size()[0], -1)
    penalty = (flat_grad ** 2).sum(1).mean()

    return penalty_amount * penalty
開發者ID:rdevon,項目名稱:cortex,代碼行數:20,代碼來源:gan.py

示例10: interpolate_penalty

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def interpolate_penalty(self, network, input, penalty_amount=0.5):
    """WGAN-GP style penalty on points interpolated between two inputs.

    ``network`` must expose ``inputs.e``, a uniform random tensor used as
    the interpolation coefficient; penalizes deviation of the gradient
    norm from 1 at the interpolated points.
    """
    input = input.detach().requires_grad_()

    if len(input) != 2:
        raise ValueError('tuple of 2 inputs required to interpolate')
    inp1, inp2 = input

    try:
        epsilon = network.inputs.e.view(-1, 1, 1, 1)
    except AttributeError:
        raise ValueError('You must initiate a uniform random variable'
                         '`e` to use interpolation')
    # Random point on the segment between the two inputs.
    mid_in = ((1. - epsilon) * inp1 + epsilon * inp2)
    mid_in.requires_grad_()

    with torch.set_grad_enabled(True):
        mid_out = network(mid_in)
    gradient = self._get_gradient(mid_in, mid_out)
    flat_grad = gradient.view(gradient.size()[0], -1)
    penalty = ((flat_grad.norm(2, dim=1) - 1.) ** 2).mean()

    return penalty_amount * penalty
開發者ID:rdevon,項目名稱:cortex,代碼行數:26,代碼來源:gan.py

示例11: _adapt

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def _adapt(self, batch_samples, set_grad=True):
        """Performs one MAML inner step to update the policy.

        Args:
            batch_samples (MAMLTrajectoryBatch): Samples data for one
                task and one gradient step.
            set_grad (bool): if False, update policy parameters in-place.
                Else, allow taking gradient of functions of updated parameters
                with respect to pre-updated parameters.

        """
        # pylint: disable=protected-access
        loss = self._inner_algo._compute_loss(*batch_samples[1:])

        # Update policy parameters with one SGD step
        self._inner_optimizer.zero_grad()
        loss.backward(create_graph=set_grad)

        with torch.set_grad_enabled(set_grad):
            self._inner_optimizer.step() 
開發者ID:rlworkgroup,項目名稱:garage,代碼行數:22,代碼來源:maml.py

示例12: preeval_batch

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def preeval_batch(self, dataset):
        """Decode every batch in *dataset* and collect token sequences for
        BLEU scoring.

        Returns:
            (cands, refs): dicts keyed by a 1-based running example index;
            ``cands[i]`` is the decoded candidate sequence and ``refs[i]``
            is a one-element list holding the reference sequence.
        """
        torch.set_grad_enabled(False)  # evaluation only; grads never needed
        refs = {}
        cands = {}
        i = 0  # running example index across all batches (1-based after first increment)
        for batch_idx in range(len(dataset.corpus)):
            # NOTE(review): get_batch returns a 12-tuple; field meanings
            # assumed from names — confirm against the dataset class.
            batch_s, batch_o_s, batch_f, batch_pf, batch_pb, _, targets, _, list_oovs, source_len, max_source_oov, w2fs = dataset.get_batch(batch_idx)
            decoded_outputs, lengths = self.model(batch_s, batch_o_s, max_source_oov, batch_f, batch_pf, batch_pb, source_len, w2fs=w2fs)
            for j in range(len(lengths)):
                i += 1
                ref = self.prepare_for_bleu(targets[j])
                refs[i] = [ref]
                out_seq = []
                for k in range(lengths[j]):
                    symbol = decoded_outputs[j][k].item()
                    if symbol < self.vocab.size:
                        # in-vocabulary id -> word
                        out_seq.append(self.vocab.idx2word[symbol])
                    else:
                        # id beyond vocab size -> copy from this example's OOV list
                        out_seq.append(list_oovs[j][symbol-self.vocab.size])
                out = self.prepare_for_bleu(out_seq)
                cands[i] = out

                # if i % 2500 == 0:
                #     print("Percentages:  %.4f" % (i/float(dataset.len)))
        return cands, refs
開發者ID:EagleW,項目名稱:Describing_a_Knowledge_Base,代碼行數:27,代碼來源:predictor.py

示例13: figure

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def figure(self, batch_s, batch_o_s, batch_f, batch_pf, batch_pb, max_source_oov, source_len, list_oovs, w2fs, type,
               visual):
        """Decode a single example and plot two attention heat maps
        (``self.png`` for self-attention, ``type.png`` for the soft
        attention over types) via ``self.showAttention``.
        """
        torch.set_grad_enabled(False)  # visualization only; grads never needed
        # fig=True makes the model also return the attention matrices.
        decoded_outputs, lengths, self_matrix, soft = self.model(batch_s, batch_o_s, max_source_oov, batch_f, batch_pf, batch_pb,
                                              source_len, w2fs=w2fs, fig=True)
        length = lengths[0]
        output = []
        # print(decoded_outputs)
        for i in range(length):
            symbol = decoded_outputs[0][i].item()
            if symbol < self.vocab.size:
                # in-vocabulary id -> word
                output.append(self.vocab.idx2word[symbol])
            else:
                # id beyond vocab size -> copy from the example's OOV list
                output.append(list_oovs[symbol-self.vocab.size])
        # Drop special tokens before plotting.
        output = [i for i in output if i != '<PAD>' and i != '<EOS>' and i != '<SOS>']
        print(self_matrix.size(), soft.size())
        pos = [str(i) for i in batch_pf[0].cpu().tolist()]
        combine = []
        for j in range(len(pos)):
            combine.append(visual[j] + " : " + pos[j])
        self.showAttention(pos, combine, self_matrix.cpu(), 'self.png', 0)
        # NOTE(review): hard-coded token window 19:25 — presumably chosen for
        # one specific figure; confirm before reusing on other examples.
        self.showAttention(type, output[19:25], soft[19:25].cpu(), 'type.png', 1)
        # return output
開發者ID:EagleW,項目名稱:Describing_a_Knowledge_Base,代碼行數:25,代碼來源:predictor.py

示例14: validate

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
def validate(opt, dset, model, criterion, mode="valid", use_hard_negatives=False):
    """Evaluate ``model`` on ``dset`` in the given mode.

    Returns:
        (valid_acc, valid_loss, qid_corrects): overall accuracy, mean loss,
        and per-question "qid\\tcorrect" strings.
    """
    dset.set_mode(mode)
    torch.set_grad_enabled(False)  # evaluation only
    model.eval()
    loader = DataLoader(dset, batch_size=opt.test_bsz, shuffle=False,
                        collate_fn=pad_collate, num_workers=opt.num_workers,
                        pin_memory=True)

    all_qids = []
    losses = []
    corrects = []
    max_len_dict = dict(
        max_sub_l=opt.max_sub_l,
        max_vid_l=opt.max_vid_l,
        max_vcpt_l=opt.max_vcpt_l,
        max_qa_l=opt.max_qa_l,
    )
    for val_idx, batch in enumerate(loader):
        model_inputs, targets, qids = prepare_inputs(
            batch, max_len_dict=max_len_dict, device=opt.device)
        model_inputs.use_hard_negatives = use_hard_negatives
        outputs, att_loss, _, temporal_loss, _ = model(model_inputs)
        # Total loss mixes classification, attention, and temporal terms.
        loss = (criterion(outputs, targets)
                + opt.att_weight * att_loss
                + opt.ts_weight * temporal_loss)
        # measure accuracy and record loss
        all_qids += [int(x) for x in qids]
        losses.append(loss.data.item())
        pred_ids = outputs.data.max(1)[1]
        corrects += pred_ids.eq(targets.data).tolist()
        if opt.debug and val_idx == 20:
            break

    valid_acc = sum(corrects) / float(len(corrects))
    valid_loss = sum(losses) / float(len(corrects))
    qid_corrects = ["%d\t%d" % (a, b) for a, b in zip(all_qids, corrects)]
    return valid_acc, valid_loss, qid_corrects
開發者ID:jayleicn,項目名稱:TVQAplus,代碼行數:35,代碼來源:main.py

示例15: async_inference_detector

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import set_grad_enabled [as 別名]
async def async_inference_detector(model, img):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        Awaitable detection results.
    """
    cfg = model.cfg
    # Run on whichever device the model currently lives on.
    device = next(model.parameters()).device
    # Build the test-time pipeline, swapping in an image loader stage.
    test_pipeline = Compose([LoadImage()] + cfg.data.test.pipeline[1:])
    # Prepare a single-image batch and move it to the model's device.
    data = test_pipeline(dict(img=img))
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    return await model.aforward_test(rescale=True, **data)
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:28,代碼來源:inference.py


注:本文中的torch.set_grad_enabled方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。