

Python Logger.scalar_summary Method Code Example

This article collects a typical usage example of the utils.logger.Logger.scalar_summary method in Python. It may help if you are wondering how Logger.scalar_summary is called, what arguments it takes, or what real-world usage looks like. You can also look further into usage examples of utils.logger.Logger, the class that defines the method.


One code example of the Logger.scalar_summary method is shown below.

Example 1: main

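# Required import: from utils.logger import Logger [as alias]
# Or: from utils.logger.Logger import scalar_summary [as alias]
import datetime

import torch
import torch.backends.cudnn as cudnn

# args, ref, splits, getModel, SourceDataset, TargetDataset, Fusion, collate_fn_cat,
# train, validate, test, getY, initLatent, stepLatent and adjust_learning_rate
# come from the 3DKeypoints-DA project and are not shown in this excerpt.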
def main():
  now = datetime.datetime.now()
  logger = Logger(args.save_path + '/logs_{}'.format(now.isoformat()))

  model = getModel(args)
  cudnn.benchmark = True
  optimizer = torch.optim.SGD(model.parameters(), args.LR,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

  valSource_dataset = SourceDataset('test', ref.nValViews)
  valTarget_dataset = TargetDataset('test', ref.nValViews)
  
  valSource_loader = torch.utils.data.DataLoader(valSource_dataset, batch_size = 1, 
                        shuffle=False, num_workers=1, pin_memory=True, collate_fn=collate_fn_cat)
  valTarget_loader = torch.utils.data.DataLoader(valTarget_dataset, batch_size = 1, 
                        shuffle=False, num_workers=1, pin_memory=True, collate_fn=collate_fn_cat)
  
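  if args.test:
    # Test-only mode: open one result file per split, evaluate, then exit.
    f = {}
    for split in splits:
      f[split] = open('{}/{}.txt'.format(args.save_path, split), 'w')
    test(args, valSource_loader, model, None, f['valSource'], 'valSource')
    test(args, valTarget_loader, model, None, f['valTarget'], 'valTarget')
    # Close the result files so buffered output is flushed to disk.
    for result_file in f.values():
      result_file.close()
    return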
  
  train_dataset = Fusion(SourceDataset, TargetDataset, nViews = args.nViews, targetRatio = args.targetRatio, totalTargetIm = args.totalTargetIm)
  trainTarget_dataset = train_dataset.targetDataset
  
  train_loader = torch.utils.data.DataLoader(
      train_dataset, batch_size=args.batchSize, shuffle=not args.test,
      num_workers=args.workers if not args.test else 1, pin_memory=True, collate_fn=collate_fn_cat)
  trainTarget_loader = torch.utils.data.DataLoader(
      trainTarget_dataset, batch_size=args.batchSize, shuffle=False,
      num_workers=args.workers if not args.test else 1, pin_memory=True, collate_fn=collate_fn_cat)

  M = None
  if args.shapeWeight > ref.eps:
    print('getY...')
    Y = getY(train_dataset.sourceDataset)
    M = initLatent(trainTarget_loader, model, Y, nViews = args.nViews, S = args.sampleSource, AVG = args.AVG)
  
  print('Start training...')
  for epoch in range(1, args.epochs + 1):
    adjust_learning_rate(optimizer, epoch, args.dropLR)
    train_mpjpe, train_loss, train_unSuploss = train(args, train_loader, model, optimizer, M, epoch)
    valSource_mpjpe, valSource_loss, valSource_unSuploss = validate(args, 'Source', valSource_loader, model, None, epoch)
    valTarget_mpjpe, valTarget_loss, valTarget_unSuploss = validate(args, 'Target', valTarget_loader, model, None, epoch)

    train_loader.dataset.targetDataset.shuffle()
    if args.shapeWeight > ref.eps and epoch % args.intervalUpdateM == 0:
      M = stepLatent(trainTarget_loader, model, M, Y, nViews = args.nViews, lamb = args.lamb, mu = args.mu, S = args.sampleSource)

    logger.write('{} {} {}\n'.format(train_mpjpe, valSource_mpjpe, valTarget_mpjpe))
    
    logger.scalar_summary('train_mpjpe', train_mpjpe, epoch)
    logger.scalar_summary('valSource_mpjpe', valSource_mpjpe, epoch)
    logger.scalar_summary('valTarget_mpjpe', valTarget_mpjpe, epoch)
    
    logger.scalar_summary('train_loss', train_loss, epoch)
    logger.scalar_summary('valSource_loss', valSource_loss, epoch)
    logger.scalar_summary('valTarget_loss', valTarget_loss, epoch)
    
    logger.scalar_summary('train_unSuploss', train_unSuploss, epoch)
    logger.scalar_summary('valSource_unSuploss', valSource_unSuploss, epoch)
    logger.scalar_summary('valTarget_unSuploss', valTarget_unSuploss, epoch)
    
    if epoch % 10 == 0:
      torch.save({
        'epoch': epoch + 1,
        'arch': args.arch,
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
      }, args.save_path + '/checkpoint_{}.pth.tar'.format(epoch))
  logger.close()
Developer: codealphago, Project: 3DKeypoints-DA, Lines of code: 77, Source: main.py
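
The example above calls four things on the logger: the constructor with a directory path, write(text), scalar_summary(tag, value, step), and close(). The project's own utils.logger module is not shown on this page, so the following is only a minimal sketch of a class with the same call shape, not the project's implementation. It assumes Python 3 and uses only the standard library, storing scalars in a CSV file instead of a TensorBoard event file. The file names log.txt and scalars.csv and the helper log_epoch are hypothetical.

```python
import csv
import os
import tempfile


class Logger(object):
    """Hypothetical stand-in for utils.logger.Logger (not the project's code)."""

    def __init__(self, log_dir):
        # Create the log directory if it does not exist yet.
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
        self.log_dir = log_dir
        # Free-form text log, appended to by write().
        self._text = open(os.path.join(log_dir, 'log.txt'), 'w')
        # Scalar log: one CSV row per (tag, step, value) triple.
        self._scalars = open(os.path.join(log_dir, 'scalars.csv'), 'w', newline='')
        self._writer = csv.writer(self._scalars)
        self._writer.writerow(['tag', 'step', 'value'])

    def write(self, text):
        # Write raw text exactly as given; the caller supplies newlines.
        self._text.write(text)

    def scalar_summary(self, tag, value, step):
        # Record a single scalar under `tag` at training step/epoch `step`.
        self._writer.writerow([tag, step, float(value)])

    def close(self):
        # Flush and close both files.
        self._text.close()
        self._scalars.close()


def log_epoch(logger, epoch, metrics):
    """Log every (tag, value) pair in the `metrics` dict for one epoch."""
    for tag, value in sorted(metrics.items()):
        logger.scalar_summary(tag, value, epoch)


if __name__ == '__main__':
    logger = Logger(tempfile.mkdtemp())
    for epoch in range(1, 4):
        # Placeholder numbers for illustration, not results from the project.
        metrics = {'train_loss': 1.0 / epoch, 'valSource_loss': 2.0 / epoch}
        logger.write('{} {}\n'.format(metrics['train_loss'], metrics['valSource_loss']))
        log_epoch(logger, epoch, metrics)
    logger.close()
    print('logs written to', logger.log_dir)
```

Collecting the metrics in a dict and looping over it, as log_epoch does, writes each tag name only once, which makes tag typos less likely than with nine separate scalar_summary calls.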


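Note: The utils.logger.Logger.scalar_summary example in this article was compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippet was selected from an open-source project contributed by its developers, and copyright of the source code belongs to the original authors. For distribution and use, refer to the License of the corresponding project. Do not reproduce without permission.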