本文整理汇总了Python中mmdet.ops.sigmoid_focal_loss方法的典型用法代码示例。如果您正苦于以下问题:Python ops.sigmoid_focal_loss方法的具体用法?Python ops.sigmoid_focal_loss怎么用?Python ops.sigmoid_focal_loss使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类mmdet.ops
的用法示例。
在下文中一共展示了ops.sigmoid_focal_loss方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: sigmoid_focal_loss
# 需要导入模块: from mmdet import ops [as 别名]
# 或者: from mmdet.ops import sigmoid_focal_loss [as 别名]
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred, target, gamma, alpha)
# TODO: find a proper way to handle the shape of weight
if weight is not None:
weight = weight.view(-1, 1)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
示例2: sigmoid_focal_loss
# 需要导入模块: from mmdet import ops [as 别名]
# 或者: from mmdet.ops import sigmoid_focal_loss [as 别名]
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
r"""A warpper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred, target, gamma, alpha)
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss