当前位置: 首页>>代码示例>>Python>>正文


Python torch.is_floating_point方法代码示例

本文整理汇总了 Python 中 torch.is_floating_point 方法的典型用法代码示例。如果您正苦于以下问题：Python torch.is_floating_point 方法的具体用法是什么？Python torch.is_floating_point 怎么用？有哪些 Python torch.is_floating_point 的使用例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torch 的用法示例。


下文共展示了 torch.is_floating_point 方法的 14 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的 Python 代码示例。

示例1: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def forward(ctx, x, bbx, idx, roi_size, interpolation, padding, valid_mask):
        """Run the ROI sampling forward pass and record state on ``ctx``.

        ``bbx`` and ``idx`` are saved on ``ctx``; the sizes of dimensions 0, 2
        and 3 of ``x`` are stored as ``ctx.input_shape``. ``interpolation`` and
        ``padding`` are translated through the ``_INTERPOLATION`` and
        ``_PADDING`` lookup tables before ``_backend.roi_sampling_forward`` is
        called.

        NOTE(review): presumably a static method of a ``torch.autograd.Function``
        subclass -- confirm against the enclosing class.

        Returns:
            ``(y, mask)`` when ``valid_mask`` is truthy, otherwise just ``y``.

        Raises:
            ValueError: if ``interpolation`` or ``padding`` is not a key of its
                lookup table.
        """
        ctx.save_for_backward(bbx, idx)
        size_0, size_2, size_3 = x.size(0), x.size(2), x.size(3)
        ctx.input_shape = (size_0, size_2, size_3)
        ctx.valid_mask = valid_mask

        # Translate the user-facing option names into backend values.
        try:
            ctx.interpolation = _INTERPOLATION[interpolation]
        except KeyError:
            raise ValueError("Unknown interpolation {}".format(interpolation))
        try:
            ctx.padding = _PADDING[padding]
        except KeyError:
            raise ValueError("Unknown padding {}".format(padding))

        y, mask = _backend.roi_sampling_forward(
            x, bbx, idx, roi_size, ctx.interpolation, ctx.padding, valid_mask
        )

        # Outputs computed from non-floating-point inputs are marked as
        # non-differentiable.
        if not torch.is_floating_point(x):
            ctx.mark_non_differentiable(y)

        if not valid_mask:
            return y

        # The mask is never differentiable.
        ctx.mark_non_differentiable(mask)
        return y, mask
开发者ID:mapillary,项目名称:seamseg,代码行数:25,代码来源:functions.py

示例2: _check_inputs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def _check_inputs(func, y0, t):
    """Normalise the inputs of an ODE solve into tuple form and validate them.

    * A single tensor ``y0`` is wrapped into a 1-tuple and ``func`` is wrapped
      so that it receives ``y[0]`` and returns a 1-tuple.
    * When ``_decreasing(t)`` is truthy, ``t`` is negated and ``func`` is
      wrapped so that it is evaluated at ``-t`` with every output negated.
    * Every element of ``y0`` and ``t`` itself must be floating point.

    Returns:
        ``(tensor_input, func, y0, t)`` where ``tensor_input`` tells whether
        ``y0`` was originally a bare tensor.

    Raises:
        AssertionError: if ``y0`` is neither a tensor nor a tuple of tensors.
        TypeError: if any element of ``y0`` or ``t`` is not floating point.
    """
    tensor_input = torch.is_tensor(y0)
    if tensor_input:
        y0 = (y0,)
        _single_tensor_func = func

        def func(t, y):
            return (_single_tensor_func(t, y[0]),)

    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for state in y0:
        assert torch.is_tensor(state), 'each element must be a torch.Tensor but received {}'.format(type(state))

    if _decreasing(t):
        t = -t
        _forward_time_func = func

        def func(t, y):
            return tuple(-deriv for deriv in _forward_time_func(-t, y))

    for state in y0:
        if torch.is_floating_point(state):
            continue
        raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(state.type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))

    return tensor_input, func, y0, t
开发者ID:rtqichen,项目名称:torchdiffeq,代码行数:25,代码来源:misc.py

示例3: scatter_mean

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Divide the ``scatter_sum`` of ``src`` by the per-index element count.

    The sums are produced by ``scatter_sum(src, index, dim, out, dim_size)``;
    the counts are produced by scattering a tensor of ones shaped like
    ``index``. Counts are clamped to a minimum of 1 before dividing.

    Floating point results are divided with ``true_divide_``; all other dtypes
    use ``floor_divide_``. The division happens in place on the tensor returned
    by ``scatter_sum``, which is also the return value.
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    # Resolve ``dim`` to a non-negative axis that exists in ``index``.
    index_dim = dim + src.dim() if dim < 0 else dim
    index_dim = min(index_dim, index.dim() - 1)

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size).clamp_(1)
    count = broadcast(count, out, dim)

    if not torch.is_floating_point(out):
        out.floor_divide_(count)
    else:
        out.true_divide_(count)
    return out
开发者ID:rusty1s,项目名称:pytorch_scatter,代码行数:24,代码来源:scatter.py

示例4: scatter_softmax

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                    eps: float = 1e-12) -> torch.Tensor:
    """Softmax of ``src`` computed separately within each ``index`` group.

    Each element is shifted by the maximum of its group (as returned by
    ``scatter_max``) before exponentiation, and ``eps`` is added to every
    group's sum before dividing.

    Raises:
        ValueError: if ``src`` is not a floating point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    # Shift by the per-group maximum, then exponentiate.
    group_max = scatter_max(src, index, dim=dim)[0]
    shifted_exp = (src - group_max.gather(dim, index)).exp()

    # Per-group sums, expanded back to one denominator per source element.
    group_sum = scatter_sum(shifted_exp, index, dim)
    denominator = group_sum.add_(eps).gather(dim, index)

    return shifted_exp.div(denominator)
开发者ID:rusty1s,项目名称:pytorch_scatter,代码行数:20,代码来源:softmax.py

示例5: scatter_log_softmax

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                        eps: float = 1e-12) -> torch.Tensor:
    """Log-softmax of ``src`` computed separately within each ``index`` group.

    Each element is shifted by the maximum of its group (as returned by
    ``scatter_max``); the log of the group's summed exponentials (plus ``eps``)
    is then subtracted from the shifted value.

    Raises:
        ValueError: if ``src`` is not a floating point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    group_max = scatter_max(src, index, dim=dim)[0]
    shifted = src - group_max.gather(dim, index)

    # log(sum(exp(shifted)) + eps) per group, expanded back per element.
    group_exp_sum = scatter_sum(shifted.exp(), index, dim)
    log_denominator = group_exp_sum.add_(eps).log_().gather(dim, index)

    # ``shifted`` is a fresh tensor, so it is updated in place.
    return shifted.sub_(log_denominator)
开发者ID:rusty1s,项目名称:pytorch_scatter,代码行数:19,代码来源:softmax.py

示例6: _check_inputs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def _check_inputs(func, y0, t, f_options):
    """Normalise the inputs of an ODE solve into tuple form and validate them.

    * A single tensor ``y0`` is wrapped into a 1-tuple and ``func`` is wrapped
      so that it receives ``y[0]`` and returns a 1-tuple.
    * When ``_decreasing(t)`` is truthy, ``t`` is negated and ``func`` is
      wrapped so that it is evaluated at ``-t`` with every output negated.
    * Every element of ``y0`` and ``t`` itself must be floating point.

    The wrappers forward any keyword arguments to the wrapped function. The
    ``f_options`` parameter itself is accepted but not read in this function.

    Returns:
        ``(tensor_input, func, y0, t)`` where ``tensor_input`` tells whether
        ``y0`` was originally a bare tensor.

    Raises:
        AssertionError: if ``y0`` is neither a tensor nor a tuple of tensors.
        TypeError: if any element of ``y0`` or ``t`` is not floating point.
    """
    tensor_input = torch.is_tensor(y0)
    if tensor_input:
        y0 = (y0,)
        _single_tensor_func = func

        def func(t, y, **f_options):
            return (_single_tensor_func(t, y[0], **f_options),)

    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for state in y0:
        assert torch.is_tensor(state), 'each element must be a torch.Tensor but received {}'.format(type(state))

    if _decreasing(t):
        t = -t
        _forward_time_func = func

        def func(t, y, **f_options):
            return tuple(-deriv for deriv in _forward_time_func(-t, y, **f_options))

    for state in y0:
        if torch.is_floating_point(state):
            continue
        raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(state.type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))

    return tensor_input, func, y0, t
开发者ID:autonomousvision,项目名称:occupancy_flow,代码行数:25,代码来源:misc.py

示例7: torch_tensor_to_schema

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def torch_tensor_to_schema(tensor):
    """Describe ``tensor`` as a nested JSON-schema-style dictionary.

    The element schema depends on the dtype: ``torch.bool`` maps to
    ``boolean``, ``torch.uint8`` to an ``integer`` bounded to 0..255, any
    floating point dtype to ``number`` and everything else to ``integer``.
    The element schema is then wrapped in one fixed-length ``array`` schema
    per dimension, innermost (last) dimension first.

    Raises:
        AssertionError: if torch is not installed or ``tensor`` is not a
            ``torch.Tensor``.
    """
    assert torch_installed, """Your Python environment does not have torch installed. You can install it with
    pip install torch
or with
    pip install 'lale[full]'"""
    assert isinstance(tensor, torch.Tensor)
    # dtype reference:
    # https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
    dtype = tensor.dtype
    if dtype == torch.bool:
        schema = {'type': 'boolean'}
    elif dtype == torch.uint8:
        schema = {'type': 'integer', 'minimum': 0, 'maximum': 255}
    elif torch.is_floating_point(tensor):
        schema = {'type': 'number'}
    else:
        schema = {'type': 'integer'}

    # Wrap from the last dimension outwards so the first dimension ends up
    # as the outermost array.
    for extent in tensor.shape[::-1]:
        schema = {
            'type': 'array',
            'minItems': extent,
            'maxItems': extent,
            'items': schema,
        }
    return schema
开发者ID:IBM,项目名称:lale,代码行数:23,代码来源:data_schemas.py

示例8: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def backward(ctx, *args):
        """Compute the gradient of ROI sampling with respect to its input.

        When ``ctx.valid_mask`` is truthy, exactly two incoming gradients are
        expected and the second is ignored; otherwise only the first positional
        gradient is used.

        Returns:
            A 7-tuple whose first entry is the result of
            ``_backend.roi_sampling_backward`` and whose remaining entries are
            ``None``.

        Raises:
            AssertionError: if the incoming gradient is not floating point.
        """
        if not ctx.valid_mask:
            dy = args[0]
        else:
            dy, _ = args

        assert torch.is_floating_point(dy), "ROISampling.backward is only defined for floating point types"
        bbx, idx = ctx.saved_tensors

        dx = _backend.roi_sampling_backward(
            dy, bbx, idx, ctx.input_shape, ctx.interpolation, ctx.padding
        )
        # Only the first forward argument receives a gradient.
        return (dx,) + (None,) * 6
开发者ID:mapillary,项目名称:seamseg,代码行数:13,代码来源:functions.py

示例9: sample_to_cuda

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def sample_to_cuda(data, dtype=None):
    """Recursively move a sample to the ``'cuda'`` device.

    Strings are returned unchanged, dictionaries and lists are rebuilt with
    every value converted, and anything else is sent through
    ``data.to('cuda', dtype=...)``.

    Args:
        data: a string, dict, list, or an object accepted by
            ``torch.is_floating_point`` that has a ``.to`` method.
        dtype: target dtype applied to floating point data only.
    """
    if isinstance(data, str):
        return data
    if isinstance(data, dict):
        return {key: sample_to_cuda(value, dtype) for key, value in data.items()}
    if isinstance(data, list):
        return [sample_to_cuda(item, dtype) for item in data]

    # Only floating point data is cast (e.g. to half); other dtypes
    # (e.g. ints) keep their type.
    target_dtype = dtype if torch.is_floating_point(data) else None
    return data.to('cuda', dtype=target_dtype)
开发者ID:TRI-ML,项目名称:packnet-sfm,代码行数:13,代码来源:base_trainer.py

示例10: scatter_logsumexp

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    """Log-sum-exp of ``src`` computed separately within each ``index`` group.

    The per-group maximum is subtracted before exponentiation and added back
    after the logarithm, i.e. ``log(sum(exp(src - m)) + eps) + m``.

    Args:
        src: floating point source tensor.
        index: group index of every source element; broadcast against ``src``.
        dim: axis along which to reduce.
        out: optional tensor of existing log-space values to accumulate onto.
            Its size along ``dim`` determines the output size.
        dim_size: output size along ``dim`` when ``out`` is not given;
            defaults to ``index.max() + 1``.
        eps: constant added to the summed exponentials before the logarithm.

    Raises:
        ValueError: if ``src`` is not a floating point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    elif dim_size is None:
        dim_size = int(index.max()) + 1

    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    # The maxima are read back from ``max_value_per_index``, which is passed
    # as the output buffer; the return value is not needed.
    scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_score = src - max_per_src_element
    # ``inf - inf`` style subtractions yield NaN; treat those entries as
    # contributing exp(-inf) == 0 to the sum.
    recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))

    if out is not None:
        # ``out`` has the shape of the per-index maxima (``size``), not of
        # ``src``. It must therefore be re-centred with ``max_value_per_index``:
        # the same offset is added back per output index at the end.
        out = out.sub(max_value_per_index).exp()

    sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
                                dim_size)

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
开发者ID:rusty1s,项目名称:pytorch_scatter,代码行数:34,代码来源:logsumexp.py

示例11: is_floating_point

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def is_floating_point(self) -> bool:
        """Tell whether the stored values are floating point.

        Returns ``True`` when ``self.storage.value()`` is ``None``; otherwise
        returns the result of ``torch.is_floating_point`` on that value.
        """
        stored = self.storage.value()
        if stored is None:
            return True
        return torch.is_floating_point(stored)
开发者ID:rusty1s,项目名称:pytorch_sparse,代码行数:5,代码来源:tensor.py

示例12: is_tensor_like

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def is_tensor_like(x):
    """Return True if ``x`` is a torch tensor or a ``torch.autograd.Variable``."""
    if torch.is_tensor(x):
        return True
    return isinstance(x, torch.autograd.Variable)

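# NOTE: in the original source file this comment precedes a separate
# `is_floating_point` helper rather than describing `is_tensor_like`:
# that helper wraps `torch.is_floating_point` if present, otherwise checks
# the suffix of `x.type()`.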
开发者ID:NVIDIA,项目名称:apex,代码行数:7,代码来源:compat.py

示例13: is_floating_point

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def is_floating_point(x):
    """Tell whether ``x`` is a floating point tensor.

    Delegates to ``torch.is_floating_point`` when the installed torch provides
    it. Otherwise falls back to checking whether ``x.type()`` ends with
    ``FloatTensor``, ``HalfTensor`` or ``DoubleTensor``; in that fallback an
    ``AttributeError`` (e.g. ``x`` has no ``type`` method) yields ``False``.
    """
    if not hasattr(torch, 'is_floating_point'):
        try:
            type_name = x.type()
            return type_name.endswith(
                ('FloatTensor', 'HalfTensor', 'DoubleTensor')
            )
        except AttributeError:
            return False
    return torch.is_floating_point(x)
开发者ID:NVIDIA,项目名称:apex,代码行数:12,代码来源:compat.py

示例14: _analyse

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def _analyse(self, module, function: str):
        """Collect dtype and shape findings for the weighted submodules of ``module``.

        Submodules are obtained by calling the method of ``module`` named by
        ``function`` and are numbered in iteration order. Only submodules that
        have a ``weight`` attribute are inspected; each is classified as
        ``"float"`` or ``"integer"`` via ``torch.is_floating_point`` on its
        weight.

        * Type: the index is recorded when the weight dtype equals none of
          ``self.float_types`` (float case) or ``self.integer_types``
          (integer case).
        * Shape: for instances of ``self.linear_types`` or
          ``self.convolution_types``, every attribute listed for the
          submodule's exact type in ``self.<operation>_inputs`` and
          ``self.<operation>_outputs`` is checked; the index is recorded when
          the attribute value is not divisible by 8 (float) or 16 (integer).

        Returns:
            A dict of the form ``{"type": {"float": [...], "integer": [...]},
            "shape": {"float": {"inputs": [...], "outputs": [...]},
            "integer": {"inputs": [...], "outputs": [...]}}}`` holding the
            recorded submodule indices.
        """
        report = {
            "type": {"float": [], "integer": []},
            "shape": {
                "float": {"inputs": [], "outputs": []},
                "integer": {"inputs": [], "outputs": []},
            },
        }

        def _check_dtype(submodule, index, kind):
            # Record the submodule when its weight dtype matches no accepted dtype.
            accepted = self.float_types if kind == "float" else self.integer_types
            weight_dtype = submodule.weight.dtype
            if not any(candidate == weight_dtype for candidate in accepted):
                report["type"][kind].append(index)

        def _check_shapes(submodule, index, operation, kind):
            # Record the submodule once per listed attribute that is not a
            # multiple of the divisor for this kind.
            divisor = 8 if kind == "float" else 16
            for entry in ("inputs", "outputs"):
                attributes = getattr(self, operation + "_" + entry)
                for attribute in attributes[type(submodule)]:
                    if not hasattr(submodule, attribute):
                        continue
                    if getattr(submodule, attribute) % divisor == 0:
                        continue
                    report["shape"][kind][entry].append(index)

        for index, submodule in enumerate(getattr(module, function)()):
            if not hasattr(submodule, "weight"):
                continue

            if torch.is_floating_point(submodule.weight):
                kind = "float"
            else:
                kind = "integer"

            _check_dtype(submodule, index, kind)

            if isinstance(submodule, self.linear_types):
                _check_shapes(submodule, index, "linear", kind)
            elif isinstance(submodule, self.convolution_types):
                _check_shapes(submodule, index, "convolution", kind)

        return report
开发者ID:szymonmaszke,项目名称:torchfunc,代码行数:63,代码来源:technology.py


注:本文中的torch.is_floating_point方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。