当前位置: 首页>>代码示例>>Python>>正文


Python torch.kthvalue方法代码示例

本文整理汇总了Python中torch.kthvalue方法的典型用法代码示例。如果您正苦于以下问题:Python torch.kthvalue方法的具体用法?Python torch.kthvalue怎么用?Python torch.kthvalue使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.kthvalue方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _bbox_forward_train

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run the bbox head on the sampled RoIs and compute its training loss.

        As a side effect, the ``beta_topk``-th smallest regression target of
        the positive samples is appended to ``self.beta_history``. Nothing is
        recorded when the batch contains no positive sample.

        Args:
            x: Features forwarded to ``self._bbox_forward``.
            sampling_results: Per-image sampling results; each exposes
                ``bboxes``.
            gt_bboxes: Ground-truth boxes forwarded to
                ``self.bbox_head.get_targets``.
            gt_labels: Ground-truth labels forwarded to
                ``self.bbox_head.get_targets``.
            img_metas: Per-image meta information; only its length is used
                here.

        Returns:
            The output of ``self._bbox_forward`` with an added ``loss_bbox``
            entry.
        """
        num_imgs = len(img_metas)
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)
        # record the `beta_topk`-th smallest target
        # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
        # and bbox_weights, respectively
        pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
        num_pos = len(pos_inds)
        # `torch.kthvalue` requires k >= 1, so a batch without positive
        # samples (which would give beta_topk == 0) must be skipped.
        if num_pos > 0:
            cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
            beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs,
                            num_pos)
            cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
            self.beta_history.append(cur_target)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results 
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:26,代码来源:dynamic_roi_head.py

示例2: select_over_all_levels

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def select_over_all_levels(self, boxlists):
        """Run multi-class NMS on every image and cap the detections kept.

        For each entry of ``boxlists`` the result of ``boxlist_ml_nms`` is
        kept as is, unless it holds more than ``self.fpn_post_nms_top_n``
        detections (and that limit is positive). In that case only the
        detections whose score reaches the score of the
        ``fpn_post_nms_top_n``-th best one are retained.

        Args:
            boxlists: Sequence with one box list per image.

        Returns:
            list: One filtered box list per image, in input order.
        """
        outputs = []
        for per_image in boxlists:
            # multiclass nms
            detections = boxlist_ml_nms(per_image, self.nms_thresh)
            total = len(detections)
            limit = self.fpn_post_nms_top_n

            # Limit to max_per_image detections **over all classes**
            if total > limit > 0:
                det_scores = detections.get_field("scores")
                # The (total - limit + 1)-th smallest score is the
                # limit-th largest one.
                cutoff = torch.kthvalue(det_scores.cpu(), total - limit + 1)[0]
                selected = torch.nonzero(det_scores >= cutoff.item()).squeeze(1)
                detections = detections[selected]
            outputs.append(detections)
        return outputs 
开发者ID:yinghdb,项目名称:EmbedMask,代码行数:22,代码来源:inference.py

示例3: compute_image

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def compute_image(self, x):
        """Blend ``x`` with a normalised map of its gradient magnitude.

        The absolute gradient of ``x`` is summed over dim 1, clipped at
        (roughly) its 99th percentile per sample and scaled to ``[0, 1]``.
        ``x`` itself is min-max normalised over the whole batch. Pixels with
        a strong gradient show the normalised input, the others fade
        towards 0.5.

        Args:
            x: 4-D tensor whose ``.grad`` is populated.
                # assumes layout (batch, channels, height, width) - TODO confirm

        Returns:
            Tensor with the same shape as ``x``. NaNs appear if the clipping
            value is 0 or if ``x`` is constant, since both normalisations
            divide by a range.
        """
        grad_img = x.grad.abs().sum(1, keepdim=True)
        b, c, _, _ = grad_img.shape
        gi_flat = grad_img.view(b, c, -1)
        # kthvalue needs k >= 1; truncating ``numel * 0.99`` yields 0 when a
        # sample has a single pixel, so clamp from below.
        k = max(1, int(grad_img[0].numel() * 0.99))
        cl = torch.kthvalue(gi_flat, k, dim=-1)[0]
        # Restore two trailing axes so the clip value broadcasts per sample.
        cl = cl.unsqueeze(-1).unsqueeze(-1)
        grad_img = torch.min(grad_img, cl) / cl
        x = x.detach()
        xm = x.min()
        xM = x.max()
        x = (x - xm) / (xM - xm)
        img = x * grad_img + 0.5 * (1 - grad_img)
        return img 
开发者ID:Vermeille,项目名称:Torchelie,代码行数:16,代码来源:callbacks.py

示例4: _pvalue

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def _pvalue(data: torch.Tensor, ratio: float = 0.25, **kwargs):
    """
    Return the value below which roughly ``(1 - ratio)`` of ``data`` lies.

    Concretely this is the ``k``-th smallest element of the flattened
    tensor, with ``k = max(1, int(data.numel() * (1 - ratio)))``.

    Args:
        data: Tensor of any shape; it is flattened before ranking.
        ratio: Fraction of the largest elements to leave above the result.
        **kwargs: Accepted and ignored.

    Returns:
        The selected element as a Python scalar.
    """
    # Flatten first: ``k`` is derived from the total element count, whereas
    # kthvalue on a multi-dimensional tensor only ranks along the last dim.
    flat = data.reshape(-1)
    cut = max(1, int(flat.numel() * (1 - ratio)))
    return torch.kthvalue(flat, cut)[0].item() 
开发者ID:facebookresearch,项目名称:pytorch-dp,代码行数:9,代码来源:utils.py

示例5: binarize_mask

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def binarize_mask(mask):
    """Binarise each mask so that roughly its mean fraction of pixels is 1.

    For every sample the mean is taken with a 224x224 average pool. A
    threshold is then picked as the k-th smallest mask value, where k grows
    as the mean shrinks, and every pixel strictly above it becomes 1.

    Args:
        mask: Float tensor, one mask per entry along dim 0.
            # assumes shape (batch, 1, 224, 224) so the pool yields one
            # value per sample - TODO confirm against callers

    Returns:
        CPU float tensor of 0/1 values with the same shape as ``mask``.
    """
    with torch.no_grad():
        batch = mask.size(0)
        # Keep the batch axis explicitly: a bare ``squeeze()`` collapses it
        # for a single-sample batch and makes ``avg[i]`` fail.
        avg = F.avg_pool2d(mask, 224, stride=1).reshape(batch, -1)
        flat_mask = mask.cpu().view(batch, -1)
        binarized_mask = torch.zeros_like(flat_mask)
        for i in range(batch):
            # 1-based rank in [1, n]; "+ 0.5" rounds to the nearest integer.
            kth = 1 + int((flat_mask[i].size(0) - 1) * (1 - avg[i].item()) + 0.5)
            th, _ = torch.kthvalue(flat_mask[i], kth)
            # Keep the threshold strictly inside (0, 1) so exact 0s never
            # pass and exact 1s always do.
            th.clamp_(1e-6, 1 - 1e-6)
            binarized_mask[i] = flat_mask[i].gt(th).float()
        binarized_mask = binarized_mask.view(mask.size())

        return binarized_mask 
开发者ID:kondiz,项目名称:casme,代码行数:15,代码来源:model_basics.py

示例6: __call__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def __call__(self, images, *targets_list):
        """Show the first image of the batch with the boxes of each target list.

        Only every ``self.show_iter``-th call draws anything; the other calls
        just advance ``self.counter``. After showing the figure the method
        blocks on ``input()`` until the user presses Enter.

        Args:
            images: Batch object exposing ``tensors``; only the first image
                is displayed.
            *targets_list: Any number of per-batch target lists (or ``None``
                to skip one). Only element 0 of each list is drawn. Targets
                without scores, or with an empty score tensor, are drawn via
                ``self.plot1``; the others via ``self.plot``.
        """
        import matplotlib.pyplot as plt
        import seaborn as sbn
        self.counter += 1
        if self.counter % self.show_iter != 0:
            return
        colors = sbn.color_palette(n_colors=len(targets_list))
        img = images.tensors[0].permute((1, 2, 0)).cpu().numpy() + self.image_mean
        # Swap the first and last channel before display.
        img = img[:, :, [2, 1, 0]]
        plt.imshow(img/255)
        title = "boxes:"
        for ci, targets in enumerate(targets_list):
            if targets is None:
                continue
            first = targets[0]
            fields = first.extra_fields
            bboxes = first.bbox.detach().cpu().numpy().tolist()
            scores = fields['scores'].detach().cpu() if 'scores' in fields else None
            locations = fields['det_locations'].detach().cpu() if 'det_locations' in fields else None
            labels = fields['labels'].cpu()
            if scores is None or len(scores) == 0:
                self.plot1(bboxes, scores, locations, labels, None, (1, 0, 0))  # ground-truth
                # No threshold exists on this path, so every box counts.
                count = len(first.bbox)
            else:
                if self.score_th is None:
                    # kthvalue needs k <= number of scores.
                    k = min(self.show_score_topk, len(scores))
                    # Negating turns "k-th smallest" into "k-th largest".
                    score_th = -torch.kthvalue(-scores, k)[0]
                else:
                    score_th = self.score_th
                self.plot(bboxes, scores, locations, labels, score_th, colors[ci])
                count = (scores > score_th).sum()
            title += "{}({}) ".format(count, len(first.bbox))
        plt.title(title)
        plt.show()
        input() 
开发者ID:ucas-vg,项目名称:TinyBenchmark,代码行数:31,代码来源:locnet.py

示例7: __call__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def __call__(self, images, *targets_list):
        """Show the first image of the batch with the boxes of each target list.

        Only every ``self.show_iter``-th call draws anything; the other calls
        just advance ``self.counter``. After showing the figure the method
        blocks on ``input()`` until the user presses Enter.

        Args:
            images: Batch object exposing ``tensors``; only the first image
                is displayed.
            *targets_list: Any number of per-batch target lists (or ``None``
                to skip one). Only element 0 of each list is drawn. Targets
                without scores, or with an empty score tensor, are drawn via
                ``self.plot1``; the others via ``self.plot``.
        """
        import matplotlib.pyplot as plt
        import seaborn as sbn
        self.counter += 1
        if self.counter % self.show_iter != 0:
            return
        colors = sbn.color_palette(n_colors=len(targets_list))
        img = images.tensors[0].permute((1, 2, 0)).cpu().numpy() + self.image_mean
        # Swap the first and last channel before display.
        img = img[:, :, [2, 1, 0]]
        plt.imshow(img/255)
        title = "boxes:"
        for ci, targets in enumerate(targets_list):
            if targets is None:
                continue
            first = targets[0]
            fields = first.extra_fields
            # detach() so tensors that still require grad can be converted
            # to numpy.
            bboxes = first.bbox.detach().cpu().numpy().tolist()
            scores = fields['scores'].detach().cpu() if 'scores' in fields else None
            locations = fields['det_locations'].detach().cpu() if 'det_locations' in fields else None
            labels = fields['labels'].cpu()
            if scores is None or len(scores) == 0:
                # An empty score tensor cannot go through kthvalue, so it is
                # drawn like a target without scores.
                self.plot1(bboxes, scores, locations, labels, None, (1, 0, 0))  # ground-truth
                count = len(first.bbox)
            else:
                if self.score_th is None:
                    # kthvalue needs k <= number of scores.
                    k = min(self.show_score_topk, len(scores))
                    # Negating turns "k-th smallest" into "k-th largest".
                    score_th = -torch.kthvalue(-scores, k)[0]
                else:
                    score_th = self.score_th
                self.plot(bboxes, scores, locations, labels, score_th, colors[ci])
                count = (scores > score_th).sum()
            title += "{}({}) ".format(count, len(first.bbox))
        plt.title(title)
        plt.show()
        input() 
开发者ID:ucas-vg,项目名称:TinyBenchmark,代码行数:31,代码来源:cascade_fcos.py

示例8: filter_results

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def filter_results(self, boxlist, num_classes):
        """Threshold the scores, run NMS per class and cap the total count.

        Args:
            boxlist: Box list whose ``bbox`` and ``"scores"`` field are
                reshaped to ``num_classes * 4`` and ``num_classes`` columns.
            num_classes: Number of classes, with class 0 as the background.

        Returns:
            A single box list with ``"scores"`` and ``"labels"`` fields. When
            more than ``self.detections_per_img`` detections survive (and
            that limit is positive), only those scoring at least as high as
            the ``detections_per_img``-th best one are kept.
        """
        # Work on plain tensors; the box list is only rebuilt per class.
        all_boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        all_scores = boxlist.get_field("scores").reshape(-1, num_classes)
        device = all_scores.device
        passed = all_scores > self.score_thresh

        per_class = []
        # Class 0 is the background and is never emitted.
        for cls_id in range(1, num_classes):
            rows = passed[:, cls_id].nonzero().squeeze(1)
            cls_scores_raw = all_scores[rows, cls_id]
            cls_boxes_raw = all_boxes[rows, cls_id * 4 : (cls_id + 1) * 4]
            cls_boxlist = BoxList(cls_boxes_raw, boxlist.size, mode="xyxy")
            cls_boxlist.add_field("scores", cls_scores_raw)
            cls_boxlist = boxlist_nms(cls_boxlist, self.nms)
            cls_labels = torch.full(
                (len(cls_boxlist),), cls_id, dtype=torch.int64, device=device
            )
            cls_boxlist.add_field("labels", cls_labels)
            per_class.append(cls_boxlist)

        merged = cat_boxlist(per_class)
        total = len(merged)
        limit = self.detections_per_img

        # Nothing to trim unless a positive limit is exceeded.
        if not (total > limit > 0):
            return merged

        # Limit to max_per_image detections **over all classes**
        merged_scores = merged.get_field("scores")
        cutoff, _ = torch.kthvalue(merged_scores.cpu(), total - limit + 1)
        selected = torch.nonzero(merged_scores >= cutoff.item()).squeeze(1)
        return merged[selected] 
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:44,代码来源:inference.py

示例9: select_over_all_levels

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def select_over_all_levels(self, boxlists):
        """Run class-wise NMS on every image and cap the detections kept.

        Args:
            boxlists: Sequence with one box list per image, each carrying
                ``"scores"`` and ``"labels"`` fields.

        Returns:
            list: One box list per image. When an image has more than
            ``self.fpn_post_nms_top_n`` detections after NMS (and that limit
            is positive), only those scoring at least as high as the
            ``fpn_post_nms_top_n``-th best one are kept.
        """
        outputs = []
        for boxlist in boxlists:
            scores = boxlist.get_field("scores")
            labels = boxlist.get_field("labels")
            boxes = boxlist.bbox

            per_class = []
            # skip the background
            for cls_id in range(1, self.num_classes):
                rows = (labels == cls_id).nonzero().view(-1)

                cls_scores_raw = scores[rows]
                cls_boxes_raw = boxes[rows, :].view(-1, 4)
                cls_boxlist = BoxList(cls_boxes_raw, boxlist.size, mode="xyxy")
                cls_boxlist.add_field("scores", cls_scores_raw)
                cls_boxlist = boxlist_nms(
                    cls_boxlist, self.nms_thresh, score_field="scores"
                )
                cls_labels = torch.full(
                    (len(cls_boxlist),), cls_id,
                    dtype=torch.int64, device=scores.device
                )
                cls_boxlist.add_field("labels", cls_labels)
                per_class.append(cls_boxlist)

            merged = cat_boxlist(per_class)
            total = len(merged)
            limit = self.fpn_post_nms_top_n

            # Limit to max_per_image detections **over all classes**
            if total > limit > 0:
                merged_scores = merged.get_field("scores")
                cutoff = torch.kthvalue(
                    merged_scores.cpu(), total - limit + 1
                )[0]
                selected = torch.nonzero(
                    merged_scores >= cutoff.item()
                ).squeeze(1)
                merged = merged[selected]
            outputs.append(merged)
        return outputs 
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:46,代码来源:inference.py

示例10: filter_results

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def filter_results(self, boxlist, num_classes):
        """Threshold the scores, suppress per class and cap the total count.

        Per class, either soft NMS or plain NMS is applied depending on
        ``cfg.TEST.SOFT_NMS.ENABLED``. When ``cfg.TEST.BBOX_VOTE.ENABLED`` is
        set and the class has candidates, the suppressed boxes are refined
        with ``boxlist_box_voting`` against the pre-suppression candidates.

        Args:
            boxlist: Box list whose ``bbox`` and ``"scores"`` field are
                reshaped to ``num_classes * 4`` and ``num_classes`` columns.
            num_classes: Number of classes, with class 0 as the background.

        Returns:
            A single box list with ``"labels"`` added. When more than
            ``self.detections_per_img`` detections survive (and that limit is
            positive), only those scoring at least as high as the
            ``detections_per_img``-th best one are kept.
        """
        # Work on plain tensors; the box list is only rebuilt per class.
        all_boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        all_scores = boxlist.get_field("scores").reshape(-1, num_classes)
        device = all_scores.device
        passed = all_scores > self.score_thresh

        per_class = []
        # Class 0 is the background and is never emitted.
        for cls_id in range(1, num_classes):
            rows = passed[:, cls_id].nonzero().squeeze(1)
            cls_scores_raw = all_scores[rows, cls_id]
            cls_boxes_raw = all_boxes[rows, cls_id * 4: (cls_id + 1) * 4]
            candidates = BoxList(cls_boxes_raw, boxlist.size, mode="xyxy")
            candidates.add_field("scores", cls_scores_raw)

            if cfg.TEST.SOFT_NMS.ENABLED:
                suppressed = boxlist_soft_nms(
                    candidates,
                    sigma=cfg.TEST.SOFT_NMS.SIGMA,
                    overlap_thresh=self.nms,
                    score_thresh=0.0001,
                    method=cfg.TEST.SOFT_NMS.METHOD
                )
            else:
                suppressed = boxlist_nms(candidates, self.nms)

            # Refine the post-NMS boxes using bounding-box voting
            if cfg.TEST.BBOX_VOTE.ENABLED and cls_boxes_raw.shape[0] > 0:
                suppressed = boxlist_box_voting(
                    suppressed,
                    candidates,
                    cfg.TEST.BBOX_VOTE.VOTE_TH,
                    scoring_method=cfg.TEST.BBOX_VOTE.SCORING_METHOD
                )

            cls_labels = torch.full(
                (len(suppressed),), cls_id, dtype=torch.int64, device=device
            )
            suppressed.add_field("labels", cls_labels)
            per_class.append(suppressed)

        merged = cat_boxlist(per_class)
        total = len(merged)
        limit = self.detections_per_img

        # Nothing to trim unless a positive limit is exceeded.
        if not (total > limit > 0):
            return merged

        # Limit to max_per_image detections **over all classes**
        merged_scores = merged.get_field("scores")
        cutoff, _ = torch.kthvalue(merged_scores.cpu(), total - limit + 1)
        selected = torch.nonzero(merged_scores >= cutoff.item()).squeeze(1)
        return merged[selected] 
开发者ID:soeaver,项目名称:Parsing-R-CNN,代码行数:62,代码来源:inference.py

示例11: filter_results

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def filter_results(self, boxlist, num_classes):
        """Threshold the scores, run NMS per class and cap the total count.

        Each class gets a ``QuadBoxList`` built from its 8-value quad boxes,
        with the matching 4-value boxes assigned to its ``bbox`` attribute.

        Args:
            boxlist: Box list exposing ``bbox``, ``quad_bbox`` and a
                ``"scores"`` field, reshaped to ``num_classes * 4``,
                ``num_classes * 8`` and ``num_classes`` columns.
            num_classes: Number of classes, with class 0 as the background.

        Returns:
            A single box list with ``"scores"`` and ``"labels"`` fields. When
            more than ``self.detections_per_img`` detections survive (and
            that limit is positive), only those scoring at least as high as
            the ``detections_per_img``-th best one are kept.
        """
        # Work on plain tensors; the box list is only rebuilt per class.
        all_boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        all_quads = boxlist.quad_bbox.reshape(-1, num_classes * 8)
        all_scores = boxlist.get_field("scores").reshape(-1, num_classes)
        device = all_scores.device
        passed = all_scores > self.score_thresh

        per_class = []
        # Class 0 is the background and is never emitted.
        for cls_id in range(1, num_classes):
            rows = passed[:, cls_id].nonzero().squeeze(1)
            cls_scores_raw = all_scores[rows, cls_id]
            cls_boxes_raw = all_boxes[rows, cls_id * 4 : (cls_id + 1) * 4]
            cls_quads_raw = all_quads[rows, cls_id * 8 : (cls_id + 1) * 8]
            cls_boxlist = QuadBoxList(cls_quads_raw, boxlist.size, mode="xyxy")
            cls_boxlist.bbox = cls_boxes_raw
            cls_boxlist.add_field("scores", cls_scores_raw)
            cls_boxlist = boxlist_nms(cls_boxlist, self.nms)
            cls_labels = torch.full(
                (len(cls_boxlist),), cls_id, dtype=torch.int64, device=device
            )
            cls_boxlist.add_field("labels", cls_labels)
            per_class.append(cls_boxlist)

        merged = cat_boxlist(per_class)
        total = len(merged)
        limit = self.detections_per_img

        # Nothing to trim unless a positive limit is exceeded.
        if not (total > limit > 0):
            return merged

        # Limit to max_per_image detections **over all classes**
        merged_scores = merged.get_field("scores")
        cutoff, _ = torch.kthvalue(merged_scores.cpu(), total - limit + 1)
        selected = torch.nonzero(merged_scores >= cutoff.item()).squeeze(1)
        return merged[selected] 
开发者ID:Xiangyu-CAS,项目名称:R2CNN.pytorch,代码行数:47,代码来源:inference.py

示例12: filter_results

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def filter_results(self, boxlist, num_classes):
        """Threshold the scores, run NMS per class and cap the total count.

        Boxes use five values per class and are wrapped in ``RBoxList`` with
        mode ``"xywha"``.

        Args:
            boxlist: Box list whose ``bbox`` and ``"scores"`` field are
                reshaped to ``num_classes * 5`` and ``num_classes`` columns.
            num_classes: Number of classes, with class 0 as the background.

        Returns:
            A single box list with ``"scores"`` and ``"labels"`` fields. When
            more than ``self.detections_per_img`` detections survive (and
            that limit is positive), only those scoring at least as high as
            the ``detections_per_img``-th best one are kept.
        """
        # Work on plain tensors; the box list is only rebuilt per class.
        all_boxes = boxlist.bbox.reshape(-1, num_classes * 5)
        all_scores = boxlist.get_field("scores").reshape(-1, num_classes)
        device = all_scores.device
        passed = all_scores > self.score_thresh

        per_class = []
        # Class 0 is the background and is never emitted.
        for cls_id in range(1, num_classes):
            rows = passed[:, cls_id].nonzero().squeeze(1)
            cls_scores_raw = all_scores[rows, cls_id]
            cls_boxes_raw = all_boxes[rows, cls_id * 5 : (cls_id + 1) * 5]
            cls_boxlist = RBoxList(cls_boxes_raw, boxlist.size, mode="xywha")
            cls_boxlist.add_field("scores", cls_scores_raw)
            cls_boxlist = boxlist_nms(
                cls_boxlist, self.nms, score_field="scores"
            )
            cls_labels = torch.full(
                (len(cls_boxlist),), cls_id, dtype=torch.int64, device=device
            )
            cls_boxlist.add_field("labels", cls_labels)
            per_class.append(cls_boxlist)

        merged = cat_boxlist(per_class)
        total = len(merged)
        limit = self.detections_per_img

        # Nothing to trim unless a positive limit is exceeded.
        if not (total > limit > 0):
            return merged

        # Limit to max_per_image detections **over all classes**
        merged_scores = merged.get_field("scores")
        cutoff, _ = torch.kthvalue(merged_scores.cpu(), total - limit + 1)
        selected = torch.nonzero(merged_scores >= cutoff.item()).squeeze(1)
        return merged[selected] 
开发者ID:clw5180,项目名称:remote_sensing_object_detection_2019,代码行数:44,代码来源:inference.py

示例13: filter_results

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def filter_results(self, boxlist, num_classes):
        """Threshold the scores, run NMS per class and cap the total count.

        Args:
            boxlist: Box list whose ``bbox`` and ``"scores"`` field are
                reshaped to ``num_classes * 4`` and ``num_classes`` columns.
            num_classes: Number of classes, with class 0 as the background.

        Returns:
            A single box list with ``"scores"`` and ``"labels"`` fields. When
            more than ``self.detections_per_img`` detections survive (and
            that limit is positive), only those scoring at least as high as
            the ``detections_per_img``-th best one are kept.
        """
        # Work on plain tensors; the box list is only rebuilt per class.
        all_boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        all_scores = boxlist.get_field("scores").reshape(-1, num_classes)
        device = all_scores.device
        passed = all_scores > self.score_thresh

        per_class = []
        # Class 0 is the background and is never emitted.
        for cls_id in range(1, num_classes):
            rows = passed[:, cls_id].nonzero().squeeze(1)
            cls_scores_raw = all_scores[rows, cls_id]
            cls_boxes_raw = all_boxes[rows, cls_id * 4 : (cls_id + 1) * 4]
            cls_boxlist = BoxList(cls_boxes_raw, boxlist.size, mode="xyxy")
            cls_boxlist.add_field("scores", cls_scores_raw)
            cls_boxlist = boxlist_nms(
                cls_boxlist, self.nms, score_field="scores"
            )
            cls_labels = torch.full(
                (len(cls_boxlist),), cls_id, dtype=torch.int64, device=device
            )
            cls_boxlist.add_field("labels", cls_labels)
            per_class.append(cls_boxlist)

        merged = cat_boxlist(per_class)
        total = len(merged)
        limit = self.detections_per_img

        # Nothing to trim unless a positive limit is exceeded.
        if not (total > limit > 0):
            return merged

        # Limit to max_per_image detections **over all classes**
        merged_scores = merged.get_field("scores")
        cutoff, _ = torch.kthvalue(merged_scores.cpu(), total - limit + 1)
        selected = torch.nonzero(merged_scores >= cutoff.item()).squeeze(1)
        return merged[selected] 
开发者ID:clw5180,项目名称:remote_sensing_object_detection_2019,代码行数:44,代码来源:inference.py

示例14: box_results_with_nms_and_limit

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def box_results_with_nms_and_limit(
    scores, boxes, score_thresh=0.05, nms=0.5, detections_per_img=100
):
    """Returns bounding-box detection results by thresholding on scores and
    applying non-maximum suppression (NMS).
    `boxes` has shape (#detections, 4 * #classes), where each row represents
    a list of predicted bounding boxes for each of the object classes in the
    dataset (including the background class). The detections in each row
    originate from the same object proposal.
    `scores` has shape (#detection, #classes), where each row represents a list
    of object detection confidence scores for each of the object classes in the
    dataset (including the background class). `scores[i, j]` corresponds to the
    box at `boxes[i, j * 4:(j + 1) * 4]`.

    Returns a tuple `(cls_scores, cls_boxes, labels)` holding the kept
    detections of all foreground classes. When more than `detections_per_img`
    detections survive NMS (and that limit is positive), only those scoring at
    least as high as the `detections_per_img`-th best one are kept. If there is
    no foreground class at all, three empty tensors are returned.
    """
    num_classes = scores.shape[1]
    cls_boxes = []
    cls_scores = []
    labels = []
    device = scores.device
    # Apply threshold on detection probabilities and apply NMS
    # Skip j = 0, because it's the background class
    for j in range(1, num_classes):
        inds = scores[:, j] > score_thresh
        scores_j = scores[inds, j]
        boxes_j = boxes[inds, j * 4 : (j + 1) * 4]
        keep = box_nms(boxes_j, scores_j, nms)
        cls_boxes.append(boxes_j[keep])
        cls_scores.append(scores_j[keep])
        # TODO see why we need the device argument
        labels.append(torch.full_like(keep, j, device=device))

    if not cls_scores:
        # Only the background class exists. torch.cat rejects an empty list,
        # so build the empty result explicitly.
        # NOTE(review): labels are assumed to be int64 like the NMS indices.
        return (
            scores.new_empty((0,)),
            boxes.new_empty((0, 4)),
            torch.empty((0,), dtype=torch.int64, device=device),
        )

    cls_scores = torch.cat(cls_scores, dim=0)
    cls_boxes = torch.cat(cls_boxes, dim=0)
    labels = torch.cat(labels, dim=0)
    number_of_detections = len(cls_scores)

    # Limit to max_per_image detections **over all classes**
    if number_of_detections > detections_per_img > 0:
        # The (n - limit + 1)-th smallest score is the limit-th largest one.
        image_thresh, _ = torch.kthvalue(
            cls_scores.cpu(), number_of_detections - detections_per_img + 1
        )
        keep = cls_scores >= image_thresh.item()
        keep = torch.nonzero(keep)
        keep = keep.squeeze(1) if keep.numel() else keep
        cls_boxes = cls_boxes[keep]
        cls_scores = cls_scores[keep]
        labels = labels[keep]
    return cls_scores, cls_boxes, labels 
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:50,代码来源:fast_rcnn.py

示例15: filter_results

# 需要导入模块: import torch [as 别名]
# 或者: from torch import kthvalue [as 别名]
def filter_results(self, boxlist, num_classes):
        """Select candidates, run NMS per class and cap the total count.

        Candidates are the scores above ``self.score_thresh``, unless
        ``self.imbalanced_decider`` is set, in which case its output on the
        score matrix is used as the selection mask instead.

        Args:
            boxlist: Box list whose ``bbox`` and ``"scores"`` field are
                reshaped to ``num_classes * 4`` and ``num_classes`` columns.
            num_classes: Number of classes, with class 0 as the background.

        Returns:
            A single box list with ``"scores"`` and ``"labels"`` fields. When
            more than ``self.detections_per_img`` detections survive (and
            that limit is positive), only those scoring at least as high as
            the ``detections_per_img``-th best one are kept.
        """
        # Work on plain tensors; the box list is only rebuilt per class.
        all_boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        all_scores = boxlist.get_field("scores").reshape(-1, num_classes)
        device = all_scores.device

        passed = (
            all_scores > self.score_thresh
            if self.imbalanced_decider is None
            else self.imbalanced_decider(all_scores)
        )

        per_class = []
        # Class 0 is the background and is never emitted.
        for cls_id in range(1, num_classes):
            rows = passed[:, cls_id].nonzero().squeeze(1)
            cls_scores_raw = all_scores[rows, cls_id]
            cls_boxes_raw = all_boxes[rows, cls_id * 4 : (cls_id + 1) * 4]
            cls_boxlist = BoxList(cls_boxes_raw, boxlist.size, mode="xyxy")
            cls_boxlist.add_field("scores", cls_scores_raw)
            cls_boxlist = boxlist_nms(cls_boxlist, self.nms)
            cls_labels = torch.full(
                (len(cls_boxlist),), cls_id, dtype=torch.int64, device=device
            )
            cls_boxlist.add_field("labels", cls_labels)
            per_class.append(cls_boxlist)

        merged = cat_boxlist(per_class)
        total = len(merged)
        limit = self.detections_per_img

        # Nothing to trim unless a positive limit is exceeded.
        if not (total > limit > 0):
            return merged

        # Limit to max_per_image detections **over all classes**
        merged_scores = merged.get_field("scores")
        cutoff, _ = torch.kthvalue(merged_scores.cpu(), total - limit + 1)
        selected = torch.nonzero(merged_scores >= cutoff.item()).squeeze(1)
        return merged[selected] 
开发者ID:ChenJoya,项目名称:sampling-free,代码行数:47,代码来源:inference.py


注:本文中的torch.kthvalue方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。