本文整理匯總了Python中torch.dtype方法的典型用法代碼示例。如果您正苦於以下問題：Python torch.dtype方法的具體用法？Python torch.dtype怎麼用？Python torch.dtype使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊torch的用法示例。
在下文中一共展示了torch.dtype方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: get_points
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def get_points(self, featmap_sizes, dtype, device, flatten=False):
    """Get points according to feature map sizes.

    Calls ``self._get_points_single`` once per feature-map level, pairing the
    i-th entry of ``featmap_sizes`` with ``self.strides[i]``.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.
        flatten (bool): Forwarded unchanged to ``_get_points_single``.

    Returns:
        list: One entry per level, each being whatever
        ``_get_points_single`` returns for that level.

    Raises:
        IndexError: If ``self.strides`` is a list/tuple with fewer entries
            than ``featmap_sizes``.
    """
    # Index into self.strides (rather than zip) so a too-short strides
    # sequence still fails loudly instead of silently dropping levels.
    return [
        self._get_points_single(featmap_size, self.strides[i], dtype, device,
                                flatten)
        for i, featmap_size in enumerate(featmap_sizes)
    ]
示例2: patch_forward_method
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Wrap a module's forward method with automatic tensor dtype casting.

    Every call to the returned wrapper converts tensors in its positional and
    keyword arguments from ``src_type`` to ``dst_type`` (via
    ``cast_tensor_type``) before invoking ``func``.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Dtype that incoming tensor arguments are cast from.
        dst_type (torch.dtype): Dtype that incoming tensor arguments are cast to.
        convert_output (bool): When true, the result of ``func`` is cast back
            from ``dst_type`` to ``src_type`` before being returned.

    Returns:
        callable: The patched forward method.
    """
    def new_forward(*args, **kwargs):
        cast_args = cast_tensor_type(args, src_type, dst_type)
        cast_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*cast_args, **cast_kwargs)
        if not convert_output:
            return result
        return cast_tensor_type(result, dst_type, src_type)

    return new_forward
示例3: get_points
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def get_points(self, featmap_sizes, dtype, device):
    """Get points according to feature map sizes.

    Calls ``self.get_points_single`` once per feature-map level, pairing the
    i-th entry of ``featmap_sizes`` with ``self.strides[i]``.

    Args:
        featmap_sizes (list[tuple]): Multi-level feature map sizes.
        dtype (torch.dtype): Type of points.
        device (torch.device): Device of points.

    Returns:
        list: One entry per level, each being whatever
        ``get_points_single`` returns for that level.

    Raises:
        IndexError: If ``self.strides`` is a list/tuple with fewer entries
            than ``featmap_sizes``.
    """
    # Index into self.strides (rather than zip) so a too-short strides
    # sequence still fails loudly instead of silently dropping levels.
    return [
        self.get_points_single(featmap_size, self.strides[i], dtype, device)
        for i, featmap_size in enumerate(featmap_sizes)
    ]
示例4: origin
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def origin(
    self, *size, dtype=None, device=None, seed=42
) -> "geoopt.ManifoldTensor":
    """
    Return the all-zeros point of the manifold.

    Parameters
    ----------
    size : shape
        the desired shape
    dtype : torch.dtype
        the desired dtype
    device : torch.device
        the desired device
    seed : int
        ignored (not used by this method)

    Returns
    -------
    ManifoldTensor
        a zero tensor of the requested shape attached to this manifold
    """
    shape = size2shape(*size)
    self._assert_check_shape(shape, "x")
    zeros = torch.zeros(*size, dtype=dtype, device=device)
    return geoopt.ManifoldTensor(zeros, manifold=self)
示例5: random_naive
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def random_naive(self, *size, dtype=None, device=None) -> torch.Tensor:
    """
    Sample a random point on the Birkhoff Polytope manifold the cheap way.

    A standard-normal sample is made non-negative with ``abs_`` and then
    mapped onto the manifold with ``self.projx``. The resulting distribution
    is not uniform over the manifold, but it is fast to compute.

    Parameters
    ----------
    size : shape
        the desired output shape
    dtype : torch.dtype
        desired dtype
    device : torch.device
        desired device

    Returns
    -------
    ManifoldTensor
        random point on Birkhoff Polytope manifold
    """
    self._assert_check_shape(size2shape(*size), "x")
    sample = torch.randn(*size, device=device, dtype=dtype)
    # projection requires all values be non-negative
    sample.abs_()
    point = self.projx(sample)
    return ManifoldTensor(point, manifold=self)
示例6: origin
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
    """
    Return the identity-matrix point of the manifold.

    The identity is built for the last two dimensions of the requested shape
    and broadcast to the full shape with ``expand``. That returns a view, so
    all batch entries share the same underlying storage.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored (not used by this method)

    Returns
    -------
    ManifoldTensor
    """
    target_shape = size2shape(*size)
    self._assert_check_shape(target_shape, "x")
    identity = torch.eye(*target_shape[-2:], dtype=dtype, device=device)
    batched = identity.expand(target_shape)
    return ManifoldTensor(batched, manifold=self)
示例7: origin
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
    """
    Return the identity-matrix point of the manifold.

    A zero tensor of the requested shape gets 1 added on the diagonal of its
    last two dimensions. The diagonal length is taken from the last
    dimension.

    Parameters
    ----------
    size : shape
        the desired shape
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype
    seed : int
        ignored (not used by this method)

    Returns
    -------
    ManifoldTensor
    """
    self._assert_check_shape(size2shape(*size), "x")
    point = torch.zeros(*size, dtype=dtype, device=device)
    diag = torch.arange(point.shape[-1])
    point[..., diag, diag] += 1
    return ManifoldTensor(point, manifold=self)
示例8: degree
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def degree(index, num_nodes: Optional[int] = None,
           dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    r"""Computes the (unweighted) degree of a given one-dimensional index
    tensor.

    Every entry of :attr:`index` adds one to the count of the node it names.

    Args:
        index (LongTensor): Index tensor.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned tensor. When :obj:`None`, :func:`torch.zeros` uses the
            global default dtype. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    N = maybe_num_nodes(index, num_nodes)
    out = torch.zeros((N, ), dtype=dtype, device=index.device)
    # Build the ones in the output's dtype/device: scatter_add_ requires the
    # source to match the destination tensor's dtype.
    one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device)
    return out.scatter_add_(0, index, one)
