

Python dataset_factory.dataset_factory Method Code Examples

This article collects typical usage examples of the Python method datasets.dataset_factory.dataset_factory. If you are wondering how dataset_factory.dataset_factory is used in practice and what it is good for, the curated examples below should help. You can also browse further usage examples from the containing module, datasets.dataset_factory.


Six code examples of dataset_factory.dataset_factory are shown below, sorted by popularity by default. Note that the name covers two distinct patterns: in the CenterNet-family repositories (Examples 1-4), dataset_factory is a dict registry mapping dataset names to Dataset classes, while in the Tacotron repositories (Examples 5-6) it is a factory function that builds a tf.data pipeline; see the sketch below. You can upvote the examples you find useful; your feedback helps the system recommend better Python code examples.
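For orientation, here is a minimal runnable sketch of the two shapes dataset_factory takes in the examples that follow. All class and function names in the sketch are illustrative placeholders, not the repositories' actual definitions:

# Pattern A (Examples 1-4): dataset_factory is a registry dict that maps
# a dataset name to a Dataset class; callers index it, then instantiate.
class CocoDataset:  # placeholder standing in for a real Dataset class
    def __init__(self, opt, split):
        self.opt, self.split = opt, split

dataset_factory = {'coco': CocoDataset}

Dataset = dataset_factory['coco']   # look up the class by name
dataset = Dataset(None, 'val')      # then construct it for a split

# Pattern B (Examples 5-6): dataset_factory is a factory function that
# wraps source/target datasets; sketched here as a plain function.
def dataset_factory_fn(source, target, hparams):
    return (source, target, hparams)  # stands in for a dataset wrapper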

Example 1: test

# Required import: from datasets import dataset_factory  [as alias]
# Or: from datasets.dataset_factory import dataset_factory  [as alias]
# Also needed from the source repository: os, Bar (e.g. from progress.bar
# import Bar), AverageMeter, Logger, and detector_factory (paths assumed).
def test(cfg):

    Dataset = dataset_factory[cfg.SAMPLE_METHOD]
    Logger(cfg)
    Detector = detector_factory[cfg.TEST.TASK]

    dataset = Dataset(cfg, 'val')
    detector = Detector(cfg)

    results = {}
    num_iters = len(dataset)
    bar = Bar('{}'.format(cfg.EXP_ID), max=num_iters)
    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    for ind in range(num_iters):
        img_id = dataset.images[ind]
        img_info = dataset.coco.loadImgs(ids=[img_id])[0]
        img_path = os.path.join(dataset.img_dir, img_info['file_name'])
        ret = detector.run(img_path)

        results[img_id] = ret['results']

        Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                       ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
        for t in avg_time_stats:
            avg_time_stats[t].update(ret[t])
            Bar.suffix = Bar.suffix + '|{} {:.3f} '.format(t, avg_time_stats[t].avg)
        bar.next()
    bar.finish()
    dataset.run_eval(results, cfg.OUTPUT_DIR) 
Developer: tensorboy | Project: centerpose | Lines of code: 33 | Source file: evaluate.py
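Examples 1, 2, and 4 rely on an AverageMeter helper to keep running averages of the per-stage timings. Below is a minimal sketch of that helper, matching the common CenterNet-style implementation; the exact version lives in each repository's utils module:

class AverageMeter(object):
    """Tracks the latest value, sum, count, and running average."""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count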

Example 2: prefetch_test

# Required import: from datasets import dataset_factory  [as alias]
# Or: from datasets.dataset_factory import dataset_factory  [as alias]
# Also needed from the source repository: os, torch, numpy as np, Bar,
# AverageMeter, Logger, opts, detector_factory, and PrefetchDataset,
# which is defined in the same test.py (a sketch follows this example).
def prefetch_test(opt):
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  Dataset = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  print(opt)
  Logger(opt)
  Detector = detector_factory[opt.task]
  
  split = 'val' if not opt.trainval else 'test'
  dataset = Dataset(opt, split)
  detector = Detector(opt)
  
  data_loader = torch.utils.data.DataLoader(
    PrefetchDataset(opt, dataset, detector.pre_process), 
    batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
  avg_time_stats = {t: AverageMeter() for t in time_stats}
  for ind, (img_id, pre_processed_images) in enumerate(data_loader):
    ret = detector.run(pre_processed_images)
    results[img_id.numpy().astype(np.int32)[0]] = ret['results']
    Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                   ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    for t in avg_time_stats:
      avg_time_stats[t].update(ret[t])
      Bar.suffix = Bar.suffix + '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
        t, tm = avg_time_stats[t])
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir) 
Developer: CaoWGG | Project: CenterNet-CondInst | Lines of code: 36 | Source file: test.py
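Example 2 wraps the dataset in a PrefetchDataset so that image loading and pre-processing happen inside DataLoader workers rather than in the main loop. The class is defined in the same test.py; the sketch below is reconstructed from the upstream CenterNet code base that CenterNet-CondInst derives from, so details may differ slightly:

import os
import cv2
import torch.utils.data

class PrefetchDataset(torch.utils.data.Dataset):
    def __init__(self, opt, dataset, pre_process_func):
        self.images = dataset.images
        self.load_image_func = dataset.coco.loadImgs
        self.img_dir = dataset.img_dir
        self.pre_process_func = pre_process_func
        self.opt = opt

    def __getitem__(self, index):
        img_id = self.images[index]
        img_info = self.load_image_func(ids=[img_id])[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        image = cv2.imread(img_path)
        # Pre-process once per test scale; detector.run() consumes this dict.
        images, meta = {}, {}
        for scale in self.opt.test_scales:
            images[scale], meta[scale] = self.pre_process_func(image, scale)
        return img_id, {'images': images, 'image': image, 'meta': meta}

    def __len__(self):
        return len(self.images)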

Example 3: test

# Required import: from datasets import dataset_factory  [as alias]
# Or: from datasets.dataset_factory import dataset_factory  [as alias]
# Also needed from the source repository: os, tqdm, opts, Logger,
# and detector_factory (module paths assumed to match the repo layout).
def test(opt):
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  Dataset = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  print(opt)
  Logger(opt)
  Detector = detector_factory[opt.task]
  
  split = 'val' if not opt.trainval else 'test'
  dataset = Dataset(opt, split)
  detector = Detector(opt)

  results = {}
  num_iters = len(dataset)
  for ind in tqdm(range(num_iters)):
    img_id = dataset.images[ind]
    img_info = dataset.coco.loadImgs(ids=[img_id])[0]
    img_path = os.path.join(dataset.img_dir, img_info['file_name'])

    if opt.task == 'ddd':
      ret = detector.run(img_path, img_info['calib'])
    else:
      ret = detector.run(img_path)
    
    results[img_id] = ret['results']
  dataset.run_eval(results, opt.save_dir) 
Developer: CaoWGG | Project: CenterNet-CondInst | Lines of code: 29 | Source file: test.py

Example 4: test

# Required import: from datasets import dataset_factory  [as alias]
# Or: from datasets.dataset_factory import dataset_factory  [as alias]
# Also needed from the source repository: os, Bar, AverageMeter, Logger,
# opts, and detector_factory (module paths assumed to match the repo layout).
def test(opt):
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  Dataset = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  print(opt)
  Logger(opt)
  Detector = detector_factory[opt.task]
  
  split = 'val' if not opt.trainval else 'test'
  dataset = Dataset(opt, split)
  detector = Detector(opt)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
  avg_time_stats = {t: AverageMeter() for t in time_stats}
  for ind in range(num_iters):
    img_id = dataset.images[ind]
    img_info = dataset.coco.loadImgs(ids=[img_id])[0]
    img_path = os.path.join(dataset.img_dir, img_info['file_name'])

    if opt.task == 'ddd':
      ret = detector.run(img_path, img_info['calib'])
    else:
      ret = detector.run(img_path)
    
    results[img_id] = ret['results']

    Bar.suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                   ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    for t in avg_time_stats:
      avg_time_stats[t].update(ret[t])
      Bar.suffix = Bar.suffix + '|{} {:.3f} '.format(t, avg_time_stats[t].avg)
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir) 
Developer: kimyoon-young | Project: centerNet-deep-sort | Lines of code: 40 | Source file: test.py
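In the CenterNet-family repositories these test functions are driven from a small __main__ block. A plausible sketch, following upstream CenterNet (the not_prefetch_test flag is taken from upstream and is an assumption for the forks shown here):

if __name__ == '__main__':
    opt = opts().parse()
    if opt.not_prefetch_test:
        test(opt)           # plain per-image loop (Examples 3 and 4)
    else:
        prefetch_test(opt)  # DataLoader-backed loop (Example 2)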

Example 5: predict

# Required import: from datasets import dataset_factory  [as alias]
# Or: from datasets.dataset_factory import dataset_factory  [as alias]
# Here dataset_factory is a factory function. Also needed from the source
# repository: os, tensorflow as tf, tacotron_model_factory, PredictedMel,
# plot_predictions, and write_prediction_result.
def predict(hparams,
            model_dir, checkpoint_path, output_dir,
            test_source_files, test_target_files):
    def predict_input_fn():
        source = tf.data.TFRecordDataset(list(test_source_files))
        target = tf.data.TFRecordDataset(list(test_target_files))
        dataset = dataset_factory(source, target, hparams)
        batched = dataset.prepare_and_zip().group_by_batch(
            batch_size=1).merge_target_to_source()
        return batched.dataset

    estimator = tacotron_model_factory(hparams, model_dir, None)

    predictions = map(
        lambda p: PredictedMel(p["id"], p["key"], p["mel"], p.get("mel_postnet"), p["mel"].shape[1], p["mel"].shape[0],
                               p["ground_truth_mel"], p["alignment"], p.get("alignment2"), p.get("alignment3"),
                               p.get("alignment4"), p.get("alignment5"), p.get("alignment6"),
                               p["source"], p["text"], p.get("accent_type")),
        estimator.predict(predict_input_fn, checkpoint_path=checkpoint_path))

    for v in predictions:
        key = v.key.decode('utf-8')
        mel_filename = f"{key}.{hparams.predicted_mel_extension}"
        mel_filepath = os.path.join(output_dir, mel_filename)
        mel = v.predicted_mel_postnet if hparams.use_postnet_v2 else v.predicted_mel
        assert mel.shape[1] == hparams.num_mels
        mel.tofile(mel_filepath, format='<f4')
        text = v.text.decode("utf-8")
        plot_filename = f"{key}.png"
        plot_filepath = os.path.join(output_dir, plot_filename)
        alignments = list(filter(lambda x: x is not None,
                                 [v.alignment, v.alignment2, v.alignment3, v.alignment4, v.alignment5, v.alignment6]))

        plot_predictions(alignments, v.ground_truth_mel, v.predicted_mel, v.predicted_mel_postnet,
                         text, v.key, plot_filepath)
        prediction_filename = f"{key}.tfrecord"
        prediction_filepath = os.path.join(output_dir, prediction_filename)
        write_prediction_result(v.id, key, alignments, mel, v.ground_truth_mel, text, v.source,
                                v.accent_type, prediction_filepath) 
Developer: nii-yamagishilab | Project: self-attention-tacotron | Lines of code: 41 | Source file: predict_mel.py
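Example 5 packs each prediction into a PredictedMel record. Its real definition lives in the source repository; the namedtuple below is only inferred from how the fields are constructed and read above, and the two shape-related field names are guesses:

from collections import namedtuple

PredictedMel = namedtuple("PredictedMel", [
    "id", "key", "predicted_mel", "predicted_mel_postnet",
    "predicted_mel_width", "predicted_target_length",  # names guessed
    "ground_truth_mel", "alignment", "alignment2", "alignment3",
    "alignment4", "alignment5", "alignment6",
    "source", "text", "accent_type",
])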

Example 6: train_and_evaluate

# Required import: from datasets import dataset_factory  [as alias]
# Or: from datasets.dataset_factory import dataset_factory  [as alias]
# Also needed from the source repository: tensorflow as tf,
# from random import shuffle, plus repo-local create_from_tfrecord_files,
# get_parallelism, and tacotron_model_factory.
def train_and_evaluate(hparams, model_dir, train_source_files, train_target_files, eval_source_files,
                       eval_target_files, use_multi_gpu):

    interleave_parallelism = get_parallelism(hparams.interleave_cycle_length_cpu_factor,
                                             hparams.interleave_cycle_length_min,
                                             hparams.interleave_cycle_length_max)

    tf.logging.info("Interleave parallelism is %d.", interleave_parallelism)

    def train_input_fn():
        source_and_target_files = list(zip(train_source_files, train_target_files))
        shuffle(source_and_target_files)
        source = [s for s, _ in source_and_target_files]
        target = [t for _, t in source_and_target_files]

        dataset = create_from_tfrecord_files(source, target, hparams,
                                             cycle_length=interleave_parallelism,
                                             buffer_output_elements=hparams.interleave_buffer_output_elements,
                                             prefetch_input_elements=hparams.interleave_prefetch_input_elements)

        zipped = dataset.prepare_and_zip()
        zipped = zipped.cache(hparams.cache_file_name) if hparams.use_cache else zipped
        # 'suffle_buffer_size' (sic) follows the hparams key used in the source repository.
        batched = zipped.filter_by_max_output_length().repeat(count=None).shuffle(
            hparams.suffle_buffer_size).group_by_batch().prefetch(hparams.prefetch_buffer_size)
        return batched.dataset

    def eval_input_fn():
        source_and_target_files = list(zip(eval_source_files, eval_target_files))
        shuffle(source_and_target_files)
        source = tf.data.TFRecordDataset([s for s, _ in source_and_target_files])
        target = tf.data.TFRecordDataset([t for _, t in source_and_target_files])

        dataset = dataset_factory(source, target, hparams)
        zipped = dataset.prepare_and_zip()
        dataset = zipped.filter_by_max_output_length().repeat().group_by_batch(batch_size=1)
        return dataset.dataset

    distribution = tf.contrib.distribute.MirroredStrategy() if use_multi_gpu else None

    run_config = tf.estimator.RunConfig(save_summary_steps=hparams.save_summary_steps,
                                        save_checkpoints_steps=hparams.save_checkpoints_steps,
                                        keep_checkpoint_max=hparams.keep_checkpoint_max,
                                        log_step_count_steps=hparams.log_step_count_steps,
                                        train_distribute=distribution)

    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=hparams.ckpt_to_initialize_from,
        vars_to_warm_start=hparams.vars_to_warm_start) if hparams.warm_start else None

    estimator = tacotron_model_factory(hparams, model_dir, run_config, ws)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                      steps=hparams.num_evaluation_steps,
                                      throttle_secs=hparams.eval_throttle_secs,
                                      start_delay_secs=hparams.eval_start_delay_secs)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 
Developer: nii-yamagishilab | Project: self-attention-tacotron | Lines of code: 60 | Source file: train.py
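Example 6 sizes its tf.data interleave with a get_parallelism helper. The real helper is repo-local; judging purely from the parameter names, a plausible reconstruction is a CPU-count scaled value clamped to a range. This is an assumption, not the repository's code:

import os

def get_parallelism(cpu_factor, min_value, max_value):
    # Hypothetical: scale the CPU count by a factor and clamp the result.
    cpus = os.cpu_count() or 1
    return int(min(max(cpus * cpu_factor, min_value), max_value))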


Note: The datasets.dataset_factory.dataset_factory method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright remains with the original authors. Please consult each project's license before redistributing or using the code, and do not republish without permission.