本文整理匯總了Python中torch.enable_grad方法的典型用法代碼示例。如果您正苦於以下問題：Python torch.enable_grad方法的具體用法？Python torch.enable_grad怎麼用？Python torch.enable_grad使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊torch的用法示例。
在下文中一共展示了torch.enable_grad方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: forward
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
    """Apply ``gnet`` to ``x``, estimate the log-det-gradient term and stash grads for backward.

    NOTE(review): this presumably is the ``forward`` of a custom
    ``torch.autograd.Function`` (the class and any ``@staticmethod`` decorator
    are not visible here). ``torch.enable_grad()`` is used so that a graph is
    recorded even if the caller runs this with grad mode disabled.

    Args:
        ctx: autograd context. ``ctx.training``, ``ctx.g`` and ``ctx.x`` are set;
            when ``training`` is truthy, ``(grad_x, *g_params, *grad_params)`` is
            stored via ``ctx.save_for_backward``.
        estimator_fn: callable invoked as
            ``estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)``;
            its result is treated as a tensor (``.sum()`` is taken).
        gnet: callable applied to ``x``.
        x: input tensor; detached and re-marked ``requires_grad`` before use.
        n_power_series, vareps, coeff_fn: forwarded unchanged to ``estimator_fn``.
        training: when truthy, gradients of ``logdetgrad.sum()`` w.r.t. ``x`` and
            ``g_params`` are computed and saved.
        *g_params: tensors (presumably gnet's parameters) to differentiate against.

    Returns:
        tuple: ``(safe_detach(g), safe_detach(logdetgrad))``; ``safe_detach`` is
        defined elsewhere in the project.
    """
    ctx.training = training
    with torch.enable_grad():
        # Make x a fresh leaf so the gradient below is taken w.r.t. this
        # tensor only, not through whatever produced the caller's x.
        x = x.detach().requires_grad_(True)
        g = gnet(x)
        ctx.g = g
        ctx.x = x
        logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)
        if training:
            # allow_unused=True: tensors that do not influence logdetgrad
            # come back as None instead of raising.
            grad_x, *grad_params = torch.autograd.grad(
                logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
            )
            if grad_x is None:
                grad_x = torch.zeros_like(x)
            # Save order matters: presumably the matching backward unpacks
            # saved_tensors in this same order (not visible here; confirm).
            ctx.save_for_backward(grad_x, *g_params, *grad_params)
    return safe_detach(g), safe_detach(logdetgrad)
示例2: test_fork_join_enable_grad
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def test_fork_join_enable_grad():
    # With autograd enabled, fork() must hand back a new tensor plus a second
    # tensor, both attached to a Fork backward node; join() must then produce
    # a new tensor attached to a Join backward node.
    source = torch.rand(1, requires_grad=True)
    with torch.enable_grad():
        forked, dep = fork(source)

    assert dep.requires_grad
    assert forked is not source
    for produced in (forked, dep):
        assert produced.requires_grad
        assert produced.grad_fn.__class__ is Fork._backward_cls

    with torch.enable_grad():
        joined = join(forked, dep)

    assert joined is not forked
    assert joined.requires_grad
    assert joined.grad_fn.__class__ is Join._backward_cls
示例3: compute_valid
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def compute_valid(self):
    """Rewrite this method to compute valid_epoch values.

    You can return a ``dict`` of values that you want to visualize.

    .. note::

        This method is under ``torch.no_grad():``. So, it will never compute grad.
        If you want to compute grad, please use ``torch.enable_grad():`` to wrap your operations.

    Example::

        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic = {}
        var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
        return var_dic

    The default implementation merges the value dicts returned (as the second
    element) by ``self.compute_g_loss()`` and ``self.compute_d_loss()``. On
    duplicate keys the value from ``compute_d_loss()`` wins. Neither input
    dict is modified.

    Returns:
        dict: a new dict holding the merged values.
    """
    _, g_var_dic = self.compute_g_loss()
    _, d_var_dic = self.compute_d_loss()
    # Copy-then-update instead of ``dict(a, **b)``: it also accepts
    # non-string keys, and keeps the same precedence (d-loss values last).
    var_dic = dict(g_var_dic)
    var_dic.update(d_var_dic)
    return var_dic
