本文整理匯總了Python中mmdet.ops.sigmoid_focal_loss方法的典型用法代碼示例。如果您正苦於以下問題:Python ops.sigmoid_focal_loss方法的具體用法?Python ops.sigmoid_focal_loss怎麽用?Python ops.sigmoid_focal_loss使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類mmdet.ops
的用法示例。
在下文中一共展示了ops.sigmoid_focal_loss方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: sigmoid_focal_loss
# 需要導入模塊: from mmdet import ops [as 別名]
# 或者: from mmdet.ops import sigmoid_focal_loss [as 別名]
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred, target, gamma, alpha)
# TODO: find a proper way to handle the shape of weight
if weight is not None:
weight = weight.view(-1, 1)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
示例2: sigmoid_focal_loss
# 需要導入模塊: from mmdet import ops [as 別名]
# 或者: from mmdet.ops import sigmoid_focal_loss [as 別名]
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
r"""A warpper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred, target, gamma, alpha)
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss