本文整理汇总了Python中trainer.train方法的典型用法代码示例。如果您正苦于以下问题：Python trainer.train方法的具体用法？Python trainer.train怎么用？Python trainer.train使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块trainer的用法示例。
在下文中一共展示了trainer.train方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def main():
    """Train a neural decision forest end to end and log the best metric.

    Options come from ``parse.parse_arg()``. A non-negative ``opt.gpuid``
    selects that CUDA device; a negative id runs on the CPU. The dataset,
    model and optimizer/scheduler are built by their respective ``prepare_*``
    helpers, and the value returned by ``trainer.train`` is logged as the best
    evaluation metric.
    """
    # logging configuration
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s]: %(levelname)s: %(message)s",
    )
    # command line parser
    opt = parse.parse_arg()
    # GPU: record the choice on opt once and reuse it below.
    opt.cuda = opt.gpuid >= 0
    if opt.cuda:
        torch.cuda.set_device(opt.gpuid)
    else:
        # Emitted at WARNING level (the format already prints the level name),
        # so it is no longer buried among INFO messages.
        logging.warning("RUN WITHOUT GPU")
    # prepare dataset
    db = dataset.prepare_db(opt)
    # initialize neural decision forest
    NDF = model.prepare_model(opt)
    # prepare optimizer and learning-rate scheduler
    optim, sche = optimizer.prepare_optim(NDF, opt)
    # train the neural decision forest
    best_metric = trainer.train(NDF, optim, sche, db, opt)
    logging.info('The best evaluation metric is %f', best_metric)
示例2: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def main(train, classify, help):
    """Dispatch the CLI: show help, train the model, or classify one image.

    ``help`` takes priority and exits the process with status 0 after printing
    ``help_message``. Otherwise a truthy ``train`` prompts for an iteration
    count and trains. Any other case prompts for an image path and classifies
    it.

    NOTE(review): ``classify`` is never read; the classification path is taken
    whenever ``train`` is falsy.
    """
    if help:
        print(help_message)
        sys.exit(0)
    elif train:
        num_iterations = click.prompt('Iteration count for training model', type=int)
        trainer.train(num_iteration=num_iterations)
    else:
        target_path = click.prompt('Image file path that is going to be classified', type=str)
        classifier.classify(file_path=target_path)
示例3: main
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def main():
    """Train a DetectionModel with SGD and a StepLR schedule.

    Options come from ``arguments()``. When ``args.resume`` is set, model and
    optimizer state (and, unless given explicitly, the start epoch) are
    restored from that checkpoint. Training runs ``trainer.train`` for epochs
    ``[args.start_epoch, args.epochs)``. A checkpoint is saved to ``weights/``
    every ``args.save_every`` epochs.
    """
    args = arguments()
    num_templates = 25  # aka the number of clusters

    # Per-channel normalisation constants (presumably the usual ImageNet
    # statistics -- confirm against the backbone the model uses).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_loader, _ = get_dataloader(args.traindata, args, num_templates,
                                     img_transforms=img_transforms)

    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # Directory where model weights are stored. makedirs(exist_ok=True)
    # replaces the racy exists()-then-mkdir() check.
    weights_dir = "weights"
    os.makedirs(weights_dir, exist_ok=True)

    # Use the first CUDA device when one is available, else the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    optimizer = optim.SGD(model.learnable_parameters(args.lr), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.resume:
        # map_location remaps saved tensors onto the current device, so a
        # checkpoint written on a GPU can be resumed on a CPU-only machine.
        # NOTE: torch.load unpickles the file -- only resume from trusted
        # checkpoints.
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Take the start epoch from the checkpoint unless one was given
        # explicitly (a falsy value such as 0 counts as "not given").
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    # last_epoch aligns the step-every-20-epochs decay with the start epoch.
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=20,
                                          last_epoch=args.start_epoch - 1)

    # Train and evaluate for the remaining epochs.
    for epoch in range(args.start_epoch, args.epochs):
        trainer.train(model, loss_fn, optimizer, train_loader, epoch,
                      device=device)
        scheduler.step()
        if (epoch + 1) % args.save_every == 0:
            # 'epoch' stores the next epoch to run, matching how it is
            # used as start_epoch on resume above.
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'batch_size': train_loader.batch_size,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, filename="checkpoint_{0}.pth".format(epoch + 1),
                save_path=weights_dir)
示例4: train
# 需要导入模块: import trainer [as 别名]
# 或者: from trainer import train [as 别名]
def train(self, generator_train, X_train, generator_val, X_val):
    """Run ``P.N_EPOCHS`` epochs of class-balanced training and validation.

    Items of ``X_train`` / ``X_val`` that contain the substring "True" are
    treated as positives, and those containing "False" as negatives. Every
    epoch uses all positives plus a freshly shuffled subset of at most as
    many negatives. This is done for both the training and validation sets.

    Args:
        generator_train: passed to ``ParallelBatchIterator`` to produce
            training batches.
        X_train: sequence of training items whose label is encoded as a
            "True"/"False" substring (presumably filenames -- confirm).
        generator_val: passed to ``ParallelBatchIterator`` to produce
            validation batches.
        X_val: validation items, labelled the same way as ``X_train``.
    """
    # Materialise lists instead of calling filter(): on Python 3 filter()
    # returns a lazy iterator, which breaks len(), in-place shuffling and the
    # list concatenation below. Comprehensions behave the same on Python 2.
    train_true = [x for x in X_train if "True" in x]
    train_false = [x for x in X_train if "False" in x]
    # Single-argument print() emits the same text on Python 2 and Python 3.
    print("N train true/false {} {}".format(len(train_true), len(train_false)))
    print(X_train[:2])
    val_true = [x for x in X_val if "True" in x]
    val_false = [x for x in X_val if "False" in x]
    n_train_true = len(train_true)
    n_val_true = len(val_true)

    logging.info("Starting training...")
    for epoch in range(P.N_EPOCHS):
        self.pre_epoch()
        # LR_SCHEDULE maps an epoch index to the learning rate to switch to.
        if epoch in LR_SCHEDULE:
            logging.info("Setting learning rate to {}".format(LR_SCHEDULE[epoch]))
            self.l_r.set_value(LR_SCHEDULE[epoch])

        # Undersample negatives to (at most) the number of positives,
        # drawing a different random subset every epoch.
        np.random.shuffle(train_false)
        np.random.shuffle(val_false)
        train_epoch_data = train_true + train_false[:n_train_true]
        val_epoch_data = val_true + val_false[:n_val_true]
        # Only the training data is shuffled; validation order is kept.
        np.random.shuffle(train_epoch_data)

        # Full pass over the training data.
        # NOTE(review): the reason for dividing the batch size by 3 is not
        # visible here -- confirm against the batch generator.
        train_gen = ParallelBatchIterator(generator_train, train_epoch_data, ordered=False,
                                          batch_size=P.BATCH_SIZE_TRAIN // 3,
                                          multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                          n_producers=P.N_WORKERS_LOAD_AUGMENTATION)
        self.do_batches(self.train_fn, train_gen, self.train_metrics)

        # And a full pass over the validation data.
        val_gen = ParallelBatchIterator(generator_val, val_epoch_data, ordered=False,
                                        batch_size=P.BATCH_SIZE_VALIDATION // 3,
                                        multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                        n_producers=P.N_WORKERS_LOAD_AUGMENTATION)
        self.do_batches(self.val_fn, val_gen, self.val_metrics)
        self.post_epoch()