本文整理汇总了Python中torch.finfo方法的典型用法代码示例。如果您正苦于以下问题：Python torch.finfo方法的具体用法是什么？torch.finfo怎么用？有哪些使用torch.finfo的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.finfo方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def __init__(self, ignore_index, reduction, normalize_targets):
    """Initializer for the soft-target cross-entropy loss.

    Unlike the standard cross-entropy loss, the targets for this loss may be
    soft / multilabel.

    Args:
        ignore_index: index whose samples should be ignored by the loss
            (stored on ``self._ignore_index``; how it is applied lives in
            code not shown here).
        reduction: reduction to apply to the output; only ``"mean"`` is
            implemented.
        normalize_targets: either ``None`` or ``"count_based"``.

    Raises:
        ValueError: if ``normalize_targets`` is not ``None`` or
            ``"count_based"``.
        NotImplementedError: if ``reduction`` is not ``"mean"``.
    """
    # Validate explicitly (an ``assert`` would vanish under ``python -O``) and
    # before initializing the parent, so a bad config fails fast.
    if normalize_targets not in (None, "count_based"):
        raise ValueError(
            'normalize_targets must be None or "count_based", got {!r}'.format(
                normalize_targets
            )
        )
    if reduction != "mean":
        raise NotImplementedError(
            'reduction type "{}" not implemented'.format(reduction)
        )
    super(SoftTargetCrossEntropyLoss, self).__init__()
    self._ignore_index = ignore_index
    self._reduction = reduction
    self._normalize_targets = normalize_targets
    # float32 machine epsilon, kept for numerically safe divisions/logs.
    self._eps = torch.finfo(torch.float32).eps
示例2: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
    """
    Calculate rank cross entropy loss.

    Rows are taken in groups of ``self.num_neg + 1``: the row at offset 0
    of each group is the positive and the rows at offsets 1..num_neg are
    the negatives. They are placed side by side as columns, and a softmax
    cross entropy is taken across each group.

    :param y_pred: Predicted result.
    :param y_true: Label.
    :return: Rank cross loss.
    """
    group = self.num_neg + 1
    # Gather every offset within a group, then concatenate once. Calling
    # torch.cat inside a loop would re-copy the growing tensor num_neg times.
    logits = torch.cat(
        [y_pred[offset::group, :] for offset in range(group)], dim=-1
    )
    labels = torch.cat(
        [y_true[offset::group, :] for offset in range(group)], dim=-1
    )
    probs = F.softmax(logits, dim=-1)
    # torch.finfo(float) is float64 machine epsilon; the offset keeps log()
    # finite where a softmax probability underflows to zero.
    log_probs = torch.log(probs + torch.finfo(float).eps)
    return -torch.mean(torch.sum(labels * log_probs, dim=-1))
示例3: _quadratic_expand
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def _quadratic_expand(x, y):
    """
    Pairwise squared Euclidean distances between the rows of x and y.

    Uses the expansion |x - y|**2 = |x|**2 + |y|**2 - 2 x.y, so a single
    matrix product is needed instead of an m x n x f difference tensor.

    Parameters
    ----------
    x : torch.tensor
        2D tensor of size m x f
    y : torch.tensor
        2D tensor of size n x f

    Returns
    -------
    torch.tensor
        2D tensor of size m x n whose (i, j) entry is the squared distance
        between x[i] and y[j], clamped into [0, finfo(dtype).max].
    """
    # Squared row norms, shaped (m, 1) and (1, n) so they broadcast to (m, n).
    sq_norm_x = torch.sum(x ** 2, dim=1).reshape(-1, 1)
    sq_norm_y = torch.sum(y ** 2, dim=1).reshape(1, -1)
    cross = torch.mm(x, y.transpose(0, 1))
    dist = sq_norm_x + sq_norm_y - 2.0 * cross
    # Floating-point round-off can push the expansion slightly below zero.
    upper = torch.finfo(dist.dtype).max
    return dist.clamp(min=0.0, max=upper)
