本文整理匯總了Python中reid.utils.serialization.load_checkpoint方法的典型用法代碼示例。如果您正苦於以下問題：Python serialization.load_checkpoint方法的具體用法？Python serialization.load_checkpoint怎麼用？Python serialization.load_checkpoint使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類reid.utils.serialization的用法示例。
在下文中一共展示了serialization.load_checkpoint方法的2個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: main
# 需要導入模塊: from reid.utils import serialization [as 別名]
# 或者: from reid.utils.serialization import load_checkpoint [as 別名]
def main(args):
    """Train and/or evaluate a DenseNet re-identification model.

    Builds the train, gallery and query loaders, creates the model,
    optionally restores weights from ``args.resume``, and wraps the model
    in ``nn.DataParallel`` on CUDA. When ``args.evaluate`` is set, only the
    evaluation is run. Otherwise the model is trained, its weights are
    saved to ``<args.logs_dir>/model.pth.tar``, and it is then evaluated.

    Args:
        args: Namespace-like object. Attributes read here: ``evaluate``,
            ``logs_dir``, ``train_path``, ``gallery_path``, ``query_path``,
            ``height``, ``width``, ``batch_size``, ``workers``,
            ``name_pattern``, ``num_feature``, ``true_class``,
            ``num_iteration``, ``start_epoch``, ``resume`` and
            ``output_feature``.
    """
    cudnn.benchmark = True

    # Redirect print to both console and log file (training runs only).
    if not args.evaluate:
        log_path = osp.join(args.logs_dir, 'log.txt')
        sys.stdout = Logger(log_path)

    def build_loader(data_path, mode, relabel):
        # The three loaders share every setting except path, mode and relabel.
        return get_loader(
            data_path,
            args.height,
            args.width,
            relabel=relabel,
            batch_size=args.batch_size,
            mode=mode,
            num_workers=args.workers,
            name_pattern=args.name_pattern,
        )

    train_loader = build_loader(args.train_path, 'train', True)
    gallery_loader = build_loader(args.gallery_path, 'test', False)
    query_loader = build_loader(args.query_path, 'test', False)

    def run_test(network):
        # Evaluate query against gallery using each dataset's ``ret`` attribute.
        evaluator = Evaluator(network)
        print("Test:")
        evaluator.evaluate(
            query_loader,
            gallery_loader,
            query_loader.dataset.ret,
            gallery_loader.dataset.ret,
            args.output_feature,
        )

    # Create model
    model = DenseNet(
        num_feature=args.num_feature,
        num_classes=args.true_class,
        num_iteration=args.num_iteration,
    )

    # Load from checkpoint; weights go into the bare model, before it is
    # wrapped in DataParallel.
    start_epoch = args.start_epoch
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
    model = nn.DataParallel(model).cuda()

    # Evaluation-only run: skip training and saving.
    if args.evaluate:
        run_test(model)
        return

    # Start training
    model = train(args, model, train_loader, start_epoch)

    # Save the state of the wrapped module (``model.module``), not the wrapper.
    weights_path = osp.join(args.logs_dir, 'model.pth.tar')
    save_checkpoint({'state_dict': model.module.state_dict()}, fpath=weights_path)

    run_test(model)
示例2: resume
# 需要導入模塊: from reid.utils import serialization [as 別名]
# 或者: from reid.utils.serialization import load_checkpoint [as 別名]
def resume(self, ckpt_file, step):
    """Recreate the model and load its weights from a checkpoint file.

    A new network is built with ``models.create`` from this object's
    ``model_name``, ``dropout``, ``num_classes`` and ``mode``, wrapped in
    ``nn.DataParallel`` on CUDA, and stored as ``self.model``. The object
    returned by ``load_checkpoint(ckpt_file)`` is then passed directly to
    ``load_state_dict`` of the wrapped model.

    Args:
        ckpt_file: Checkpoint location handed to ``load_checkpoint``.
        step: Step value; only used in the printed message.
    """
    print("continued from step", step)

    network = models.create(
        self.model_name,
        dropout=self.dropout,
        num_classes=self.num_classes,
        mode=self.mode,
    )
    # Assign the wrapped model before loading so the state dict is applied
    # to the DataParallel wrapper.
    self.model = nn.DataParallel(network).cuda()

    state = load_checkpoint(ckpt_file)
    self.model.load_state_dict(state)