本文整理汇总了Python中options.train_options.TrainOptions方法的典型用法代码示例。如果您正苦于以下问题:Python train_options.TrainOptions方法的具体用法?Python train_options.TrainOptions怎么用?Python train_options.TrainOptions使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块options.train_options的用法示例。
在下文中一共展示了train_options.TrainOptions方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from options import train_options [as 别名]
# 或者: from options.train_options import TrainOptions [as 别名]
def __init__(self):
    """Parse training options, prepare train/test data, then start training.

    Side effects: prints the number of train and test images, builds the
    model via ``ModelsFactory`` and a ``TBVisualizer``, and immediately
    calls ``self._train()``.
    """
    opt = TrainOptions().parse()
    self._opt = opt

    # Construct the train loader first, then the test loader.
    loader_train, loader_test = [
        CustomDatasetDataLoader(opt, is_for_train=flag) for flag in (True, False)
    ]

    self._dataset_train = loader_train.load_data()
    self._dataset_test = loader_test.load_data()
    self._dataset_train_size, self._dataset_test_size = (
        len(loader_train),
        len(loader_test),
    )

    print('#train images = %d' % self._dataset_train_size)
    print('#test images = %d' % self._dataset_test_size)

    self._model = ModelsFactory.get_by_name(opt.model, opt)
    self._tb_visualizer = TBVisualizer(opt)
    self._train()
