当前位置: 首页>>代码示例>>Python>>正文


Python visualizer.Visualizer方法代码示例

本文整理汇总了Python中utils.visualizer.Visualizer方法的典型用法代码示例。如果您正苦于以下问题:Python visualizer.Visualizer方法的具体用法?Python visualizer.Visualizer怎么用?Python visualizer.Visualizer使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils.visualizer的用法示例。


在下文中一共展示了visualizer.Visualizer方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: evaluate

# 需要导入模块: from utils import visualizer [as 别名]
# 或者: from utils.visualizer import Visualizer [as 别名]
def evaluate(opt, dloader, model, use_saved_file=False):
  """Evaluate `model` on every batch of `dloader` and return the scores.

  Args:
    opt: options namespace. Reads `ckpt_path`, `dset_name`, `n_components`,
      `n_frames_input` and `log_every`, plus the optional flags
      `save_visuals` / `save_all_results`; each of those flags is written
      back onto `opt` as False when it is absent or falsy.
    dloader: iterable of batches supporting `len()` and exposing `.dataset`.
      Batches are `(input, gt)`, or `(input, gt, positions)` in the
      bouncing-balls case.
    model: must provide `setup(is_train=...)`, `test(input, gt)` returning
      `(output, latent)`, and `get_visuals()`.
    use_saved_file: bouncing-balls only; if True, `<ckpt_path>/positions.npy`
      is passed to `utils.VelocityMetrics`, otherwise an empty path is.

  Returns:
    dict of scores from `utils.Metrics.get_scores()`, extended with
    `utils.VelocityMetrics.get_scores()` in the bouncing-balls case.
  """
  # Visualizer writing under <ckpt_path>/tb_test, only when requested.
  vis = None
  if hasattr(opt, 'save_visuals') and opt.save_visuals:
    vis = Visualizer(os.path.join(opt.ckpt_path, 'tb_test'))
  else:
    opt.save_visuals = False

  model.setup(is_train=False)
  metric = utils.Metrics()
  results = {}

  save_dir = None
  if hasattr(opt, 'save_all_results') and opt.save_all_results:
    save_dir = os.path.join(opt.ckpt_path, 'results')
    os.makedirs(save_dir, exist_ok=True)
  else:
    opt.save_all_results = False

  # HACK: 4-component bouncing-balls runs also report velocity metrics,
  # which need the dataset to additionally return ground-truth positions.
  is_bouncing_balls = ('bouncing_balls' in opt.dset_name) and opt.n_components == 4
  velocity_metric = None
  if is_bouncing_balls:
    dloader.dataset.return_positions = True

  try:
    if is_bouncing_balls:
      saved_positions = (os.path.join(opt.ckpt_path, 'positions.npy')
                         if use_saved_file else '')
      velocity_metric = utils.VelocityMetrics(saved_positions)

    count = 0  # running counter threaded through save_images
    for step, data in enumerate(dloader):
      if is_bouncing_balls:
        inputs, gt, positions = data
      else:
        inputs, gt = data
      output, latent = model.test(inputs, gt)
      # Drop the first n_frames_input steps along axis 1 (presumably the
      # time axis) so that only predicted frames are scored.
      pred = output[:, opt.n_frames_input:, ...]
      metric.update(gt, pred)

      if opt.save_all_results:
        # Saved reference sequence = input frames followed by ground truth.
        full_gt = np.concatenate([inputs.numpy(), gt.numpy()], axis=1)
        prediction = utils.to_numpy(output)
        count = save_images(prediction, full_gt, latent, save_dir, count)

      if is_bouncing_balls:
        # Calculate position and velocity from the pose latent.
        pose = latent['pose'].data.cpu()
        velocity_metric.update(positions, pose, opt.n_frames_input)

      if (step + 1) % opt.log_every == 0:
        print('{}/{}'.format(step + 1, len(dloader)))
        if opt.save_visuals:
          vis.add_images(model.get_visuals(), step, prefix='test')
  finally:
    if is_bouncing_balls:
      # Restore the dataset's default (input, gt) batches even if evaluation
      # raised, so other users of this dataset are not affected.
      dloader.dataset.return_positions = False

  # BCE, MSE
  results.update(metric.get_scores())
  if is_bouncing_balls:
    results.update(velocity_metric.get_scores())

  return results
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:60,代码来源:test.py

示例2: main

# 需要导入模块: from utils import visualizer [as 别名]
# 或者: from utils.visualizer import Visualizer [as 别名]
def main():
    """Collect rotation distributions on the training and validation splits.

    Parses ``TrainOptions``, builds a 2-stack hourglass network wrapped in
    ``torch.nn.DataParallel`` on GPU, loads the checkpoint derived from
    ``opt.load_prefix_pose`` and calls ``collect_train_valid_data`` for both
    splits. Output files (``train_rotations.txt`` / ``val_rotations.txt`` and
    their ``_copy`` variants) are placed in
    ``<exp_dir>/<exp_id>/<sr_dir>-<load_prefix_pose[0:-1]>``.

    Exits when ``opt.sr_dir`` is empty. Raises ValueError for datasets other
    than 'mpii', whose class count is not known here.
    """
    import sys

    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        sys.exit()
    # NOTE(review): [0:-1] drops the last character of the prefix, presumably
    # a trailing separator; confirm against how prefixes are saved.
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    checkpoint_hg = Checkpoint()
    train_distri_path = os.path.join(sr_pretrain_dir, 'train_rotations.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_rotations_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_rotations.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_rotations_copy.txt')

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Previously this fell through to an UnboundLocalError on num_classes.
        raise ValueError('unsupported dataset: %s' % opt.dataset)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    if opt.load_prefix_pose == '':
        # NOTE(review): the original author disabled exiting here; the run
        # continues and tries to load '<exp_dir>/<exp_id>' as the checkpoint.
        print('please input the checkpoint name of the pose model')
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    print('collecting training distributions ...\n')
    collect_train_valid_data(train_distri_path, train_distri_path_2,
                             hg, opt, is_train=True)

    print('collecting validation distributions ...\n')
    collect_train_valid_data(val_distri_path, val_distri_path_2,
                             hg, opt, is_train=False)
