当前位置: 首页>>代码示例>>Python>>正文


Python torch.dtype方法代码示例

本文整理汇总了Python中torch.dtype的典型用法代码示例。如果您想了解torch.dtype的具体用法、怎么使用它,或者想找一些使用torch.dtype的例子,这里精选的代码示例或许可以为您提供帮助。需要说明的是,torch.dtype 实际上是表示张量数据类型的类(例如 torch.float32、torch.int64),而不是一个方法。您也可以进一步了解torch模块的其他用法示例。


在下文中一共展示了torch.dtype方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_points

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Collect the points of every feature-map level.

        Level ``i`` is produced by ``self._get_points_single`` from
        ``featmap_sizes[i]`` and the matching stride ``self.strides[i]``.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.
            flatten (bool): Passed through unchanged to
                ``_get_points_single``.

        Returns:
            list: Per-level results of ``_get_points_single``, in the same
            order as ``featmap_sizes``.
        """
        return [
            self._get_points_single(size, self.strides[level], dtype, device,
                                    flatten)
            for level, size in enumerate(featmap_sizes)
        ]
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:19,代码来源:anchor_free_head.py

示例2: patch_forward_method

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Wrap a forward method so that its tensor inputs are cast first.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Tensor type the inputs are converted from.
        dst_type (torch.dtype): Tensor type the inputs are converted to.
        convert_output (bool): If True, the result of ``func`` is cast from
            ``dst_type`` back to ``src_type`` before being returned.

    Returns:
        callable: A replacement forward that forwards all positional and
        keyword arguments (after casting) to ``func``.
    """

    def new_forward(*args, **kwargs):
        # Positional and keyword inputs are cast separately so that both can
        # be re-splatted into the wrapped call.
        cast_args = cast_tensor_type(args, src_type, dst_type)
        cast_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*cast_args, **cast_kwargs)
        if not convert_output:
            return result
        return cast_tensor_type(result, dst_type, src_type)

    return new_forward
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:23,代码来源:hooks.py

示例3: get_points

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def get_points(self, featmap_sizes, dtype, device):
        """Build the points for each feature-map level.

        For every level the size ``featmap_sizes[i]`` and the stride
        ``self.strides[i]`` are handed to ``self.get_points_single``.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.

        Returns:
            list: One ``get_points_single`` result per level, ordered like
            ``featmap_sizes``.
        """
        points_per_level = []
        for level, size in enumerate(featmap_sizes):
            points_per_level.append(
                self.get_points_single(size, self.strides[level], dtype,
                                       device))
        return points_per_level
开发者ID:dingjiansw101,项目名称:AerialDetection,代码行数:19,代码来源:fcos_head.py

示例4: origin

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def origin(
        self, *size, dtype=None, device=None, seed=42
    ) -> "geoopt.ManifoldTensor":
        """
        Return the all-zeros point as the origin of this manifold.

        Parameters
        ----------
        size : shape
            the desired shape
        dtype : torch.dtype
            the desired dtype
        device : torch.device
            the desired device
        seed : int
            ignored; the origin is deterministic

        Returns
        -------
        ManifoldTensor
            zeros of the requested shape, attached to this manifold
        """
        shape = size2shape(*size)
        self._assert_check_shape(shape, "x")
        zeros = torch.zeros(*size, dtype=dtype, device=device)
        return geoopt.ManifoldTensor(zeros, manifold=self)
开发者ID:geoopt,项目名称:geoopt,代码行数:27,代码来源:euclidean.py

示例5: random_naive

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def random_naive(self, *size, dtype=None, device=None) -> torch.Tensor:
        """
        Sample a point on the Birkhoff Polytope manifold the naive way.

        A Gaussian sample is made non-negative and then projected onto the
        manifold with ``projx``. This is cheap to compute, but the resulting
        distribution over the manifold is not uniform.

        Parameters
        ----------
        size : shape
            the desired output shape
        dtype : torch.dtype
            desired dtype
        device : torch.device
            desired device

        Returns
        -------
        ManifoldTensor
            random point on Birkhoff Polytope manifold
        """
        self._assert_check_shape(size2shape(*size), "x")
        # The projection needs non-negative entries, so fold the Gaussian
        # sample onto the non-negative half-line (in place) first.
        sample = torch.randn(*size, device=device, dtype=dtype)
        sample.abs_()
        projected = self.projx(sample)
        return ManifoldTensor(projected, manifold=self)
