當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.dtype方法代碼示例

本文整理匯總了Python中torch.dtype方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.dtype方法的具體用法?Python torch.dtype怎麼用?Python torch.dtype使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.dtype方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: get_points

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def get_points(self, featmap_sizes, dtype, device, flatten=False):
    """Get points according to feature map sizes.

    Calls ``self._get_points_single`` once per feature-map level, pairing the
    i-th feature map size with ``self.strides[i]``.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.
        flatten (bool): Forwarded unchanged to ``_get_points_single`` for
            every level. Default: False.

    Returns:
        list: One entry per level, each being whatever
            ``_get_points_single`` returns for that level.
    """
    # self.strides is indexed per level, so it must provide at least
    # len(featmap_sizes) entries; an IndexError is raised otherwise.
    return [
        self._get_points_single(featmap_size, self.strides[level],
                                dtype, device, flatten)
        for level, featmap_size in enumerate(featmap_sizes)
    ]
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:19,代碼來源:anchor_free_head.py

示例2: patch_forward_method

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Wrap a module's forward method with dtype casting.

    The returned callable passes its positional and keyword arguments through
    ``cast_tensor_type(..., src_type, dst_type)`` before delegating to
    ``func``. When ``convert_output`` is true, the result is passed back
    through ``cast_tensor_type(..., dst_type, src_type)``.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        casted_args = cast_tensor_type(args, src_type, dst_type)
        casted_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*casted_args, **casted_kwargs)
        if not convert_output:
            return result
        return cast_tensor_type(result, dst_type, src_type)

    return new_forward
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:23,代碼來源:hooks.py

示例3: get_points

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def get_points(self, featmap_sizes, dtype, device):
    """Get points according to feature map sizes.

    Calls ``self.get_points_single`` once per feature-map level, pairing the
    i-th feature map size with ``self.strides[i]``.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.

    Returns:
        list: One entry per level, each being whatever
            ``get_points_single`` returns for that level.
    """
    # self.strides is indexed per level, so it must provide at least
    # len(featmap_sizes) entries; an IndexError is raised otherwise.
    return [
        self.get_points_single(featmap_size, self.strides[level],
                               dtype, device)
        for level, featmap_size in enumerate(featmap_sizes)
    ]
開發者ID:dingjiansw101,項目名稱:AerialDetection,代碼行數:19,代碼來源:fcos_head.py

示例4: origin

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def origin(
    self, *size, dtype=None, device=None, seed=42
) -> "geoopt.ManifoldTensor":
    """
    Return the origin of the Euclidean manifold: an all-zeros tensor.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored (accepted only for interface compatibility)

    Returns
    -------
    ManifoldTensor
        zeros of the requested shape, bound to this manifold
    """
    requested_shape = size2shape(*size)
    self._assert_check_shape(requested_shape, "x")
    zeros = torch.zeros(*size, dtype=dtype, device=device)
    return geoopt.ManifoldTensor(zeros, manifold=self)
開發者ID:geoopt,項目名稱:geoopt,代碼行數:27,代碼來源:euclidean.py

示例5: random_naive

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def random_naive(self, *size, dtype=None, device=None) -> torch.Tensor:
    """
    Sample a random point on the Birkhoff Polytope manifold (naive method).

    Draws a Gaussian sample, takes its absolute value in place, and projects
    the result with ``self.projx``. The resulting measure is non-uniform, but
    the method is fast to compute.

    Parameters
    ----------
    size : shape
        the desired output shape
    dtype : torch.dtype
        desired dtype
    device : torch.device
        desired device

    Returns
    -------
    ManifoldTensor
        random point on Birkhoff Polytope manifold
    """
    self._assert_check_shape(size2shape(*size), "x")
    sample = torch.randn(*size, device=device, dtype=dtype)
    # The projection requires all entries to be non-negative.
    sample.abs_()
    projected = self.projx(sample)
    return ManifoldTensor(projected, manifold=self)
