当前位置: 首页>>代码示例>>Python>>正文


Python utils.AverageMeter方法代码示例

本文整理汇总了 Python 中 utils.utils.AverageMeter 的典型用法代码示例（AverageMeter 实际上是一个类，通过 update(val, n) 累计数值，并提供当前值 val 与加权平均值 avg）。如果您想了解 utils.AverageMeter 的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解它所在模块 utils.utils 的其他用法示例。


下文共展示了 utils.AverageMeter 的 14 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: update_states

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def update_states(self, states, batch_size=1):
    """Store the latest values from ``states`` and fold them into running averages.

    On the first call (while ``self.states`` is empty) the tracked names are
    taken from ``states``: each name starts at 0 and gets its own
    ``AverageMeter``. On every call the stored values are then overwritten
    with ``states`` and each tracked meter is updated with weight
    ``batch_size``.
    """
    if not self.states:
        # Lazily discover the tracked names from the first update.
        names = list(states.keys())
        self.states = OrderedDict((name, 0) for name in names)
        self.average_meters = OrderedDict((name, AverageMeter()) for name in names)

    self.states.update(states)
    for name in self.average_meters:
        # A name missing from this call's ``states`` re-feeds its stored value.
        self.average_meters[name].update(self.states[name], batch_size)
开发者ID:zhixinwang,项目名称:frustum-convnet,代码行数:18,代码来源:training_states.py

示例2: test

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def test(cfg):
    """Run the configured detector over the ``'val'`` split and evaluate it.

    Each image is passed by path to ``detector.run``; its ``'results'`` entry
    is stored under the image id. The per-stage timings in the returned dict
    are averaged and shown in the progress bar. At the end the collected
    results go to ``dataset.run_eval`` together with ``cfg.OUTPUT_DIR``.
    """
    dataset_cls = dataset_factory[cfg.SAMPLE_METHOD]
    Logger(cfg)
    detector_cls = detector_factory[cfg.TEST.TASK]

    dataset = dataset_cls(cfg, 'val')
    detector = detector_cls(cfg)

    results = {}
    total = len(dataset)
    bar = Bar('{}'.format(cfg.EXP_ID), max=total)
    stage_names = ('tot', 'load', 'pre', 'net', 'dec', 'post', 'merge')
    meters = {name: AverageMeter() for name in stage_names}

    for position in range(total):
        image_id = dataset.images[position]
        info = dataset.coco.loadImgs(ids=[image_id])[0]
        path = os.path.join(dataset.img_dir, info['file_name'])
        outcome = detector.run(path)

        results[image_id] = outcome['results']

        suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
            position, total, total=bar.elapsed_td, eta=bar.eta_td)
        for name, meter in meters.items():
            meter.update(outcome[name])
            suffix += '|{} {:.3f} '.format(name, meter.avg)
        # Bar.suffix is a class attribute read by the progress bar on next().
        Bar.suffix = suffix
        bar.next()
    bar.finish()
    dataset.run_eval(results, cfg.OUTPUT_DIR)
开发者ID:tensorboy,项目名称:centerpose,代码行数:33,代码来源:evaluate.py

示例3: validate

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def validate(self, loader, model, criterion, epoch, args):
    """Evaluate ``model`` on the leading ``args.val_size`` fraction of ``loader``.

    Per-batch loss, accuracy (``triplet_accuracy``) and weighted accuracy are
    kept as sample-weighted running averages and printed every
    ``args.print_freq`` batches. All per-sample outputs are collected and
    scored with ``triplet_allk`` at the end.

    Returns:
        dict: the metrics from ``triplet_allk`` plus ``'top1val'`` and
        ``'wtop1val'`` (average accuracy / weighted accuracy).
    """
    timer = Timer()
    losses = AverageMeter()
    top1 = AverageMeter()
    wtop1 = AverageMeter()
    alloutputs = []
    metrics = {}

    # switch to evaluate mode
    model.eval()

    # Only the first val_size fraction of the loader is evaluated.
    num_batches = int(len(loader) * args.val_size)
    for i, batch in enumerate(itertools.islice(loader, num_batches)):
        inputs, target, meta = parse(batch)
        output, loss, weights = forward(inputs, target, model, criterion, meta['id'], train=False)
        batch_size = inputs[0].size(0)
        prec1 = triplet_accuracy(output, target)
        wprec1 = triplet_accuracy(output, target, weights)
        # On PyTorch >= 0.4 the loss is a 0-dim tensor, so loss.data[0]
        # raises IndexError; .item() returns the Python scalar.
        losses.update(loss.item(), batch_size)
        top1.update(prec1, batch_size)
        wtop1.update(wprec1, batch_size)
        # NOTE(review): x.data[0] / y.data[0] pick the first element of each
        # per-sample row of the outputs; confirm this matches what
        # triplet_allk expects.
        alloutputs.extend(zip([(x.data[0], y.data[0]) for x, y in zip(*output)], target, weights))
        timer.tic()

        if i % args.print_freq == 0:
            print('[{name}] Test [{epoch}]: [{0}/{1} ({2})]\t'
                  'Time {timer.val:.3f} ({timer.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'WAcc@1 {wtop1.val:.3f} ({wtop1.avg:.3f})\t'.format(
                      i, num_batches, len(loader), name=args.name,
                      timer=timer, loss=losses, top1=top1, epoch=epoch, wtop1=wtop1))

    metrics.update(triplet_allk(*zip(*alloutputs)))
    metrics.update({'top1val': top1.avg, 'wtop1val': wtop1.avg})
    print(' * Acc@1 {top1val:.3f} \t WAcc@1 {wtop1val:.3f}'
          '\n   topk1: {topk1:.3f} \t topk2: {topk2:.3f} \t '
          'topk5: {topk5:.3f} \t topk10: {topk10:.3f} \t topk50: {topk50:.3f}'
          .format(**metrics))

    return metrics
开发者ID:gsig,项目名称:actor-observer,代码行数:43,代码来源:train.py

示例4: prefetch_test

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def prefetch_test(opt):
  """Evaluate a detector, pre-processing images in a DataLoader worker.

  The split is ``'test'`` when ``opt.trainval`` is set and ``'val'``
  otherwise. Images come from ``PrefetchDataset`` (one per batch, one
  worker). Per-stage timings from ``detector.run`` are shown as current and
  average seconds in the progress bar. The collected results are evaluated
  with ``dataset.run_eval`` into ``opt.save_dir``.
  """
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  dataset_cls = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, dataset_cls)
  print(opt)
  Logger(opt)
  detector_cls = detector_factory[opt.task]

  split = 'test' if opt.trainval else 'val'
  dataset = dataset_cls(opt, split)
  detector = detector_cls(opt)

  loader = torch.utils.data.DataLoader(
    PrefetchDataset(opt, dataset, detector.pre_process),
    batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  stage_names = ('tot', 'load', 'pre', 'net', 'dec', 'post', 'merge')
  meters = {name: AverageMeter() for name in stage_names}
  for position, (img_id, pre_processed_images) in enumerate(loader):
    outcome = detector.run(pre_processed_images)
    # The loader yields the id as a 1-element tensor; unwrap it as int32.
    key = img_id.numpy().astype(np.int32)[0]
    results[key] = outcome['results']
    suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
      position, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    for name, meter in meters.items():
      meter.update(outcome[name])
      suffix += '|{} {tm.val:.3f}s ({tm.avg:.3f}s) '.format(name, tm=meter)
    Bar.suffix = suffix
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir)
开发者ID:CaoWGG,项目名称:CenterNet-CondInst,代码行数:36,代码来源:test.py

