当前位置: 首页>>代码示例>>Python>>正文


Python Logger.close方法代码示例

本文整理汇总了Python中utils.logger.Logger.close方法的典型用法代码示例。如果您正苦于以下问题:Python Logger.close方法的具体用法?Python Logger.close怎么用?Python Logger.close使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils.logger.Logger的用法示例。


在下文中一共展示了Logger.close方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from utils.logger import Logger [as 别名]
# 或者: from utils.logger.Logger import close [as 别名]
def main():
    """Train the model on the fused source/target data, or only evaluate it.

    Configuration comes from the module-level ``args`` namespace.

    * With ``args.test`` set, the model is evaluated with ``test()`` on the
      source and target validation loaders. Results go to
      ``<args.save_path>/<split>.txt``, with one file opened per entry of
      ``splits``.
    * Otherwise the model is trained with SGD for ``args.epochs`` epochs.
      Each epoch's MPJPE, loss and unsupervised loss are logged for the
      train, source-val and target-val sets. A checkpoint is saved to
      ``<args.save_path>/checkpoint_<epoch>.pth.tar`` every 10 epochs.

    The logger and the per-split result files are closed on every exit path.
    This includes the early return in test mode and an exception raised
    during training.
    """
    now = datetime.datetime.now()
    logger = Logger(args.save_path + '/logs_{}'.format(now.isoformat()))
    try:
        model = getModel(args)
        cudnn.benchmark = True
        optimizer = torch.optim.SGD(model.parameters(), args.LR,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        valSource_dataset = SourceDataset('test', ref.nValViews)
        valTarget_dataset = TargetDataset('test', ref.nValViews)

        valSource_loader = torch.utils.data.DataLoader(
            valSource_dataset, batch_size=1, shuffle=False, num_workers=1,
            pin_memory=True, collate_fn=collate_fn_cat)
        valTarget_loader = torch.utils.data.DataLoader(
            valTarget_dataset, batch_size=1, shuffle=False, num_workers=1,
            pin_memory=True, collate_fn=collate_fn_cat)

        if args.test:
            # One result file per split, named <save_path>/<split>.txt.
            # Files are closed even if test() raises or an open() fails.
            f = {}
            try:
                for split in splits:
                    f['{}'.format(split)] = open(
                        '{}/{}.txt'.format(args.save_path, split), 'w')
                test(args, valSource_loader, model, None, f['valSource'], 'valSource')
                test(args, valTarget_loader, model, None, f['valTarget'], 'valTarget')
            finally:
                for handle in f.values():
                    handle.close()
            return

        train_dataset = Fusion(SourceDataset, TargetDataset, nViews=args.nViews,
                               targetRatio=args.targetRatio,
                               totalTargetIm=args.totalTargetIm)
        trainTarget_dataset = train_dataset.targetDataset

        # Test mode has already returned, so args.test is false here: the
        # training loader always shuffles and uses the configured worker count.
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batchSize, shuffle=True,
            num_workers=args.workers, pin_memory=True, collate_fn=collate_fn_cat)
        trainTarget_loader = torch.utils.data.DataLoader(
            trainTarget_dataset, batch_size=args.batchSize, shuffle=False,
            num_workers=args.workers, pin_memory=True, collate_fn=collate_fn_cat)

        # Y and M are only computed when args.shapeWeight exceeds ref.eps.
        # Otherwise train() receives M=None and stepLatent is never called.
        M = None
        Y = None
        if args.shapeWeight > ref.eps:
            print('getY...')
            Y = getY(train_dataset.sourceDataset)
            M = initLatent(trainTarget_loader, model, Y, nViews=args.nViews,
                           S=args.sampleSource, AVG=args.AVG)

        print('Start training...')
        for epoch in range(1, args.epochs + 1):
            adjust_learning_rate(optimizer, epoch, args.dropLR)
            train_mpjpe, train_loss, train_unSuploss = train(
                args, train_loader, model, optimizer, M, epoch)
            valSource_mpjpe, valSource_loss, valSource_unSuploss = validate(
                args, 'Source', valSource_loader, model, None, epoch)
            valTarget_mpjpe, valTarget_loss, valTarget_unSuploss = validate(
                args, 'Target', valTarget_loader, model, None, epoch)

            train_loader.dataset.targetDataset.shuffle()
            if args.shapeWeight > ref.eps and epoch % args.intervalUpdateM == 0:
                M = stepLatent(trainTarget_loader, model, M, Y, nViews=args.nViews,
                               lamb=args.lamb, mu=args.mu, S=args.sampleSource)

            logger.write('{} {} {}\n'.format(train_mpjpe, valSource_mpjpe, valTarget_mpjpe))

            # The tag was previously misspelled 'valTatget_loss'; it now
            # matches the other valTarget_* tags.
            for tag, value in (
                    ('train_mpjpe', train_mpjpe),
                    ('valSource_mpjpe', valSource_mpjpe),
                    ('valTarget_mpjpe', valTarget_mpjpe),
                    ('train_loss', train_loss),
                    ('valSource_loss', valSource_loss),
                    ('valTarget_loss', valTarget_loss),
                    ('train_unSuploss', train_unSuploss),
                    ('valSource_unSuploss', valSource_unSuploss),
                    ('valTarget_unSuploss', valTarget_unSuploss)):
                logger.scalar_summary(tag, value, epoch)

            if epoch % 10 == 0:
                torch.save({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, args.save_path + '/checkpoint_{}.pth.tar'.format(epoch))
    finally:
        # Close the logger on every exit path, not only after a full run.
        logger.close()
开发者ID:codealphago,项目名称:3DKeypoints-DA,代码行数:77,代码来源:main.py

示例2: main

# 需要导入模块: from utils.logger import Logger [as 别名]
# 或者: from utils.logger.Logger import close [as 别名]

#.........这里部分代码省略.........
    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print("==> Creating model '{}-{}', stacks={}, blocks={}, feats={}".format(
        args.netType, args.pointType, args.nStacks, args.nModules, args.nFeats))

    print("=> Models will be saved at: {}".format(args.checkpoint))

    model = models.__dict__[args.netType](
        num_stacks=args.nStacks,
        num_blocks=args.nModules,
        num_feats=args.nFeats,
        use_se=args.use_se,
        use_attention=args.use_attention,
        num_classes=68)

    model = torch.nn.DataParallel(model).cuda()

    criterion = torch.nn.MSELoss(size_average=True).cuda()

    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    title = args.checkpoint.split('/')[-1] + ' on ' + args.data.split('/')[-1]

    Loader = get_loader(args.data)

    val_loader = torch.utils.data.DataLoader(
        Loader(args, 'A'),
        batch_size=args.val_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Valid Loss', 'Train Acc', 'Val Acc', 'AUC'])

    cudnn.benchmark = True
    print('=> Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / (1024. * 1024)))

    if args.evaluation:
        print('=> Evaluation only')
        D = args.data.split('/')[-1]
        save_dir = os.path.join(args.checkpoint, D)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        loss, acc, predictions, auc = validate(val_loader, model, criterion, args.netType,
                                                        args.debug, args.flip)
        save_pred(predictions, checkpoint=save_dir)
        return

    train_loader = torch.utils.data.DataLoader(
        Loader(args, 'train'),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('=> Epoch: %d | LR %.8f' % (epoch + 1, lr))

        train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.netType,
                                      args.debug, args.flip)
        # do not save predictions in model file
        valid_loss, valid_acc, predictions, valid_auc = validate(val_loader, model, criterion, args.netType,
                                                      args.debug, args.flip)

        logger.append([int(epoch + 1), lr, train_loss, valid_loss, train_acc, valid_acc, valid_auc])

        is_best = valid_auc >= best_auc
        best_auc = max(valid_auc, best_auc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'netType': args.netType,
                'state_dict': model.state_dict(),
                'best_acc': best_auc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            predictions,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot(['AUC'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
开发者ID:jiaxiangshang,项目名称:pyhowfar,代码行数:104,代码来源:main.py


注:本文中的utils.logger.Logger.close方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。