开发者ID:zhiqiangdon,项目名称:pose-adv-aug,代码行数:48,代码来源:collect-rotation-ditri.py

示例3: main

# 需要导入模块: from utils import visualizer [as 别名]
# 或者: from utils.visualizer import Visualizer [as 别名]
def main():
    """Collect scale distributions on the training and validation splits.

    Parses ``TrainOptions``, builds a 2-stack hourglass network wrapped in
    ``torch.nn.DataParallel`` on GPU, loads the checkpoint derived from
    ``opt.load_prefix_pose`` and calls ``collect_train_valid_data`` for both
    splits. Output files (``train_scales.txt`` / ``val_scales.txt`` and their
    ``_copy`` variants) are placed in
    ``<exp_dir>/<exp_id>/<sr_dir>-<load_prefix_pose[0:-1]>``.

    Exits when ``opt.sr_dir`` or ``opt.load_prefix_pose`` is empty. Raises
    ValueError for datasets other than 'mpii', whose class count is not
    known here.
    """
    import sys

    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        sys.exit()
    # Validate before load_prefix_pose is used to name the output directory
    # and before the model is built, so an empty value fails fast and does not
    # leave a stray '<sr_dir>-' directory behind.
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        sys.exit()
    # NOTE(review): [0:-1] drops the last character of the prefix, presumably
    # a trailing separator; confirm against how prefixes are saved.
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    checkpoint_hg = Checkpoint()
    train_distri_path = os.path.join(sr_pretrain_dir, 'train_scales.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_scales_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_scales.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_scales_copy.txt')

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Previously this fell through to an UnboundLocalError on num_classes.
        raise ValueError('unsupported dataset: %s' % opt.dataset)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    print('collecting training distributions ...\n')
    collect_train_valid_data(train_distri_path, train_distri_path_2,
                             hg, opt, is_train=True)

    print('collecting validation distributions ...\n')
    collect_train_valid_data(val_distri_path, val_distri_path_2,
                             hg, opt, is_train=False)
开发者ID:zhiqiangdon,项目名称:pose-adv-aug,代码行数:48,代码来源:collect-scale-ditri.py

示例4: main

# 需要导入模块: from utils import visualizer [as 别名]
# 或者: from utils.visualizer import Visualizer [as 别名]
def main():
    """Sample images from a trained generator checkpoint.

    Command-line flags:
      --seed             seed for torch's CPU and CUDA RNGs (default 0).
      --gpu_id           CUDA device index; a negative value, or no CUDA
                         available, selects the CPU (default 0).
      --g_path           generator state-dict path (required). Its file stem
                         must end in ``_<iteration>``, and ``netG_params.pkl``
                         is read from the same directory.
      --out              output directory, created if missing ('samples').
      --num_samples      forwarded to ``Visualizer`` (default 10).
      --eval_batch_size  forwarded to ``Visualizer`` (default 128).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', default=0, type=int)
    cli.add_argument('--gpu_id', type=int, default=0)
    cli.add_argument('--g_path', type=str, required=True)
    cli.add_argument('--out', type=str, default='samples')
    cli.add_argument('--num_samples', type=int, default=10)
    cli.add_argument('--eval_batch_size', type=int, default=128)
    opts = cli.parse_args()

    # Seed the CPU and CUDA generators.
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed(opts.seed)

    use_gpu = torch.cuda.is_available() and opts.gpu_id >= 0
    device = torch.device(('cuda:%d' % opts.gpu_id) if use_gpu else 'cpu')

    # Generator hyper-parameters live next to the checkpoint; the training
    # iteration is the last '_'-separated token of the checkpoint file stem.
    ckpt_dir = os.path.dirname(opts.g_path)
    gen_kwargs = util.load_params(os.path.join(ckpt_dir, 'netG_params.pkl'))
    ckpt_stem, _ext = os.path.splitext(os.path.basename(opts.g_path))
    iteration = int(ckpt_stem.split('_')[-1])

    def _keep_storage(storage, location):
        # map_location hook: return each deserialized storage unchanged.
        return storage

    generator = resnet.Generator(**gen_kwargs)
    generator.to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(opts.g_path, map_location=_keep_storage)
    generator.load_state_dict(state)
    generator.eval()

    if not os.path.exists(opts.out):
        os.makedirs(opts.out)

    sampler = Visualizer(generator, device, opts.out, opts.num_samples,
                         generator.num_classes, opts.eval_batch_size)
    sampler.visualize(iteration)
开发者ID:takuhirok,项目名称:rGAN,代码行数:47,代码来源:test.py


注:本文中的utils.visualizer.Visualizer方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。