示例2: main
# 需要导入模块: from options import train_options [as 别名]
# 或者: from options.train_options import TrainOptions [as 别名]
def main():
    """Train a model from command-line options, resuming from a checkpoint.

    Training starts after ``model.start_epoch`` and runs through epoch
    ``opt.niter + opt.niter_decay``. Each epoch updates the learning rate,
    feeds every batch to the model, periodically displays/prints/plots
    results, then saves a checkpoint and the 'latest' generator. Every
    ``opt.save_epoch_freq`` epochs the generator is also saved under the
    epoch number.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    # NOTE(review): assumes len(data_loader) counts batches, making this the
    # number of samples per epoch -- confirm against CreateDataLoader.
    dataset_size = len(data_loader) * opt.batch_size
    visualizer = Visualizer(opt)
    model = create_model(opt)
    start_epoch = model.start_epoch
    # total_steps advances by batch_size per batch; continue it from the
    # resumed epoch.
    total_steps = start_epoch * dataset_size
    for epoch in range(start_epoch + 1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        model.update_lr()
        # Forces a display (and save) on the first batch of every epoch.
        save_result = True
        for data in data_loader:
            iter_start_time = time.time()
            total_steps += opt.batch_size
            # Progress within the current epoch, in the same units as total_steps.
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.prepare_data(data)
            model.update_model()
            if save_result or total_steps % opt.display_freq == 0:
                save_result = save_result or total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch, ncols=1, save_result=save_result)
                save_result = False
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                # Wall time of this iteration divided by the batch size.
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)
        print('epoch {} cost time {}'.format(epoch, time.time() - epoch_start_time))
        model.save_ckpt(epoch)
        model.save_generator('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the generator at the end of epoch {}, iters {}'.format(epoch, total_steps))
            model.save_generator(epoch)
示例3: main
# 需要导入模块: from options import train_options [as 别名]
# 或者: from options.train_options import TrainOptions [as 别名]
def main():
    """Collect rotation distributions on the train and validation splits.

    Loads the pose network built by ``model.create_hg`` from the checkpoint
    at ``opt.exp_dir/opt.exp_id/opt.load_prefix_pose`` (minus its last
    character). It then calls ``collect_train_valid_data`` once per split,
    passing file paths inside ``<sr_dir>-<load_prefix_pose[0:-1]>`` under
    the experiment directory.

    Prints a message and exits with status 1 if ``opt.sr_dir`` or
    ``opt.load_prefix_pose`` is empty, or if ``opt.dataset`` is unsupported.
    """
    import sys

    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        sys.exit(1)
    # Validate before doing any work: the prefix also names the output
    # directory below. Continuing without it would create a malformed
    # directory and point the checkpoint loader at the experiment root.
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        sys.exit(1)
    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Without this branch num_classes would be unbound further down.
        print('unsupported dataset: {}'.format(opt.dataset))
        sys.exit(1)

    # load_prefix_pose[0:-1] drops the prefix's last character (presumably a
    # trailing separator -- TODO confirm).
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)

    # NOTE(review): train_history is not used below; kept only in case the
    # constructor has side effects -- confirm and remove.
    train_history = ASNTrainHistory()
    checkpoint_hg = Checkpoint()

    train_distri_path = os.path.join(sr_pretrain_dir, 'train_rotations.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_rotations_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_rotations.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_rotations_copy.txt')

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()

    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # print(...) with one argument behaves the same on Python 2 and 3.
    print('collecting training distributions ...\n')
    train_distri_list = collect_train_valid_data(train_distri_path,
                                                 train_distri_path_2, hg, opt, is_train=True)
    print('collecting validation distributions ...\n')
    val_distri_list = collect_train_valid_data(val_distri_path,
                                               val_distri_path_2, hg, opt, is_train=False)
示例4: main
# 需要导入模块: from options import train_options [as 别名]
# 或者: from options.train_options import TrainOptions [as 别名]
def main():
    """Collect scale distributions on the train and validation splits.

    Loads the pose network built by ``model.create_hg`` from the checkpoint
    at ``opt.exp_dir/opt.exp_id/opt.load_prefix_pose`` (minus its last
    character). It then calls ``collect_train_valid_data`` once per split,
    passing file paths inside ``<sr_dir>-<load_prefix_pose[0:-1]>`` under
    the experiment directory.

    Prints a message and exits with status 1 if ``opt.sr_dir`` or
    ``opt.load_prefix_pose`` is empty, or if ``opt.dataset`` is unsupported.
    """
    import sys

    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        sys.exit(1)
    # Validate before doing any work: the prefix also names the output
    # directory below. Checking it only after makedirs and GPU setup left a
    # malformed directory behind.
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        sys.exit(1)
    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Without this branch num_classes would be unbound further down.
        print('unsupported dataset: {}'.format(opt.dataset))
        sys.exit(1)

    # load_prefix_pose[0:-1] drops the prefix's last character (presumably a
    # trailing separator -- TODO confirm).
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)

    checkpoint_hg = Checkpoint()

    train_distri_path = os.path.join(sr_pretrain_dir, 'train_scales.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_scales_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_scales.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_scales_copy.txt')

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()

    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # print(...) with one argument behaves the same on Python 2 and 3.
    print('collecting training distributions ...\n')
    train_distri_list = collect_train_valid_data(train_distri_path,
                                                 train_distri_path_2, hg, opt, is_train=True)
    print('collecting validation distributions ...\n')
    val_distri_list = collect_train_valid_data(val_distri_path,
                                               val_distri_path_2, hg, opt, is_train=False)
示例5: main
# 需要导入模块: from options import train_options [as 别名]
# 或者: from options.train_options import TrainOptions [as 别名]
def main():
    """Train a SingleGAN model with a linearly decaying learning rate.

    Runs ``opt.niter + opt.niter_decay`` epochs. The model starts at
    ``opt.lr``. After each epoch past ``opt.niter``, the rate passed to
    ``model.update_lr`` drops linearly and reaches 0 after the final epoch.
    The 'latest' model is saved every ``opt.save_latest_freq`` steps. Every
    ``opt.save_epoch_freq`` epochs it is saved both as 'latest' and under the
    epoch number.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    # NOTE(review): assumes len(data_loader) counts batches, making this the
    # number of samples per epoch -- confirm against CreateDataLoader.
    dataset_size = len(data_loader) * opt.batchSize
    visualizer = Visualizer(opt)
    model = SingleGAN()
    model.initialize(opt)
    total_steps = 0
    last_epoch = opt.niter + opt.niter_decay
    for epoch in range(1, last_epoch + 1):
        # Forces a display (and save) on the first batch of every epoch.
        save_result = True
        for data in data_loader:
            iter_start_time = time.time()
            total_steps += opt.batchSize
            # Progress within the current epoch, in the same units as total_steps.
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.update_model(data)
            if save_result or total_steps % opt.display_freq == 0:
                save_result = save_result or total_steps % opt.update_html_freq == 0
                print('mode:{} dataset:{}'.format(opt.mode,opt.name))
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch, ncols=1, save_result=save_result)
                save_result = False
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                # Wall time of this iteration divided by the batch size.
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %(epoch, total_steps))
                model.save('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %(epoch, total_steps))
            model.save('latest')
            model.save(epoch)
        if epoch > opt.niter:
            # Closed form of the linear decay. Repeatedly subtracting
            # opt.lr / opt.niter_decay accumulated float error and could end
            # slightly below zero. float() keeps Python 2 from doing integer
            # division.
            lr = opt.lr * (last_epoch - epoch) / float(opt.niter_decay)
            model.update_lr(lr)