

Python torch.is_grad_enabled Method Code Examples

This article collects typical usage examples of Python's torch.is_grad_enabled method. If you are unsure what torch.is_grad_enabled does or how to use it, the curated examples below may help. You can also browse further usage examples from its parent module, torch.


Below are 15 code examples of the torch.is_grad_enabled method, sorted by popularity by default.
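
Before the examples, a minimal refresher sketch (plain PyTorch, not drawn from any of the projects below): torch.is_grad_enabled() reports whether autograd is currently recording operations, and torch.no_grad() / torch.set_grad_enabled() are the standard ways to toggle that state.

import torch

print(torch.is_grad_enabled())      # True by default

with torch.no_grad():
    print(torch.is_grad_enabled())  # False inside the no_grad block

torch.set_grad_enabled(False)
print(torch.is_grad_enabled())      # False until explicitly re-enabled
torch.set_grad_enabled(True)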

Example 1: test_simple_inference

# Required import: import torch
# Or: from torch import is_grad_enabled
async def test_simple_inference(self):
    if not torch.cuda.is_available():
        import pytest

        pytest.skip('test requires GPU and torch+cuda')

    ori_grad_enabled = torch.is_grad_enabled()
    root_dir = os.path.dirname(os.path.dirname(__name__))
    model_config = os.path.join(
        root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
    detector = MaskRCNNDetector(model_config)
    await detector.init()
    img_path = os.path.join(root_dir, 'demo/demo.jpg')
    bboxes, _ = await detector.apredict(img_path)
    self.assertTrue(bboxes)
    # The async inference detector hacks the global grad mode, so restore
    # it here to keep it from influencing other tests.
    torch.set_grad_enabled(ori_grad_enabled)
Author: open-mmlab | Project: mmdetection | Lines: 20 | Source: test_async.py
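
The save-and-restore pattern in this test generalizes; a minimal sketch (not from the original test) that guards the global grad mode with try/finally, so it is restored even if inference raises:

import torch

ori_grad_enabled = torch.is_grad_enabled()
try:
    torch.set_grad_enabled(False)
    # ... run inference here ...
finally:
    torch.set_grad_enabled(ori_grad_enabled)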

Example 2: _save_input

# Required import: import torch
# Or: from torch import is_grad_enabled
def _save_input(self, module, input):
        # Forward pre-hook: collect K-FAC input statistics only while
        # autograd is enabled (i.e. during training) and only every `Ts` steps.
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay) 
Author: montrealrobotics | Project: dal | Lines: 18 | Source: kfac.py
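
_save_input is meant to be registered as a forward pre-hook. A runnable sketch of the same gating idea on a toy model (the model and the log_input helper are illustrative stand-ins, not part of the kfac.py source):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(),
                      nn.Linear(8 * 30 * 30, 10))

def log_input(module, inputs):
    # Mirrors the gate in _save_input: skip statistics collection whenever
    # autograd is globally disabled (e.g. under torch.no_grad()).
    if torch.is_grad_enabled():
        print('collecting stats for', module.__class__.__name__)

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        m.register_forward_pre_hook(log_input)

model(torch.randn(1, 3, 32, 32))       # hooks collect statistics
with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))   # hooks skip collection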

Example 3: _get_rnn_output

# Required import: import torch
# Or: from torch import is_grad_enabled
def _get_rnn_output(self, input_ids: Tensor, input_mask: Tensor,
                        first_subtokens: List[List[int]], last_subtokens: List[List[int]], mask: Tensor = None) \
            -> Tensor:
        # [batch, length, word_dim]
        # Run BERT with gradients only when fine-tuning it *and* gradients
        # are globally enabled; otherwise treat it as a frozen extractor.
        with torch.set_grad_enabled(self.fine_tune and torch.is_grad_enabled()):
            sequence_output = self.bert(input_ids, attention_mask=input_mask)
            if self.fine_tune:
                sequence_output = sequence_output[0]
            else:
                sequence_output = torch.cat(tuple(sequence_output[2][-self.bert_layers:]), 2).detach()
            batch, _, word_dim = sequence_output.size()
            input = sequence_output.new_zeros((batch, max([len(fst) for fst in first_subtokens]), word_dim))
            for i, subtokens_list_tuple in enumerate(zip(first_subtokens, last_subtokens)):
                for j, subtokens_tuple in enumerate(zip(subtokens_list_tuple[0], subtokens_list_tuple[1])):
                    input[i, j, :] = torch.mean(sequence_output[i, subtokens_tuple[0]:subtokens_tuple[1], :], dim=0)
        # output from rnn [batch, length, hidden_size]
        output, hn = self.rnn(input, mask)

        # apply dropout for the output of rnn
        # [batch, length, hidden_size] --> [batch, hidden_size, length] --> [batch, length, hidden_size]
        output = self.dropout_out(output.transpose(1, 2)).transpose(1, 2)

        return output 
Author: yahshibu | Project: nested-ner-tacl2020-transformers | Lines: 25 | Source: sequence_labeling.py
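
The torch.set_grad_enabled(self.fine_tune and torch.is_grad_enabled()) line is the heart of this example: one flag decides whether BERT is fine-tuned or used as a frozen feature extractor. The same pattern in isolation (encoder and fine_tune are illustrative stand-ins):

import torch
import torch.nn as nn

encoder = nn.Linear(16, 16)   # stand-in for the BERT encoder
fine_tune = False             # mirrors self.fine_tune

x = torch.randn(2, 16)
# Grad flows through the encoder only when fine_tune is True *and* the
# caller has not already disabled autograd (e.g. during evaluation).
with torch.set_grad_enabled(fine_tune and torch.is_grad_enabled()):
    features = encoder(x)
print(features.requires_grad)  # False here, since fine_tune is False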

Example 4: async_inference_detector

# Required import: import torch
# Or: from torch import is_grad_enabled
def async_inference_detector(model, img):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        img (str or ndarray): Either an image file path or a loaded image.

    Returns:
        Awaitable detection results.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result 
Author: open-mmlab | Project: mmdetection | Lines: 28 | Source: inference.py
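
Being a coroutine, the function has to be awaited; a hedged usage sketch (the model construction and image path are assumptions, not shown in the source):

import asyncio

# `model` is assumed to be an initialized mmdetection detector.
result = asyncio.run(async_inference_detector(model, 'demo/demo.jpg'))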

Example 5: replicate

# Required import: import torch
# Or: from torch import is_grad_enabled
def replicate(self, module, device_ids):
        # Detach the replicas when autograd is off: during inference there
        # is no need to keep them connected to the source parameters.
        return replicate(module, device_ids, not torch.is_grad_enabled())
Author: PistonY | Project: torch-toolbox | Lines: 4 | Source: EncodingDataParallel.py
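
Core PyTorch ships an analogous helper as torch.nn.parallel.replicate, whose detach flag controls whether replicas stay attached to the source parameters. A small sketch of the same gating (guarded so it only runs when two GPUs are available; whether torch-toolbox wraps exactly this helper is an assumption):

import torch
from torch.nn.parallel import replicate

if torch.cuda.device_count() >= 2:
    module = torch.nn.Linear(4, 4).cuda(0)
    # With grads enabled the replicas stay connected to the source
    # parameters; under no_grad they become detached copies.
    replicas = replicate(module, [0, 1], detach=not torch.is_grad_enabled())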

Example 6: fork

# Required import: import torch
# Or: from torch import is_grad_enabled
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor."""
    if torch.is_grad_enabled() and input.requires_grad:
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)

    return input, phony 
Author: kakaobrain | Project: torchgpipe | Lines: 10 | Source: dependency.py

Example 7: join

# Required import: import torch
# Or: from torch import is_grad_enabled
def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merges two autograd lanes."""
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)

    return input 
Author: kakaobrain | Project: torchgpipe | Lines: 8 | Source: dependency.py
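
Together, fork and join add a pure scheduling edge between two autograd lanes: the phony tensor carries no data, only an ordering constraint for the backward pass. A hedged sketch of the intended pairing (the torchgpipe.dependency import path is assumed from the source file name):

import torch
from torchgpipe.dependency import fork, join  # assumed import path

a = torch.randn(2, requires_grad=True)
b = torch.randn(2, requires_grad=True)

# Make b's autograd lane depend on a's: backward through b is now
# ordered after backward through a.
a, phony = fork(a)
b = join(b, phony)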

Example 8: test_grad_mode

# Required import: import torch
# Or: from torch import is_grad_enabled
def test_grad_mode(grad_mode):
    def detect_grad_enabled():
        x = torch.rand(1, requires_grad=torch.is_grad_enabled())
        return Batch(x)

    with torch.set_grad_enabled(grad_mode):
        with spawn_workers([torch.device('cpu')]) as (in_queues, out_queues):
            task = Task(CPUStream, compute=detect_grad_enabled, finalize=None)
            in_queues[0].put(task)

            ok, (_, batch) = out_queues[0].get()

            assert ok
            assert batch[0].requires_grad == grad_mode 
Author: kakaobrain | Project: torchgpipe | Lines: 16 | Source: test_worker.py
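
The trick this test relies on is that requires_grad at tensor creation captures the ambient grad mode; reduced to plain PyTorch (a minimal sketch, independent of the worker machinery):

import torch

with torch.set_grad_enabled(False):
    x = torch.rand(1, requires_grad=torch.is_grad_enabled())
print(x.requires_grad)  # False: the disabled grad mode was captured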

Example 9: async_inference_detector

# Required import: import torch
# Or: from torch import is_grad_enabled
def async_inference_detector(model, img):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        img (str or ndarray): Either an image file path or a loaded image.

    Returns:
        Awaitable detection results.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result


# TODO: merge this method with the one in BaseDetector 
Author: open-mmlab | Project: mmfashion | Lines: 31 | Source: inference.py

Example 10: train_step

# Required import: import torch
# Or: from torch import is_grad_enabled
def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.train_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output 
Author: open-mmlab | Project: mmcv | Lines: 33 | Source: distributed.py

Example 11: val_step

# Required import: import torch
# Or: from torch import is_grad_enabled
def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.val_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output 
Author: open-mmlab | Project: mmcv | Lines: 33 | Source: distributed.py

Example 12: forward

# Required import: import torch
# Or: from torch import is_grad_enabled
def forward(self, *args):
        # Alternative gating left commented out upstream: update the power
        # iteration only when autograd is enabled as well.
        # if torch.is_grad_enabled() and self.module.training:
        if self.module.training:
            self._update_u_v()
        else:
            self._noupdate_u_v()
        return self.module.forward(*args) 
Author: Yaoyi-Li | Project: GCA-Matting | Lines: 9 | Source: ops.py

Example 13: test_helper_train

# Required import: import torch
# Or: from torch import is_grad_enabled
def test_helper_train(self):
        """
        Tests the train/eval mode helper methods. ``TorchVGSLModel.train()``
        and ``eval()`` also toggle the global autograd mode, which the
        ``torch.is_grad_enabled()`` assertions below verify.
        """
        rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
        rnn.train()
        self.assertTrue(torch.is_grad_enabled())
        self.assertTrue(rnn.nn.training)
        rnn.eval()
        self.assertFalse(torch.is_grad_enabled())
        self.assertFalse(rnn.nn.training) 
Author: mittagessen | Project: kraken | Lines: 13 | Source: test_vgsl.py

Example 14: scatter

# Required import: import torch
# Or: from torch import is_grad_enabled
def scatter(self, inputs, kwargs, device_ids):
        try:
            params = kwargs.pop('params')
        except KeyError:
            return super(DataParallel, self).scatter(inputs, kwargs, device_ids)

        inputs_, kwargs_ = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
        # Add params argument unchanged back in kwargs
        replicas = self._replicate_params(params, inputs_, device_ids,
                                          detach=not torch.is_grad_enabled())
        kwargs_ = tuple(dict(params=replica, **kwarg)
                        for (kwarg, replica) in zip(kwargs_, replicas))
        return inputs_, kwargs_ 
Author: tristandeleu | Project: pytorch-meta | Lines: 15 | Source: parallel.py

Example 15: cached_cast

# Required import: import torch
# Or: from torch import is_grad_enabled
def cached_cast(cast_fn, x, cache):
    if is_nested(x):
        # Recurse into nested containers, forwarding cast_fn and cache.
        return type(x)([cached_cast(cast_fn, y, cache) for y in x])
    if x in cache:
        cached_x = cache[x]
        if x.requires_grad and cached_x.requires_grad:
            # Make sure x is actually cached_x's autograd parent.
            if cached_x.grad_fn.next_functions[1][0].variable is not x:
                raise RuntimeError("x and cache[x] both require grad, but x is not "
                                   "cache[x]'s parent.  This is likely an error.")
        # During eval, it's possible to end up caching casted weights with
        # requires_grad=False.  On the next training iter, if cached_x is found
        # and reused from the cache, it will not actually have x as its parent.
        # Therefore, we choose to invalidate the cache (and force refreshing the cast)
        # if x.requires_grad and cached_x.requires_grad do not match.
        #
        # During eval (i.e. running under with torch.no_grad()) the invalidation
        # check would cause the cached value to be dropped every time, because
        # cached_x would always be created with requires_grad=False, while x would
        # still have requires_grad=True.  This would render the cache effectively
        # useless during eval.  Therefore, if we are running under the no_grad()
        # context manager (torch.is_grad_enabled=False) we elide the invalidation
        # check, and use the cached value even though its requires_grad flag doesn't
        # match.  During eval, we don't care that there's no autograd-graph
        # connection between x and cached_x.
        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
            del cache[x]
        else:
            return cached_x

    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x 
Author: NVIDIA | Project: apex | Lines: 35 | Source: utils.py
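
A hedged usage sketch of cached_cast (the is_nested stub, cast function, and cache dict are illustrative assumptions; apex wires these up internally):

import torch

def is_nested(x):  # stub for apex's helper, assumed semantics
    return isinstance(x, (list, tuple))

def cast_fn(t):    # hypothetical fp16 cast
    return t.half()

cache = {}
w = torch.randn(4, 4)  # requires_grad=False, so the parent check is skipped
w16_a = cached_cast(cast_fn, w, cache)
w16_b = cached_cast(cast_fn, w, cache)
assert w16_a is w16_b  # the second call is served from the cache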


Note: the torch.is_grad_enabled examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective developers; copyright of the source code belongs to the original authors. Please consult each project's license before distributing or using the code. Do not reproduce without permission.