示例9: grid
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def grid(height, width, dtype=None, device=None):
    r"""Returns the edge indices of a two-dimensional grid graph with height
    :attr:`height` and width :attr:`width` and its node positions.

    The edge indices come from :func:`grid_index` and the positions from
    :func:`grid_pos`.

    Args:
        height (int): The height of the grid.
        width (int): The width of the grid.
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned position tensor.
        device (:obj:`torch.device`, optional): The desired device of the
            returned tensors.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    edge_index = grid_index(height, width, device)
    pos = grid_pos(height, width, dtype, device)
    return edge_index, pos
示例10: grid_index
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def grid_index(height, width, device=None):
    """Build the edge index of a ``height x width`` grid graph.

    Each node is connected to its (up to) 8 surrounding neighbours plus
    itself (offset 0 in ``kernel`` yields a self-loop). Nodes are numbered
    ``0 .. height * width - 1``. The ``.view(height, -1)`` below treats each
    run of ``width`` consecutive node ids as one grid row (row-major layout).

    Args:
        height (int): Number of grid rows.
        width (int): Number of grid columns.
        device (torch.device, optional): Device for the created tensors.

    Returns:
        Tensor of shape ``[2, num_edges]`` holding (source, target) node ids,
        post-processed by ``coalesce`` (presumably sorted and de-duplicated —
        confirm against its definition).
    """
    w = width
    # Flat-index offsets of the 3x3 neighbourhood, grouped as
    # [left column, centre column, right column]; 0 is the self-loop.
    kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1]
    kernel = torch.tensor(kernel, device=device)
    # One candidate edge per (node, offset): each node id is repeated 9 times.
    row = torch.arange(height * width, dtype=torch.long, device=device)
    row = row.view(-1, 1).repeat(1, kernel.size(0))
    col = row + kernel.view(1, -1)
    # Group candidates by grid row: each tensor row now holds 9 * width entries.
    # NOTE(review): with height == 0 this view of an empty tensor with an
    # inferred -1 dimension is expected to raise — confirm callers never do it.
    row, col = row.view(height, -1), col.view(height, -1)
    # Drop the first 3 entries (left-column offsets of the first node in the
    # grid row) and the last 3 (right-column offsets of the last node); those
    # would otherwise wrap around into the adjacent grid row.
    index = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device)
    row, col = row[:, index].view(-1), col[:, index].view(-1)
    # Remove neighbours above the first or below the last grid row.
    mask = (col >= 0) & (col < height * width)
    row, col = row[mask], col[mask]
    edge_index = torch.stack([row, col], dim=0)
    edge_index, _ = coalesce(edge_index, None, height * width, height * width)
    return edge_index
示例11: patch_forward_method
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Return a forward method that casts tensor arguments before calling ``func``.

    Positional and keyword arguments are passed through ``cast_tensor_type``
    (``src_type`` -> ``dst_type``) and then handed to ``func``.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Dtype to convert input tensors from.
        dst_type (torch.dtype): Dtype to convert input tensors to.
        convert_output (bool): Whether the value returned by ``func`` is cast
            back from ``dst_type`` to ``src_type``.

    Returns:
        callable: The patched forward method.
    """
    def new_forward(*args, **kwargs):
        output = func(
            *cast_tensor_type(args, src_type, dst_type),
            **cast_tensor_type(kwargs, src_type, dst_type)
        )
        return (cast_tensor_type(output, dst_type, src_type)
                if convert_output else output)

    return new_forward
示例12: torch_dtype_to_np_dtype
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
# Lookup table from torch dtypes to NumPy dtypes, built once at import time
# instead of on every call. ``np.bool_`` is used because the ``np.bool``
# alias was deprecated in NumPy 1.20 and removed in NumPy 1.24.
_TORCH_TO_NUMPY_DTYPE = {
    torch.bool: np.dtype(np.bool_),
    torch.uint8: np.dtype(np.uint8),
    torch.int8: np.dtype(np.int8),
    torch.int16: np.dtype(np.int16),
    torch.short: np.dtype(np.int16),
    torch.int32: np.dtype(np.int32),
    torch.int: np.dtype(np.int32),
    torch.int64: np.dtype(np.int64),
    torch.long: np.dtype(np.int64),
    torch.float16: np.dtype(np.float16),
    torch.half: np.dtype(np.float16),
    torch.float32: np.dtype(np.float32),
    torch.float: np.dtype(np.float32),
    torch.float64: np.dtype(np.float64),
    torch.double: np.dtype(np.float64),
    torch.complex64: np.dtype(np.complex64),
    torch.complex128: np.dtype(np.complex128),
}


def torch_dtype_to_np_dtype(dtype):
    """Return the ``numpy.dtype`` equivalent to a ``torch.dtype``.

    Args:
        dtype (torch.dtype): The torch dtype to convert.

    Returns:
        numpy.dtype: The matching NumPy dtype.

    Raises:
        KeyError: If ``dtype`` has no NumPy equivalent in the table (for
            example ``torch.bfloat16``).
    """
    try:
        return _TORCH_TO_NUMPY_DTYPE[dtype]
    except KeyError:
        raise KeyError(
            "unsupported torch dtype for NumPy conversion: {}".format(dtype)
        ) from None
