Python torch.is_floating_point方法代码示例

本文整理汇总了 Python 中 torch.is_floating_point 方法的典型用法代码示例。如果您想了解 torch.is_floating_point 方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torch 的其他用法示例。


示例1: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def forward(ctx, x, bbx, idx, roi_size, interpolation, padding, valid_mask):
    """Autograd forward pass: sample ROIs from ``x`` via the compiled backend.

    Args:
        ctx: autograd context; stores ``bbx``/``idx`` and the options that
            ``backward`` needs.
        x: input tensor; sizes of dims 0, 2 and 3 are recorded, so it must have
            at least 4 dims (presumably N, C, H, W -- TODO confirm).
        bbx, idx: box and index tensors forwarded to the backend.
        roi_size: output ROI size forwarded to the backend.
        interpolation: key into ``_INTERPOLATION``.
        padding: key into ``_PADDING``.
        valid_mask: when truthy, the mask produced by the backend is returned too.

    Returns:
        ``(y, mask)`` if ``valid_mask`` else ``y``.

    Raises:
        ValueError: if ``interpolation`` or ``padding`` is not a known key.
    """
    ctx.save_for_backward(bbx, idx)
    ctx.input_shape = (x.size(0), x.size(2), x.size(3))
    ctx.valid_mask = valid_mask

    try:
        ctx.interpolation = _INTERPOLATION[interpolation]
    except KeyError:
        raise ValueError("Unknown interpolation {}".format(interpolation)) from None
    try:
        ctx.padding = _PADDING[padding]
    except KeyError:
        raise ValueError("Unknown padding {}".format(padding)) from None

    y, mask = _backend.roi_sampling_forward(x, bbx, idx, roi_size, ctx.interpolation, ctx.padding, valid_mask)

    if not torch.is_floating_point(x):
        # Non-floating outputs cannot carry gradients; tell autograd not to expect one.
        ctx.mark_non_differentiable(y)
    if valid_mask:
        # The mask's incoming gradient is discarded in backward, so mark it non-differentiable.
        ctx.mark_non_differentiable(mask)
        return y, mask
    return y

示例2: _check_inputs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def _check_inputs(func, y0, t):
    """Validate and normalise the inputs of an integration call (presumably an ODE solve).

    A bare tensor ``y0`` is wrapped into a 1-tuple, and ``func`` is adapted to
    take and return tuples. When ``_decreasing(t)`` holds, ``t`` is negated and
    ``func`` is wrapped so that it is evaluated at ``-t`` with its outputs
    negated.

    Returns:
        ``(tensor_input, func, y0, t)``, where ``tensor_input`` records whether
        ``y0`` was originally a single tensor.

    Raises:
        AssertionError: if ``y0`` is neither a tensor nor a tuple of tensors.
        TypeError: if any ``y0`` element or ``t`` is not floating point.
    """
    tensor_input = torch.is_tensor(y0)
    if tensor_input:
        y0 = (y0,)
        single_func = func
        func = lambda t, y: (single_func(t, y[0]),)
    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for component in y0:
        assert torch.is_tensor(component), 'each element must be a torch.Tensor but received {}'.format(type(component))

    if _decreasing(t):
        t = -t
        forward_func = func
        func = lambda t, y: tuple(-d for d in forward_func(-t, y))

    non_float = [component for component in y0 if not torch.is_floating_point(component)]
    if non_float:
        raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(non_float[0].type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))

    return tensor_input, func, y0, t

