当前位置: 首页>>代码示例>>Python>>正文


Python Bar.suffix属性代码示例

本文整理汇总了Python中progress.bar.Bar.suffix属性的典型用法代码示例。suffix 是 Bar 的属性(而不是方法),用于设置进度条末尾显示的附加文本。如果您想了解 Python Bar.suffix 的具体用法和使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在类 progress.bar.Bar 的用法示例。注意:多数示例直接给类属性 Bar.suffix 赋值,这会影响进程中所有 Bar 实例;更稳妥的做法是给实例赋值,例如 bar.suffix = '...'。


在下文中一共展示了Bar.suffix属性的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def test(cfg):
    """Run the configured detector on the validation split and evaluate it.

    Every validation image is passed through ``detector.run``; its
    ``'results'`` entry is stored under the image id, the per-stage timings
    are averaged and shown on a progress bar, and the collected results are
    finally handed to ``dataset.run_eval``.

    Args:
        cfg: experiment config. Reads ``SAMPLE_METHOD`` (dataset factory key),
            ``TEST.TASK`` (detector factory key), ``EXP_ID`` (bar label) and
            ``OUTPUT_DIR`` (passed to ``run_eval``).
    """
    Dataset = dataset_factory[cfg.SAMPLE_METHOD]
    Logger(cfg)  # return value is discarded
    Detector = detector_factory[cfg.TEST.TASK]

    dataset = Dataset(cfg, 'val')
    detector = Detector(cfg)

    results = {}
    num_iters = len(dataset)
    bar = Bar('{}'.format(cfg.EXP_ID), max=num_iters)
    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    for ind in range(num_iters):
        img_id = dataset.images[ind]
        img_info = dataset.coco.loadImgs(ids=[img_id])[0]
        img_path = os.path.join(dataset.img_dir, img_info['file_name'])
        ret = detector.run(img_path)

        results[img_id] = ret['results']

        suffix_parts = ['[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
            ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)]
        for t, meter in avg_time_stats.items():
            meter.update(ret[t])
            suffix_parts.append('|{} {:.3f} '.format(t, meter.avg))
        # Assign on the instance: writing ``Bar.suffix`` mutates the class
        # attribute, which leaks this text into every other Bar object.
        bar.suffix = ''.join(suffix_parts)
        bar.next()
    bar.finish()
    dataset.run_eval(results, cfg.OUTPUT_DIR)
开发者ID:tensorboy,项目名称:centerpose,代码行数:33,代码来源:evaluate.py

