當前位置: 首頁>>代碼示例>>Python>>正文


Python data.Sampler方法代碼示例

本文整理匯總了Python中mxnet.gluon.data.Sampler方法的典型用法代碼示例。如果您正苦於以下問題:Python data.Sampler方法的具體用法?Python data.Sampler怎麽用?Python data.Sampler使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在mxnet.gluon.data的用法示例。


在下文中一共展示了data.Sampler方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from mxnet.gluon import data [as 別名]
# 或者: from mxnet.gluon.data import Sampler [as 別名]
def __init__(self, batch_size, cls_idx_dict1, cls_idx_dict2, ratio=1):
        """
        Balanced Two Steam Sampler, use cls_idx_dict1 as main dictinary and list
        :param batch_size: batch size
        :param cls_idx_dict1: class index dictionary
        :param cls_idx_dict2: class index dictionary
        :param ratio: negative / positive flag
        """
        self.batch_size = batch_size
        self.cls_idx_dict1 = cls_idx_dict1
        self.cls_idx_dict2 = cls_idx_dict2
        self.ratio = ratio

        assert set(cls_idx_dict1.keys()) == set(cls_idx_dict2.keys()), 'The labels of two classes are not consistent'

        self.n_cls = len(cls_idx_dict1.keys())
        self.n_samples = self.batch_size // self.n_cls

        assert self.batch_size >= self.n_cls, "batch size should equal or larger than number of classes"

        self.length = self.cal_len() 
開發者ID:aws-samples,項目名稱:d-SNE,代碼行數:23,代碼來源:samplers.py

示例2: forward

# 需要導入模塊: from mxnet.gluon import data [as 別名]
# 或者: from mxnet.gluon.data import Sampler [as 別名]
def forward(self, matches, ious):
        """Quota Sampler

        Parameters:
        ----------
        matches : NDArray or Symbol
            Matching results, positive number for positive matching, -1 for not matched.
        ious : NDArray or Symbol
            IOU overlaps with shape (N, M), batching is supported.

        Returns:
        --------
        NDArray or Symbol
            Sampling results with same shape as ``matches``.
            1 for positive, -1 for negative, 0 for ignore.

        """
        F = mx.nd
        max_pos = int(round(self._pos_ratio * self._num_sample))
        max_neg = int(self._neg_ratio * self._num_sample)
        results = []
        for i in range(matches.shape[0]):
            # init with 0s, which are ignored
            result = F.zeros_like(matches[0])
            # positive samples
            ious_max = ious.max(axis=-1)[i]
            result = F.where(matches[i] >= 0, F.ones_like(result), result)
            result = F.where(ious_max >= self._pos_thresh, F.ones_like(result), result)
            # negative samples with label -1
            neg_mask = ious_max < self._neg_thresh_high
            neg_mask = neg_mask * (ious_max >= self._neg_thresh_low)
            result = F.where(neg_mask, F.ones_like(result) * -1, result)

            # re-balance if number of positive or negative exceed limits
            result = result.asnumpy()
            num_pos = int((result > 0).sum())
            if num_pos > max_pos:
                disable_indices = np.random.choice(
                    np.where(result > 0)[0], size=(num_pos - max_pos), replace=False)
                result[disable_indices] = 0  # use 0 to ignore
            num_neg = int((result < 0).sum())
            if self._fill_negative:
                # if pos_sample is less than quota, we can have negative samples filling the gap
                max_neg = max(self._num_sample - min(num_pos, max_pos), max_neg)
            if num_neg > max_neg:
                disable_indices = np.random.choice(
                    np.where(result < 0)[0], size=(num_neg - max_neg), replace=False)
                result[disable_indices] = 0
            results.append(mx.nd.array(result))

        return mx.nd.stack(*results, axis=0) 
開發者ID:dmlc,項目名稱:gluon-cv,代碼行數:53,代碼來源:sampler.py


注:本文中的mxnet.gluon.data.Sampler方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。