開發者ID:geoopt,項目名稱:geoopt,代碼行數:27,代碼來源:birkhoff_polytope.py

示例6: origin

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
    """
    Return the identity-matrix origin of the Birkhoff Polytope manifold.

    An identity matrix is built from the last two dimensions of the requested
    shape and broadcast (``expand``, a view rather than a copy) to the full
    shape.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored (accepted only for interface compatibility)

    Returns
    -------
    ManifoldTensor
    """
    target_shape = size2shape(*size)
    self._assert_check_shape(target_shape, "x")
    identity = torch.eye(*target_shape[-2:], dtype=dtype, device=device).expand(
        target_shape
    )
    return ManifoldTensor(identity, manifold=self)
開發者ID:geoopt,項目名稱:geoopt,代碼行數:26,代碼來源:birkhoff_polytope.py

示例7: origin

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
    """
    Return the (rectangular) identity-matrix origin of the Stiefel manifold.

    Builds zeros of the requested shape and adds 1 at every position
    ``[..., i, i]`` for ``i`` in ``range(size[-1])``.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored (accepted only for interface compatibility)

    Returns
    -------
    ManifoldTensor
    """
    self._assert_check_shape(size2shape(*size), "x")
    point = torch.zeros(*size, dtype=dtype, device=device)
    diag = torch.arange(point.shape[-1])
    point[..., diag, diag] += 1
    return ManifoldTensor(point, manifold=self)
開發者ID:geoopt,項目名稱:geoopt,代碼行數:25,代碼來源:stiefel.py

示例8: degree

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def degree(index, num_nodes: Optional[int] = None,
           dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    r"""Computes the (unweighted) degree of a given one-dimensional index
    tensor.

    Entry ``i`` of the result is the number of times ``i`` occurs in
    :obj:`index`. The count is accumulated with ``scatter_add_``, so
    :obj:`index` must hold integer (int64) node ids in ``[0, N)``.

    Args:
        index (LongTensor): Index tensor.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned tensor. (default: :obj:`None`, i.e. torch's default
            dtype)

    :rtype: :class:`Tensor`
    """
    N = maybe_num_nodes(index, num_nodes)
    out = torch.zeros((N, ), dtype=dtype, device=index.device)
    # One unit of weight per occurrence, in the same dtype/device as `out`.
    one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device)
    return out.scatter_add_(0, index, one)
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:20,代碼來源:degree.py

