当前位置: 首页>>代码示例>>Python>>正文


Python Model.load_state_dict方法代码示例

本文整理汇总了Python中model.Model.load_state_dict方法的典型用法代码示例。如果您正苦于以下问题:Python Model.load_state_dict方法的具体用法?Python Model.load_state_dict怎么用?Python Model.load_state_dict使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在model.Model的用法示例。


在下文中一共展示了Model.load_state_dict方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# Required import: from model import Model [as alias]
# load_state_dict is an instance method of model.Model (call it as model.load_state_dict(...)); it cannot be imported on its own
def main():
    """Run the model over the validation list and save its raw predictions.

    Parses command-line arguments into the module-level ``args``, sets
    ``CUDA_VISIBLE_DEVICES`` from ``args.gpu``, builds ``Model`` on the GPU,
    optionally restores weights from ``args.resume``, then runs inference over
    every image listed in ``valid_info.csv`` and writes the concatenated
    outputs to ``preds_raw.npy``.
    """
    global args, best_prec1, best_loss
    args = parser.parse_args()
    # Must happen before the first CUDA call so only the chosen GPU(s) are visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # create model (single GPU; a DataParallel variant was disabled upstream)
    model = Model().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints saved with only 'epoch' and 'state_dict' (as the
            # training loop in main.py does) have no 'best_prec1' entry, so
            # treat it as optional instead of raising KeyError.
            if 'best_prec1' in checkpoint:
                best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint (epoch {})"
                  .format(checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    cudnn.benchmark = True

    # Data loading code. These assignments override whatever was passed on
    # the command line.
    args.arch = 'alex'
    args.data = '/home/thuyen/Research/pupil/input/'

    valdir = args.data
    df = pd.read_csv('valid_info.csv')

    valid_loader = torch.utils.data.DataLoader(
        ImageList(df, valdir, for_train=False),
        batch_size=16, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Switch to evaluation mode before predicting, so layers such as dropout
    # or batch-norm (if Model contains any) use inference behaviour.
    model.eval()

    outputs = []
    for images, _target in valid_loader:
        input_var = torch.autograd.Variable(images.cuda(), volatile=True)
        output_var = model(input_var)
        outputs.append(output_var.data.cpu().numpy())
    outputs = np.concatenate(outputs)
    np.save('preds_raw.npy', outputs)
开发者ID:thuyen,项目名称:test,代码行数:59,代码来源:predict.py

示例2: main

# Required import: from model import Model [as alias]
# load_state_dict is an instance method of model.Model (call it as net.load_state_dict(...)); it cannot be imported on its own
def main():
    """Evaluate a trained model on the test list and report top-1 accuracy.

    Builds ``Model`` from the command-line settings and loads weights from
    ``args.weights``. Each test video is run through the network, with scores
    averaged over segments and crops. Prints the accuracy. When
    ``args.save_scores`` is set, also saves per-video scores, labels and names,
    ordered by sorted video name, to an ``.npz`` file.

    Raises:
        ValueError: if ``args.test_crops`` is neither 1 nor 10.
    """
    net = Model(num_class, args.test_segments, args.representation,
                base_model=args.arch)

    checkpoint = torch.load(args.weights)
    print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))

    # Drop the first dotted component of every parameter name; presumably the
    # 'module.' prefix added when the checkpoint was saved from a DataParallel
    # wrapper -- confirm against the training script.
    base_dict = {'.'.join(k.split('.')[1:]): v
                 for k, v in checkpoint['state_dict'].items()}
    net.load_state_dict(base_dict)

    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(net.crop_size),
        ])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(net.crop_size, net.scale_size, is_mv=(args.representation == 'mv'))
        ])
    else:
        raise ValueError("Only 1 and 10 crops are supported, but got {}.".format(args.test_crops))

    data_loader = torch.utils.data.DataLoader(
        CoviarDataSet(
            args.data_root,
            args.data_name,
            video_list=args.test_list,
            num_segments=args.test_segments,
            representation=args.representation,
            transform=cropping,
            is_train=False,
            accumulate=(not args.no_accumulation),
            ),
        batch_size=1, shuffle=False,
        num_workers=args.workers * 2, pin_memory=True)

    # NOTE(review): the GPU list is sized by args.workers (the data-loader
    # worker count), as in the upstream code. Slicing instead of indexing
    # avoids an IndexError when fewer GPUs than workers are supplied.
    if args.gpus is not None:
        devices = list(args.gpus[:args.workers])
    else:
        devices = list(range(args.workers))

    net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices)
    net.eval()

    total_num = len(data_loader.dataset)
    output = []

    def forward_video(data):
        """Return the segment/crop-averaged class scores of one video as a NumPy array."""
        input_var = torch.autograd.Variable(data, volatile=True)
        scores = net(input_var)
        # Group the flat (segments * crops) predictions per video, then average.
        scores = scores.view((-1, args.test_segments * args.test_crops) + scores.size()[1:])
        scores = torch.mean(scores, dim=1)
        return scores.data.cpu().numpy().copy()

    proc_start_time = time.time()

    for i, (data, label) in enumerate(data_loader):
        video_scores = forward_video(data)
        output.append((video_scores, label[0]))
        if (i + 1) % 100 == 0:
            cnt_time = time.time() - proc_start_time
            print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1,
                                                                            total_num,
                                                                            float(cnt_time) / (i+1)))

    video_pred = [np.argmax(x[0]) for x in output]
    video_labels = [x[1] for x in output]

    # max(..., 1) keeps an empty test set from raising ZeroDivisionError.
    print('Accuracy {:.02f}% ({})'.format(
        float(np.sum(np.array(video_pred) == np.array(video_labels))) / max(len(video_pred), 1) * 100.0,
        len(video_pred)))

    if args.save_scores is not None:
        # First whitespace-separated field of each non-blank line is the video
        # name; the file handle is closed once the names are read.
        with open(args.test_list) as list_file:
            name_list = [line.split()[0] for line in list_file if line.strip()]
        order_dict = {name: rank for rank, name in enumerate(sorted(name_list))}

        reorder_output = [None] * len(output)
        reorder_label = [None] * len(output)
        reorder_name = [None] * len(output)

        # Assumes the loader (shuffle=False) yields videos in test-list order.
        for i, video_output in enumerate(output):
            idx = order_dict[name_list[i]]
            reorder_output[idx] = video_output
            reorder_label[idx] = video_labels[i]
            reorder_name[idx] = name_list[i]

        np.savez(args.save_scores, scores=reorder_output, labels=reorder_label, names=reorder_name)
