当前位置: 首页>>代码示例>>Python>>正文


Python config.num_classes方法代码示例

本文整理汇总了Python中config.config.num_classes方法的典型用法代码示例。如果您正苦于以下问题:Python config.num_classes方法的具体用法?Python config.num_classes怎么用?Python config.num_classes使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在config.config的用法示例。


在下文中一共展示了config.num_classes方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: next_sample

# 需要导入模块: from config import config [as 别名]
# 或者: from config.config import num_classes [as 别名]
def next_sample(self):
      """Helper function for reading in next sample."""
      if self.cur >= len(self.seq):
        raise StopIteration
      idx = self.seq[self.cur]
      self.cur += 1
      uv_path = self.uv_file_list[idx]
      image_path = self.image_file_list[idx]
      uvmap = np.load(uv_path)
      img = cv2.imread(image_path)[:,:,::-1]#to rgb
      hlabel = uvmap
      #print(hlabel.shape)
      #hlabel = np.array(header.label).reshape( (self.output_label_size, self.output_label_size, self.num_classes) )
      hlabel /= self.input_img_size

      return img, hlabel 
开发者ID:deepinsight,项目名称:insightface,代码行数:18,代码来源:data.py

示例2: next_sample

# 需要导入模块: from config import config [as 别名]
# 或者: from config.config import num_classes [as 别名]
def next_sample(self):
      """Helper function for reading in next sample."""
      if self.cur >= len(self.seq):
        raise StopIteration
      idx = self.seq[self.cur]
      self.cur += 1
      s = self.imgrec.read_idx(idx)
      header, img = recordio.unpack(s)
      img = mx.image.imdecode(img).asnumpy()
      hlabel = np.array(header.label).reshape( (self.num_classes,2) )
      if not config.label_xfirst:
        hlabel = hlabel[:,::-1] #convert to X/W first
      annot = {'scale': config.base_scale}

      #ul = np.array( (50000,50000), dtype=np.int32)
      #br = np.array( (0,0), dtype=np.int32)
      #for i in range(hlabel.shape[0]):
      #  h = int(hlabel[i][0])
      #  w = int(hlabel[i][1])
      #  key = np.array((h,w))
      #  ul = np.minimum(key, ul)
      #  br = np.maximum(key, br)

      return img, hlabel, annot 
开发者ID:deepinsight,项目名称:insightface,代码行数:26,代码来源:data.py

示例3: compute_metric

# 需要导入模块: from config import config [as 别名]
# 或者: from config.config import num_classes [as 别名]
def compute_metric(self, results):
        hist = np.zeros((config.num_classes, config.num_classes))
        correct = 0
        labeled = 0
        count = 0
        for d in results:
            hist += d['hist']
            correct += d['correct']
            labeled += d['labeled']
            count += 1

        iu, mean_IU, _, mean_pixel_acc = compute_score(hist, correct,
                                                       labeled)
        result_line = print_iou(iu, mean_pixel_acc,
                                dataset.get_class_names(), True)
        return result_line 
开发者ID:StevenGrove,项目名称:TreeFilter-Torch,代码行数:18,代码来源:eval.py

示例4: split_data

# 需要导入模块: from config import config [as 别名]
# 或者: from config.config import num_classes [as 别名]
def split_data(file2idx, val_ratio=0.1):
    '''
    划分数据集,val需保证每类至少有1个样本
    :param file2idx:
    :param val_ratio:验证集占总数据的比例
    :return:训练集,验证集路径
    '''
    data = set(os.listdir(config.train_dir))
    val = set()
    idx2file = [[] for _ in range(config.num_classes)]
    for file, list_idx in file2idx.items():
        for idx in list_idx:
            idx2file[idx].append(file)
    for item in idx2file:
        # print(len(item), item)
        num = int(len(item) * val_ratio)
        val = val.union(item[:num])
    train = data.difference(val)
    return list(train), list(val) 
开发者ID:JavisPeng,项目名称:ecg_pytorch,代码行数:21,代码来源:data_process.py

示例5: func_per_iteration

# 需要导入模块: from config import config [as 别名]
# 或者: from config.config import num_classes [as 别名]
def func_per_iteration(self, data, device):
        img = data['data']
        label = data['label']
        name = data['fn']

        pred = self.sliding_eval(img, config.eval_crop_size,
                                 config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred,
                                                       label)
        results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp,
                        'correct': correct_tmp}

        if self.save_path is not None:
            fn = name + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            colors = self.dataset.get_class_colors()
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean,
                                label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            cv2.waitKey(0)

        return results_dict 
开发者ID:StevenGrove,项目名称:TreeFilter-Torch,代码行数:31,代码来源:eval.py

示例6: func_per_iteration

# 需要导入模块: from config import config [as 别名]
# 或者: from config.config import num_classes [as 别名]
def func_per_iteration(self, data, device):
        img = data['data']
        label = data['label']
        name = data['fn']

        pred = self.sliding_eval(img, config.eval_crop_size,
                                 config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred,
                                                       label)
        results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp,
                        'correct': correct_tmp}

        if self.save_path is not None:
            fn = name + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            colors = self.dataset.get_class_colors
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean,
                                label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            cv2.waitKey(0)

        return results_dict 
开发者ID:StevenGrove,项目名称:TreeFilter-Torch,代码行数:31,代码来源:eval.py

示例7: count_labels

# 需要导入模块: from config import config [as 别名]
# 或者: from config.config import num_classes [as 别名]
def count_labels(data, file2idx):
    '''
    统计每个类别的样本数
    :param data:
    :param file2idx:
    :return:
    '''
    cc = [0] * config.num_classes
    for fp in data:
        for i in file2idx[fp]:
            cc[i] += 1
    return np.array(cc) 
开发者ID:JavisPeng,项目名称:ecg_pytorch,代码行数:14,代码来源:data_process.py


注:本文中的config.config.num_classes方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。