示例3: scatter_mean

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Average the values of ``src`` that share the same ``index`` along ``dim``.

    Args:
        src: values to reduce.
        index: group index of each element of ``src``.
        dim: axis along which to scatter.
        out: optional output buffer forwarded to ``scatter_sum``.
        dim_size: optional output size along ``dim`` forwarded to ``scatter_sum``.

    Returns:
        The per-index means, divided in place into the summed tensor.
        Floating-point results use true division. Integer results use floor
        division. Indices that receive no elements keep their summed value
        (divided by 1) instead of producing NaN/inf or a division error.
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    # ``index`` may have fewer dims than ``src``; then count along its last axis.
    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    # Empty buckets have count 0; clamp so they divide by 1 rather than by 0.
    count.clamp_(min=1)
    count = broadcast(count, out, dim)
    if torch.is_floating_point(out):
        out.true_divide_(count)
    else:
        out.div_(count, rounding_mode='floor')
    return out

示例4: scatter_softmax

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                    eps: float = 1e-12) -> torch.Tensor:
    """Softmax of ``src`` computed independently within each ``index`` group along ``dim``.

    Each group's maximum is subtracted before exponentiating, for numerical
    stability, and ``eps`` is added to every group's normaliser.

    Raises:
        ValueError: if ``src`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    group_max = scatter_max(src, index, dim=dim)[0].gather(dim, index)
    exp_scores = (src - group_max).exp()

    group_sum = scatter_sum(exp_scores, index, dim).add_(eps)
    return exp_scores.div(group_sum.gather(dim, index))

示例5: scatter_log_softmax

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                        eps: float = 1e-12) -> torch.Tensor:
    """Log-softmax of ``src`` computed independently within each ``index`` group along ``dim``.

    The result is ``(src - max) - log(sum(exp(src - max)) + eps)`` per group.

    Raises:
        ValueError: if ``src`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    group_max = scatter_max(src, index, dim=dim)[0].gather(dim, index)
    shifted = src - group_max

    log_norm = scatter_sum(shifted.exp(), index, dim).add_(eps).log_()
    # ``shifted`` is a fresh tensor, so updating it in place is safe.
    return shifted.sub_(log_norm.gather(dim, index))

示例6: _check_inputs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def _check_inputs(func, y0, t, f_options):
    """Validate and normalise integration inputs (presumably for an ODE solve).

    Same as the option-less variant, but the adapted ``func`` wrappers forward
    arbitrary keyword arguments (``**f_options``) to the wrapped function.
    The ``f_options`` parameter itself is accepted but not read here.

    Returns:
        ``(tensor_input, func, y0, t)``.

    Raises:
        AssertionError: if ``y0`` is neither a tensor nor a tuple of tensors.
        TypeError: if any ``y0`` element or ``t`` is not floating point.
    """
    tensor_input = torch.is_tensor(y0)
    if tensor_input:
        y0 = (y0,)
        single_func = func
        func = lambda t, y, **f_options: (single_func(t, y[0], **f_options),)
    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for component in y0:
        assert torch.is_tensor(component), 'each element must be a torch.Tensor but received {}'.format(type(component))

    if _decreasing(t):
        t = -t
        forward_func = func
        func = lambda t, y, **f_options: tuple(-d for d in forward_func(-t, y, **f_options))

    non_float = [component for component in y0 if not torch.is_floating_point(component)]
    if non_float:
        raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(non_float[0].type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))

    return tensor_input, func, y0, t

示例7: torch_tensor_to_schema

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def torch_tensor_to_schema(tensor):
    """Describe ``tensor`` as a JSON-schema-style dict.

    The element schema is chosen from the dtype:
    - bool -> ``boolean``;
    - uint8 -> ``integer`` in [0, 255];
    - floating point -> ``number``;
    - any other dtype -> ``integer``.

    It is then wrapped in one fixed-length ``array`` schema per dimension.

    Raises:
        AssertionError: if torch is unavailable or ``tensor`` is not a ``torch.Tensor``.
    """
    assert torch_installed, """Your Python environment does not have torch installed. You can install it with
    pip install torch
or with
    pip install 'lale[full]'"""
    assert isinstance(tensor, torch.Tensor)
    if tensor.dtype == torch.bool:
        result = {'type': 'boolean'}
    elif tensor.dtype == torch.uint8:
        result = {'type': 'integer', 'minimum': 0, 'maximum': 255}
    elif torch.is_floating_point(tensor):
        result = {'type': 'number'}
    else:
        result = {'type': 'integer'}
    # Wrap from the innermost dimension outwards so the outermost array is dim 0.
    for dim in reversed(tensor.shape):
        result = {
            'type': 'array',
            'minItems': dim, 'maxItems': dim,
            'items': result}
    return result

示例8: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def backward(ctx, *args):
    """Autograd backward pass for ROI sampling.

    ``args`` holds the incoming gradients. It is ``(dy, dmask)`` when the
    forward pass ran with ``valid_mask`` set (the mask gradient is ignored),
    otherwise ``(dy,)``.

    Returns:
        The gradient for ``x``, followed by ``None`` for the six other forward
        inputs (bbx, idx, roi_size, interpolation, padding, valid_mask).

    Raises:
        AssertionError: if ``dy`` is not floating point.
    """
    if ctx.valid_mask:
        dy, _ = args
    else:
        dy = args[0]

    assert torch.is_floating_point(dy), "ROISampling.backward is only defined for floating point types"
    bbx, idx = ctx.saved_tensors

    dx = _backend.roi_sampling_backward(dy, bbx, idx, ctx.input_shape, ctx.interpolation, ctx.padding)
    return dx, None, None, None, None, None, None

示例9: sample_to_cuda

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def sample_to_cuda(data, dtype=None):
    """Recursively move a (possibly nested) sample to the CUDA device.

    Strings are returned unchanged; dicts and lists are rebuilt with every
    value converted. Any other value is treated as a tensor and moved with
    ``.to('cuda')``. ``dtype`` is applied only to floating-point tensors, so
    integer tensors keep their dtype.
    """
    if isinstance(data, str):
        return data
    elif isinstance(data, dict):
        return {key: sample_to_cuda(value, dtype) for key, value in data.items()}
    elif isinstance(data, list):
        return [sample_to_cuda(val, dtype) for val in data]
    else:
        # only convert floats (e.g., to half), otherwise preserve (e.g, ints)
        dtype = dtype if torch.is_floating_point(data) else None
        return data.to('cuda', dtype=dtype)

示例10: scatter_logsumexp

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    """Compute ``log(sum(exp(src)))`` within each ``index`` group along ``dim``.

    Each group's maximum is subtracted before exponentiating and added back
    after the log, for numerical stability. ``eps`` is added before the log.

    Args:
        src: floating-point values to reduce.
        index: group index of each element of ``src``.
        dim: axis along which to scatter.
        out: optional tensor shaped like the result. It is rescaled
            out-of-place (the caller's tensor is not modified) and passed to
            ``scatter_sum`` as its output buffer.
        dim_size: output size along ``dim``. Ignored when ``out`` is given.
            Defaults to ``index.max() + 1``.
        eps: small constant added before taking the log.

    Raises:
        ValueError: if ``src`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    elif dim_size is None:
        dim_size = int(index.max()) + 1

    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    # ``max_value_per_index`` is passed as the output buffer and read afterwards;
    # scatter_max's return value is not used.
    scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_score = src - max_per_src_element
    # NaNs (e.g. from -inf - -inf) are treated as -inf so they contribute exp(-inf) = 0.
    recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))

    if out is not None:
        # ``out`` is laid out like the result, so rescale it by the per-index
        # maximum, not the per-source-element one (which has ``src``'s shape).
        out = out.sub(max_value_per_index).exp()

    sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
                                dim_size)

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)

示例11: is_floating_point

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def is_floating_point(self) -> bool:
    """Report whether the stored value tensor has a floating-point dtype.

    Returns ``True`` when ``self.storage.value()`` is ``None``.
    """
    stored = self.storage.value()
    if stored is None:
        return True
    return torch.is_floating_point(stored)

示例12: is_tensor_like

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def is_tensor_like(x):
    """Return True when ``x`` is a ``torch.Tensor`` or a ``torch.autograd.Variable``."""
    if torch.is_tensor(x):
        return True
    return isinstance(x, torch.autograd.Variable)

# Wraps `torch.is_floating_point` if present, otherwise checks
# the suffix of `x.type()`. 

示例13: is_floating_point

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def is_floating_point(x):
    """Return True if ``x`` is a floating-point tensor.

    Uses ``torch.is_floating_point`` when this torch build provides it. On
    builds without it, the function falls back to checking the legacy type
    string from ``x.type()`` for a Float/Half/Double tensor suffix. Objects
    without a ``type()`` method are then reported as non-floating.
    """
    if hasattr(torch, 'is_floating_point'):
        return torch.is_floating_point(x)
    try:
        torch_type = x.type()
    except AttributeError:
        return False
    return torch_type.endswith(('FloatTensor', 'HalfTensor', 'DoubleTensor'))

