当前位置: 首页>>代码示例>>Python>>正文


Python ndarray.equal方法代码示例

本文整理汇总了Python中mxnet.ndarray.equal方法的典型用法代码示例。如果您正苦于以下问题:Python ndarray.equal方法的具体用法?Python ndarray.equal怎么用?Python ndarray.equal使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在mxnet.ndarray的用法示例。


在下文中一共展示了ndarray.equal方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: batch_intersection_union

# 需要导入模块: from mxnet import ndarray [as 别名]
# 或者: from mxnet.ndarray import equal [as 别名]
def batch_intersection_union(output, target, nclass):
    """mIoU"""
    # inputs are NDarray, output 4D, target 3D
    predict = F.argmax(output, 1)
    target = target.astype(predict.dtype)
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = predict.asnumpy() + 1
    target = target.asnumpy() + 1

    predict = predict * (target > 0).astype(predict.dtype)
    #intersection = predict * (F.equal(predict, target)).astype(predict.dtype)
    intersection = predict * (predict == target)
    # areas of intersection and union
    area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter
    return area_inter, area_union 
开发者ID:zzdang,项目名称:cascade_rcnn_gluon,代码行数:22,代码来源:voc_segmentation.py

示例2: array_equal

# 需要导入模块: from mxnet import ndarray [as 别名]
# 或者: from mxnet.ndarray import equal [as 别名]
def array_equal(a, b):
    return nd.equal(a, b).asnumpy().all() 
开发者ID:dmlc,项目名称:dgl,代码行数:4,代码来源:__init__.py

示例3: hybrid_forward

# 需要导入模块: from mxnet import ndarray [as 别名]
# 或者: from mxnet.ndarray import equal [as 别名]
def hybrid_forward(self, F, fts, ys, ftt, yt):
        """
        Semantic Alignment Loss
        :param F: Function
        :param yt: label for the target domain [N]
        :param ftt: features for the target domain [N, K]
        :param ys: label for the source domain [M]
        :param fts: features for the source domain [M, K]
        :return:
        """
        if self._fn:
            # Normalize ft
            fts = F.L2Normalization(fts, mode='instance')
            ftt = F.L2Normalization(ftt, mode='instance')

        fts_rpt = F.broadcast_to(fts.expand_dims(axis=0), shape=(self._bs_tgt, self._bs_src, self._embed_size))
        ftt_rpt = F.broadcast_to(ftt.expand_dims(axis=1), shape=(self._bs_tgt, self._bs_src, self._embed_size))

        dists = F.sum(F.square(ftt_rpt - fts_rpt), axis=2)

        yt_rpt = F.broadcast_to(yt.expand_dims(axis=1), shape=(self._bs_tgt, self._bs_src)).astype('int32')
        ys_rpt = F.broadcast_to(ys.expand_dims(axis=0), shape=(self._bs_tgt, self._bs_src)).astype('int32')

        y_same = F.equal(yt_rpt, ys_rpt).astype('float32')
        y_diff = F.not_equal(yt_rpt, ys_rpt).astype('float32')

        intra_cls_dists = dists * y_same
        inter_cls_dists = dists * y_diff

        max_dists = F.max(dists, axis=1, keepdims=True)
        max_dists = F.broadcast_to(max_dists, shape=(self._bs_tgt, self._bs_src))
        revised_inter_cls_dists = F.where(y_same, max_dists, inter_cls_dists)

        max_intra_cls_dist = F.max(intra_cls_dists, axis=1)
        min_inter_cls_dist = F.min(revised_inter_cls_dists, axis=1)

        loss = F.relu(max_intra_cls_dist - min_inter_cls_dist + self._margin)

        return loss 
开发者ID:aws-samples,项目名称:d-SNE,代码行数:41,代码来源:custom_layers.py


注:本文中的mxnet.ndarray.equal方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。