当前位置: 首页>>代码示例>>Python>>正文


Python trainer.train方法代码示例

本文整理汇总了Python中trainer.train方法的典型用法代码示例。如果您正苦于以下问题:Python trainer.train方法的具体用法?Python trainer.train怎么用?Python trainer.train使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在trainer的用法示例。


在下文中一共展示了trainer.train方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def main():
    """Configure logging, build the training pipeline and train the forest.

    Parses the command-line options, selects the CUDA device when
    ``opt.gpuid`` is non-negative, prepares the dataset, the model and the
    optimizer/scheduler pair, then runs ``trainer.train`` and logs the value
    it returns as the best evaluation metric.
    """
    # logging configuration
    log_format = "[%(asctime)s]: %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    # command line parser
    opt = parse.parse_arg()

    # GPU: a negative gpuid means no CUDA device is selected
    use_gpu = opt.gpuid >= 0
    opt.cuda = use_gpu
    if not use_gpu:
        logging.info("WARNING: RUN WITHOUT GPU")
    else:
        torch.cuda.set_device(opt.gpuid)

    # prepare dataset
    db = dataset.prepare_db(opt)

    # initialize neural decision forest
    NDF = model.prepare_model(opt)

    # prepare optimizer (and its companion scheduler)
    optim, sche = optimizer.prepare_optim(NDF, opt)

    # train the neural decision forest and report the result
    best_metric = trainer.train(NDF, optim, sche, db, opt)
    summary = 'The best evaluation metric is %f' % best_metric
    logging.info(summary)
开发者ID:Nicholasli1995,项目名称:VisualizingNDF,代码行数:31,代码来源:main.py

示例2: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def main(train, classify, help):
    """Dispatch the command-line action: show help, train, or classify.

    Args:
        train: Truthy to prompt for an iteration count and train the model.
        classify: Not read by this function; any non-help, non-train
            invocation falls through to classification.
        help: Truthy to print ``help_message`` and exit with status 0.
    """
    if help:
        print(help_message)
        sys.exit(0)

    if not train:
        # Default action: classify a single image chosen interactively.
        image_file_path = click.prompt(
            'Image file path that is going to be classified', type=str
        )
        classifier.classify(file_path=image_file_path)
        return

    iteration = click.prompt('Iteration count for training model', type=int)
    trainer.train(num_iteration=iteration)
开发者ID:abdullahselek,项目名称:plant-disease-classification,代码行数:13,代码来源:__main__.py

示例3: main

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def main():
    """Train the detection model, optionally resuming from a checkpoint.

    Builds the training data loader, the model, the loss and an SGD
    optimizer from the parsed command-line arguments, restores model and
    optimizer state when ``args.resume`` is set, then trains from
    ``args.start_epoch`` up to ``args.epochs``. A checkpoint is written to
    the ``weights`` directory every ``args.save_every`` epochs.
    """
    args = arguments()

    num_templates = 25  # aka the number of clusters

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    train_loader, _ = get_dataloader(args.traindata, args, num_templates,
                                     img_transforms=img_transforms)

    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # directory where we'll store model weights; exist_ok avoids the
    # check-then-create race of a separate exists() test
    weights_dir = "weights"
    os.makedirs(weights_dir, exist_ok=True)

    # check for CUDA
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    optimizer = optim.SGD(model.learnable_parameters(args.lr), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only host instead of failing during deserialization.
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Set the start epoch if it has not been
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=20,
                                          last_epoch=args.start_epoch-1)

    # train and evaluate for `epochs`
    for epoch in range(args.start_epoch, args.epochs):
        trainer.train(model, loss_fn, optimizer, train_loader, epoch, device=device)
        scheduler.step()

        if (epoch+1) % args.save_every == 0:
            # 'epoch' stores the next epoch to run, which is what the
            # resume branch above reads back as the start epoch.
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'batch_size': train_loader.batch_size,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, filename="checkpoint_{0}.pth".format(epoch+1), save_path=weights_dir)
开发者ID:varunagrawal,项目名称:tiny-faces-pytorch,代码行数:58,代码来源:main.py

示例4: train

# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def train(self, generator_train, X_train, generator_val, X_val):
        """Run the training loop with a class-balanced sample each epoch.

        Items of ``X_train`` / ``X_val`` are split by whether they contain
        ``"True"`` or ``"False"`` (presumably labelled file names -- confirm
        against the caller). Every epoch keeps all "True" items and draws an
        equally sized random subset of the "False" items, for both the
        training and the validation pass.

        Args:
            generator_train: Passed to ``ParallelBatchIterator`` to produce
                training batches.
            X_train: Sequence of training items.
            generator_val: Passed to ``ParallelBatchIterator`` to produce
                validation batches.
            X_val: Sequence of validation items.
        """
        # List comprehensions (rather than filter()) yield real lists on both
        # Python 2 and Python 3; len(), shuffle and "+" below need lists.
        train_true = [x for x in X_train if "True" in x]
        train_false = [x for x in X_train if "False" in x]

        # Single-argument print(...) produces the same output on Python 2 and 3.
        print("N train true/false %d %d" % (len(train_true), len(train_false)))
        print(X_train[:2])

        val_true = [x for x in X_val if "True" in x]
        val_false = [x for x in X_val if "False" in x]

        n_train_true = len(train_true)
        n_val_true = len(val_true)

        logging.info("Starting training...")
        for epoch in range(P.N_EPOCHS):
            self.pre_epoch()

            if epoch in LR_SCHEDULE:
                logging.info("Setting learning rate to {}".format(LR_SCHEDULE[epoch]))
                self.l_r.set_value(LR_SCHEDULE[epoch])

            # Reshuffle so a different subset of "False" items is drawn each epoch.
            np.random.shuffle(train_false)
            np.random.shuffle(val_false)

            # Balance the classes: as many "False" items as there are "True" items.
            train_epoch_data = train_true + train_false[:n_train_true]
            val_epoch_data = val_true + val_false[:n_val_true]

            # Only the training data is shuffled; validation order is left as built.
            np.random.shuffle(train_epoch_data)

            # Full pass over the training data
            train_gen = ParallelBatchIterator(generator_train, train_epoch_data, ordered=False,
                                                batch_size=P.BATCH_SIZE_TRAIN//3,
                                                multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                                n_producers=P.N_WORKERS_LOAD_AUGMENTATION)

            self.do_batches(self.train_fn, train_gen, self.train_metrics)

            # And a full pass over the validation data:
            val_gen = ParallelBatchIterator(generator_val, val_epoch_data, ordered=False,
                                                batch_size=P.BATCH_SIZE_VALIDATION//3,
                                                multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                                n_producers=P.N_WORKERS_LOAD_AUGMENTATION)

            self.do_batches(self.val_fn, val_gen, self.val_metrics)
            self.post_epoch()
开发者ID:gzuidhof,项目名称:luna16,代码行数:53,代码来源:resnet_trainer.py


注:本文中的trainer.train方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。