本文整理汇总了Python中mxnet.ndarray.equal方法的典型用法代码示例。如果您正苦于以下问题：Python ndarray.equal方法的具体用法？Python ndarray.equal怎么用？Python ndarray.equal使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块mxnet.ndarray的用法示例。
在下文中一共展示了ndarray.equal方法的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: batch_intersection_union
# 需要导入模块: from mxnet import ndarray [as 别名]
# 或者: from mxnet.ndarray import equal [as 别名]
def batch_intersection_union(output, target, nclass):
    """Count per-class intersection and union pixels for mean IoU (mIoU).

    Inputs may be mxnet NDArrays (``output`` 4D, ``target`` 3D) or NumPy
    arrays. The predicted class is the argmax over axis 1 of ``output``.
    ``target`` holds integer class ids in ``[0, nclass)``. Pixels with a
    negative integer target (e.g. ``-1``) are ignored: they are removed from
    the prediction counts and fall outside the label histogram.

    NOTE(review): target values >= nclass are NOT removed from the prediction
    counts; they only drop out of the label histogram. Confirm callers never
    use such values (e.g. 255) as an ignore label.

    :param output: class scores; axis 1 is the class axis (presumably
        (N, C, H, W)).
    :param target: ground-truth class ids (presumably (N, H, W)).
    :param nclass: number of classes.
    :return: ``(area_inter, area_union)``, two integer NumPy arrays of length
        ``nclass`` indexed by class id.
    """
    def _as_numpy(arr):
        # mxnet NDArrays expose ``asnumpy()``; NumPy inputs pass through.
        return arr.asnumpy() if hasattr(arr, "asnumpy") else np.asarray(arr)

    mini = 1
    maxi = nclass
    nbins = nclass

    # argmax runs on the input's own backend (NDArray or NumPy), so only the
    # label map, not the full score tensor, is copied to host memory.
    predict = _as_numpy(output.argmax(axis=1)) + 1
    # Shift labels by +1: class k -> k + 1 and the ignore value -1 -> 0, which
    # lies outside the histogram range used below.
    target = _as_numpy(target).astype(predict.dtype) + 1

    # Zero out predictions on ignored pixels (shifted target == 0).
    predict = predict * (target > 0).astype(predict.dtype)
    # Keep the class id only where prediction and label agree, 0 elsewhere.
    # This NumPy ``==`` is the host-side counterpart of F.equal(predict, target).
    intersection = predict * (predict == target)

    # With nbins == nclass spread over [1, nclass], each bin holds exactly one
    # integer class id 1..nclass. Zeros (ignored or mismatched) fall outside.
    area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter
    return area_inter, area_union
示例2: array_equal
# 需要导入模块: from mxnet import ndarray [as 别名]
# 或者: from mxnet.ndarray import equal [as 别名]
def array_equal(a, b):
    """Return True if ``a`` and ``b`` have the same shape and equal elements.

    ``nd.equal`` broadcasts its operands. Without the explicit shape check,
    arrays such as ``[1, 1, 1]`` and ``[1]`` would compare equal. If either
    operand has no ``shape`` attribute (e.g. a Python scalar), the shape check
    is skipped and the scalar is broadcast against the other operand.

    :param a: first operand (presumably an mxnet NDArray).
    :param b: second operand (NDArray, or a scalar accepted by ``nd.equal``).
    :return: a Python ``bool``.
    """
    shape_a = getattr(a, "shape", None)
    shape_b = getattr(b, "shape", None)
    if shape_a is not None and shape_b is not None and tuple(shape_a) != tuple(shape_b):
        return False
    return bool(nd.equal(a, b).asnumpy().all())
示例3: hybrid_forward
# 需要导入模块: from mxnet import ndarray [as 别名]
# 或者: from mxnet.ndarray import equal [as 别名]
def hybrid_forward(self, F, fts, ys, ftt, yt):
    """Semantic alignment loss between target-domain and source-domain samples.

    For each target sample, take its largest squared distance to a same-class
    source sample and its smallest squared distance to a different-class
    source sample. The loss hinges their difference at ``self._margin``.

    If a target sample has no same-class source sample, the first term is 0.
    If it has no different-class source sample, the second term is the largest
    distance in its row.

    :param F: operator namespace (presumably ``mxnet.ndarray`` or
        ``mxnet.symbol``, as passed to a Gluon ``hybrid_forward``).
    :param fts: source-domain features [M, K]
    :param ys: source-domain labels [M]
    :param ftt: target-domain features [N, K]
    :param yt: target-domain labels [N]
    :return: per-target-sample loss [N]. The broadcasts below require
        N == self._bs_tgt, M == self._bs_src and K == self._embed_size.
    """
    if self._fn:
        # Optionally L2-normalise every feature vector before measuring distances.
        fts = F.L2Normalization(fts, mode='instance')
        ftt = F.L2Normalization(ftt, mode='instance')

    n_tgt, n_src, dim = self._bs_tgt, self._bs_src, self._embed_size
    grid = (n_tgt, n_src)

    # Pairwise squared Euclidean distances: row i = target sample, column j = source sample.
    src_grid = F.broadcast_to(fts.expand_dims(axis=0), shape=(n_tgt, n_src, dim))
    tgt_grid = F.broadcast_to(ftt.expand_dims(axis=1), shape=(n_tgt, n_src, dim))
    dists = F.sum(F.square(tgt_grid - src_grid), axis=2)

    # 0/1 masks on the same grid: same-class pairs and different-class pairs.
    tgt_lbl = F.broadcast_to(yt.expand_dims(axis=1), shape=grid).astype('int32')
    src_lbl = F.broadcast_to(ys.expand_dims(axis=0), shape=grid).astype('int32')
    same = F.equal(tgt_lbl, src_lbl).astype('float32')
    diff = F.not_equal(tgt_lbl, src_lbl).astype('float32')

    same_dists = dists * same
    diff_dists = dists * diff

    # Same-class cells get the row maximum. The row-wise minimum therefore
    # ranges over different-class distances whenever the row has any.
    row_max = F.broadcast_to(F.max(dists, axis=1, keepdims=True), shape=grid)
    neg_candidates = F.where(same, row_max, diff_dists)

    farthest_same = F.max(same_dists, axis=1)
    nearest_diff = F.min(neg_candidates, axis=1)
    return F.relu(farthest_same - nearest_diff + self._margin)