示例9: grid

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def grid(height, width, dtype=None, device=None):
    r"""Returns the edge indices of a two-dimensional grid graph with height
    :attr:`height` and width :attr:`width` and its node positions.

    Args:
        height (int): The height of the grid.
        width (int): The width of the grid.
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned position tensor.
        device (:obj:`torch.device`, optional): The desired device of the
            returned tensors.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # dtype only applies to the positions; the edge index keeps its own
    # integer dtype and only follows `device`.
    edge_index = grid_index(height, width, device)
    pos = grid_pos(height, width, dtype, device)
    return edge_index, pos
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:20,代碼來源:grid.py

示例10: grid_index

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def grid_index(height, width, device=None):
    """Build the edge index of a ``height x width`` grid graph.

    Nodes are numbered row-major (horizontal neighbours differ by 1, vertical
    neighbours by ``width``). Each node is connected to its 8 surrounding
    nodes and to itself (offset 0). Candidate edges that wrap across the
    left/right border or fall outside ``[0, height * width)`` are removed.
    The result is then passed through ``coalesce``.

    Args:
        height (int): Number of grid rows.
        width (int): Number of grid columns.
        device (torch.device, optional): Device of the returned tensor.

    Returns:
        LongTensor: Edge index with two rows (source, target).
    """
    num_nodes = height * width
    # Ordered as: 3 left-column offsets, 3 same-column offsets,
    # 3 right-column offsets.
    offsets = torch.tensor(
        [-width - 1, -1, width - 1, -width, 0, width, -width + 1, 1, width + 1],
        device=device,
    )

    src = torch.arange(num_nodes, dtype=torch.long, device=device)
    src = src.view(-1, 1).repeat(1, offsets.size(0))
    dst = src + offsets.view(1, -1)

    # One tensor row per grid row, holding width * 9 candidates. The first 3
    # columns are the left-side offsets of the leftmost node and the last 3
    # are the right-side offsets of the rightmost node; both would wrap onto
    # a neighbouring grid row, so they are dropped.
    src = src.view(height, -1)
    dst = dst.view(height, -1)
    keep = torch.arange(3, src.size(1) - 3, dtype=torch.long, device=device)
    src = src[:, keep].view(-1)
    dst = dst[:, keep].view(-1)

    # Remove targets above the first row or below the last row.
    in_bounds = (dst >= 0) & (dst < num_nodes)
    src = src[in_bounds]
    dst = dst[in_bounds]

    edge_index = torch.stack([src, dst], dim=0)
    edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
    return edge_index
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:21,代碼來源:grid.py

示例11: patch_forward_method

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Wrap a module's forward method with dtype casting.

    The returned callable passes its positional and keyword arguments through
    ``cast_tensor_type(..., src_type, dst_type)`` before delegating to
    ``func``. When ``convert_output`` is true, the result is passed back
    through ``cast_tensor_type(..., dst_type, src_type)``.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        casted_args = cast_tensor_type(args, src_type, dst_type)
        casted_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*casted_args, **casted_kwargs)
        if not convert_output:
            return result
        return cast_tensor_type(result, dst_type, src_type)

    return new_forward
開發者ID:DeepMotionAIResearch,項目名稱:DenseMatchingBenchmark,代碼行數:21,代碼來源:hooks.py

示例12: torch_dtype_to_np_dtype

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def torch_dtype_to_np_dtype(dtype):
    """Return the NumPy dtype corresponding to a ``torch.dtype``.

    Args:
        dtype (torch.dtype): The torch dtype to translate.

    Returns:
        numpy.dtype: The equivalent NumPy dtype.

    Raises:
        KeyError: If ``dtype`` has no entry in the table below.
    """
    # torch.short/int/long/half/float/double are the very same objects as
    # torch.int16/int32/int64/float16/float32/float64, so the canonical keys
    # below already cover those aliases.
    dtype_dict = {
        # `np.bool` was a deprecated alias of the builtin `bool` and was
        # removed in NumPy 1.24; `np.bool_` is NumPy's boolean scalar type.
        torch.bool: np.dtype(np.bool_),
        torch.uint8: np.dtype(np.uint8),
        torch.int8: np.dtype(np.int8),
        torch.int16: np.dtype(np.int16),
        torch.int32: np.dtype(np.int32),
        torch.int64: np.dtype(np.int64),
        torch.float16: np.dtype(np.float16),
        torch.float32: np.dtype(np.float32),
        torch.float64: np.dtype(np.float64),
    }
    return dtype_dict[dtype]


# ---------------------- InferenceEngine internal types ------------------------ 
開發者ID:pfnet-research,項目名稱:chainer-compiler,代碼行數:24,代碼來源:types.py

示例13: update_device_and_dtype

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def update_device_and_dtype(state, *args, **kwargs):
    """Function gets data type and device values from the args / kwargs and updates state.

    Keyword values are read first, then positional arguments (so positional
    values win). Positional arguments are interpreted like the overloads of
    ``torch.nn.Module.to``:

    - a ``torch.dtype`` sets the data type;
    - a ``torch.Tensor`` supplies both data type and device;
    - a ``bool`` is the ``non_blocking`` flag, which carries no dtype/device
      information and is ignored;
    - anything else is treated as a device.

    Args:
        state (State): The :class:`.State` to update
        args: Arguments to the :func:`Trial.to` function
        kwargs: Keyword arguments to the :func:`Trial.to` function

    Returns:
        state
    """
    dtype_key = str(torchbearer.DATA_TYPE)
    device_key = str(torchbearer.DEVICE)

    # Look up and read each keyword with the same key string, so the check
    # and the read can never disagree.
    if dtype_key in kwargs:
        state[torchbearer.DATA_TYPE] = kwargs[dtype_key]
    if device_key in kwargs:
        state[torchbearer.DEVICE] = kwargs[device_key]

    for arg in args:
        if isinstance(arg, bool):
            # non_blocking flag; previously mis-stored as the device.
            continue
        if isinstance(arg, torch.dtype):
            state[torchbearer.DATA_TYPE] = arg
        elif isinstance(arg, torch.Tensor):
            state[torchbearer.DATA_TYPE] = arg.dtype
            state[torchbearer.DEVICE] = arg.device
        else:
            state[torchbearer.DEVICE] = arg

    return state
開發者ID:pytorchbearer,項目名稱:torchbearer,代碼行數:26,代碼來源:trial.py

示例14: protobuf_tensor_serializer

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def protobuf_tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> TensorDataPB:
    """Serialize a torch tensor into a ``TensorDataPB`` protobuf message.

    The dtype name is looked up in ``TORCH_DTYPE_STR``. The flattened
    values are appended to the message field ``"contents_" + <dtype name>``
    and the tensor size is recorded in ``shape.dims``. For quantized tensors
    the message also records ``is_quantized``, ``scale`` and ``zero_point``,
    and stores the integer representation (``int_repr``) of the values.
    ``worker`` is not used here.
    """
    dtype_name = TORCH_DTYPE_STR[tensor.dtype]

    message = TensorDataPB()

    if not tensor.is_quantized:
        values = torch.flatten(tensor).tolist()
    else:
        message.is_quantized = True
        message.scale = tensor.q_scale()
        message.zero_point = tensor.q_zero_point()
        values = torch.flatten(tensor).int_repr().tolist()

    message.dtype = dtype_name
    message.shape.dims.extend(tensor.size())
    contents_field = getattr(message, "contents_" + dtype_name)
    contents_field.extend(values)

    return message
開發者ID:OpenMined,項目名稱:PySyft,代碼行數:21,代碼來源:torch_serde.py

示例15: protobuf_tensor_deserializer

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def protobuf_tensor_deserializer(
    worker: AbstractWorker, protobuf_tensor: TensorDataPB
) -> torch.Tensor:
    """Rebuild a torch tensor from a ``TensorDataPB`` protobuf message.

    Values are read from the field ``"contents_" + protobuf_tensor.dtype``
    and reshaped to ``protobuf_tensor.shape.dims``. Quantized messages are
    rebuilt from their integer representation together with the stored
    ``scale`` and ``zero_point``. ``worker`` is not used here.
    """
    shape = tuple(protobuf_tensor.shape.dims)
    values = getattr(protobuf_tensor, "contents_" + protobuf_tensor.dtype)

    if not protobuf_tensor.is_quantized:
        plain_dtype = TORCH_STR_DTYPE[protobuf_tensor.dtype]
        return torch.tensor(values, dtype=plain_dtype).reshape(shape)

    # Quantized dtype names carry a leading 'q'; stripping it yields the
    # underlying integer dtype name.
    int_dtype = TORCH_STR_DTYPE[protobuf_tensor.dtype[1:]]
    int_values = torch.tensor(values, dtype=int_dtype).reshape(shape)
    # Automatically converts int types to quantized types.
    return torch._make_per_tensor_quantized_tensor(
        int_values, protobuf_tensor.scale, protobuf_tensor.zero_point
    )
開發者ID:OpenMined,項目名稱:PySyft,代碼行數:20,代碼來源:torch_serde.py


注:本文中的torch.dtype方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。