示例2: validate

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def validate(model, dataset, opt, ctx, report_every=1000):
    """Evaluate ``model`` on the validation ``dataset`` with a CenterDetector.

    Runs ``detector.run`` on every image, stores each ``'results'`` entry by
    image id, keeps running averages of the per-stage timings and finally
    calls ``dataset.run_eval`` on the collected results.

    Args:
        model: network assigned to ``detector.model`` before inference.
        dataset: validation dataset; reads ``images``, ``coco`` and
            ``img_dir`` and calls ``run_eval``.
        opt: options passed to ``CenterDetector``; ``opt.exp_id`` labels the bar.
        ctx: not used in this function; kept for call-site compatibility.
        report_every: refresh the progress bar once per this many images
            (and after the last image). Must be >= 1.

    Raises:
        ValueError: if ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be >= 1, got {}'.format(report_every))
    detector = CenterDetector(opt)
    detector.model = model

    results = {}
    num_iters = len(dataset)
    bar = Bar('{}'.format(opt.exp_id), max=num_iters)
    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    print("Reporting every {} images...".format(report_every))
    reported = 0  # images already credited to the progress bar
    for ind in range(num_iters):
        img_id = dataset.images[ind]
        img_info = dataset.coco.loadImgs(ids=[img_id])[0]
        img_path = os.path.join(dataset.img_dir, img_info['file_name'])

        ret = detector.run(img_path)
        results[img_id] = ret['results']
        for t, meter in avg_time_stats.items():
            meter.update(ret[t])

        done = ind + 1
        if done % report_every == 0 or done == num_iters:
            suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                done, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
            suffix += ''.join('|{} {:.3f} '.format(t, meter.avg)
                              for t, meter in avg_time_stats.items())
            # Set on the instance (``Bar.suffix`` would change the class
            # attribute shared by all bars), and advance by every image
            # processed since the last refresh so the bar's count and ETA
            # track real progress and it reaches ``max``.
            bar.suffix = suffix
            bar.next(done - reported)
            reported = done
    bar.finish()
    # Evaluate the same dataset object that was iterated above.
    dataset.run_eval(results = results, save_dir = './output/')
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:29,代码来源:test_own_validation.py

示例3: validate

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def validate(model, dataset, opt, ctx, report_every=1000):
    """Evaluate ``model`` on the validation ``dataset`` with a DddDetector.

    Each image is run together with its ``img_info['calib']`` entry; the
    per-image ``'results'`` are collected by image id, per-stage timings are
    averaged, and ``dataset.run_eval`` is called at the end.

    Args:
        model: network assigned to ``detector.model`` before inference.
        dataset: validation dataset (KITTI converted to COCO format); reads
            ``images``, ``coco`` and ``img_dir`` and calls ``run_eval``.
        opt: options passed to ``DddDetector``; ``opt.exp_id`` labels the bar.
        ctx: not used in this function; kept for call-site compatibility.
        report_every: refresh the progress bar once per this many images
            (and after the last image). Must be >= 1.

    Raises:
        ValueError: if ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be >= 1, got {}'.format(report_every))
    detector = DddDetector(opt)
    detector.model = model

    results = {}
    num_iters = len(dataset)
    bar = Bar('{}'.format(opt.exp_id), max=num_iters)
    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    print("Reporting every {} images...".format(report_every))
    reported = 0  # images already credited to the progress bar
    for ind in range(num_iters):
        img_id = dataset.images[ind]
        img_info = dataset.coco.loadImgs(ids=[img_id])[0]  # KITTI has been transformed to COCO format
        img_path = os.path.join(dataset.img_dir, img_info['file_name'])

        ret = detector.run(img_path, img_info['calib'])

        results[img_id] = ret['results']
        for t, meter in avg_time_stats.items():
            meter.update(ret[t])

        done = ind + 1
        if done % report_every == 0 or done == num_iters:
            suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                done, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
            suffix += ''.join('|{} {:.3f} '.format(t, meter.avg)
                              for t, meter in avg_time_stats.items())
            # Set on the instance (``Bar.suffix`` would change the class
            # attribute shared by all bars), and advance by every image
            # processed since the last refresh so the bar's count and ETA
            # track real progress and it reaches ``max``.
            bar.suffix = suffix
            bar.next(done - reported)
            reported = done
    bar.finish()
    # Evaluate the same dataset object that was iterated above.
    dataset.run_eval(results = results, save_dir = './output/')
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:30,代码来源:train_3dod.py

