當前位置: 首頁>>代碼示例>>Python>>正文


Python serialization.load_checkpoint方法代碼示例

本文整理匯總了Python中reid.utils.serialization.load_checkpoint方法的典型用法代碼示例。如果您正苦於以下問題:Python serialization.load_checkpoint方法的具體用法?Python serialization.load_checkpoint怎麽用?Python serialization.load_checkpoint使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在reid.utils.serialization的用法示例。


在下文中一共展示了serialization.load_checkpoint方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: from reid.utils import serialization [as 別名]
# 或者: from reid.utils.serialization import load_checkpoint [as 別名]
def main(args):
    cudnn.benchmark = True
    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    train_loader = get_loader(args.train_path, args.height, args.width, relabel=True,
                                   batch_size=args.batch_size, mode='train', num_workers=args.workers, name_pattern = args.name_pattern)

    gallery_loader = get_loader(args.gallery_path, args.height, args.width, relabel=False,
                                   batch_size=args.batch_size, mode='test', num_workers=args.workers, name_pattern = args.name_pattern)

    query_loader = get_loader(args.query_path, args.height, args.width, relabel=False,
                                   batch_size=args.batch_size, mode='test', num_workers=args.workers, name_pattern = args.name_pattern)

    # Create model
    model = DenseNet(num_feature=args.num_feature, num_classes=args.true_class, num_iteration = args.num_iteration)

    # Load from checkpoint
    start_epoch = args.start_epoch
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])

    model = nn.DataParallel(model).cuda()

    # Evaluator
    if args.evaluate:
        evaluator = Evaluator(model)
        print("Test:")
        evaluator.evaluate(query_loader, gallery_loader, query_loader.dataset.ret, gallery_loader.dataset.ret, args.output_feature)
        return

    # Start training
    model= train(args, model, train_loader, start_epoch)
    save_checkpoint({'state_dict': model.module.state_dict()}, fpath=osp.join(args.logs_dir, 'model.pth.tar'))

    evaluator = Evaluator(model)
    print("Test:")
    evaluator.evaluate(query_loader, gallery_loader, query_loader.dataset.ret, gallery_loader.dataset.ret, args.output_feature) 
開發者ID:Huang-3,項目名稱:Celeb-reID,代碼行數:42,代碼來源:train.py

示例2: resume

# 需要導入模塊: from reid.utils import serialization [as 別名]
# 或者: from reid.utils.serialization import load_checkpoint [as 別名]
def resume(self, ckpt_file, step):
        print("continued from step", step)
        model = models.create(self.model_name, dropout=self.dropout, num_classes=self.num_classes, mode=self.mode)
        self.model = nn.DataParallel(model).cuda()
        self.model.load_state_dict(load_checkpoint(ckpt_file)) 
開發者ID:Yu-Wu,項目名稱:Exploit-Unknown-Gradually,代碼行數:7,代碼來源:eug.py


注:本文中的reid.utils.serialization.load_checkpoint方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。