

Python torch.enable_grad Method Code Examples

This article collects typical usage examples of the torch.enable_grad method in Python. If you are wondering what torch.enable_grad does, how to call it, or what real-world usage looks like, the curated code examples below may help. You can also explore other usage examples from the torch package, where the method lives.


The sections below show 15 code examples of torch.enable_grad, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the site surface better Python code examples.
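
Before diving into the examples, here is a minimal sketch of what torch.enable_grad does: it re-enables gradient tracking inside a scope where autograd has been disabled, typically by torch.no_grad. The tensor values below are arbitrary placeholders.

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2                      # no autograd graph is built here
    assert not y.requires_grad

    with torch.enable_grad():
        z = x * 2                  # enable_grad overrides the surrounding no_grad
        assert z.requires_grad

z.sum().backward()                 # backward works because z tracked gradients
print(x.grad)                      # tensor([2., 2., 2.])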

Example 1: forward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
        ctx.training = training
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            g = gnet(x)
            ctx.g = g
            ctx.x = x
            logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)

            if training:
                grad_x, *grad_params = torch.autograd.grad(
                    logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
                )
                if grad_x is None:
                    grad_x = torch.zeros_like(x)
                ctx.save_for_backward(grad_x, *g_params, *grad_params)

        return safe_detach(g), safe_detach(logdetgrad) 
Author: rtqichen, Project: residual-flows, Lines: 20, Source: iresblock.py

Example 2: test_fork_join_enable_grad

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def test_fork_join_enable_grad():
    x = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        x2, p = fork(x)

    assert p.requires_grad
    assert x2 is not x
    x = x2

    assert x.requires_grad
    assert p.requires_grad
    assert x.grad_fn.__class__ is Fork._backward_cls
    assert p.grad_fn.__class__ is Fork._backward_cls

    with torch.enable_grad():
        x2 = join(x, p)

    assert x2 is not x
    x = x2

    assert x.requires_grad
    assert x.grad_fn.__class__ is Join._backward_cls 
Author: kakaobrain, Project: torchgpipe, Lines: 25, Source: test_dependency.py

Example 3: compute_valid

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def compute_valid(self):
        """Rewrite this method to compute valid_epoch values.

        You can return a ``dict`` of values that you want to visualize.

        .. note::

            This method runs under ``torch.no_grad()``, so it never computes gradients.
            If you need gradients, wrap the relevant operations in ``torch.enable_grad()``.

        Example::

            d_fake = self.netD(self.fake.detach())
            d_real = self.netD(self.ground_truth)
            var_dic = {}
            var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
            return var_dic

        """
        _, d_var_dic = self.compute_g_loss()
        _, g_var_dic = self.compute_d_loss()
        var_dic = dict(d_var_dic, **g_var_dic)
        return var_dic 
Author: dingguanglei, Project: jdit, Lines: 25, Source: pix2pix.py

Example 4: forward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def forward(self, inputs, targets):
        if not args.attack:
            return self.model(inputs), inputs

        x = inputs.detach()
        if self.rand:
            x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
        for i in range(self.num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                logits = self.model(x)
                loss = F.cross_entropy(logits, targets, size_average=False)
            grad = torch.autograd.grad(loss, [x])[0]
            # print(grad)
            x = x.detach() + self.step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
            x = torch.clamp(x, 0, 1)

        return self.model(x), x 
Author: YyzHarry, Project: ME-Net, Lines: 21, Source: train_adv.py

Example 5: backward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
            if preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_rng_state)
                if ctx.had_cuda_in_fwd:
                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        return (None,) + tuple(inp.grad for inp in detached_inputs) 
Author: Lyken17, Project: pytorch-memonger, Lines: 23, Source: checkpoint.py
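
Example 5 re-runs the forward pass under torch.enable_grad inside backward, which is the essence of activation checkpointing. For most use cases the stock torch.utils.checkpoint API already wraps this pattern; a minimal usage sketch follows (the block and shapes are placeholders):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Activations inside `block` are not stored during the forward pass; they are
# recomputed under torch.enable_grad() when backward needs them.
y = checkpoint(block, x)
y.sum().backward()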

Example 6: __call__

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def __call__(self, *args, **kwargs) -> Any:
        """
        Call super class with correct torch context

        Args:
            *args: forwarded positional arguments
            **kwargs: forwarded keyword arguments

        Returns:
            Any: transformed data

        """
        if self.grad:
            context = torch.enable_grad()
        else:
            context = torch.no_grad()

        with context:
            return super().__call__(*args, **kwargs) 
Author: PhoenixDL, Project: rising, Lines: 21, Source: abstract.py
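
As a side note to Example 6, torch.enable_grad and torch.no_grad can also be applied as decorators rather than explicit context managers. A minimal sketch (model and batch are placeholder names):

import torch

@torch.no_grad()
def validate(model, batch):
    # The whole call runs without building an autograd graph.
    return model(batch)

@torch.enable_grad()
def compute_penalty(model, batch):
    # Forces graph construction even if the caller disabled autograd.
    out = model(batch)
    return out.pow(2).mean()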

Example 7: train

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def train(self, training_batch, train=True):
        learning_input = training_batch.training_input

        with torch.enable_grad():
            action_preds = self.imitator(learning_input.state.float_features)
            # Classification label is index of action with value 1
            pred_action_idxs = torch.max(action_preds, dim=1)[1]
            actual_action_idxs = torch.max(learning_input.action, dim=1)[1]

            if train:
                imitator_loss = torch.nn.CrossEntropyLoss()
                bcq_loss = imitator_loss(action_preds, actual_action_idxs)
                bcq_loss.backward()
                self._maybe_run_optimizer(
                    self.imitator_optimizer, self.minibatches_per_step
                )

        return self._imitator_accuracy(pred_action_idxs, actual_action_idxs) 
Author: facebookresearch, Project: ReAgent, Lines: 20, Source: imitator_training.py

Example 8: rsample

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def rsample(self, sample_shape, log_score=True):
        """
        sample_shape: number of samples from the PL distribution. Scalar.
        """
        with torch.enable_grad():  # torch.distributions turns off autograd
            n_samples = sample_shape[0]

            def sample_gumbel(samples_shape, eps=1e-20):
                U = torch.zeros(samples_shape, device='cuda').uniform_()
                return -torch.log(-torch.log(U + eps) + eps)
            if not log_score:
                log_s_perturb = torch.log(self.scores.unsqueeze(
                    0)) + sample_gumbel([n_samples, 1, self.n, 1])
            else:
                log_s_perturb = self.scores.unsqueeze(
                    0) + sample_gumbel([n_samples, 1, self.n, 1])
            log_s_perturb = log_s_perturb.view(-1, self.n, 1)
            P_hat = self.relaxed_sort(log_s_perturb)
            P_hat = P_hat.view(n_samples, -1, self.n, self.n)

            return P_hat.squeeze() 
Author: ermongroup, Project: neuralsort, Lines: 23, Source: pl.py

Example 9: gradient_penalty

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def gradient_penalty(self, x_real, x_fake, netD, index=0):
        device = x_real.device
        with torch.enable_grad():
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
            alpha = alpha.expand_as(x_real)
            x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
            x_hat.requires_grad = True
            output = netD(x_hat)[index]
            grad_output = torch.ones(output.size()).to(device)
            grad = torch.autograd.grad(outputs=output,
                                       inputs=x_hat,
                                       grad_outputs=grad_output,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            loss_gp = ((grad.norm(p=2, dim=1) - 1)**2).mean()
        return loss_gp 
Author: takuhirok, Project: rGAN, Lines: 20, Source: racgan_trainer.py

Example 10: cond_gradient_penalty

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def cond_gradient_penalty(self, x_real, x_fake, y, netD, index=0):
        device = x_real.device
        with torch.enable_grad():
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
            alpha = alpha.expand_as(x_real)
            x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
            x_hat.requires_grad = True
            output = netD(x_hat, y)[index]
            grad_output = torch.ones(output.size()).to(device)
            grad = torch.autograd.grad(outputs=output,
                                       inputs=x_hat,
                                       grad_outputs=grad_output,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.contiguous().view(grad.size(0), -1)
            loss_gp = ((grad.norm(p=2, dim=1) - 1)**2).mean()
        return loss_gp 
Author: takuhirok, Project: rGAN, Lines: 20, Source: rcgan_trainer.py

Example 11: backward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def backward(ctx, grad_output):

        with torch.enable_grad():
            input, target, mask = ctx.saved_tensors
            former = input.narrow(1, 0, ctx.c//2)
            former_in_mask = torch.mul(former, mask)
            if former_in_mask.size() != target.size():  # For the last iteration of one epoch
                target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask)
            
            former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True)
            ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength
            ctx.loss.backward()

        grad_output[:,0:ctx.c//2, :,:] += former_in_mask_clone.grad  
        
        return grad_output, None, None, None, None 
Author: Zhaoyi-Yan, Project: Shift-Net_pytorch, Lines: 18, Source: InnerCosFunction.py

Example 12: perturb_hinge

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def perturb_hinge(net, x_nat):
    # Perturb function based on (E[\phi(f(x)f(x'))])
    # init with random noise
    net.eval()
    x = x_nat.detach() + 0.001 * torch.randn(x_nat.shape).cuda().detach()
    for _ in range(args.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            # perturb based on hinge loss
            loss = torch.mean(torch.clamp(1 - net(x).squeeze(1) * (net(x_nat).squeeze(1) / args.beta), min=0))
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + args.step_size * torch.sign(grad.detach())
        x = torch.min(torch.max(x, x_nat - args.epsilon), x_nat + args.epsilon)
        x = torch.clamp(x, 0.0, 1.0)
    net.train()
    return x 
Author: yaodongyu, Project: TRADES, Lines: 18, Source: train_trades_mnist_binary.py

Example 13: train

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def train(self,
              data_loader: Iterable or DataLoader,
              mode: str = TRAIN):
        """ Training the model for an epoch.

        :param data_loader:
        :param mode: Name of this loop. Default is `train`. Passed to callbacks.
        """

        self._is_train = True
        self._epoch += 1
        self.model.train()
        if hasattr(self.loss_f, "train"):
            self.loss_f.train()
        with torch.enable_grad():
            self._loop(data_loader, mode=mode)

        if self.scheduler is not None and self.update_scheduler_by_epoch:
            self.scheduler.step()

        if isinstance(data_loader, DataLoader) and isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.set_epoch(self.epoch) 
Author: moskomule, Project: homura, Lines: 24, Source: trainers.py

Example 14: forward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def forward(self, model, bx, by, by_prime, curr_batch_size):
        """
        :param model: the classifier's forward method
        :param bx: batch of images
        :param by: true labels
        :return: perturbed batch of images
        """
        adv_bx = bx.detach()
        adv_bx += torch.zeros_like(adv_bx).uniform_(-self.epsilon, self.epsilon)

        for i in range(self.num_steps):
            adv_bx.requires_grad_()
            with torch.enable_grad():
                logits, pen = model(adv_bx * 2 - 1)
                loss = F.cross_entropy(logits[:curr_batch_size], by, reduction='sum')
                if self.attack_rotations:
                    loss += F.cross_entropy(model.module.rot_pred(pen[curr_batch_size:]), by_prime, reduction='sum') / 8.
            grad = torch.autograd.grad(loss, adv_bx, only_inputs=True)[0]

            adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())

            adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)

        return adv_bx 
Author: hendrycks, Project: ss-ood, Lines: 26, Source: attacks.py

Example 15: train

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def train(self, data_loader, print_freq=20):
        self.model.train()
        with torch.enable_grad():
            return self.train_iteration(data_loader, print_freq) 
Author: iBelieveCJM, Project: Tricks-of-Semi-supervisedDeepLeanring-Pytorch, Lines: 6, Source: PIv1.py


Note: The torch.enable_grad examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers; copyright remains with the original authors, and redistribution or use should follow each project's License. Please do not repost without permission.