示例14: _analyse

# 需要导入模块: import torch [as 别名]
# 或者: from torch import is_floating_point [as 别名]
def _analyse(self, module, function: str):
    """Collect submodules whose dtype or layer sizes violate the configured rules.

    Submodules are enumerated with ``getattr(module, function)()``. For every
    submodule that has a ``weight``:

    * its weight dtype is checked against ``self.float_types`` (floating-point
      weights) or ``self.integer_types`` (other weights);
    * if it is an instance of ``self.linear_types`` or ``self.convolution_types``,
      every attribute named in ``self.<operation>_inputs`` /
      ``self.<operation>_outputs`` (dicts keyed by submodule type; operation is
      "linear" or "convolution") must be a multiple of 8 (float) or 16 (integer).

    Returns:
        dict holding the enumeration indices of offending submodules::

            {"type":  {"float": [...], "integer": [...]},
             "shape": {"float":   {"inputs": [...], "outputs": [...]},
                       "integer": {"inputs": [...], "outputs": [...]}}}
    """
    def _correct_types(data, submodule, index, is_float: bool):
        # Record ``index`` when the weight dtype is not one of the accepted types.
        correct_types = self.float_types if is_float else self.integer_types
        if not any(
            correct_type == submodule.weight.dtype for correct_type in correct_types
        ):
            data["type"]["float" if is_float else "integer"].append(index)

    def _correct_shapes(
        data, submodule, index, attributes, attribute_name, is_float: bool
    ):
        # Record ``index`` under ``attribute_name`` for every size attribute
        # that is not a multiple of 8 (float) / 16 (integer).
        for attribute in attributes[type(submodule)]:
            if hasattr(submodule, attribute):
                shape = getattr(submodule, attribute)
                correct = shape % (8 if is_float else 16) == 0
                if not correct:
                    data["shape"]["float" if is_float else "integer"][
                        attribute_name
                    ].append(index)

    def _find_problems(data, submodule, index, is_float: bool):
        def _operation_problems(operation: str):
            # Check both the input- and output-size attributes registered for this operation.
            for entry in ("inputs", "outputs"):
                _correct_shapes(
                    data,
                    submodule,
                    index,
                    getattr(self, operation + "_" + entry),
                    entry,
                    is_float,
                )

        _correct_types(data, submodule, index, is_float)

        if isinstance(submodule, self.linear_types):
            _operation_problems("linear")
        elif isinstance(submodule, self.convolution_types):
            _operation_problems("convolution")

    # MAIN FUNCTION

    data = {
        "type": {"float": [], "integer": []},
        "shape": {
            "float": {"inputs": [], "outputs": []},
            "integer": {"inputs": [], "outputs": []},
        },
    }

    for index, submodule in enumerate(getattr(module, function)()):
        if hasattr(submodule, "weight"):
            _find_problems(
                data,
                submodule,
                index,
                is_float=torch.is_floating_point(submodule.weight),
            )

    return data