示例4: forward
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def forward(self, inputs, targets):
    """Return model logits on (optionally) PGD-perturbed inputs.

    If ``args.attack`` is falsy, the clean ``inputs`` are passed straight
    through. ``args`` is a module-level object defined elsewhere.

    Otherwise the attack runs ``self.num_steps`` steps of sign-gradient ascent
    on the summed cross-entropy:
      - it optionally starts from a uniform random point in the epsilon ball
        (``self.rand``);
      - after every step the point is projected back onto the L-inf ball of
        radius ``self.epsilon`` around ``inputs`` and clamped to [0, 1].

    Args:
        inputs: input batch.
        targets: class labels passed to ``F.cross_entropy``.

    Returns:
        tuple: ``(self.model(x), x)``, where ``x`` is the final (possibly
        perturbed) input.
    """
    if not args.attack:
        return self.model(inputs), inputs
    x = inputs.detach()
    if self.rand:
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
    for _ in range(self.num_steps):
        x.requires_grad_()
        # Force graph construction in case the caller has disabled grad
        # mode: the attack needs d(loss)/dx.
        with torch.enable_grad():
            logits = self.model(x)
            # reduction='sum' is the non-deprecated spelling of size_average=False.
            loss = F.cross_entropy(logits, targets, reduction='sum')
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + self.step_size * torch.sign(grad.detach())
        x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
        x = torch.clamp(x, 0, 1)
    return self.model(x), x
示例5: backward
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def backward(ctx, *args):
    """Recompute the checkpointed forward with grad enabled and back-propagate through it.

    Steps:
      1. Replay ``ctx.run_function`` on detached copies of ``ctx.saved_tensors``.
         When the module-level ``preserve_rng_state`` (defined elsewhere) is
         truthy, the CPU (and, if used, CUDA) RNG state captured during forward
         is restored first.
      2. Run ``torch.autograd.backward`` from the recomputed outputs using the
         incoming gradients.

    Args:
        ctx: autograd context. Reads ``saved_tensors``, ``had_cuda_in_fwd``,
            ``fwd_cpu_rng_state``, ``fwd_cuda_rng_state`` and ``run_function``.
        *args: gradients w.r.t. the forward outputs.

    Returns:
        tuple: ``None`` for the first forward argument, followed by the
        ``.grad`` of each detached input. The first argument is presumably
        ``run_function``; confirm against forward.

    Raises:
        RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is false.
            Per the message, this happens when the checkpoint is used with
            ``.grad()`` rather than ``.backward()``.
    """
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
    inputs = ctx.saved_tensors
    # Stash the surrounding rng state, and mimic the state that was
    # present at this time during forward. Restore the surrounding state
    # when we're done.
    rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
    with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
        if preserve_rng_state:
            torch.set_rng_state(ctx.fwd_cpu_rng_state)
            if ctx.had_cuda_in_fwd:
                torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
        # detach_variable is defined elsewhere. Presumably it returns detached
        # copies that require grad, since their .grad is read back below.
        detached_inputs = detach_variable(inputs)
        # Grad mode may be off while backward runs; the recomputation must
        # record a graph so it can be differentiated.
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    torch.autograd.backward(outputs, args)
    return (None,) + tuple(inp.grad for inp in detached_inputs)
示例6: __call__
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def __call__(self, *args, **kwargs) -> Any:
    """
    Invoke the parent class's ``__call__`` inside a grad-mode context.

    ``self.grad`` selects the context: ``torch.enable_grad()`` when truthy,
    otherwise ``torch.no_grad()``.

    Args:
        *args: positional arguments passed through unchanged
        **kwargs: keyword arguments passed through unchanged

    Returns:
        Any: whatever the parent ``__call__`` returns (the transformed data)
    """
    grad_mode = torch.enable_grad() if self.grad else torch.no_grad()
    with grad_mode:
        result = super().__call__(*args, **kwargs)
    return result