示例5: test

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def test(opt):
  """Run the detector image by image over the evaluation split and score it.

  The split is ``'test'`` when ``opt.trainval`` is set and ``'val'``
  otherwise. For ``opt.task == 'ddd'`` the image's ``'calib'`` entry is also
  passed to ``detector.run``. Averaged stage timings are shown in the
  progress bar, and the collected results are evaluated with
  ``dataset.run_eval`` into ``opt.save_dir``.
  """
  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str

  dataset_cls = dataset_factory[opt.dataset]
  opt = opts().update_dataset_info_and_set_heads(opt, dataset_cls)
  print(opt)
  Logger(opt)
  detector_cls = detector_factory[opt.task]

  split = 'test' if opt.trainval else 'val'
  dataset = dataset_cls(opt, split)
  detector = detector_cls(opt)

  results = {}
  num_iters = len(dataset)
  bar = Bar('{}'.format(opt.exp_id), max=num_iters)
  stage_names = ('tot', 'load', 'pre', 'net', 'dec', 'post', 'merge')
  meters = {name: AverageMeter() for name in stage_names}

  for position in range(num_iters):
    image_id = dataset.images[position]
    info = dataset.coco.loadImgs(ids=[image_id])[0]
    path = os.path.join(dataset.img_dir, info['file_name'])

    if opt.task == 'ddd':
      run_args = (path, info['calib'])
    else:
      run_args = (path,)
    outcome = detector.run(*run_args)

    results[image_id] = outcome['results']

    suffix = '[{0}/{1}]|Tot: {total:} |ETA: {eta:} '.format(
      position, num_iters, total=bar.elapsed_td, eta=bar.eta_td)
    for name, meter in meters.items():
      meter.update(outcome[name])
      suffix += '|{} {:.3f} '.format(name, meter.avg)
    Bar.suffix = suffix
    bar.next()
  bar.finish()
  dataset.run_eval(results, opt.save_dir)
开发者ID:kimyoon-young,项目名称:centerNet-deep-sort,代码行数:40,代码来源:test.py

示例6: __init__

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def __init__(self, state_names=None):
    """Create zero-initialised states and one ``AverageMeter`` per state name.

    Args:
        state_names: names of the states to track. Defaults to an empty
            list, in which case no states or meters are created here.
    """
    # Use a fresh list per instance. The former ``[]`` default was one list
    # object shared (via ``self.state_names``) by every instance created
    # without arguments.
    if state_names is None:
        state_names = []

    self.states = OrderedDict((key, 0) for key in state_names)
    self.average_meters = OrderedDict((key, AverageMeter()) for key in state_names)
    self.state_names = state_names
开发者ID:zhixinwang,项目名称:frustum-convnet,代码行数:14,代码来源:training_states.py