开发者ID:baiyancheng20,项目名称:pytorch-coviar,代码行数:95,代码来源:test.py

示例3: main

# Required import: from model import Model [as alias]
# load_state_dict is an instance method of model.Model (call it as model.load_state_dict(...)); it cannot be imported on its own
def main():
    """Train ``Model`` with a negative soft-Dice loss, or only run inference.

    Parses command-line arguments into the module-level ``args``, seeds
    PyTorch, creates ``args.ckpts`` if needed, and optionally resumes from
    ``args.resume``. With ``args.evaluate`` set, it predicts on
    ``args.valid_list``, saves the raw outputs to ``args.out_file`` and
    returns. Otherwise it trains on ``args.train_list`` from
    ``args.start_epoch`` to ``args.epochs``, writing a checkpoint after every
    epoch.
    """
    global args, best_loss
    args = parser.parse_args()
    # Must happen before the first CUDA call so only the chosen GPU(s) are visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    torch.manual_seed(args.seed)
    if not os.path.exists(args.ckpts):
        os.makedirs(args.ckpts)

    # create model
    model = Model().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # Report the file that was actually loaded (previously this
            # printed args.evaluate by mistake).
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    cudnn.benchmark = True

    if args.evaluate:
        df = pd.read_csv(args.valid_list)
        valid_loader = torch.utils.data.DataLoader(
            ImageList(df, args.data, for_train=False),
            batch_size=16, shuffle=False,
            num_workers=args.workers, pin_memory=True)

        # Switch to evaluation mode before predicting, so layers such as
        # dropout or batch-norm (if Model contains any) use inference behaviour.
        model.eval()

        outputs = []
        for images, _target in valid_loader:
            input_var = torch.autograd.Variable(images.cuda(), volatile=True)
            output_var = model(input_var)
            outputs.append(output_var.data.cpu().numpy())
        outputs = np.concatenate(outputs)
        np.save(args.out_file, outputs)
        return

    df = pd.read_csv(args.train_list)

    train_loader = torch.utils.data.DataLoader(
        ImageList(df, args.data, for_train=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Smoothing term: keeps the Dice ratio finite when x and y both sum to zero.
    eps = 1e-2

    def criterion(x, y):
        """Return the negative soft Dice coefficient of predictions ``x`` and targets ``y``.

        Minimizing it maximizes the overlap 2*sum(x*y) / (sum(x) + sum(y)).
        """
        num = 2 * (x * y).sum() + eps
        den = x.sum() + y.sum() + eps
        return -num / den

    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 weight_decay=args.weight_decay)

    # Lazy %-style argument: formatting is deferred to the logging framework.
    logging.info('-------------- New training session, LR = %f ----------------', args.lr)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)  # adam, same lr

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # Validation is disabled, so no checkpoint is ever flagged as best.
        is_best = False
        filename = os.path.join(args.ckpts, 'model_{}.pth.tar'.format(epoch + 1))
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict()
        }, is_best, filename=filename)

        msg = 'Epoch: {0:02d} Train loss {1:.4f}'.format(epoch + 1, train_loss)
        logging.info(msg)
开发者ID:thuyen,项目名称:test,代码行数:88,代码来源:main.py


注:本文中的model.Model.load_state_dict方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。