示例4: safe_cumprod
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def safe_cumprod(x: torch.Tensor,
                 *args,
                 **kwargs) -> torch.Tensor:
    r"""Computes cumprod of x in logspace using cumsum to avoid underflow.

    The cumprod function and its gradient can result in numerical
    instabilities when its argument has very small and/or zero values.
    As long as the argument is all positive, we can instead compute the
    cumulative product as `exp(cumsum(log(x)))`. This function can be called
    identically to :torch:`cumprod`.

    Note that ``x`` is clipped to ``[finfo(dtype).tiny, 1]`` before taking
    the log. Zeros therefore become ``tiny``, and values above 1 become 1.
    Non-floating (integer/bool) inputs are first cast to the default float
    dtype, because ``torch.finfo`` rejects them.

    Args:
        x: Tensor (or tensor-like) to take the cumulative product of.
        *args: Passed on to cumsum; these are identical to those in cumprod.
        **kwargs: Passed on to cumsum; these are identical to those in cumprod.

    Returns:
        Cumulative product of x.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if not (x.is_floating_point() or x.is_complex()):
        # torch.finfo raises TypeError for integer/bool dtypes.
        x = x.to(torch.get_default_dtype())
    tiny = torch.finfo(x.dtype).tiny
    clipped = torch.clamp(x, tiny, 1)
    return torch.exp(torch.cumsum(torch.log(clipped), *args, **kwargs))
示例5: topk
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def topk(x, ratio, batch, min_score=None, tol=1e-7):
    """Select the indices of the nodes to keep, graph by graph.

    Args:
        x: per-node scores (assumed 1-D, one score per node — TODO confirm).
        ratio: fraction of each graph's nodes to keep. The count kept is
            ``ceil(ratio * num_nodes)``; used only when ``min_score`` is None.
        batch: graph id of each node.
        min_score: if given, keep nodes scoring above
            ``min(min_score, graph_max - tol)`` instead of a fixed ratio.
        tol: slack subtracted from each graph's max score so the best node
            of every graph always survives the ``min_score`` filter.

    Returns:
        LongTensor of selected node indices.
    """
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter_max(x, batch)[0][batch] - tol
        scores_min = scores_max.clamp(max=min_score)
        perm = torch.nonzero(x > scores_min).view(-1)
    else:
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()

        cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1),
             num_nodes.cumsum(dim=0)[:-1]], dim=0)

        # Scatter the scores into a dense (batch_size, max_num_nodes) grid,
        # padding short graphs with the lowest finite value so padding sorts last.
        index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
        index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
        dense_x = x.new_full((batch_size * max_num_nodes, ),
                             torch.finfo(x.dtype).min)
        dense_x[index] = x
        dense_x = dense_x.view(batch_size, max_num_nodes)

        _, perm = dense_x.sort(dim=-1, descending=True)
        # Convert per-graph positions back to global node indices.
        perm = perm + cum_num_nodes.view(-1, 1)

        k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
        # keep[i, j] is True for the first k[i] slots of graph i. Boolean
        # indexing walks rows in order, so the result matches concatenating
        # the per-graph prefixes. This avoids a Python loop with one device
        # sync per graph.
        slot = torch.arange(max_num_nodes, dtype=torch.long, device=x.device)
        keep = slot.view(1, -1) < k.to(x.device).view(-1, 1)
        perm = perm[keep]
    return perm
示例6: rodrigues
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def rodrigues(self, r):
    """
    Rodrigues' rotation formula that turns axis-angle tensor into rotation
    matrix in a batch-ed manner.

    Computes R = cos(theta) * I + (1 - cos(theta)) * k k^T + sin(theta) * K,
    where theta = |r|, k = r / theta, and K is the skew-symmetric
    cross-product matrix of k. ``self.eye`` supplies the identity term
    (presumably a 3x3 identity tensor — confirm against the class).

    Parameter:
    ----------
    r: Axis-angle rotation tensor of shape [N, 1, 3].

    Return:
    -------
    Rotation matrix of shape [N, 3, 3].
    """
    theta = torch.norm(r, dim=(1, 2), keepdim=True)
    # Avoid division by zero. clamp is out-of-place, so it works under
    # autograd; torch.max(..., out=theta) raises once r requires grad.
    # It also takes a Python scalar, so no extra tensor is built or
    # uploaded to the device.
    theta = theta.clamp(min=torch.finfo(theta.dtype).tiny)
    r_hat = r / theta
    z_stick = torch.zeros_like(r_hat[:, 0, 0])
    # Skew-symmetric cross-product matrix K of the unit axis.
    m = torch.stack(
        (z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1],
         r_hat[:, 0, 2], z_stick, -r_hat[:, 0, 0],
         -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick), dim=1)
    m = m.reshape(-1, 3, 3)
    dot = torch.bmm(r_hat.transpose(1, 2), r_hat)  # Batched outer product k k^T.
    cos = theta.cos()
    R = cos * self.eye + (1 - cos) * dot + theta.sin() * m
    return R
示例7: _finfo
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def _finfo(x):
    """Return the ``torch.finfo`` record for the dtype of tensor ``x``."""
    dtype = x.dtype
    return torch.finfo(dtype)
示例8: _reciprocal
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def _reciprocal(x):
    """Elementwise ``1 / x`` with results capped at the dtype's largest finite value.

    Only the upper bound is clamped: ``1 / 0`` becomes ``finfo.max`` instead
    of ``+inf``, while negative results are left untouched.
    """
    upper = torch.finfo(x.dtype).max
    return torch.clamp(x.reciprocal(), max=upper)
示例9: _safediv
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def _safediv(x, y):
    """Return ``x / y``, computed as ``x * (1 / y)`` with ``1 / y`` capped above.

    The cap is ``torch.finfo(y.dtype).max``. For dtypes that ``torch.finfo``
    rejects with TypeError (e.g. integers), ``torch.iinfo(y.dtype).max`` is
    used instead.
    """
    try:
        limit = torch.finfo(y.dtype).max
    except TypeError:
        # Not a floating dtype: fall back to the integer limits.
        limit = torch.iinfo(y.dtype).max
    inv_y = y.reciprocal().clamp(max=limit)
    return x * inv_y
示例10: _safesub
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def _safesub(x, y):
    """Return ``x - y``, with ``-y`` clamped above at the dtype's maximum.

    This keeps ``y = -inf`` from contributing ``+inf``. The bound comes from
    ``torch.finfo(y.dtype)``. For dtypes that ``torch.finfo`` rejects with
    TypeError (e.g. integers), ``torch.iinfo(y.dtype)`` is used instead.
    """
    try:
        ceiling = torch.finfo(y.dtype).max
    except TypeError:
        # Not a floating dtype: fall back to the integer limits.
        ceiling = torch.iinfo(y.dtype).max
    neg_y = torch.clamp(-y, max=ceiling)
    return x + neg_y
示例11: info_value_of_dtype
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def info_value_of_dtype(dtype: torch.dtype):
    """
    Look up the numeric limits record for a PyTorch dtype.

    Floating-point dtypes yield ``torch.finfo(dtype)``; every other non-bool
    dtype yields ``torch.iinfo(dtype)``.

    Raises:
        TypeError: if ``dtype`` is ``torch.bool``.
    """
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")
    info_cls = torch.finfo if dtype.is_floating_point else torch.iinfo
    return info_cls(dtype)
示例12: get_loss
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def get_loss(self, y_pred, y_true, *args, **kwargs):
    """Compute the loss, taking the log of ``y_pred`` first when the criterion is NLLLoss.

    When ``self.criterion_`` is a ``torch.nn.NLLLoss``, ``y_pred`` is replaced
    by ``log(y_pred + eps)``, with eps the machine epsilon of its dtype. This
    presumes ``y_pred`` holds probabilities — confirm against the caller.
    The result is then delegated to the parent class's ``get_loss``.
    """
    if not isinstance(self.criterion_, torch.nn.NLLLoss):
        return super().get_loss(y_pred, y_true, *args, **kwargs)
    # eps keeps log() finite where a predicted probability is exactly zero.
    log_pred = torch.log(y_pred + torch.finfo(y_pred.dtype).eps)
    return super().get_loss(log_pred, y_true, *args, **kwargs)
# pylint: disable=signature-differs
示例13: compute_class_ap
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def compute_class_ap(self, tp, fp, npos):
    """
    Interpolated average precision over ``self.num_points`` recall levels.

    At each sampled recall level t in linspace(0, 1, num_points), the
    interpolated precision is the maximum precision among points whose
    recall is >= t, or 0 if there are none. AP is the mean of those values.

    Args:
        tp (Tensor, shape [N*D]): cumulative sum of true positive detections
        fp (Tensor, shape [N*D]): cumulative sum of false positive detections
        npos (Tensor, int): actual positives (from ground truth)
    Return:
        ap (Tensor, float): average precision calculation
    """
    recall = tp / npos
    # Clamp the denominator so that tp + fp == 0 does not divide by zero.
    precision = tp / torch.clamp(tp + fp, min=torch.finfo(torch.float).eps)
    levels = torch.linspace(0, 1, self.num_points)
    ap = 0.
    for level in levels:
        reached = recall >= level
        # Interpolated P-R curve: best precision at or to the right of this recall.
        best = torch.max(precision[reached]) if reached.any() else 0
        ap = ap + best / self.num_points
    return ap
示例14: __new__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def __new__(cls, dtype):
    """Create a ``finfo`` instance for an inexact (floating-point) data type.

    ``dtype`` is first normalised with ``heat_type_of``. If that fails with
    KeyError, IndexError or TypeError, the value is kept as given.

    Raises:
        TypeError: if the resulting type is not in ``_inexact``.
    """
    try:
        dtype = heat_type_of(dtype)
    except (KeyError, IndexError, TypeError):
        # Not convertible to a heat type; validate the value as given below.
        pass
    if dtype in _inexact:
        return super(finfo, cls).__new__(cls)._init(dtype)
    raise TypeError("Data type {} not inexact, not supported".format(dtype))
示例15: _init
# 需要导入模块: import torch [as 别名]
# 或者: from torch import finfo [as 别名]
def _init(self, dtype):
    """Fill this object's limits from ``torch.finfo(dtype.torch_type())``.

    Copies ``bits``, ``eps``, ``max`` and ``tiny``, and sets ``min`` to
    ``-max``. Returns ``self`` so the call can be chained after ``__new__``.
    """
    source = torch.finfo(dtype.torch_type())
    self.bits = source.bits
    self.eps = source.eps
    self.max = source.max
    self.tiny = source.tiny
    self.min = -self.max
    return self