示例7: step

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def step(split, epoch, opt, dataLoader, model, criterion, optimizer = None):
    """Run one pass over ``dataLoader`` for training or evaluation.

    Args:
        split: ``'train'`` puts the model in train mode and applies
            ``optimizer`` after each batch. Any other value runs in eval
            mode, averages the last-stack output with the output for a
            flipped input (``Flip``/``ShuffleLR``) and collects final
            predictions.
        epoch: epoch number, used only in the progress-bar text.
        opt: options; reads ``expID``, ``DEBUG`` and ``nStack``.
        dataLoader: iterable of ``(input, target, meta)`` batches.
        model: network whose output is indexed per stack (``output[k]``).
        criterion: loss applied to every stack output against the target.
        optimizer: required when ``split == 'train'``.

    Returns:
        tuple: ``({'Loss': average loss, 'Acc': average accuracy}, preds)``;
        ``preds`` is only filled when ``split != 'train'``.
    """
    if split == 'train':
        model.train()
    else:
        model.eval()
    Loss, Acc = AverageMeter(), AverageMeter()
    preds = []

    nIters = len(dataLoader)
    bar = Bar('{}'.format(opt.expID), max=nIters)
    for i, (input, target, meta) in enumerate(dataLoader):
        input_var = torch.autograd.Variable(input).float().cuda()
        target_var = torch.autograd.Variable(target).float().cuda()
        output = model(input_var)

        if opt.DEBUG >= 2:
            # .numpy() only works on CPU tensors; the former .cuda().numpy()
            # raised instead of producing the debug visualisation.
            gt = getPreds(target.cpu().numpy()) * 4
            pred = getPreds((output[opt.nStack - 1].data).cpu().numpy()) * 4
            debugger = Debugger()
            img = (input[0].numpy().transpose(1, 2, 0)*256).astype(np.uint8).copy()
            debugger.addImg(img)
            debugger.addPoint2D(pred[0], (255, 0, 0))
            debugger.addPoint2D(gt[0], (0, 0, 255))
            debugger.showAllImg(pause = True)

        # Intermediate supervision: sum the loss over every stack's output.
        loss = criterion(output[0], target_var)
        for k in range(1, opt.nStack):
            loss += criterion(output[k], target_var)
        # On PyTorch >= 0.4 the loss is a 0-dim tensor, so loss.data[0]
        # raises IndexError; .item() returns the Python scalar.
        Loss.update(loss.item(), input.size(0))
        Acc.update(Accuracy((output[opt.nStack - 1].data).cpu().numpy(), (target_var.data).cpu().numpy()))
        if split == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # Flip test: average the prediction with that of the mirrored input.
            # NOTE(review): the view(1, ...) reshape means evaluation assumes a
            # batch size of 1.
            input_ = input.cpu().numpy()
            input_[0] = Flip(input_[0]).copy()
            inputFlip_var = torch.autograd.Variable(torch.from_numpy(input_).view(1, input_.shape[1], ref.inputRes, ref.inputRes)).float().cuda()
            outputFlip = model(inputFlip_var)
            outputFlip = ShuffleLR(Flip((outputFlip[opt.nStack - 1].data).cpu().numpy()[0])).reshape(1, ref.nJoints, ref.outputRes, ref.outputRes)
            output_ = ((output[opt.nStack - 1].data).cpu().numpy() + outputFlip) / 2
            preds.append(finalPreds(output_, meta['center'], meta['scale'], meta['rotate'])[0])

        Bar.suffix = '{split} Epoch: [{0}][{1}/{2}]| Total: {total:} | ETA: {eta:} | Loss {loss.avg:.6f} | Acc {Acc.avg:.6f} ({Acc.val:.6f})'.format(epoch, i, nIters, total=bar.elapsed_td, eta=bar.eta_td, loss=Loss, Acc=Acc, split = split)
        bar.next()

    bar.finish()
    return {'Loss': Loss.avg, 'Acc': Acc.avg}, preds
开发者ID:IcewineChen,项目名称:pytorch-PyraNet,代码行数:54,代码来源:train.py