开发者ID:geoopt,项目名称:geoopt,代码行数:27,代码来源:birkhoff_polytope.py

示例6: origin

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
        """
        Return the identity matrix (broadcast over batch dims) as origin.

        Parameters
        ----------
        size : shape
            the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype
        seed : int
            ignored; the origin is deterministic

        Returns
        -------
        ManifoldTensor
            identity built from the last two dims of ``size`` and expanded
            (as a view, without copying) to the full shape
        """
        target_shape = size2shape(*size)
        self._assert_check_shape(target_shape, "x")
        matrix_dims = target_shape[-2:]
        identity = torch.eye(*matrix_dims, dtype=dtype, device=device)
        return ManifoldTensor(identity.expand(target_shape), manifold=self)
开发者ID:geoopt,项目名称:geoopt,代码行数:26,代码来源:birkhoff_polytope.py

示例7: origin

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
        """
        Return an identity-like matrix point as the origin.

        Parameters
        ----------
        size : shape
            the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype
        seed : int
            ignored; the origin is deterministic

        Returns
        -------
        ManifoldTensor
            tensor of the requested shape with ones at ``[..., i, i]`` for
            every ``i < size[-1]`` and zeros elsewhere
        """
        self._assert_check_shape(size2shape(*size), "x")
        point = torch.zeros(*size, dtype=dtype, device=device)
        diagonal = torch.arange(point.shape[-1])
        point[..., diagonal, diagonal] += 1
        return ManifoldTensor(point, manifold=self)
开发者ID:geoopt,项目名称:geoopt,代码行数:25,代码来源:stiefel.py

示例8: degree

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def degree(index: torch.Tensor, num_nodes: Optional[int] = None,
           dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    r"""Computes the (unweighted) degree of a given one-dimensional index
    tensor.

    Args:
        index (LongTensor): Index tensor.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. When :obj:`None` it is
            resolved by :func:`maybe_num_nodes`. (default: :obj:`None`)
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned tensor. :obj:`None` uses PyTorch's default dtype.
            (default: :obj:`None`)

    :rtype: :class:`Tensor` of shape :obj:`[num_nodes]` whose entry ``i``
        is the number of occurrences of ``i`` in :attr:`index`.
    """
    N = maybe_num_nodes(index, num_nodes)
    out = torch.zeros((N, ), dtype=dtype, device=index.device)
    # Every occurrence of a node id contributes a count of one.
    one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device)
    return out.scatter_add_(0, index, one)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:20,代码来源:degree.py