示例7: train
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def train(self, training_batch, train=True):
    """Score the imitator on a batch and optionally take a training step.

    Args:
        training_batch: object exposing ``training_input``, which has
            ``state.float_features`` and an ``action`` tensor. The action
            tensor is one-hot per the original comment.
        train: when True, back-propagate the cross-entropy between the
            predictions and the action indices, then call
            ``self._maybe_run_optimizer``.

    Returns:
        Whatever ``self._imitator_accuracy(predicted_idxs, actual_idxs)`` returns.
    """
    batch_input = training_batch.training_input
    with torch.enable_grad():
        logits = self.imitator(batch_input.state.float_features)
        # Classification label is index of action with value 1
        predicted = logits.max(dim=1).indices
        expected = batch_input.action.max(dim=1).indices
        if train:
            loss = torch.nn.CrossEntropyLoss()(logits, expected)
            loss.backward()
            self._maybe_run_optimizer(self.imitator_optimizer, self.minibatches_per_step)
    return self._imitator_accuracy(predicted, expected)
示例8: rsample
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def rsample(self, sample_shape, log_score=True):
    """Draw reparameterized relaxed permutation matrices from the PL distribution.

    The (log-)scores are perturbed with i.i.d. Gumbel noise and the sort is
    relaxed with ``self.relaxed_sort``.

    Args:
        sample_shape: number of samples, given either as an int or as a
            sequence / ``torch.Size`` whose first entry is the sample count.
        log_score: if True, ``self.scores`` are used directly as log-scores;
            otherwise ``torch.log`` is applied to them first.

    Returns:
        Tensor: the relaxed permutations. They are reshaped to
        ``(n_samples, -1, self.n, self.n)`` and then ``squeeze()``-d, which
        drops every size-1 dimension.
    """
    # torch.distributions turns off autograd (per the original note);
    # re-enable it so samples stay differentiable w.r.t. the scores.
    with torch.enable_grad():
        n_samples = sample_shape if isinstance(sample_shape, int) else sample_shape[0]
        scores = self.scores
        # Noise is drawn on the scores' device rather than a hard-coded
        # 'cuda', so sampling also works on CPU-only machines.
        uniform = torch.zeros([n_samples, 1, self.n, 1], device=scores.device).uniform_()
        eps = 1e-20
        gumbel = -torch.log(-torch.log(uniform + eps) + eps)
        log_s = scores if log_score else torch.log(scores)
        log_s_perturb = log_s.unsqueeze(0) + gumbel
        log_s_perturb = log_s_perturb.view(-1, self.n, 1)
        P_hat = self.relaxed_sort(log_s_perturb)
        P_hat = P_hat.view(n_samples, -1, self.n, self.n)
        return P_hat.squeeze()
示例9: gradient_penalty
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def gradient_penalty(self, x_real, x_fake, netD, index=0):
    """Gradient penalty on random interpolates of real and fake samples.

    The interpolate is ``x_hat = alpha * x_real + (1 - alpha) * x_fake``,
    with one uniform ``alpha`` per sample. The result is
    ``mean((||d netD(x_hat)[index] / d x_hat||_2 - 1) ** 2)``, where the norm
    is taken per sample over all non-batch dims.

    Args:
        x_real: real batch; dim 0 is the batch dim (any rank >= 1).
        x_fake: generated batch with the same shape as ``x_real``.
        netD: critic; ``netD(x_hat)[index]`` is the output that is differentiated.
        index: which element of the critic's output to use.

    Returns:
        Tensor: the scalar penalty. It is built with ``create_graph=True``, so
        it can be back-propagated as part of the critic loss.
    """
    with torch.enable_grad():
        # One mixing weight per sample, broadcast over the remaining dims
        # (previously hard-coded to 4-D image batches).
        alpha_shape = (x_real.size(0),) + (1,) * (x_real.dim() - 1)
        alpha = torch.rand(alpha_shape, device=x_real.device)
        alpha = alpha.expand_as(x_real)
        x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
        x_hat.requires_grad_(True)
        output = netD(x_hat)[index]
        grad = torch.autograd.grad(
            outputs=output,
            inputs=x_hat,
            grad_outputs=torch.ones_like(output),
            retain_graph=True,
            create_graph=True,
        )[0]
        # reshape rather than view: the gradient is not guaranteed contiguous.
        grad = grad.reshape(grad.size(0), -1)
        loss_gp = ((grad.norm(p=2, dim=1) - 1) ** 2).mean()
    return loss_gp