示例8: train

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def train(self, loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one epoch on the leading ``args.train_size`` fraction of ``loader``.

    Gradients accumulate over ``args.accum_grad`` batches before each
    optimizer step. Loss, accuracy and weighted accuracy are kept as running
    averages and printed every ``args.print_freq`` batches.

    Returns:
        dict: ``{'top1': ..., 'wtop1': ...}`` running-average accuracies.
    """
    adjust_learning_rate(args.lr, args.lr_decay_rate, optimizer, epoch)
    timer = Timer()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    wtop1 = AverageMeter()
    metrics = {}

    # switch to train mode
    model.train()
    optimizer.zero_grad()

    # Only the first train_size fraction of the loader is used.
    num_batches = int(len(loader) * args.train_size)
    for i, batch in enumerate(itertools.islice(loader, num_batches)):
        inputs, target, meta = parse(batch)
        data_time.update(timer.thetime() - timer.end)
        output, loss, weights = forward(inputs, target, model, criterion, meta['id'])
        batch_size = inputs[0].size(0)
        prec1 = triplet_accuracy(output, target)
        wprec1 = triplet_accuracy(output, target, weights)
        # On PyTorch >= 0.4 the loss is a 0-dim tensor, so loss.data[0]
        # raises IndexError; .item() returns the Python scalar.
        losses.update(loss.item(), batch_size)
        top1.update(prec1, batch_size)
        wtop1.update(wprec1, batch_size)

        loss.backward()
        # Step only every accum_grad batches so gradients accumulate in between.
        # NOTE(review): the loss is not divided by accum_grad, and gradients of a
        # trailing partial group are discarded by the zero_grad() at the start
        # of the next call without an optimizer step.
        if i % args.accum_grad == args.accum_grad - 1:
            print('updating parameters')
            optimizer.step()
            optimizer.zero_grad()

        timer.tic()
        if i % args.print_freq == 0:
            print('[{name}] Epoch: [{0}][{1}/{2}({3})]\t'
                  'Time {timer.val:.3f} ({timer.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'WAcc@1 {wtop1.val:.3f} ({wtop1.avg:.3f})\t'.format(
                      epoch, i, num_batches, len(loader), name=args.name,
                      timer=timer, data_time=data_time, loss=losses, top1=top1, wtop1=wtop1))

    metrics.update({'top1': top1.avg, 'wtop1': wtop1.avg})
    return metrics
开发者ID:gsig,项目名称:actor-observer,代码行数:46,代码来源:train.py

示例9: _valid_epoch

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def _valid_epoch(self):
    """Run the model over ``self.valid_loader`` without gradients.

    Each batch is a tuple. Item 0 is passed (as float) to
    ``self.criterion``. Items 1-9 are the token ids / type ids / attention
    masks of the a, b and c texts. Items 10-12 are extra features, used only
    when ``self.fts_flag`` is set.

    Returns:
        dict: ``"val_loss"``, the sample-weighted mean criterion value, and
        ``"val_probs"``, the concatenated per-batch arrays of
        ``||anchor - positive|| - ||anchor - negative||`` (last axis kept).
    """
    meter = AverageMeter()
    prob_chunks = []
    text_keys = (
        "input_ids_a", "token_type_ids_a", "attention_mask_a",
        "input_ids_b", "token_type_ids_b", "attention_mask_b",
        "input_ids_c", "token_type_ids_c", "attention_mask_c",
    )

    def _distance(u, v):
        # Euclidean distance along the last axis, keeping that axis.
        return np.sqrt(np.sum(np.square(u - v), axis=-1, keepdims=True))

    for raw_batch in self.valid_loader:
        self.model.eval()
        batch = tuple(t.to(self.device) for t in raw_batch)
        n = batch[1].size(0)

        with torch.no_grad():
            op = batch[0]
            inputs = {key: batch[pos] for pos, key in enumerate(text_keys, start=1)}
            if self.fts_flag:
                inputs.update({"x_a": batch[10], "x_b": batch[11], "x_c": batch[12]})
            anchor, positive, negative = self.model(**inputs)

            loss = self.criterion(op.float(), anchor, positive, negative)
            meter.update(loss.item(), n)

        anchor, positive, negative = (
            emb.to("cpu").numpy() for emb in (anchor, positive, negative)
        )
        prob_chunks.append(_distance(anchor, positive) - _distance(anchor, negative))

    return {"val_loss": meter.avg, "val_probs": np.concatenate(prob_chunks)}
开发者ID:GuidoPaul,项目名称:CAIL2019,代码行数:52,代码来源:trainer.py

示例10: eval

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def eval(self):
    """Score ``self.test_loader`` and log summary statistics via ``self.logger``.

    For every sample it computes
    ``||anchor - positive|| - ||anchor - negative||``. A value <= 0 counts as
    correct. It logs the min / max / mean of these values, the average
    criterion loss, and the count and fraction of correct samples.
    """
    meter = AverageMeter()
    prob_chunks = []
    text_keys = (
        "input_ids_a", "token_type_ids_a", "attention_mask_a",
        "input_ids_b", "token_type_ids_b", "attention_mask_b",
        "input_ids_c", "token_type_ids_c", "attention_mask_c",
    )

    def _distance(u, v):
        # Euclidean distance along the last axis, keeping that axis.
        return np.sqrt(np.sum(np.square(u - v), axis=-1, keepdims=True))

    for raw_batch in self.test_loader:
        self.model.eval()
        batch = tuple(t.to(self.device) for t in raw_batch)
        n = batch[1].size(0)

        with torch.no_grad():
            op = batch[0]
            inputs = {key: batch[pos] for pos, key in enumerate(text_keys, start=1)}
            if self.fts_flag:
                inputs.update({"x_a": batch[10], "x_b": batch[11], "x_c": batch[12]})
            anchor, positive, negative = self.model(**inputs)

            loss = self.criterion(op.float(), anchor, positive, negative)
            meter.update(loss.item(), n)

        anchor, positive, negative = (
            emb.to("cpu").numpy() for emb in (anchor, positive, negative)
        )
        prob_chunks.append(_distance(anchor, positive) - _distance(anchor, negative))

    test_probs = np.concatenate(prob_chunks)
    test_loss = meter

    correct = int(np.count_nonzero(test_probs <= 0))
    self.logger.info(
        f"min: {np.min(test_probs):.4f} "
        f"max: {np.max(test_probs):.4f} "
        f"avg: {np.average(test_probs):.4f} "
        f"loss: {test_loss.avg:.4f} "
        f"acc: {correct}, {float(correct / len(test_probs)):.4f}"
    )
开发者ID:GuidoPaul,项目名称:CAIL2019,代码行数:55,代码来源:tester.py

示例11: run_train_epoch

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def run_train_epoch(model, optimizer, criterion, train_dataloader, epoch, args):
    """Train ``model`` for one epoch with a flattened per-element loss.

    Each batch is a dict. ``"x"`` holds the features (cast to float32) and
    ``"y"`` the labels (given an extra axis at dim 2 and cast to long; both
    are flattened before ``criterion``). Gradients are clipped to
    ``args.max_grad_norm`` before every optimizer step. Progress is printed
    every ``args.print_freq`` batches.
    """
    batch_time = utils.AverageMeter('Time', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    grad_norm = utils.AverageMeter('grad_norm', ':.4e')
    progress = utils.ProgressMeter(len(train_dataloader), batch_time, losses, grad_norm,
                                   prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i, data in enumerate(train_dataloader, 0):
        x = data["x"].to(th.float32)
        y = data["y"].unsqueeze(2).long()

        if th.cuda.is_available():
            x = x.cuda()
            y = y.cuda()

        prediction = model(x)
        # Flatten to (N, num_classes) vs (N,) as expected by the criterion.
        loss = criterion(prediction.view(-1, prediction.shape[2]), y.view(-1))

        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping; the returned total norm is recorded for logging.
        norm = nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()

        # Store a Python float rather than a (possibly CUDA) tensor in the meter.
        grad_norm.update(float(norm))
        losses.update(loss.item(), x.size(0))

        # Per-batch wall time. ``end`` must be reset every iteration; otherwise
        # the meter records the cumulative time since the epoch started.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.print(i)
开发者ID:jzlianglu,项目名称:pykaldi2,代码行数:44,代码来源:train_ce.py

示例12: run_epoch

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def run_epoch(self, phase, epoch, data_loader):
    """Run one pass over ``data_loader`` in training or evaluation mode.

    Args:
      phase: ``'train'`` puts the model in train mode and steps
        ``self.optimizer`` after each batch. Any other value switches to eval
        mode (using ``self.model_with_loss.module`` when ``self.opt.gpus``
        has more than one entry) and calls ``torch.cuda.empty_cache()``.
      epoch: epoch number, used only in the progress-bar text.
      data_loader: iterable of dict batches. Every key except ``'meta'`` is
        moved to ``opt.device``, and ``batch['input'].size(0)`` weights the
        loss averages.

    Returns:
      tuple: ``(ret, results)``. ``ret`` maps each name in
      ``self.loss_stats`` to its running average, plus ``'time'`` (the
      progress bar's elapsed time in minutes). ``results`` is filled by
      ``self.save_result`` only when ``opt.test`` is set.
    """
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
      # With several GPUs, evaluate the wrapped ``.module`` directly.
      # NOTE(review): presumably a DataParallel wrapper; confirm.
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    # A negative opt.num_iters means "use the whole loader".
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      # Time spent waiting for the loader since the previous iteration ended.
      data_time.update(time.time() - end)

      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)
      # NOTE(review): there is no torch.no_grad() in this method, so autograd
      # graphs are still built during non-train phases.
      output, loss, loss_stats = model_with_loss(batch)
      loss = loss.mean()
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      # Full iteration time (data loading + forward/backward).
      batch_time.update(time.time() - end)
      end = time.time()

      # Bar.suffix is a class attribute that the progress bar renders.
      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      # With print_iter > 0, progress is printed as text lines instead of
      # advancing the bar.
      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()

      if opt.debug > 0:
        self.debug(batch, output, iter_id)

      if opt.test:
        self.save_result(output, batch, results)
      # Drop references to this batch's tensors before loading the next one.
      del output, loss, loss_stats

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results
开发者ID:kimyoon-young,项目名称:centerNet-deep-sort,代码行数:63,代码来源:base_trainer.py

示例13: validate

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def validate(data_loader, model, epoch, logger=None):
    """Evaluate ``model`` on ``data_loader`` and log averaged losses/metrics.

    Every loss and metric tensor returned by the model is reduced with
    ``.mean()`` (to combine multi-GPU outputs) and accumulated in a
    ``TrainingStates`` weighted by the batch size. The log line shows the
    last batch's data / batch times. When ``logger`` is given, every averaged
    state is also written with ``logger.scalar_summary``.

    Returns:
        The averaged state stored under ``'IoU_' + str(cfg.IOU_THRESH)``.
    """
    data_timer = AverageMeter()
    batch_timer = AverageMeter()

    model.eval()

    start = time.time()
    num_batches = len(data_loader)  # computed but not used below

    tracker = TrainingStates()

    for data_dicts in data_loader:
        data_timer.update(time.time() - start)

        n = data_dicts['point_cloud'].shape[0]

        with torch.no_grad():
            on_device = {name: tensor.cuda() for name, tensor in data_dicts.items()}

            losses, metrics = model(on_device)
            # mean for multi-gpu setting
            loss_values = {name: v.detach().mean().item() for name, v in losses.items()}
            metric_values = {name: v.detach().mean().item() for name, v in metrics.items()}

        tracker.update_states(dict(**loss_values, **metric_values), n)

        batch_timer.update(time.time() - start)
        start = time.time()

    states = tracker.get_states(avg=True)

    summary = tracker.format_states(states)
    header = 'Validation Epoch: {:03d} Time:{:.3f}/{:.3f} '.format(
        epoch + 1, data_timer.val, batch_timer.val)

    logging.info(header + summary)

    if logger is not None:
        for tag, value in states.items():
            logger.scalar_summary(tag, value, int(epoch))

    return states['IoU_' + str(cfg.IOU_THRESH)]
开发者ID:zhixinwang,项目名称:frustum-convnet,代码行数:44,代码来源:train_net_det.py

示例14: step

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import AverageMeter [as 别名]
def step(args, split, epoch, loader, model, optimizer = None, M = None, f = None, tag = None):
  """Run one epoch of training or evaluation for the multi-view keypoint model.

  Args:
    args: options; reads ``shapeWeight``, and ``save_path`` / ``saveVis``
      when ``split == 'test'``.
    split: ``'train'`` enables backprop with ``optimizer``. ``'test'`` also
      dumps each input image to ``{save_path}/img_{tag}/`` and writes
      predictions, ground truth and (optionally) visibility to ``f``.
    epoch: epoch number shown in the progress bar.
    loader: yields ``(input, target, meta)``; ``loader.dataset.nViews`` is read.
    model: network evaluated on each input batch.
    optimizer: required when ``split == 'train'``.
    M: forwarded to ``ShapeConsistencyCriterion`` and ``shapeConsistency``.
    f: writable text file, used when ``split == 'test'``.
    tag: sub-directory tag for the dumped images.

  Returns:
    tuple: ``(mpjpe.avg, losses.avg, shapeLosses.avg)``.
  """
  losses, mpjpe, mpjpe_r = AverageMeter(), AverageMeter(), AverageMeter()
  shapeLosses = AverageMeter()

  if split == 'train':
    model.train()
  else:
    model.eval()
  bar = Bar('{}'.format(ref.category), max=len(loader))

  nViews = loader.dataset.nViews
  # Running index for dumped test images. The old ``i * batch_size + j`` used
  # the *current* batch size, so a smaller final batch overwrote images from
  # earlier batches.
  sample_idx = 0
  for i, (input, target, meta) in enumerate(loader):
    input_var = torch.autograd.Variable(input)
    target_var = torch.autograd.Variable(target)
    output = model(input_var)
    loss = ShapeConsistencyCriterion(nViews, supWeight = 1, unSupWeight = args.shapeWeight, M = M)(output, target_var, torch.autograd.Variable(meta))

    if split == 'test':
      # Convert each tensor to numpy once per batch instead of once per sample.
      input_np = input.numpy()
      gt_np = target.cpu().numpy()
      pred_np = (output.data).cpu().numpy()
      meta_np = meta.cpu().numpy()
      for j in range(input_np.shape[0]):
        img = (input_np[j] * 255).transpose(1, 2, 0).astype(np.uint8)
        cv2.imwrite('{}/img_{}/{}.png'.format(args.save_path, tag, sample_idx), img)
        sample_idx += 1
        gt = gt_np[j]
        pred = pred_np[j]
        vis = meta_np[j][5:]
        # Predictions are flat (x, y, z) triples; ground truth is indexed [joint, axis].
        f.write(''.join('{} {} {} '.format(pred[t * 3], pred[t * 3 + 1], pred[t * 3 + 2]) for t in range(ref.J)) + '\n')
        f.write(''.join('{} {} {} '.format(gt[t, 0], gt[t, 1], gt[t, 2]) for t in range(ref.J)) + '\n')
        if args.saveVis:
          f.write(''.join('{} 0 0 '.format(vis[t]) for t in range(ref.J)) + '\n')

    mpjpe_this = accuracy(output.data, target, meta)
    mpjpe_r_this = accuracy_dis(output.data, target, meta)
    shapeLoss = shapeConsistency(output.data, meta, nViews, M, split = split)

    # On PyTorch >= 0.4 the loss is a 0-dim tensor, so loss.data[0] raises
    # IndexError; .item() returns the Python scalar.
    losses.update(loss.item(), input.size(0))
    shapeLosses.update(shapeLoss, input.size(0))
    mpjpe.update(mpjpe_this, input.size(0))
    mpjpe_r.update(mpjpe_r_this, input.size(0))

    if split == 'train':
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    Bar.suffix = '{split:10}: [{0:2}][{1:3}/{2:3}] | Total: {total:} | ETA: {eta:} | Loss {loss.avg:.6f} | shapeLoss {shapeLoss.avg:.6f} | AE {mpjpe.avg:.6f} | ShapeDis {mpjpe_r.avg:.6f}'.format(epoch, i, len(loader), total=bar.elapsed_td, eta=bar.eta_td, loss=losses, mpjpe=mpjpe, split = split, shapeLoss = shapeLosses, mpjpe_r = mpjpe_r)
    bar.next()

  bar.finish()
  return mpjpe.avg, losses.avg, shapeLosses.avg
开发者ID:xingyizhou,项目名称:3DKeypoints-DA,代码行数:56,代码来源:train.py


注：本文中的 utils.utils.AverageMeter 示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各位开发者贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的许可证（License）；未经允许，请勿转载。