# ---------------------- InferenceEngine internal types ------------------------
示例13: update_device_and_dtype
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def update_device_and_dtype(state, *args, **kwargs):
    """Function gets data type and device values from the args / kwargs and updates state.

    Keyword arguments whose names equal ``str(torchbearer.DATA_TYPE)`` /
    ``str(torchbearer.DEVICE)`` are stored under those state keys. Each
    positional argument is stored as the data type when it is a
    ``torch.dtype``, and as the device otherwise.

    Args:
        state (State): The :class:`.State` to update
        args: Arguments to the :func:`Trial.to` function
        kwargs: Keyword arguments to the :func:`Trial.to` function

    Returns:
        state: the same ``state`` object, updated in place
    """
    # Use the same key for the membership test and the lookup so the two can
    # never disagree.
    dtype_key = str(torchbearer.DATA_TYPE)
    device_key = str(torchbearer.DEVICE)
    if dtype_key in kwargs:
        state[torchbearer.DATA_TYPE] = kwargs[dtype_key]
    if device_key in kwargs:
        state[torchbearer.DEVICE] = kwargs[device_key]
    for arg in args:
        # NOTE(review): any positional that is not a torch.dtype (e.g. a
        # tensor passed as ``to(tensor)``) is recorded as the device — confirm
        # this is intended.
        if isinstance(arg, torch.dtype):
            state[torchbearer.DATA_TYPE] = arg
        else:
            state[torchbearer.DEVICE] = arg
    return state
示例14: protobuf_tensor_serializer
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def protobuf_tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> TensorDataPB:
    """Serialize a torch tensor into a ``TensorDataPB`` Protobuf message.

    The tensor is flattened into a Python list and written to the
    ``contents_<dtype>`` field chosen by ``TORCH_DTYPE_STR[tensor.dtype]``,
    together with its shape and dtype name. For quantized tensors, the
    per-tensor scale and zero point are recorded as well, and the integer
    representation (``int_repr``) is stored as the contents.
    """
    dtype_name = TORCH_DTYPE_STR[tensor.dtype]
    message = TensorDataPB()
    if not tensor.is_quantized:
        values = torch.flatten(tensor).tolist()
    else:
        message.is_quantized = True
        message.scale = tensor.q_scale()
        message.zero_point = tensor.q_zero_point()
        values = torch.flatten(tensor).int_repr().tolist()
    message.dtype = dtype_name
    message.shape.dims.extend(tensor.size())
    getattr(message, "contents_" + dtype_name).extend(values)
    return message
示例15: protobuf_tensor_deserializer
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import dtype [as 別名]
def protobuf_tensor_deserializer(
    worker: AbstractWorker, protobuf_tensor: TensorDataPB
) -> torch.Tensor:
    """Rebuild a torch tensor from a ``TensorDataPB`` Protobuf message.

    Values are read from the ``contents_<dtype>`` field and reshaped to the
    stored dimensions. For quantized messages, the stored dtype name minus
    its leading character selects the integer dtype (as looked up in
    ``TORCH_STR_DTYPE``). The resulting integer tensor is then wrapped with
    the stored scale and zero point.
    """
    shape = tuple(protobuf_tensor.shape.dims)
    stored_dtype = protobuf_tensor.dtype
    values = getattr(protobuf_tensor, "contents_" + stored_dtype)
    if not protobuf_tensor.is_quantized:
        plain_dtype = TORCH_STR_DTYPE[stored_dtype]
        return torch.tensor(values, dtype=plain_dtype).reshape(shape)
    # Quantized names carry a leading 'q'; drop it to get the int dtype name.
    int_dtype = TORCH_STR_DTYPE[stored_dtype[1:]]
    int_tensor = torch.tensor(values, dtype=int_dtype).reshape(shape)
    return torch._make_per_tensor_quantized_tensor(
        int_tensor, protobuf_tensor.scale, protobuf_tensor.zero_point
    )