示例9: grid

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def grid(height, width, dtype=None, device=None):
    r"""Returns the edge indices of a two-dimensional grid graph with height
    :attr:`height` and width :attr:`width` and its node positions.

    Args:
        height (int): The height of the grid.
        width (int): The width of the grid.
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned position tensor.
        device (:obj:`torch.device`, optional): The desired device of the
            returned tensors.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Connectivity is independent of dtype; only the positions are cast.
    edge_index = grid_index(height, width, device)
    pos = grid_pos(height, width, dtype, device)
    return edge_index, pos
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:20,代码来源:grid.py

示例10: grid_index

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def grid_index(height, width, device=None):
    r"""Returns the edge index of a :obj:`height x width` grid graph.

    Every node is connected to itself (offset 0) and to its up to eight
    horizontal, vertical and diagonal neighbours. Nodes are numbered
    row-major, i.e. node ``y * width + x``.

    Returns a :obj:`[2, E]` edge index as produced by :func:`coalesce`.
    """
    w = width
    # 3x3 neighbourhood offsets ordered column by column: entries 0-2 are the
    # left column (x - 1), 3-5 the same column, 6-8 the right column (x + 1).
    # The trimming step below relies on this exact order.
    kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1]
    kernel = torch.tensor(kernel, device=device)

    # One line per node, one column per kernel offset.
    row = torch.arange(height * width, dtype=torch.long, device=device)
    row = row.view(-1, 1).repeat(1, kernel.size(0))
    col = row + kernel.view(1, -1)
    # Regroup so each line holds the (node, offset) pairs of one grid row,
    # laid out node by node: width * 9 entries per line.
    row, col = row.view(height, -1), col.view(height, -1)
    # Skipping the first 3 entries drops the left-column offsets of the
    # first node of the grid row (x == 0); skipping the last 3 drops the
    # right-column offsets of the last node (x == width - 1). Those offsets
    # would otherwise wrap around to the opposite side of the grid.
    index = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device)
    row, col = row[:, index].view(-1), col[:, index].view(-1)

    # Drop offsets that point above the first or below the last grid row.
    mask = (col >= 0) & (col < height * width)
    row, col = row[mask], col[mask]

    edge_index = torch.stack([row, col], dim=0)
    # NOTE(review): `coalesce` is imported elsewhere; presumably it sorts the
    # edges and merges duplicates — confirm against the imported helper.
    edge_index, _ = coalesce(edge_index, None, height * width, height * width)

    return edge_index
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:21,代码来源:grid.py

示例11: patch_forward_method

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Return a forward method that casts its tensor arguments before use.
    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type the input arguments are converted from.
        dst_type (torch.dtype): Type the input arguments are converted to.
        convert_output (bool): When True, the output is cast back from
            ``dst_type`` to ``src_type``.
    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        converted_args = cast_tensor_type(args, src_type, dst_type)
        converted_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        output = func(*converted_args, **converted_kwargs)
        return (cast_tensor_type(output, dst_type, src_type)
                if convert_output else output)

    return new_forward
开发者ID:DeepMotionAIResearch,项目名称:DenseMatchingBenchmark,代码行数:21,代码来源:hooks.py

示例12: torch_dtype_to_np_dtype

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
# Mapping from torch dtypes to the equivalent NumPy dtypes, built once at
# import time instead of on every call. Aliases such as torch.short,
# torch.int, torch.long, torch.half, torch.float and torch.double are the
# very same objects as torch.int16/int32/int64/float16/float32/float64, so
# they are covered by the canonical keys below.
_TORCH_TO_NUMPY_DTYPE = {
    # np.bool was removed in NumPy 1.24; np.bool_ is the NumPy boolean type.
    torch.bool: np.dtype(np.bool_),
    torch.uint8: np.dtype(np.uint8),
    torch.int8: np.dtype(np.int8),
    torch.int16: np.dtype(np.int16),
    torch.int32: np.dtype(np.int32),
    torch.int64: np.dtype(np.int64),
    torch.float16: np.dtype(np.float16),
    torch.float32: np.dtype(np.float32),
    torch.float64: np.dtype(np.float64),
}


def torch_dtype_to_np_dtype(dtype):
    """Return the ``numpy.dtype`` corresponding to a ``torch.dtype``.

    Args:
        dtype (torch.dtype): A boolean, integer or float16/32/64 torch dtype.

    Returns:
        numpy.dtype: The equivalent NumPy dtype.

    Raises:
        KeyError: If ``dtype`` has no entry in the mapping
            (e.g. ``torch.bfloat16`` or complex dtypes).
    """
    return _TORCH_TO_NUMPY_DTYPE[dtype]


# ---------------------- InferenceEngine internal types ------------------------ 
开发者ID:pfnet-research,项目名称:chainer-compiler,代码行数:24,代码来源:types.py

示例13: update_device_and_dtype

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def update_device_and_dtype(state, *args, **kwargs):
    """Function gets data type and device values from the args / kwargs and updates state.

    The arguments follow ``torch.nn.Module.to``: ``dtype`` / ``device`` may be
    given as keywords, or positionally as a :class:`torch.dtype`, a device
    (anything that is not a dtype, tensor or bool) or a tensor whose dtype and
    device are both adopted. A positional ``bool`` is the ``non_blocking``
    flag and is ignored.

    Args:
        state (State): The :class:`.State` to update
        args: Arguments to the :func:`Trial.to` function
        kwargs: Keyword arguments to the :func:`Trial.to` function

    Returns:
        state
    """
    if 'dtype' in kwargs:
        state[torchbearer.DATA_TYPE] = kwargs['dtype']
    if 'device' in kwargs:
        state[torchbearer.DEVICE] = kwargs['device']

    for arg in args:
        if isinstance(arg, bool):
            # non_blocking flag, e.g. to('cuda', torch.float, True): it is
            # neither a dtype nor a device and must not overwrite the device.
            continue
        if isinstance(arg, torch.dtype):
            state[torchbearer.DATA_TYPE] = arg
        elif isinstance(arg, torch.Tensor):
            # Module.to(tensor) takes both dtype and device from the tensor.
            state[torchbearer.DATA_TYPE] = arg.dtype
            state[torchbearer.DEVICE] = arg.device
        else:
            state[torchbearer.DEVICE] = arg

    return state
开发者ID:pytorchbearer,项目名称:torchbearer,代码行数:26,代码来源:trial.py

示例14: protobuf_tensor_serializer

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def protobuf_tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> TensorDataPB:
    """Strategy to serialize a tensor using Protobuf.

    The dtype name looked up in ``TORCH_DTYPE_STR`` is stored in the message
    and also selects the ``contents_<dtype>`` repeated field that receives
    the flattened values. For quantized tensors the scale and zero point are
    recorded and the integer representation is stored.
    """
    dtype_name = TORCH_DTYPE_STR[tensor.dtype]
    message = TensorDataPB()

    flat = torch.flatten(tensor)
    if not tensor.is_quantized:
        values = flat.tolist()
    else:
        message.is_quantized = True
        message.scale = tensor.q_scale()
        message.zero_point = tensor.q_zero_point()
        values = flat.int_repr().tolist()

    message.dtype = dtype_name
    message.shape.dims.extend(tensor.size())
    getattr(message, "contents_" + dtype_name).extend(values)
    return message
开发者ID:OpenMined,项目名称:PySyft,代码行数:21,代码来源:torch_serde.py

示例15: protobuf_tensor_deserializer

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dtype [as 别名]
def protobuf_tensor_deserializer(
    worker: AbstractWorker, protobuf_tensor: TensorDataPB
) -> torch.Tensor:
    """Strategy to deserialize a binary input using Protobuf.

    Values are read from the ``contents_<dtype>`` field named by the stored
    dtype string and reshaped to the stored dims. Quantized messages are
    rebuilt from their integer representation plus scale and zero point.
    """
    stored_dtype = protobuf_tensor.dtype
    shape = tuple(protobuf_tensor.shape.dims)
    values = getattr(protobuf_tensor, "contents_" + stored_dtype)

    if not protobuf_tensor.is_quantized:
        plain_dtype = TORCH_STR_DTYPE[stored_dtype]
        return torch.tensor(values, dtype=plain_dtype).reshape(shape)

    # The quantized dtype name carries a leading 'q' (e.g. "qint8");
    # dropping it yields the integer dtype of the stored representation.
    int_dtype = TORCH_STR_DTYPE[stored_dtype[1:]]
    int_tensor = torch.tensor(values, dtype=int_dtype).reshape(shape)
    # Converts the integer tensor into a per-tensor quantized tensor.
    return torch._make_per_tensor_quantized_tensor(
        int_tensor, protobuf_tensor.scale, protobuf_tensor.zero_point
    )
开发者ID:OpenMined,项目名称:PySyft,代码行数:20,代码来源:torch_serde.py


注:本文中的torch.dtype方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。