示例4: validate

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def validate(model, dataset, opt, ctx, report_every=1000):
    """Evaluate ``model`` on the validation ``dataset`` with a PoseDetector.

    Runs ``detector.run`` on every image, stores each ``'results'`` entry by
    image id, keeps running averages of the per-stage timings and finally
    calls ``dataset.run_eval`` on the collected results.

    Args:
        model: network assigned to ``detector.model`` before inference.
        dataset: validation dataset; reads ``images``, ``coco`` and
            ``img_dir`` and calls ``run_eval``.
        opt: options passed to ``PoseDetector``; ``opt.exp_id`` labels the bar.
        ctx: not used in this function; kept for call-site compatibility.
        report_every: refresh the progress bar once per this many images
            (and after the last image). Must be >= 1.

    Raises:
        ValueError: if ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be >= 1, got {}'.format(report_every))
    detector = PoseDetector(opt)
    detector.model = model

    results = {}
    num_iters = len(dataset)
    bar = Bar('{}'.format(opt.exp_id), max=num_iters)
    time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
    avg_time_stats = {t: AverageMeter() for t in time_stats}
    print("Reporting every {} images...".format(report_every))
    reported = 0  # images already credited to the progress bar

    for ind in range(num_iters):
        img_id = dataset.images[ind]
        img_info = dataset.coco.loadImgs(ids=[img_id])[0]
        img_path = os.path.join(dataset.img_dir, img_info['file_name'])

        ret = detector.run(img_path)
        results[img_id] = ret['results']
        for t, meter in avg_time_stats.items():
            meter.update(ret[t])

        done = ind + 1
        if done % report_every == 0 or done == num_iters:
            suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                done, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
            suffix += ''.join('|{} {:.3f} '.format(t, meter.avg)
                              for t, meter in avg_time_stats.items())
            # Set on the instance (``Bar.suffix`` would change the class
            # attribute shared by all bars), and advance by every image
            # processed since the last refresh so the bar's count and ETA
            # track real progress and it reaches ``max``.
            bar.suffix = suffix
            bar.next(done - reported)
            reported = done

    bar.finish()
    # Evaluate the same dataset object that was iterated above.
    dataset.run_eval(results = results, save_dir = './output/')
开发者ID:Guanghan,项目名称:mxnet-centernet,代码行数:31,代码来源:train_2dpose.py

示例5: _epoch

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def _epoch(self, dataloader, epoch, mode = 'train'):
    """Run one epoch over ``dataloader`` and return a summary string.

    Args:
        dataloader: yields ``(data, target, meta1, meta2)`` batches.
        epoch: epoch number, shown in the progress-bar text.
        mode: ``'train'`` puts the model in train mode and applies optimizer
            steps; any other value uses eval mode with no updates.

    Returns:
        str: ``self.loss.avg`` followed by the ``.avg`` of every metric named
        in ``self.metrics``, space separated.
    """
    self.initepoch()
    training = mode == 'train'
    if training:
        self.model.train()
    else:
        self.model.eval()

    # Move the model once instead of on every batch. Assumes self.model is a
    # torch.nn.Module, whose .to() moves parameters in place and returns self.
    model = self.model.to(self.gpu)

    nIters = len(dataloader)
    bar = Bar('==>', max=nIters)

    for batch_idx, (data, target, meta1, meta2) in enumerate(dataloader):
        data = data.to(self.gpu, non_blocking=True).float()
        target = target.to(self.gpu, non_blocking=True).float()
        output = model(data)

        meta1_dev = meta1.to(self.gpu, non_blocking=True).float().unsqueeze(-1)
        loss = self.Loss(output, target, meta1_dev)
        self.loss.update(loss.item(), data.shape[0])

        self._eval_metrics(output, target, meta1, meta2, data.shape[0])

        if training:
            loss.backward()
            # Gradients accumulate over mini_batch_count batches per step.
            # NOTE(review): when nIters is not a multiple of mini_batch_count,
            # the final partial group is never stepped or zeroed here; confirm
            # whether initepoch() clears those gradients.
            if (batch_idx + 1) % self.opts.mini_batch_count == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

        # Set on the instance; ``Bar.suffix`` would mutate the class attribute
        # shared by every Bar object.
        bar.suffix = mode + ' Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss: {loss.avg:.6f} ({loss.val:.6f})'.format(
            epoch, batch_idx + 1, nIters, total=bar.elapsed_td, eta=bar.eta_td, loss=self.loss) + self._print_metrics()
        bar.next()
    bar.finish()
    return '{:8f} '.format(self.loss.avg) + ' '.join(
        ['{:4f}'.format(getattr(self, key).avg) for key, _ in self.metrics.items()])
开发者ID:Naman-ntc,项目名称:Pytorch-Human-Pose-Estimation,代码行数:41,代码来源:trainer.py

示例6: prefetch_test

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def prefetch_test(opt):
  """Evaluate a detector using a DataLoader that pre-processes images.

  Selects the dataset and detector from ``opt``, wraps the dataset in a
  ``PrefetchDataset`` driven by ``detector.pre_process`` (batch size 1),
  collects per-image ``'results'`` keyed by image id, shows per-stage timing
  (last value and running average) on a progress bar, and finally calls
  ``dataset.run_eval``.

  Args:
    opt: options; reads ``gpus_str``, ``dataset``, ``task``, ``trainval``,
      ``exp_id`` and ``save_dir``. It is rebound to the result of
      ``update_dataset_info_and_set_heads`` before use.
  """
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  Dataset = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  print(opt)
  Logger(opt)  # return value is discarded
  Detector = detector_factory[opt.task]

  split = 'val' if not opt.trainval else 'test'
  dataset = Dataset(opt, split)
  detector = Detector(opt)

  data_loader = torch.utils.data.DataLoader(
    PrefetchDataset(opt, dataset, detector.pre_process),
    batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
  avg_time_stats = {t: AverageMeter() for t in time_stats}
  for ind, (img_id, pre_processed_images) in enumerate(data_loader):
    ret = detector.run(pre_processed_images)
    # batch_size=1, so element [0] is this batch's only image id.
    results[img_id.numpy().astype(np.int32)[0]] = ret['results']
    suffix_parts = ['[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                     ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)]
    for t, meter in avg_time_stats.items():
      meter.update(ret[t])
      suffix_parts.append('|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(
        t, tm = meter))
    # Assign on the instance; ``Bar.suffix`` would mutate the class
    # attribute shared by every Bar object.
    bar.suffix = ''.join(suffix_parts)
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir)
开发者ID:CaoWGG,项目名称:CenterNet-CondInst,代码行数:36,代码来源:test.py

示例7: test

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def test(opt):
  """Run the selected detector over the val/test split and evaluate it.

  For the ``'ddd'`` task each image is run together with its
  ``img_info['calib']`` entry; other tasks run on the image path alone.
  Per-image ``'results'`` are collected by image id, per-stage timings are
  averaged on a progress bar, and ``dataset.run_eval`` is called at the end.

  Args:
    opt: options; reads ``gpus_str``, ``dataset``, ``task``, ``trainval``,
      ``exp_id`` and ``save_dir``. It is rebound to the result of
      ``update_dataset_info_and_set_heads`` before use.
  """
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  Dataset = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
  print(opt)
  Logger(opt)  # return value is discarded
  Detector = detector_factory[opt.task]

  split = 'val' if not opt.trainval else 'test'
  dataset = Dataset(opt, split)
  detector = Detector(opt)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  time_stats = ['tot', 'load', 'pre', 'net', 'dec', 'post', 'merge']
  avg_time_stats = {t: AverageMeter() for t in time_stats}
  for ind in range(num_iters):
    img_id = dataset.images[ind]
    img_info = dataset.coco.loadImgs(ids=[img_id])[0]
    img_path = os.path.join(dataset.img_dir, img_info['file_name'])

    if opt.task == 'ddd':
      ret = detector.run(img_path, img_info['calib'])
    else:
      ret = detector.run(img_path)

    results[img_id] = ret['results']

    suffix_parts = ['[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
                     ind, num_iters, total=bar.elapsed_td, eta=bar.eta_td)]
    for t, meter in avg_time_stats.items():
      meter.update(ret[t])
      suffix_parts.append('|{} {:.3f} '.format(t, meter.avg))
    # Assign on the instance; ``Bar.suffix`` would mutate the class
    # attribute shared by every Bar object.
    bar.suffix = ''.join(suffix_parts)
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir)
开发者ID:kimyoon-young,项目名称:centerNet-deep-sort,代码行数:40,代码来源:test.py

示例8: step

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def step(split, epoch, opt, dataLoader, model, optimizer = None):
    """Run one epoch of the 3D hourglass network over ``dataLoader``.

    Args:
        split: ``'train'`` puts the model in train mode and applies gradient
            updates; any other value uses eval mode without updates.
        epoch: epoch number, used in the progress text and as a debug tag.
        opt: options; reads ``DEBUG``, ``nStack``, ``trainBatch`` and
            ``dataloaderSize``.
        dataLoader: yields ``(input, targetMaps, target2D, target3D, meta)``.
        model: network; ``model(x)[0]`` is indexed per stack below.
        optimizer: used only when ``split == 'train'``.

    Returns:
        tuple: ``(Loss2D.avg, Acc.avg)`` -- mean heatmap loss and mean PCK.
    """
    if split == 'train':
        model.train()
    else:
        model.eval()
    # Cast once up front; nothing in the loop changes the model's dtype.
    model = model.float()
    Loss2D, Acc = AverageMeter(), AverageMeter()

    nIters = len(dataLoader)
    bar = Bar('==>', max=nIters)
    last = opt.nStack - 1  # output index of the final stack

    # target2D / target3D are unpacked but not used by this function.
    for i, (inputs, targetMaps, _target2D, _target3D, meta) in enumerate(dataLoader):
        input_var = inputs.float().cuda()
        targetMaps = targetMaps.float().cuda()
        output = model(input_var)[0]

        if opt.DEBUG == 2:
            for j in range(input_var.shape[2]):
                test_heatmaps(targetMaps[0, :, j, :, :].cpu(), input_var.data[0, :, j, :, :].cpu(), 0)
                a = np.zeros((16, 3))
                b = np.zeros((16, 3))
                # Slice index j of axis 2 for the first sample, as the
                # DEBUG == 3 branch does; the batch counter ``i`` is not an
                # index into that axis.
                a[:, :2] = getPreds(targetMaps[:1, :, j, :, :].cpu().numpy())
                b[:, :2] = getPreds(output[last][:1, :, j, :, :].data.cpu().numpy())
                visualise3d(b, a, '%d' % (epoch), i, j,
                            input_var.data[0, :, j, :, :].transpose(0, 1).transpose(1, 2).cpu().numpy(), opt)

        loss = 0
        for k in range(opt.nStack):
            loss += Joints2DHeatMapsSquaredError(output[k], targetMaps)

        Loss2D.update(loss.item(), inputs.size(0))

        # Swap axes 1 and 2 and merge the leading axes so Accuracy receives
        # 4-D (N, nJoints, outputRes, outputRes) arrays.
        pred_maps = output[last].data.transpose(1, 2).reshape(-1, ref.nJoints, ref.outputRes, ref.outputRes)
        gt_maps = targetMaps.data.transpose(1, 2).reshape(-1, ref.nJoints, ref.outputRes, ref.outputRes)
        tempAcc = Accuracy(pred_maps.cpu().numpy(), gt_maps.cpu().numpy())
        Acc.update(tempAcc)

        if opt.DEBUG == 3 and (float(tempAcc) < 0.80):
            for j in range(input_var.shape[2]):
                a = np.zeros((16, 3))
                b = np.zeros((16, 3))
                a[:, :2] = getPreds(targetMaps[:1, :, j, :, :].cpu().numpy())
                b[:, :2] = getPreds(output[last][:1, :, j, :, :].data.cpu().numpy())
                visualise3d(b, a, 'train-vis', i, j,
                            input_var.data[0, :, j, :, :].transpose(0, 1).transpose(1, 2).cpu().numpy(), opt)

        if split == 'train':
            # Scale the loss and step once the accumulated sample count
            # reaches a multiple of trainBatch (gradient accumulation).
            loss = loss / opt.trainBatch
            loss.backward()
            if (opt.dataloaderSize * (i + 1)) % opt.trainBatch == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Assign on the instance; ``Bar.suffix`` would mutate the class
        # attribute shared by every Bar object.
        bar.suffix = '{split} Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss2D {loss.avg:.6f} | PCK {PCK.avg:.6f} {PCK.val:.6f}'.format(
            epoch, i, nIters, total=bar.elapsed_td, eta=bar.eta_td, loss=Loss2D, split = split, PCK = Acc)
        bar.next()

    bar.finish()
    return Loss2D.avg, Acc.avg
开发者ID:Naman-ntc,项目名称:3D-HourGlass-Network,代码行数:61,代码来源:train.py

示例9: step

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def step(split, epoch, opt, dataLoader, model, criterion, optimizer = None):
    """Run one epoch of the stacked pose network over ``dataLoader``.

    Args:
        split: ``'train'`` applies optimizer updates; any other value runs in
            eval mode and collects flip-averaged predictions.
        epoch: epoch number, shown in the progress text.
        opt: options; reads ``expID``, ``DEBUG`` and ``nStack``.
        dataLoader: yields ``(input, target, meta)``; ``meta`` must provide
            ``'center'``, ``'scale'`` and ``'rotate'`` outside training.
        model: network returning one output per stack.
        criterion: loss applied to every stack's output and summed.
        optimizer: used only when ``split == 'train'``.

    Returns:
        tuple: ``({'Loss': Loss.avg, 'Acc': Acc.avg}, preds)`` where ``preds``
        is filled only when ``split != 'train'``.
    """
    if split == 'train':
        model.train()
    else:
        model.eval()
    Loss, Acc = AverageMeter(), AverageMeter()
    preds = []

    nIters = len(dataLoader)
    bar = Bar('{}'.format(opt.expID), max=nIters)
    last = opt.nStack - 1  # output index of the final stack
    for i, (inputs, target, meta) in enumerate(dataLoader):
        input_var = inputs.float().cuda()
        target_var = target.float().cuda()
        output = model(input_var)

        if opt.DEBUG >= 2:
            # .numpy() only works on CPU tensors, so copy to CPU (the previous
            # .cuda().numpy() raised).
            gt = getPreds(target.cpu().numpy()) * 4
            pred = getPreds(output[last].data.cpu().numpy()) * 4
            debugger = Debugger()
            img = (inputs[0].numpy().transpose(1, 2, 0) * 256).astype(np.uint8).copy()
            debugger.addImg(img)
            debugger.addPoint2D(pred[0], (255, 0, 0))
            debugger.addPoint2D(gt[0], (0, 0, 255))
            debugger.showAllImg(pause = True)

        loss = criterion(output[0], target_var)
        for k in range(1, opt.nStack):
            loss += criterion(output[k], target_var)
        # Tensor.item() replaces loss.data[0], which raises on 0-dim tensors
        # in PyTorch >= 0.4.
        Loss.update(loss.item(), inputs.size(0))
        Acc.update(Accuracy(output[last].data.cpu().numpy(), target_var.data.cpu().numpy()))
        if split == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # Average the prediction with the one for the Flip-ed input.
            # NOTE(review): only sample 0 is flipped and the result is viewed
            # as a batch of 1, so this assumes a validation batch size of 1.
            input_ = inputs.cpu().numpy()
            input_[0] = Flip(input_[0]).copy()
            inputFlip_var = torch.from_numpy(input_).view(1, input_.shape[1], ref.inputRes, ref.inputRes).float().cuda()
            outputFlip = model(inputFlip_var)
            outputFlip = ShuffleLR(Flip(outputFlip[last].data.cpu().numpy()[0])).reshape(1, ref.nJoints, ref.outputRes, ref.outputRes)
            output_ = (output[last].data.cpu().numpy() + outputFlip) / 2
            preds.append(finalPreds(output_, meta['center'], meta['scale'], meta['rotate'])[0])

        # Assign on the instance; ``Bar.suffix`` would mutate the class
        # attribute shared by every Bar object.
        bar.suffix = '{split} Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss {loss.avg:.6f} | Acc {Acc.avg:.6f} ({Acc.val:.6f})'.format(
            epoch, i, nIters, total=bar.elapsed_td, eta=bar.eta_td, loss=Loss, Acc=Acc, split = split)
        bar.next()

    bar.finish()
    return {'Loss': Loss.avg, 'Acc': Acc.avg}, preds
开发者ID:IcewineChen,项目名称:pytorch-PyraNet,代码行数:54,代码来源:train.py

示例10: run_epoch

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def run_epoch(self, phase, epoch, data_loader):
  """Run one train or evaluation epoch and report running loss averages.

  Args:
    phase: ``'train'`` runs optimizer updates; any other value puts the model
      in eval mode (using ``.module`` when more than one GPU is configured)
      and empties the CUDA cache first.
    epoch: epoch number, shown in the progress text.
    data_loader: iterable of batch dicts; every key except ``'meta'`` is
      moved to ``opt.device``. At most ``opt.num_iters`` batches are used
      when that option is non-negative.

  Returns:
    tuple: ``(ret, results)`` where ``ret`` maps each loss-stat name to its
    average plus ``'time'`` (elapsed minutes), and ``results`` is filled by
    ``self.save_result`` when ``opt.test`` is set.
  """
  model_with_loss = self.model_with_loss
  if phase == 'train':
    model_with_loss.train()
  else:
    if len(self.opt.gpus) > 1:
      model_with_loss = self.model_with_loss.module
    model_with_loss.eval()
    torch.cuda.empty_cache()

  opt = self.opt
  results = {}
  data_time, batch_time = AverageMeter(), AverageMeter()
  avg_loss_stats = {name: AverageMeter() for name in self.loss_stats}
  num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
  bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
  end = time.time()
  for iter_id, batch in enumerate(data_loader):
    if iter_id >= num_iters:
      break
    data_time.update(time.time() - end)

    for k in batch:
      if k != 'meta':
        batch[k] = batch[k].to(device=opt.device, non_blocking=True)
    output, loss, loss_stats = model_with_loss(batch)
    loss = loss.mean()
    if phase == 'train':
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
    batch_time.update(time.time() - end)
    end = time.time()

    suffix_parts = ['{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
      epoch, iter_id, num_iters, phase=phase,
      total=bar.elapsed_td, eta=bar.eta_td)]
    for name, meter in avg_loss_stats.items():
      meter.update(loss_stats[name].mean().item(), batch['input'].size(0))
      suffix_parts.append('|{} {:.4f} '.format(name, meter.avg))
    if not opt.hide_data_time:
      suffix_parts.append('|Data {dt.val:.3f}s({dt.avg:.3f}s) '
                          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time))
    # Assign on the instance; ``Bar.suffix`` would mutate the class
    # attribute shared by every Bar object.
    bar.suffix = ''.join(suffix_parts)
    if opt.print_iter > 0:
      if iter_id % opt.print_iter == 0:
        print('{}/{}| {}'.format(opt.task, opt.exp_id, bar.suffix))
    else:
      bar.next()

    if opt.debug > 0:
      self.debug(batch, output, iter_id)

    if opt.test:
      self.save_result(output, batch, results)
    del output, loss, loss_stats

  bar.finish()
  ret = {k: v.avg for k, v in avg_loss_stats.items()}
  ret['time'] = bar.elapsed_td.total_seconds() / 60.
  return ret, results
开发者ID:kimyoon-young,项目名称:centerNet-deep-sort,代码行数:63,代码来源:base_trainer.py

示例11: initLatent

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def initLatent(loader, model, Y, nViews, S, AVG = False):
  """Build an initial latent keypoint array ``M`` from model predictions.

  Predictions are grouped into ``nViews`` consecutive rows per object, and
  the object index is read from ``meta[g * nViews, 1]``.

  * ``AVG=True``: each view is aligned to view 0 with ``horn87`` and the
    aligned views are averaged into ``M[idx]``.
  * ``AVG=False``: every view gets a score
    ``sum(exp(-Dis(Y[k], view) / 2 / sigma2))`` over every ``S``-th row of
    ``Y``, and the best-scoring view is stored in ``M[idx]``.

  Args:
    loader: loader yielding ``(input, target, meta)``; ``loader.dataset``
      must expose ``nImages``.
    model: network whose output reshapes to ``(G, nViews, ref.J, 3)``.
    Y: reference keypoint sets compared against each view.
    nViews: number of consecutive rows per object.
    S: stride over the rows of ``Y``.
    AVG: choose averaging instead of best-view selection.

  Returns:
    np.ndarray of shape ``(loader.dataset.nImages, ref.J, 3)``.
  """
  model.eval()
  nIters = len(loader)
  N = loader.dataset.nImages
  M = np.zeros((N, ref.J, 3))
  bar = Bar('==>', max=nIters)
  sum_sigma2 = 0
  cnt_sigma2 = 1  # starts at 1, so the mean shown below never divides by 0
  for i, (input, target, meta) in enumerate(loader):
    output = (model(torch.autograd.Variable(input)).data).cpu().numpy()
    # Floor division keeps G an int for range()/reshape() on Python 3; the
    # former "/" only produced an int under Python 2.
    G = output.shape[0] // nViews
    output = output.reshape(G, nViews, ref.J, 3)
    if AVG:
      for g in range(G):
        idx = int(meta[g * nViews, 1])
        for j in range(nViews):
          RR, tt = horn87(output[g, j].transpose(), output[g, 0].transpose())
          MM = (np.dot(RR, output[g, j].transpose())).transpose().copy()
          M[idx] += MM.copy() / nViews
    else:
      for g in range(G):
        p = np.zeros(nViews)
        sigma2 = 0.1
        for j in range(nViews):
          for kk in range(Y.shape[0] // S):
            k = kk * S
            d = Dis(Y[k], output[g, j])
            sum_sigma2 += d
            cnt_sigma2 += 1
            p[j] += np.exp(- d / 2 / sigma2)

        idx = int(meta[g * nViews, 1])
        M[idx] = output[g, p.argmax()]

        # DEBUG is not defined in this function; presumably a module-level flag.
        if DEBUG and g == 0:
          print('M[id]', idx, M[idx], p.argmax())
          debugger = Debugger()
          for j in range(nViews):
            RR, tt = horn87(output[g, j].transpose(), output[g, p.argmax()].transpose())
            MM = (np.dot(RR, output[g, j].transpose())).transpose().copy()
            debugger.addPoint3D(MM, 'b')
            debugger.addImg(input[g * nViews + j].numpy().transpose(1, 2, 0), j)
          debugger.showAllImg()
          debugger.addPoint3D(M[idx], 'r')
          debugger.show3D()

    # Assign on the instance; ``Bar.suffix`` would mutate the class
    # attribute shared by every Bar object.
    bar.suffix = 'Init    : [{0:3}/{1:3}] | Total: {total:} | ETA: {eta:} | Dis: {dis:.6f}'.format(
      i, nIters, total=bar.elapsed_td, eta=bar.eta_td, dis = sum_sigma2 / cnt_sigma2)
    bar.next()
  bar.finish()
  return M
开发者ID:xingyizhou,项目名称:3DKeypoints-DA,代码行数:55,代码来源:optim_latent.py

示例12: step

# 需要导入模块: from progress.bar import Bar [as 别名]
# 用法: 在 Bar 实例上设置 suffix 属性,例如 bar.suffix = '...'(suffix 是属性,不能单独 import)
def step(args, split, epoch, loader, model, optimizer = None, M = None, f = None, tag = None):
  """Run one epoch with the shape-consistency criterion and report metrics.

  For ``split == 'test'`` every input image is written to
  ``{args.save_path}/img_{tag}/{index}.png``. The predicted and ground-truth
  keypoints (and, with ``args.saveVis``, the visibility values from
  ``meta[j][5:]``) are written as text lines to ``f``.

  Args:
    args: options; reads ``shapeWeight``, ``save_path`` and ``saveVis``.
    split: ``'train'`` applies optimizer updates; any other value uses eval
      mode.
    epoch: epoch number, shown in the progress text.
    loader: yields ``(input, target, meta)``; ``loader.dataset.nViews`` is
      read.
    model: network under evaluation or training.
    optimizer: used only when ``split == 'train'``.
    M: latent shape array passed to the criterion and ``shapeConsistency``.
    f: writable text file, used only when ``split == 'test'``.
    tag: directory tag for saved test images.

  Returns:
    tuple: ``(mpjpe.avg, losses.avg, shapeLosses.avg)``.
  """
  losses, mpjpe, mpjpe_r = AverageMeter(), AverageMeter(), AverageMeter()
  shapeLosses = AverageMeter()

  if split == 'train':
    model.train()
  else:
    model.eval()
  bar = Bar('{}'.format(ref.category), max=len(loader))

  nViews = loader.dataset.nViews
  for i, (input, target, meta) in enumerate(loader):
    input_var = torch.autograd.Variable(input)
    target_var = torch.autograd.Variable(target)
    output = model(input_var)
    loss = ShapeConsistencyCriterion(nViews, supWeight = 1, unSupWeight = args.shapeWeight, M = M)(output, target_var, torch.autograd.Variable(meta))

    if split == 'test':
      # Convert each tensor once per batch instead of once per sample.
      inputs_np = input.numpy()
      batch_size = inputs_np.shape[0]
      gts = target.cpu().numpy()
      preds = (output.data).cpu().numpy()
      metas = meta.cpu().numpy()
      for j in range(batch_size):
        img = (inputs_np[j] * 255).transpose(1, 2, 0).astype(np.uint8)
        cv2.imwrite('{}/img_{}/{}.png'.format(args.save_path, tag, i * batch_size + j), img)
        gt = gts[j]
        pred = preds[j]
        f.write(''.join('{} {} {} '.format(pred[t * 3], pred[t * 3 + 1], pred[t * 3 + 2]) for t in range(ref.J)))
        f.write('\n')
        f.write(''.join('{} {} {} '.format(gt[t, 0], gt[t, 1], gt[t, 2]) for t in range(ref.J)))
        f.write('\n')
        if args.saveVis:
          vis = metas[j][5:]
          f.write(''.join('{} 0 0 '.format(vis[t]) for t in range(ref.J)))
          f.write('\n')

    mpjpe_this = accuracy(output.data, target, meta)
    mpjpe_r_this = accuracy_dis(output.data, target, meta)
    shapeLoss = shapeConsistency(output.data, meta, nViews, M, split = split)

    n = input.size(0)
    # Tensor.item() replaces loss.data[0], which raises on 0-dim tensors in
    # PyTorch >= 0.4.
    losses.update(loss.item(), n)
    shapeLosses.update(shapeLoss, n)
    mpjpe.update(mpjpe_this, n)
    mpjpe_r.update(mpjpe_r_this, n)

    if split == 'train':
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    # Assign on the instance; ``Bar.suffix`` would mutate the class
    # attribute shared by every Bar object.
    bar.suffix = '{split:10}: [{0:2}][{1:3}/{2:3}] | Total: {total:} | ETA: {eta:} | Loss {loss.avg:.6f} | shapeLoss {shapeLoss.avg:.6f} | AE {mpjpe.avg:.6f} | ShapeDis {mpjpe_r.avg:.6f}'.format(
      epoch, i, len(loader), total=bar.elapsed_td, eta=bar.eta_td, loss=losses, mpjpe=mpjpe, split = split, shapeLoss = shapeLosses, mpjpe_r = mpjpe_r)
    bar.next()

  bar.finish()
  return mpjpe.avg, losses.avg, shapeLosses.avg
开发者ID:xingyizhou,项目名称:3DKeypoints-DA,代码行数:56,代码来源:train.py


注:本文中的progress.bar.Bar.suffix属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。