

Python torch.enable_grad Method Code Examples

This article collects typical usage examples of Python's torch.enable_grad method. If you are looking for concrete answers to questions such as how torch.enable_grad works and what real code that uses it looks like, the curated examples below should help. You can also explore further usage examples from the torch package in which this method lives.


The following presents 15 code examples of the torch.enable_grad method, sorted by popularity by default.
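
Before diving into the project code, here is a minimal, self-contained sketch (not taken from any of the projects below) of what torch.enable_grad does: it re-enables gradient tracking inside a region where tracking has been disabled, for example by torch.no_grad.

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():              # gradient tracking is disabled in this block
    y = x * 2                      # y is not tracked: y.requires_grad == False
    with torch.enable_grad():      # re-enable tracking inside the no_grad block
        z = x * 3                  # z is tracked: z.requires_grad == True

print(y.requires_grad)  # False
print(z.requires_grad)  # True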

Example 1: forward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
        ctx.training = training
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            g = gnet(x)
            ctx.g = g
            ctx.x = x
            logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)

            if training:
                grad_x, *grad_params = torch.autograd.grad(
                    logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
                )
                if grad_x is None:
                    grad_x = torch.zeros_like(x)
                ctx.save_for_backward(grad_x, *g_params, *grad_params)

        return safe_detach(g), safe_detach(logdetgrad) 
Developer: rtqichen, Project: residual-flows, Lines: 20, Source: iresblock.py

Example 2: test_fork_join_enable_grad

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def test_fork_join_enable_grad():
    x = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        x2, p = fork(x)

    assert p.requires_grad
    assert x2 is not x
    x = x2

    assert x.requires_grad
    assert p.requires_grad
    assert x.grad_fn.__class__ is Fork._backward_cls
    assert p.grad_fn.__class__ is Fork._backward_cls

    with torch.enable_grad():
        x2 = join(x, p)

    assert x2 is not x
    x = x2

    assert x.requires_grad
    assert x.grad_fn.__class__ is Join._backward_cls 
Developer: kakaobrain, Project: torchgpipe, Lines: 25, Source: test_dependency.py

Example 3: compute_valid

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def compute_valid(self):
        """Rewrite this method to compute valid_epoch values.

        You can return a ``dict`` of values that you want to visualize.

        .. note::

            This method is under ``torch.no_grad():``. So, it will never compute grad.
            If you want to compute grad, please use ``torch.enable_grad():`` to wrap your operations.

        Example::

            d_fake = self.netD(self.fake.detach())
            d_real = self.netD(self.ground_truth)
            var_dic = {}
            var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
            return var_dic

        """
        _, d_var_dic = self.compute_g_loss()
        _, g_var_dic = self.compute_d_loss()
        var_dic = dict(d_var_dic, **g_var_dic)
        return var_dic 
Developer: dingguanglei, Project: jdit, Lines: 25, Source: pix2pix.py
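
The docstring above points out that compute_valid is executed under torch.no_grad(), so any metric that needs gradients has to be wrapped in torch.enable_grad() explicitly. Below is a hedged sketch of such an override; it is not part of the jdit project, and the extra "GradNorm" metric is purely illustrative (netD, fake and ground_truth are the attributes already used in the docstring's example).

def compute_valid(self):
        # Hedged sketch (not from the jdit project): override compute_valid when one
        # of the reported metrics needs gradients. The trainer calls this method
        # under torch.no_grad(), so the gradient-based part is wrapped explicitly.
        var_dic = {}
        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic["WD"] = (d_real.mean() - d_fake.mean()).detach()

        with torch.enable_grad():
            x = self.ground_truth.detach().requires_grad_(True)
            out = self.netD(x)
            grad = torch.autograd.grad(out.sum(), x)[0]
            # "GradNorm" is an illustrative metric name, not part of jdit.
            var_dic["GradNorm"] = grad.view(grad.size(0), -1).norm(p=2, dim=1).mean().detach()
        return var_dic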

Example 4: forward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def forward(self, inputs, targets):
        if not args.attack:
            return self.model(inputs), inputs

        x = inputs.detach()
        if self.rand:
            x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
        for i in range(self.num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                logits = self.model(x)
                loss = F.cross_entropy(logits, targets, size_average=False)
            grad = torch.autograd.grad(loss, [x])[0]
            # print(grad)
            x = x.detach() + self.step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
            x = torch.clamp(x, 0, 1)

        return self.model(x), x 
Developer: YyzHarry, Project: ME-Net, Lines: 21, Source: train_adv.py

Example 5: backward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
            if preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_rng_state)
                if ctx.had_cuda_in_fwd:
                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        return (None,) + tuple(inp.grad for inp in detached_inputs) 
Developer: Lyken17, Project: pytorch-memonger, Lines: 23, Source: checkpoint.py

Example 6: __call__

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def __call__(self, *args, **kwargs) -> Any:
        """
        Call super class with correct torch context

        Args:
            *args: forwarded positional arguments
            **kwargs: forwarded keyword arguments

        Returns:
            Any: transformed data

        """
        if self.grad:
            context = torch.enable_grad()
        else:
            context = torch.no_grad()

        with context:
            return super().__call__(*args, **kwargs) 
Developer: PhoenixDL, Project: rising, Lines: 21, Source: abstract.py

Example 7: train

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def train(self, training_batch, train=True):
        learning_input = training_batch.training_input

        with torch.enable_grad():
            action_preds = self.imitator(learning_input.state.float_features)
            # Classification label is index of action with value 1
            pred_action_idxs = torch.max(action_preds, dim=1)[1]
            actual_action_idxs = torch.max(learning_input.action, dim=1)[1]

            if train:
                imitator_loss = torch.nn.CrossEntropyLoss()
                bcq_loss = imitator_loss(action_preds, actual_action_idxs)
                bcq_loss.backward()
                self._maybe_run_optimizer(
                    self.imitator_optimizer, self.minibatches_per_step
                )

        return self._imitator_accuracy(pred_action_idxs, actual_action_idxs) 
Developer: facebookresearch, Project: ReAgent, Lines: 20, Source: imitator_training.py

Example 8: rsample

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def rsample(self, sample_shape, log_score=True):
        """
        sample_shape: number of samples from the PL distribution. Scalar.
        """
        with torch.enable_grad():  # torch.distributions turns off autograd
            n_samples = sample_shape[0]

            def sample_gumbel(samples_shape, eps=1e-20):
                U = torch.zeros(samples_shape, device='cuda').uniform_()
                return -torch.log(-torch.log(U + eps) + eps)
            if not log_score:
                log_s_perturb = torch.log(self.scores.unsqueeze(
                    0)) + sample_gumbel([n_samples, 1, self.n, 1])
            else:
                log_s_perturb = self.scores.unsqueeze(
                    0) + sample_gumbel([n_samples, 1, self.n, 1])
            log_s_perturb = log_s_perturb.view(-1, self.n, 1)
            P_hat = self.relaxed_sort(log_s_perturb)
            P_hat = P_hat.view(n_samples, -1, self.n, self.n)

            return P_hat.squeeze() 
Developer: ermongroup, Project: neuralsort, Lines: 23, Source: pl.py

Example 9: gradient_penalty

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def gradient_penalty(self, x_real, x_fake, netD, index=0):
        device = x_real.device
        with torch.enable_grad():
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
            alpha = alpha.expand_as(x_real)
            x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
            x_hat.requires_grad = True
            output = netD(x_hat)[index]
            grad_output = torch.ones(output.size()).to(device)
            grad = torch.autograd.grad(outputs=output,
                                       inputs=x_hat,
                                       grad_outputs=grad_output,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.view(grad.size(0), -1)
            loss_gp = ((grad.norm(p=2, dim=1) - 1)**2).mean()
        return loss_gp 
Developer: takuhirok, Project: rGAN, Lines: 20, Source: racgan_trainer.py

Example 10: cond_gradient_penalty

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def cond_gradient_penalty(self, x_real, x_fake, y, netD, index=0):
        device = x_real.device
        with torch.enable_grad():
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
            alpha = alpha.expand_as(x_real)
            x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
            x_hat.requires_grad = True
            output = netD(x_hat, y)[index]
            grad_output = torch.ones(output.size()).to(device)
            grad = torch.autograd.grad(outputs=output,
                                       inputs=x_hat,
                                       grad_outputs=grad_output,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            grad = grad.contiguous().view(grad.size(0), -1)
            loss_gp = ((grad.norm(p=2, dim=1) - 1)**2).mean()
        return loss_gp 
Developer: takuhirok, Project: rGAN, Lines: 20, Source: rcgan_trainer.py

Example 11: backward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def backward(ctx, grad_output):

        with torch.enable_grad():
            input, target, mask = ctx.saved_tensors
            former = input.narrow(1, 0, ctx.c//2)
            former_in_mask = torch.mul(former, mask)
            if former_in_mask.size() != target.size():  # For the last iteration of one epoch
                target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask)
            
            former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True)
            ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength
            ctx.loss.backward()

        grad_output[:,0:ctx.c//2, :,:] += former_in_mask_clone.grad  
        
        return grad_output, None, None, None, None 
Developer: Zhaoyi-Yan, Project: Shift-Net_pytorch, Lines: 18, Source: InnerCosFunction.py

Example 12: perturb_hinge

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def perturb_hinge(net, x_nat):
    # Perturb function based on (E[\phi(f(x)f(x'))])
    # init with random noise
    net.eval()
    x = x_nat.detach() + 0.001 * torch.randn(x_nat.shape).cuda().detach()
    for _ in range(args.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            # perturb based on hinge loss
            loss = torch.mean(torch.clamp(1 - net(x).squeeze(1) * (net(x_nat).squeeze(1) / args.beta), min=0))
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + args.step_size * torch.sign(grad.detach())
        x = torch.min(torch.max(x, x_nat - args.epsilon), x_nat + args.epsilon)
        x = torch.clamp(x, 0.0, 1.0)
    net.train()
    return x 
Developer: yaodongyu, Project: TRADES, Lines: 18, Source: train_trades_mnist_binary.py

Example 13: train

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def train(self,
              data_loader: Iterable or DataLoader,
              mode: str = TRAIN):
        """ Training the model for an epoch.

        :param data_loader:
        :param mode: Name of this loop. Default is `train`. Passed to callbacks.
        """

        self._is_train = True
        self._epoch += 1
        self.model.train()
        if hasattr(self.loss_f, "train"):
            self.loss_f.train()
        with torch.enable_grad():
            self._loop(data_loader, mode=mode)

        if self.scheduler is not None and self.update_scheduler_by_epoch:
            self.scheduler.step()

        if isinstance(data_loader, DataLoader) and isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.set_epoch(self.epoch) 
Developer: moskomule, Project: homura, Lines: 24, Source: trainers.py

Example 14: forward

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def forward(self, model, bx, by, by_prime, curr_batch_size):
        """
        :param model: the classifier's forward method
        :param bx: batch of images
        :param by: true labels
        :return: perturbed batch of images
        """
        adv_bx = bx.detach()
        adv_bx += torch.zeros_like(adv_bx).uniform_(-self.epsilon, self.epsilon)

        for i in range(self.num_steps):
            adv_bx.requires_grad_()
            with torch.enable_grad():
                logits, pen = model(adv_bx * 2 - 1)
                loss = F.cross_entropy(logits[:curr_batch_size], by, reduction='sum')
                if self.attack_rotations:
                    loss += F.cross_entropy(model.module.rot_pred(pen[curr_batch_size:]), by_prime, reduction='sum') / 8.
            grad = torch.autograd.grad(loss, adv_bx, only_inputs=True)[0]

            adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())

            adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)

        return adv_bx 
Developer: hendrycks, Project: ss-ood, Lines: 26, Source: attacks.py

Example 15: train

# Required import: import torch [as alias]
# Or: from torch import enable_grad [as alias]
def train(self, data_loader, print_freq=20):
        self.model.train()
        with torch.enable_grad():
            return self.train_iteration(data_loader, print_freq) 
Developer: iBelieveCJM, Project: Tricks-of-Semi-supervisedDeepLeanring-Pytorch, Lines: 6, Source: PIv1.py


Note: The torch.enable_grad examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets come from open-source projects contributed by their respective authors; copyright remains with the original authors, and any redistribution or use should follow each project's license. Please do not reproduce without permission.