示例10: cond_gradient_penalty
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def cond_gradient_penalty(self, x_real, x_fake, y, netD, index=0):
    """Conditional gradient penalty on random interpolates of real and fake samples.

    Same as the unconditional penalty, but the critic is evaluated as
    ``netD(x_hat, y)``. The interpolate is
    ``x_hat = alpha * x_real + (1 - alpha) * x_fake``, with one uniform
    ``alpha`` per sample. The result is
    ``mean((||d netD(x_hat, y)[index] / d x_hat||_2 - 1) ** 2)``.

    Args:
        x_real: real batch; dim 0 is the batch dim (any rank >= 1).
        x_fake: generated batch with the same shape as ``x_real``.
        y: conditioning input passed unchanged to ``netD``.
        netD: critic; ``netD(x_hat, y)[index]`` is the output that is differentiated.
        index: which element of the critic's output to use.

    Returns:
        Tensor: the scalar penalty. It is built with ``create_graph=True``, so
        it can be back-propagated as part of the critic loss.
    """
    with torch.enable_grad():
        # One mixing weight per sample, broadcast over the remaining dims
        # (previously hard-coded to 4-D image batches).
        alpha_shape = (x_real.size(0),) + (1,) * (x_real.dim() - 1)
        alpha = torch.rand(alpha_shape, device=x_real.device)
        alpha = alpha.expand_as(x_real)
        x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
        x_hat.requires_grad_(True)
        output = netD(x_hat, y)[index]
        grad = torch.autograd.grad(
            outputs=output,
            inputs=x_hat,
            grad_outputs=torch.ones_like(output),
            retain_graph=True,
            create_graph=True,
        )[0]
        # reshape handles non-contiguous gradients (same as contiguous().view).
        grad = grad.reshape(grad.size(0), -1)
        loss_gp = ((grad.norm(p=2, dim=1) - 1) ** 2).mean()
    return loss_gp
示例11: backward
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def backward(ctx, grad_output):
    """Add the gradient of an auxiliary loss to the incoming gradient.

    The auxiliary loss ``ctx.criterion(input[:, :c//2] * mask, target) * ctx.strength``
    is recomputed with autograd enabled. It is back-propagated to a detached
    copy of the masked features, and that gradient is added to the first
    ``ctx.c // 2`` channels of (a copy of) ``grad_output``.

    Args:
        ctx: autograd context. ``ctx.saved_tensors`` is ``(input, target, mask)``;
            ``c``, ``criterion`` and ``strength`` are read, and ``ctx.loss``
            is set here.
        grad_output: gradient w.r.t. the forward output, indexed as 4-D
            ``[:, channels, :, :]``. It is not modified.

    Returns:
        tuple: ``(grad_input, None, None, None, None)``; only the first
        forward argument receives a gradient.
    """
    with torch.enable_grad():
        feats, target, mask = ctx.saved_tensors
        half = ctx.c // 2
        former = feats.narrow(1, 0, half)
        former_in_mask = torch.mul(former, mask)
        if former_in_mask.size() != target.size():  # For the last iteration of one epoch
            target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask)
        former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True)
        ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength
        ctx.loss.backward()
    # Work on a copy: grad_output is owned by autograd and may be shared with
    # other nodes, so updating it in place risks corrupting their gradients.
    grad_input = grad_output.clone()
    grad_input[:, 0:half, :, :] += former_in_mask_clone.grad
    return grad_input, None, None, None, None
