本文整理汇总了 Python 中 torch.is_grad_enabled 方法的典型用法代码示例。如果您正苦于以下问题：Python 中 torch.is_grad_enabled 方法的具体用法是什么？torch.is_grad_enabled 该怎么用？有哪些使用 torch.is_grad_enabled 的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torch 的用法示例。
在下文中一共展示了 torch.is_grad_enabled 方法的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: test_simple_inference
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
async def test_simple_inference(self):
    """Run async Mask R-CNN inference on the demo image and expect boxes.

    Skipped when CUDA is unavailable. Per the original note, the async
    detector changes the global grad mode, so the caller's setting is saved
    first and restored in ``finally`` -- even when the detector or an
    assertion fails -- so it cannot leak into other tests.

    This must be a coroutine function: it awaits ``detector.init()`` and
    ``detector.apredict()``, and ``await`` inside a plain ``def`` is a
    SyntaxError.
    """
    if not torch.cuda.is_available():
        import pytest
        pytest.skip('test requires GPU and torch+cuda')

    ori_grad_enabled = torch.is_grad_enabled()
    # NOTE(review): ``__name__`` is a dotted module name and never contains a
    # path separator, so this always evaluates to '' and the config/demo paths
    # below resolve against the current working directory (i.e. the tests
    # assume they run from the repository root). Confirm before relying on it.
    root_dir = os.path.dirname(os.path.dirname(__name__))
    model_config = os.path.join(
        root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
    try:
        detector = MaskRCNNDetector(model_config)
        await detector.init()
        img_path = os.path.join(root_dir, 'demo/demo.jpg')
        bboxes, _ = await detector.apredict(img_path)
        self.assertTrue(bboxes)
    finally:
        # Async inference hacks grad_enabled; restore it unconditionally so
        # it does not influence other tests.
        torch.set_grad_enabled(ori_grad_enabled)
示例2: _save_input
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def _save_input(self, module, input):
    """Accumulate a running statistic of ``compute_cov_a`` for ``module``.

    The ``(module, input)`` signature matches a forward pre-hook; presumably
    it is registered as one (confirm at the registration site). Work is only
    done while grad mode is enabled and on steps that are a multiple of
    ``self.Ts``. On step 0 the per-module buffer in ``self.m_aa`` is seeded
    with a copy of the first statistic before ``update_running_stat`` is
    applied with ``self.stat_decay``.
    """
    # Guard clause; the step check is only evaluated when grad is enabled.
    if not torch.is_grad_enabled() or self.steps % self.Ts != 0:
        return

    kind = module.__class__.__name__
    # Conv layers need their geometry to build the statistic; others do not.
    geometry = (
        (module.kernel_size, module.stride, module.padding)
        if kind == 'Conv2d' else None
    )
    stat = compute_cov_a(input[0].data, kind, geometry, self.fast_cnn)

    if self.steps == 0:
        # First step: create the buffer that the running update writes into.
        self.m_aa[module] = stat.clone()
    update_running_stat(stat, self.m_aa[module], self.stat_decay)
示例3: _get_rnn_output
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def _get_rnn_output(self, input_ids: Tensor, input_mask: Tensor,
                    first_subtokens: List[List[int]], last_subtokens: List[List[int]], mask: Tensor = None) \
        -> Tensor:
    """Encode sub-tokens with ``self.bert``, pool them per word, then run ``self.rnn``.

    Args:
        input_ids: sub-token ids passed to ``self.bert``.
        input_mask: passed to ``self.bert`` as ``attention_mask``.
        first_subtokens: per sentence, the start sub-token index of each word.
        last_subtokens: per sentence, the end sub-token index of each word;
            each word vector is the mean of ``hidden[i, start:end]``, so the
            end is presumably exclusive -- confirm with the caller.
        mask: forwarded unchanged to ``self.rnn``.

    Returns:
        The RNN output after ``self.dropout_out``; per the original comments
        its layout is [batch, length, hidden_size].
    """
    # The encoder only records gradients when fine-tuning AND the caller
    # has grad mode enabled.
    with torch.set_grad_enabled(self.fine_tune and torch.is_grad_enabled()):
        encoded = self.bert(input_ids, attention_mask=input_mask)
        if not self.fine_tune:
            # Frozen encoder: concatenate the last `bert_layers` entries of
            # encoded[2] along dim 2 and cut them from the graph.
            hidden = torch.cat(tuple(encoded[2][-self.bert_layers:]), 2).detach()
        else:
            hidden = encoded[0]

    batch_size, _, dim = hidden.size()
    longest = max(len(starts) for starts in first_subtokens)
    word_repr = hidden.new_zeros((batch_size, longest, dim))
    for b, (starts, ends) in enumerate(zip(first_subtokens, last_subtokens)):
        for w, (start, end) in enumerate(zip(starts, ends)):
            word_repr[b, w, :] = torch.mean(hidden[b, start:end, :], dim=0)

    rnn_out, _ = self.rnn(word_repr, mask)
    # Dropout sees [batch, hidden_size, length]; transpose back afterwards.
    return self.dropout_out(rnn_out.transpose(1, 2)).transpose(1, 2)
示例4: async_inference_detector
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
async def async_inference_detector(model, img):
    """Async inference of one image with the detector.

    Must be a coroutine function: it awaits ``model.aforward_test``, and
    ``await`` inside a plain ``def`` is a SyntaxError.

    Args:
        model (nn.Module): The loaded detector.
        img (str/ndarray): Either an image file or a loaded image. It is
            processed as a single sample (``collate([data], samples_per_gpu=1)``).

    Returns:
        The detection results from ``model.aforward_test``; calling this
        function returns an awaitable that resolves to them.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline: replace the first configured step with LoadImage
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result
示例5: replicate
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def replicate(self, module, device_ids):
    """Replicate ``module`` across ``device_ids``.

    The third positional argument handed to the module-level ``replicate``
    is ``True`` exactly when grad mode is off (presumably its ``detach``
    flag, as in ``torch.nn.parallel.replicate`` -- confirm the import).
    """
    # When defined inside a class (as ``self`` indicates), the bare name
    # ``replicate`` resolves to the module-level function, not this method.
    grad_off = not torch.is_grad_enabled()
    return replicate(module, device_ids, grad_off)
示例6: fork
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branch out from the autograd lane of ``input``.

    Returns an ``(input, phony)`` pair. ``Fork.apply`` is used only when grad
    mode is on and ``input`` requires grad; otherwise ``input`` is returned
    untouched alongside a phony from ``get_phony`` that does not require grad.
    """
    if not (torch.is_grad_enabled() and input.requires_grad):
        return input, get_phony(input.device, requires_grad=False)
    branched, phony = Fork.apply(input)
    return branched, phony
示例7: join
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merge two autograd lanes.

    ``Join.apply`` is only invoked when grad mode is on and at least one of
    ``input``/``phony`` requires grad; otherwise ``input`` is returned as-is.
    """
    needs_graph = torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad)
    return Join.apply(input, phony) if needs_graph else input
示例8: test_grad_mode
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def test_grad_mode(grad_mode):
    """A task run by a spawned CPU worker must see the caller's grad mode.

    ``grad_mode`` is presumably a pytest parameter (True/False). The probe
    task creates a tensor whose ``requires_grad`` mirrors the grad mode
    observed inside the worker, and that flag is compared with ``grad_mode``.
    """
    def make_probe_batch():
        probe = torch.rand(1, requires_grad=torch.is_grad_enabled())
        return Batch(probe)

    with torch.set_grad_enabled(grad_mode), \
            spawn_workers([torch.device('cpu')]) as (in_queues, out_queues):
        in_queues[0].put(Task(CPUStream, compute=make_probe_batch, finalize=None))
        ok, (_, batch) = out_queues[0].get()
        assert ok
        assert batch[0].requires_grad == grad_mode
示例9: async_inference_detector
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
async def async_inference_detector(model, img):
    """Async inference of one image with the detector.

    Must be a coroutine function: it awaits ``model.aforward_test``, and
    ``await`` inside a plain ``def`` is a SyntaxError.

    Args:
        model (nn.Module): The loaded detector.
        img (str/ndarray): Either an image file or a loaded image. It is
            processed as a single sample (``collate([data], samples_per_gpu=1)``).

    Returns:
        The detection results from ``model.aforward_test``; calling this
        function returns an awaitable that resolves to them.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline: replace the first configured step with LoadImage
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result
# TODO: merge this method with the one in BaseDetector
示例10: train_step
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def train_step(self, *inputs, **kwargs):
    """train_step() API for module wrapped by DistributedDataParallel.

    This method is basically the same as
    ``DistributedDataParallel.forward()``, while replacing
    ``self.module.forward()`` with ``self.module.train_step()``.
    It was written against PyTorch 1.1 - 1.5. The torch version gate at the
    end compares versions numerically, so it still triggers for 1.10+.
    """
    if getattr(self, 'require_forward_param_sync', True):
        self._sync_params()
    if self.device_ids:
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            output = self.module.train_step(*inputs[0], **kwargs[0])
        else:
            outputs = self.parallel_apply(
                self._module_copies[:len(inputs)], inputs, kwargs)
            output = self.gather(outputs, self.output_device)
    else:
        output = self.module.train_step(*inputs, **kwargs)

    if torch.is_grad_enabled() and getattr(
            self, 'require_backward_grad_sync', True):
        # Arm the reducer for the coming backward pass; with
        # find_unused_parameters the output tensors are handed over too.
        if self.find_unused_parameters:
            self.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            self.reducer.prepare_for_backward([])
    else:
        import re
        # Compare (major, minor) numerically: the former string comparison
        # `TORCH_VERSION > '1.2'` ranks '1.10.0' below '1.2' and skipped this
        # reset on PyTorch >= 1.10. Versions that cannot be parsed keep the
        # original string comparison.
        match = re.match(r'(\d+)\.(\d+)', str(TORCH_VERSION))
        if match:
            at_least_1_2 = tuple(int(part) for part in match.groups()) >= (1, 2)
        else:
            at_least_1_2 = TORCH_VERSION > '1.2'
        if at_least_1_2:
            self.require_forward_param_sync = False
    return output
示例11: val_step
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def val_step(self, *inputs, **kwargs):
    """val_step() API for module wrapped by DistributedDataParallel.

    This method is basically the same as
    ``DistributedDataParallel.forward()``, while replacing
    ``self.module.forward()`` with ``self.module.val_step()``.
    It was written against PyTorch 1.1 - 1.5. The torch version gate at the
    end compares versions numerically, so it still triggers for 1.10+.
    """
    if getattr(self, 'require_forward_param_sync', True):
        self._sync_params()
    if self.device_ids:
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            output = self.module.val_step(*inputs[0], **kwargs[0])
        else:
            outputs = self.parallel_apply(
                self._module_copies[:len(inputs)], inputs, kwargs)
            output = self.gather(outputs, self.output_device)
    else:
        output = self.module.val_step(*inputs, **kwargs)

    if torch.is_grad_enabled() and getattr(
            self, 'require_backward_grad_sync', True):
        # Arm the reducer for the coming backward pass; with
        # find_unused_parameters the output tensors are handed over too.
        if self.find_unused_parameters:
            self.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            self.reducer.prepare_for_backward([])
    else:
        import re
        # Compare (major, minor) numerically: the former string comparison
        # `TORCH_VERSION > '1.2'` ranks '1.10.0' below '1.2' and skipped this
        # reset on PyTorch >= 1.10. Versions that cannot be parsed keep the
        # original string comparison.
        match = re.match(r'(\d+)\.(\d+)', str(TORCH_VERSION))
        if match:
            at_least_1_2 = tuple(int(part) for part in match.groups()) >= (1, 2)
        else:
            at_least_1_2 = TORCH_VERSION > '1.2'
        if at_least_1_2:
            self.require_forward_param_sync = False
    return output
示例12: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def forward(self, *args):
    """Run the wrapped module after refreshing (or not) its u/v state.

    ``self._update_u_v`` is called when the wrapped module is in training
    mode, ``self._noupdate_u_v`` otherwise (presumably spectral-norm power
    iteration vectors -- confirm). The result of ``self.module.forward``
    is returned; note that calling ``forward`` directly bypasses module hooks.
    """
    # Only the training flag is consulted; an earlier variant also required
    # torch.is_grad_enabled(), but that check is not in effect.
    refresh = self._update_u_v if self.module.training else self._noupdate_u_v
    refresh()
    return self.module.forward(*args)
示例13: test_helper_train
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def test_helper_train(self):
    """Tests train/eval mode helper methods.

    ``train()`` must enable grad mode and put the network in training mode;
    ``eval()`` must disable both. Because these helpers change torch's grad
    mode as a side effect, the caller's setting is restored in ``finally``.
    Otherwise every later test in the same thread would run with autograd
    disabled.
    """
    grad_was_enabled = torch.is_grad_enabled()
    try:
        rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
        rnn.train()
        self.assertTrue(torch.is_grad_enabled())
        self.assertTrue(rnn.nn.training)
        rnn.eval()
        self.assertFalse(torch.is_grad_enabled())
        self.assertFalse(rnn.nn.training)
    finally:
        torch.set_grad_enabled(grad_was_enabled)
示例14: scatter
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def scatter(self, inputs, kwargs, device_ids):
    """Scatter inputs/kwargs, giving each device its own copy of ``params``.

    Without a ``params`` keyword this defers to the parent class's
    ``scatter``. Otherwise ``params`` is popped from ``kwargs`` (mutating
    it), the remaining arguments are split with ``scatter_kwargs`` along
    ``self.dim``, and ``self._replicate_params`` supplies one replica per
    device. The replicas are requested with ``detach=True`` whenever grad
    mode is off. Each replica is added back under the ``params`` key.
    """
    try:
        params = kwargs.pop('params')
    except KeyError:
        return super(DataParallel, self).scatter(inputs, kwargs, device_ids)

    split_inputs, split_kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
    detach = not torch.is_grad_enabled()
    replicas = self._replicate_params(params, split_inputs, device_ids, detach=detach)
    merged_kwargs = tuple(
        dict(params=replica, **per_device)
        for per_device, replica in zip(split_kwargs, replicas)
    )
    return split_inputs, merged_kwargs
示例15: cached_cast
# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_grad_enabled [as 别名]
def cached_cast(cast_fn, x, cache):
    """Cast ``x`` with ``cast_fn``, reusing a cached result when still valid.

    Args:
        cast_fn: callable applied to a tensor that has no usable cache entry.
        x: tensor to cast, or a nested container (as judged by ``is_nested``)
            whose elements are cast recursively; the container type is kept.
        cache: dict mapping original tensors to their cast results; it is
            read and updated in place.

    Returns:
        The cast tensor, or a container of the same type holding cast tensors.

    Raises:
        RuntimeError: if both ``x`` and its cached cast require grad but the
            cached value's autograd graph does not lead back to ``x``.
    """
    if is_nested(x):
        # Thread cast_fn and cache through the recursion; the former
        # `cached_cast(y)` raised TypeError for every nested input.
        return type(x)([cached_cast(cast_fn, y, cache) for y in x])
    if x in cache:
        cached_x = cache[x]
        if x.requires_grad and cached_x.requires_grad:
            # Make sure x is actually cached_x's autograd parent.
            # NOTE(review): assumes cast_fn's grad_fn lists x's accumulator as
            # its second next_function -- confirm for the cast ops in use.
            if cached_x.grad_fn.next_functions[1][0].variable is not x:
                raise RuntimeError("x and cache[x] both require grad, but x is not "
                                   "cache[x]'s parent. This is likely an error.")
        # During eval, it's possible to end up caching casted weights with
        # requires_grad=False. On the next training iter, if cached_x is found
        # and reused from the cache, it will not actually have x as its parent.
        # Therefore, we choose to invalidate the cache (and force refreshing the cast)
        # if x.requires_grad and cached_x.requires_grad do not match.
        #
        # During eval (i.e. running under with torch.no_grad()) the invalidation
        # check would cause the cached value to be dropped every time, because
        # cached_x would always be created with requires_grad=False, while x would
        # still have requires_grad=True. This would render the cache effectively
        # useless during eval. Therefore, if we are running under the no_grad()
        # context manager (torch.is_grad_enabled=False) we elide the invalidation
        # check, and use the cached value even though its requires_grad flag doesn't
        # match. During eval, we don't care that there's no autograd-graph
        # connection between x and cached_x.
        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
            del cache[x]
        else:
            return cached_x
    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x