This article collects typical usage examples of the Python method model.Model.load_state_dict. If you have been wondering what Model.load_state_dict does, how to call it, or what working code that uses it looks like, the curated examples below should help; you can also consult the documentation of its containing class, model.Model.
Three code examples of Model.load_state_dict follow.
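Before the full examples, a minimal sketch of the save/load round trip that all three build on, using a trivial hypothetical stand-in for model.Model:

import torch
import torch.nn as nn

class Model(nn.Module):  # hypothetical stand-in for model.Model
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

model = Model()
# state_dict() maps parameter and buffer names to tensors; save that, not the module.
torch.save({'state_dict': model.state_dict()}, 'checkpoint.pth.tar')

# load_state_dict() copies the saved tensors into an already-constructed model.
restored = Model()
checkpoint = torch.load('checkpoint.pth.tar')
restored.load_state_dict(checkpoint['state_dict'])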
Example 1: main
# Required import: from model import Model [as alias]
# Or: from model.Model import load_state_dict [as alias]
import os
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn

def main():
    global args, best_prec1, best_loss
    # `parser` and `ImageList` are defined elsewhere in the original script.
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Create the model on the GPU.
    model = Model().cuda()

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # Restore the trained weights into the freshly built model.
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint (epoch {})"
                  .format(checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code.
    k = 4
    n = 2 * k + 1
    args.arch = 'alex'
    args.data = '/home/thuyen/Research/pupil/input/'
    valdir = args.data
    df = pd.read_csv('valid_info.csv')
    valid_loader = torch.utils.data.DataLoader(
        ImageList(df, valdir, for_train=False),
        batch_size=16, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Run inference over the validation set and save the raw predictions.
    model.eval()
    outputs = []
    for j, (input, target) in enumerate(valid_loader):
        # volatile=True is the pre-0.4 PyTorch way to disable autograd for inference.
        input_var = torch.autograd.Variable(input.cuda(), volatile=True)
        output_var = model(input_var)
        outputs.append(output_var.data.cpu().numpy())
    outputs = np.concatenate(outputs)
    np.save('preds_raw.npy', outputs)
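For readers on PyTorch 0.4 or later: volatile=True was removed in favor of the torch.no_grad() context manager. A minimal modern sketch of the same resume-and-predict pattern, reusing the Model and valid_loader names from the example above (the checkpoint path is hypothetical):

import torch

model = Model().cuda()
# map_location controls which device the checkpoint tensors are loaded onto.
checkpoint = torch.load('checkpoint.pth.tar', map_location='cuda')
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # freeze dropout/batchnorm behavior for inference

outputs = []
with torch.no_grad():  # replaces volatile=True from PyTorch < 0.4
    for input, target in valid_loader:
        outputs.append(model(input.cuda()).cpu().numpy())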
Example 2: main
# Required import: from model import Model [as alias]
# Or: from model.Model import load_state_dict [as alias]
import time
import numpy as np
import torch
import torchvision

def main():
    # `args`, `num_class`, the Group* transforms, and CoviarDataSet are
    # defined elsewhere in the original script.
    net = Model(num_class, args.test_segments, args.representation,
                base_model=args.arch)
    checkpoint = torch.load(args.weights)
    print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))

    # The checkpoint was saved from a DataParallel model, so every key is
    # prefixed with 'module.'; strip that prefix before loading.
    base_dict = {'.'.join(k.split('.')[1:]): v
                 for k, v in list(checkpoint['state_dict'].items())}
    net.load_state_dict(base_dict)

    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(net.crop_size),
        ])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(net.crop_size, net.scale_size,
                            is_mv=(args.representation == 'mv'))
        ])
    else:
        raise ValueError("Only 1 and 10 crops are supported, but got {}.".format(args.test_crops))

    data_loader = torch.utils.data.DataLoader(
        CoviarDataSet(
            args.data_root,
            args.data_name,
            video_list=args.test_list,
            num_segments=args.test_segments,
            representation=args.representation,
            transform=cropping,
            is_train=False,
            accumulate=(not args.no_accumulation),
        ),
        batch_size=1, shuffle=False,
        num_workers=args.workers * 2, pin_memory=True)

    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))

    net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices)
    net.eval()

    data_gen = enumerate(data_loader)
    total_num = len(data_loader.dataset)
    output = []

    def forward_video(data):
        # Average the scores over all segments and crops of one video.
        input_var = torch.autograd.Variable(data, volatile=True)
        scores = net(input_var)
        scores = scores.view((-1, args.test_segments * args.test_crops) + scores.size()[1:])
        scores = torch.mean(scores, dim=1)
        return scores.data.cpu().numpy().copy()

    proc_start_time = time.time()
    for i, (data, label) in data_gen:
        video_scores = forward_video(data)
        output.append((video_scores, label[0]))
        cnt_time = time.time() - proc_start_time
        if (i + 1) % 100 == 0:
            print('video {} done, total {}/{}, average {} sec/video'.format(
                i, i + 1, total_num, float(cnt_time) / (i + 1)))

    video_pred = [np.argmax(x[0]) for x in output]
    video_labels = [x[1] for x in output]
    print('Accuracy {:.02f}% ({})'.format(
        float(np.sum(np.array(video_pred) == np.array(video_labels))) / len(video_pred) * 100.0,
        len(video_pred)))

    if args.save_scores is not None:
        # Re-sort the per-video results by sorted video name before saving.
        name_list = [x.strip().split()[0] for x in open(args.test_list)]
        order_dict = {e: i for i, e in enumerate(sorted(name_list))}
        reorder_output = [None] * len(output)
        reorder_label = [None] * len(output)
        reorder_name = [None] * len(output)
        for i in range(len(output)):
            idx = order_dict[name_list[i]]
            reorder_output[idx] = output[i]
            reorder_label[idx] = video_labels[i]
            reorder_name[idx] = name_list[i]
        np.savez(args.save_scores,
                 scores=reorder_output, labels=reorder_label, names=reorder_name)
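The dictionary comprehension in this example strips the 'module.' prefix that torch.nn.DataParallel adds to every parameter name when a wrapped model is checkpointed. Recent PyTorch releases ship a helper for exactly this; a minimal sketch, with a hypothetical checkpoint path:

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

checkpoint = torch.load('dataparallel_checkpoint.pth.tar')
state_dict = checkpoint['state_dict']
# Drops the 'module.' prefix from every key, in place, if it is present.
consume_prefix_in_state_dict_if_present(state_dict, 'module.')
net.load_state_dict(state_dict)  # net is the bare, unwrapped model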
Example 3: main
# Required import: from model import Model [as alias]
# Or: from model.Model import load_state_dict [as alias]
import logging
import os
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn

def main():
    global args, best_loss
    # `parser`, `ImageList`, `train`, `adjust_learning_rate`, and
    # `save_checkpoint` are defined elsewhere in the original script.
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.manual_seed(args.seed)
    if not os.path.exists(args.ckpts):
        os.makedirs(args.ckpts)

    # Create the model.
    model = Model().cuda()

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        # Evaluation-only path: predict on the validation set and exit.
        df = pd.read_csv(args.valid_list)
        valid_loader = torch.utils.data.DataLoader(
            ImageList(df, args.data, for_train=False),
            batch_size=16, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        model.eval()
        outputs = []
        for j, (input, target) in enumerate(valid_loader):
            # volatile=True is the pre-0.4 way to disable autograd for inference.
            input_var = torch.autograd.Variable(input.cuda(), volatile=True)
            output_var = model(input_var)
            outputs.append(output_var.data.cpu().numpy())
        outputs = np.concatenate(outputs)
        np.save(args.out_file, outputs)
        return

    df = pd.read_csv(args.train_list)
    train_loader = torch.utils.data.DataLoader(
        ImageList(df, args.data, for_train=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Soft Dice loss, negated so that minimizing it maximizes overlap;
    # eps keeps the ratio defined when both x and y are all zeros.
    eps = 1e-2
    def criterion(x, y):
        num = 2 * (x * y).sum() + eps
        den = x.sum() + y.sum() + eps
        return -num / den

    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 weight_decay=args.weight_decay)
    logging.info('-------------- New training session, LR = %f ----------------' % (args.lr,))

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch, then checkpoint the weights.
        train_loss = train(train_loader, model, criterion, optimizer, epoch)
        is_best = False
        filename = os.path.join(args.ckpts, 'model_{}.pth.tar'.format(epoch + 1))
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict()
        }, is_best, filename=filename)
        msg = 'Epoch: {0:02d} Train loss {1:.4f}'.format(epoch + 1, train_loss)
        logging.info(msg)
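Note that this example checkpoints only the epoch counter and the model weights, so a resumed run restarts Adam's moment estimates from scratch. A minimal sketch of a fuller checkpoint, reusing the model and optimizer names above with a hypothetical filename:

# Saving: persist the optimizer state alongside the weights.
torch.save({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'model_full.pth.tar')

# Resuming: restore both, so training continues where it left off.
checkpoint = torch.load('model_full.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']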