本文整理汇总了Python中reid.utils.serialization.load_checkpoint方法的典型用法代码示例。如果您正苦于以下问题：Python serialization.load_checkpoint方法的具体用法？Python serialization.load_checkpoint怎么用？Python serialization.load_checkpoint使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块reid.utils.serialization的用法示例。
在下文中一共展示了serialization.load_checkpoint方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from reid.utils import serialization [as 别名]
# 或者: from reid.utils.serialization import load_checkpoint [as 别名]
def _run_evaluation(model, query_loader, gallery_loader, output_feature):
    """Evaluate ``model`` on the query/gallery loaders.

    Prints a "Test:" header, then calls ``Evaluator(model).evaluate`` with both
    loaders, each loader's ``dataset.ret`` attribute, and ``output_feature``.
    The evaluator's return value is discarded, matching the original code.
    """
    evaluator = Evaluator(model)
    print("Test:")
    evaluator.evaluate(query_loader, gallery_loader,
                       query_loader.dataset.ret, gallery_loader.dataset.ret,
                       output_feature)


def main(args):
    """Train and/or evaluate a DenseNet re-ID model as configured by ``args``.

    Evaluate-only mode (``args.evaluate`` truthy):
        Build the query/gallery loaders and the model, optionally restore
        weights from ``args.resume``, evaluate, and return.
        The training loader is not built in this mode.

    Training mode:
        Redirect stdout to ``<logs_dir>/log.txt`` via ``Logger``, train
        starting at ``args.start_epoch``, save the unwrapped model weights to
        ``<logs_dir>/model.pth.tar``, then evaluate.

    Args:
        args: parsed options namespace. Fields read here: evaluate, logs_dir,
            train_path, gallery_path, query_path, height, width, batch_size,
            workers, name_pattern, num_feature, true_class, num_iteration,
            start_epoch, resume, output_feature.
    """
    cudnn.benchmark = True

    train_loader = None
    if not args.evaluate:
        # Redirect print to both console and log file
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
        # The training loader is only consumed by train() below. Evaluate-only
        # runs return before training, so don't load the training set there.
        train_loader = get_loader(args.train_path, args.height, args.width,
                                  relabel=True, batch_size=args.batch_size,
                                  mode='train', num_workers=args.workers,
                                  name_pattern=args.name_pattern)
    gallery_loader = get_loader(args.gallery_path, args.height, args.width,
                                relabel=False, batch_size=args.batch_size,
                                mode='test', num_workers=args.workers,
                                name_pattern=args.name_pattern)
    query_loader = get_loader(args.query_path, args.height, args.width,
                              relabel=False, batch_size=args.batch_size,
                              mode='test', num_workers=args.workers,
                              name_pattern=args.name_pattern)

    # Create model; the class count comes from args, not from the train loader.
    model = DenseNet(num_feature=args.num_feature,
                     num_classes=args.true_class,
                     num_iteration=args.num_iteration)

    start_epoch = args.start_epoch
    if args.resume:
        # Checkpoints written by this function store model.module.state_dict()
        # (the unwrapped network), so load them before wrapping in DataParallel.
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
    model = nn.DataParallel(model).cuda()

    if args.evaluate:
        _run_evaluation(model, query_loader, gallery_loader,
                        args.output_feature)
        return

    # Start training
    model = train(args, model, train_loader, start_epoch)
    save_checkpoint({'state_dict': model.module.state_dict()},
                    fpath=osp.join(args.logs_dir, 'model.pth.tar'))
    _run_evaluation(model, query_loader, gallery_loader, args.output_feature)
示例2: resume
# 需要导入模块: from reid.utils import serialization [as 别名]
# 或者: from reid.utils.serialization import load_checkpoint [as 别名]
def resume(self, ckpt_file, step):
    """Recreate the network and restore its weights from ``ckpt_file``.

    Prints the step being resumed from. It then builds a fresh model via
    ``models.create`` using this object's ``model_name``, ``dropout``,
    ``num_classes`` and ``mode``, and wraps it in ``nn.DataParallel`` on the
    GPU, storing it as ``self.model``. Finally it loads the object returned
    by ``load_checkpoint(ckpt_file)`` into that wrapper.

    Args:
        ckpt_file: path of the checkpoint to restore.
        step: step index being resumed from; only used in the printed message.
    """
    print("continued from step", step)
    net = models.create(
        self.model_name,
        dropout=self.dropout,
        num_classes=self.num_classes,
        mode=self.mode,
    )
    wrapped = nn.DataParallel(net).cuda()
    self.model = wrapped
    # NOTE(review): the checkpoint object is handed straight to the
    # DataParallel wrapper's load_state_dict. It presumably must be a bare
    # state dict saved from the wrapped model (keys prefixed with "module.").
    # A {'state_dict': ...} dict, or one saved from the unwrapped network,
    # would not load here. Confirm against the matching save code.
    restored = load_checkpoint(ckpt_file)
    wrapped.load_state_dict(restored)