示例12: perturb_hinge
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def perturb_hinge(net, x_nat):
    """Return an adversarial version of ``x_nat`` found by sign-gradient ascent on a hinge loss.

    The search starts from ``x_nat`` plus Gaussian noise with std 0.001 and
    runs ``args.num_steps`` steps, each of which:
      - increases ``mean(clamp(1 - net(x) * net(x_nat) / args.beta, min=0))``;
      - projects onto the L-inf ball of radius ``args.epsilon`` around ``x_nat``;
      - clamps to [0, 1].

    ``args`` is a module-level object defined elsewhere.

    ``net`` is switched to eval mode for the attack. Its previous train/eval
    mode is restored afterwards, even if an exception is raised.

    Args:
        net: module whose output has a size-1 dim 1 (``net(x).squeeze(1)``).
        x_nat: clean input batch.

    Returns:
        Tensor: the perturbed batch (not attached to any graph).
    """
    # Perturb function based on (E[\phi(f(x)f(x'))])
    was_training = net.training
    net.eval()
    try:
        # init with random noise, drawn on x_nat's own device (was hard-coded .cuda())
        x = x_nat.detach() + 0.001 * torch.randn_like(x_nat)
        # net(x_nat) does not depend on x, so evaluate it once, without a
        # graph, instead of on every step.
        with torch.no_grad():
            nat_scaled = net(x_nat).squeeze(1) / args.beta
        for _ in range(args.num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                # perturb based on hinge loss
                loss = torch.mean(torch.clamp(1 - net(x).squeeze(1) * nat_scaled, min=0))
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + args.step_size * torch.sign(grad.detach())
            x = torch.min(torch.max(x, x_nat - args.epsilon), x_nat + args.epsilon)
            x = torch.clamp(x, 0.0, 1.0)
    finally:
        # Restore the caller's mode. Unconditionally calling net.train() would
        # flip eval-mode callers into training mode.
        net.train(was_training)
    return x
示例13: train
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def train(self,
          data_loader: Iterable,
          mode: str = TRAIN):
    """ Training the model for an epoch.

    Steps:
      1. Increment ``self._epoch`` and put ``self.model`` (and ``self.loss_f``,
         if it has ``train``) in training mode.
      2. Run ``self._loop`` with autograd enabled.
      3. Step ``self.scheduler`` when it is set and
         ``self.update_scheduler_by_epoch`` is truthy.
      4. For a ``DataLoader`` whose sampler is a ``DistributedSampler``, call
         ``sampler.set_epoch(self.epoch)``.

    :param data_loader: any iterable of batches, typically a ``DataLoader``.
        The previous annotation ``Iterable or DataLoader`` was a boolean
        expression that evaluated to ``Iterable``; this annotation states
        that explicitly.
    :param mode: Name of this loop. Default is `train`. Passed to callbacks.
    """
    self._is_train = True
    self._epoch += 1
    self.model.train()
    if hasattr(self.loss_f, "train"):
        self.loss_f.train()
    with torch.enable_grad():
        self._loop(data_loader, mode=mode)
    if self.scheduler is not None and self.update_scheduler_by_epoch:
        self.scheduler.step()
    if isinstance(data_loader, DataLoader) and isinstance(data_loader.sampler, DistributedSampler):
        data_loader.sampler.set_epoch(self.epoch)
示例14: forward
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def forward(self, model, bx, by, by_prime, curr_batch_size):
    """
    Run an L-inf PGD attack from a random start and return the perturbed batch.

    :param model: the classifier. It is called as ``model(images * 2 - 1)`` and
        must return ``(logits, pen)``. When ``self.attack_rotations`` is set,
        ``model.module.rot_pred(pen[curr_batch_size:])`` is also used.
    :param bx: batch of images; not modified.
    :param by: true labels for the first ``curr_batch_size`` rows of ``logits``
    :param by_prime: rotation labels, used only when ``self.attack_rotations``
    :param curr_batch_size: number of leading rows of ``logits`` that ``by`` labels
    :return: perturbed batch of images, within ``self.epsilon`` (L-inf) of
        ``bx`` and clamped to [0, 1] after each step
    """
    # Random start inside the epsilon ball. This must not be an in-place add
    # on bx.detach(): detach() shares storage with bx, so `+=` would
    # overwrite the caller's images and shift the projection centre below.
    adv_bx = bx.detach() + torch.zeros_like(bx).uniform_(-self.epsilon, self.epsilon)
    for _ in range(self.num_steps):
        adv_bx.requires_grad_()
        # Force graph construction in case the caller has disabled grad mode.
        with torch.enable_grad():
            logits, pen = model(adv_bx * 2 - 1)
            loss = F.cross_entropy(logits[:curr_batch_size], by, reduction='sum')
            if self.attack_rotations:
                loss += F.cross_entropy(model.module.rot_pred(pen[curr_batch_size:]), by_prime, reduction='sum') / 8.
        grad = torch.autograd.grad(loss, adv_bx)[0]
        adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())
        adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)
    return adv_bx
示例15: train
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import enable_grad [as 別名]
def train(self, data_loader, print_freq=20):
    """Switch the model to training mode and run one training pass with autograd on.

    Returns:
        Whatever ``self.train_iteration(data_loader, print_freq)`` returns.
    """
    self.model.train()
    grad_ctx = torch.enable_grad()
    with grad_ctx:
        result = self.train_iteration(data_loader, print_freq)
    return result