当前位置: 首页>>代码示例>>Python>>正文


Python serialization.load_checkpoint方法代码示例

本文整理汇总了Python中reid.utils.serialization.load_checkpoint方法的典型用法代码示例。如果您正苦于以下问题:Python serialization.load_checkpoint方法的具体用法?Python serialization.load_checkpoint怎么用?Python serialization.load_checkpoint使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在reid.utils.serialization的用法示例。


在下文中一共展示了serialization.load_checkpoint方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from reid.utils import serialization [as 别名]
# 或者: from reid.utils.serialization import load_checkpoint [as 别名]
def main(args):
    cudnn.benchmark = True
    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    train_loader = get_loader(args.train_path, args.height, args.width, relabel=True,
                                   batch_size=args.batch_size, mode='train', num_workers=args.workers, name_pattern = args.name_pattern)

    gallery_loader = get_loader(args.gallery_path, args.height, args.width, relabel=False,
                                   batch_size=args.batch_size, mode='test', num_workers=args.workers, name_pattern = args.name_pattern)

    query_loader = get_loader(args.query_path, args.height, args.width, relabel=False,
                                   batch_size=args.batch_size, mode='test', num_workers=args.workers, name_pattern = args.name_pattern)

    # Create model
    model = DenseNet(num_feature=args.num_feature, num_classes=args.true_class, num_iteration = args.num_iteration)

    # Load from checkpoint
    start_epoch = args.start_epoch
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])

    model = nn.DataParallel(model).cuda()

    # Evaluator
    if args.evaluate:
        evaluator = Evaluator(model)
        print("Test:")
        evaluator.evaluate(query_loader, gallery_loader, query_loader.dataset.ret, gallery_loader.dataset.ret, args.output_feature)
        return

    # Start training
    model= train(args, model, train_loader, start_epoch)
    save_checkpoint({'state_dict': model.module.state_dict()}, fpath=osp.join(args.logs_dir, 'model.pth.tar'))

    evaluator = Evaluator(model)
    print("Test:")
    evaluator.evaluate(query_loader, gallery_loader, query_loader.dataset.ret, gallery_loader.dataset.ret, args.output_feature) 
开发者ID:Huang-3,项目名称:Celeb-reID,代码行数:42,代码来源:train.py

示例2: resume

# 需要导入模块: from reid.utils import serialization [as 别名]
# 或者: from reid.utils.serialization import load_checkpoint [as 别名]
def resume(self, ckpt_file, step):
        print("continued from step", step)
        model = models.create(self.model_name, dropout=self.dropout, num_classes=self.num_classes, mode=self.mode)
        self.model = nn.DataParallel(model).cuda()
        self.model.load_state_dict(load_checkpoint(ckpt_file)) 
开发者ID:Yu-Wu,项目名称:Exploit-Unknown-Gradually,代码行数:7,代码来源:eug.py


注:本文中的reid.utils.serialization.load_checkpoint方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。