當前位置: 首頁>>代碼示例>>Python>>正文


Python train_options.TrainOptions方法代碼示例

本文整理匯總了Python中options.train_options.TrainOptions方法的典型用法代碼示例。如果您正苦於以下問題:Python train_options.TrainOptions方法的具體用法?Python train_options.TrainOptions怎麼用?Python train_options.TrainOptions使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在options.train_options的用法示例。


在下文中一共展示了train_options.TrainOptions方法的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from options import train_options [as 別名]
# 或者: from options.train_options import TrainOptions [as 別名]
def __init__(self):
    """Set up everything needed for training, then start training.

    Parses the training options, builds one data loader for the training
    split and one for the test split, reports how many images each holds,
    creates the model named by ``opt.model`` plus a TensorBoard visualizer,
    and finally calls ``self._train()``.
    """
    opt = TrainOptions().parse()
    self._opt = opt

    # Both loaders are constructed before either one's data is loaded,
    # mirroring the original ordering in case construction has side effects.
    train_loader = CustomDatasetDataLoader(opt, is_for_train=True)
    test_loader = CustomDatasetDataLoader(opt, is_for_train=False)

    self._dataset_train = train_loader.load_data()
    self._dataset_test = test_loader.load_data()

    self._dataset_train_size = len(train_loader)
    self._dataset_test_size = len(test_loader)
    for line in ('#train images = %d' % self._dataset_train_size,
                 '#test images = %d' % self._dataset_test_size):
        print(line)

    self._model = ModelsFactory.get_by_name(opt.model, opt)
    self._tb_visualizer = TBVisualizer(opt)

    # Training starts as soon as the object is constructed.
    self._train()
開發者ID:albertpumarola,項目名稱:GANimation,代碼行數:19,代碼來源:train.py

示例2: main

# 需要導入模塊: from options import train_options [as 別名]
# 或者: from options.train_options import TrainOptions [as 別名]
def main():
    """Train a DMIT model, resuming after ``model.start_epoch``.

    For every epoch from ``model.start_epoch + 1`` to
    ``opt.niter + opt.niter_decay`` this updates the learning rate, makes one
    pass over the data loader (displaying visuals, printing and optionally
    plotting losses at the frequencies configured in ``opt``), then saves a
    checkpoint and the ``'latest'`` generator. Every ``opt.save_epoch_freq``
    epochs the generator is also saved under the epoch number.
    ``total_steps`` counts samples (it advances by ``opt.batch_size``).
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    # Samples per epoch, estimated as number of batches * batch size.
    dataset_size = len(data_loader) * opt.batch_size
    visualizer = Visualizer(opt)
    model = create_model(opt)
    start_epoch = model.start_epoch
    # When resuming, continue the global sample counter where it left off so
    # that epoch_iter below starts from zero in the first resumed epoch.
    total_steps = start_epoch * dataset_size
    for epoch in range(start_epoch + 1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        model.update_lr()
        # The first batch of every epoch is always displayed (and saved).
        save_result = True
        for data in data_loader:
            iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.prepare_data(data)
            model.update_model()
            if save_result or total_steps % opt.display_freq == 0:
                save_result = save_result or total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, ncols=1, save_result=save_result)
                save_result = False
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                # Time of this iteration divided by the batch size.
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
        # Fixed typo in the log message ("cost dime" -> "cost time").
        print('epoch {} cost time {}'.format(epoch, time.time() - epoch_start_time))
        model.save_ckpt(epoch)
        model.save_generator('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the generator at the end of epoch {}, iters {}'.format(epoch, total_steps))
            model.save_generator(epoch)
開發者ID:Xiaoming-Yu,項目名稱:DMIT,代碼行數:36,代碼來源:train.py

示例3: main

# 需要導入模塊: from options import train_options [as 別名]
# 或者: from options.train_options import TrainOptions [as 別名]
def main():
    """Collect rotation distributions from a pretrained pose model.

    Builds the pose network via ``model.create_hg`` (2 stacks, 1 module),
    wraps it in ``torch.nn.DataParallel`` on the GPU, loads the checkpoint
    named by ``opt.load_prefix_pose`` and calls ``collect_train_valid_data``
    for the training and validation splits. The ``*_rotations.txt`` /
    ``*_rotations_copy.txt`` paths it passes live in the SR pretrain
    directory. They are presumably the files the distributions are written
    to; confirm this in ``collect_train_valid_data``.

    Raises:
        ValueError: if ``opt.dataset`` is not supported (only ``'mpii'``).
    """
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()
    # [0:-1] drops the last character of the prefix. That is presumably a
    # trailing separator; confirm against how load_prefix_pose is specified.
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    # NOTE(review): train_history is never used below. The construction is
    # kept only in case the constructor has side effects.
    train_history = ASNTrainHistory()
    checkpoint_hg = Checkpoint()
    train_distri_path = os.path.join(sr_pretrain_dir, 'train_rotations.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_rotations_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_rotations.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_rotations_copy.txt')

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Without this branch num_classes would be unbound and the script
        # would crash with an opaque UnboundLocalError further down.
        raise ValueError('unsupported dataset: {!r}'.format(opt.dataset))
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    if opt.load_prefix_pose == '':
        # NOTE(review): execution deliberately continues after this warning,
        # so the checkpoint load below runs with an empty prefix. Confirm
        # that this is intended.
        print('please input the checkpoint name of the pose model')
    # [0:-1] drops the last character of the joined checkpoint path.
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # The function-call form of print gives the same output on Python 2 and 3.
    print('collecting training distributions ...\n')
    collect_train_valid_data(train_distri_path, train_distri_path_2,
                             hg, opt, is_train=True)

    print('collecting validation distributions ...\n')
    collect_train_valid_data(val_distri_path, val_distri_path_2,
                             hg, opt, is_train=False)
開發者ID:zhiqiangdon,項目名稱:pose-adv-aug,代碼行數:48,代碼來源:collect-rotation-ditri.py

示例4: main

# 需要導入模塊: from options import train_options [as 別名]
# 或者: from options.train_options import TrainOptions [as 別名]
def main():
    """Collect scale distributions from a pretrained pose model.

    Builds the pose network via ``model.create_hg`` (2 stacks, 1 module),
    wraps it in ``torch.nn.DataParallel`` on the GPU, loads the checkpoint
    named by ``opt.load_prefix_pose`` and calls ``collect_train_valid_data``
    for the training and validation splits. The ``*_scales.txt`` /
    ``*_scales_copy.txt`` paths it passes live in the SR pretrain directory.
    They are presumably the files the distributions are written to; confirm
    this in ``collect_train_valid_data``.

    Exits early if ``opt.sr_dir`` or ``opt.load_prefix_pose`` is empty.

    Raises:
        ValueError: if ``opt.dataset`` is not supported (only ``'mpii'``).
    """
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()
    # [0:-1] drops the last character of the prefix. That is presumably a
    # trailing separator; confirm against how load_prefix_pose is specified.
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    checkpoint_hg = Checkpoint()
    train_distri_path = os.path.join(sr_pretrain_dir, 'train_scales.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_scales_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_scales.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_scales_copy.txt')

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Without this branch num_classes would be unbound and the script
        # would crash with an opaque UnboundLocalError further down.
        raise ValueError('unsupported dataset: {!r}'.format(opt.dataset))
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        exit()
    # [0:-1] drops the last character of the joined checkpoint path.
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # The function-call form of print gives the same output on Python 2 and 3.
    print('collecting training distributions ...\n')
    collect_train_valid_data(train_distri_path, train_distri_path_2,
                             hg, opt, is_train=True)

    print('collecting validation distributions ...\n')
    collect_train_valid_data(val_distri_path, val_distri_path_2,
                             hg, opt, is_train=False)
開發者ID:zhiqiangdon,項目名稱:pose-adv-aug,代碼行數:48,代碼來源:collect-scale-ditri.py

示例5: main

# 需要導入模塊: from options import train_options [as 別名]
# 或者: from options.train_options import TrainOptions [as 別名]
def main():
    """Train a SingleGAN model.

    Runs epochs 1 .. ``opt.niter + opt.niter_decay``. After each epoch
    beyond ``opt.niter`` the learning rate is lowered by
    ``opt.lr / opt.niter_decay`` and pushed to the model (linear decay).
    Visuals, losses and ``'latest'`` snapshots are emitted at the
    frequencies configured in ``opt``. At the end of every
    ``opt.save_epoch_freq``-th epoch the model is saved as ``'latest'`` and
    under the epoch number. ``total_steps`` counts samples, not batches.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset_size = len(data_loader) * opt.batchSize
    visualizer = Visualizer(opt)

    model = SingleGAN()
    model.initialize(opt)

    total_steps = 0
    lr = opt.lr
    last_epoch = opt.niter + opt.niter_decay
    for epoch in range(1, last_epoch + 1):
        epoch_start_time = time.time()
        # The first batch of each epoch is always displayed and saved.
        save_visuals = True
        for batch in data_loader:
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.update_model(batch)

            if save_visuals or total_steps % opt.display_freq == 0:
                if not save_visuals:
                    save_visuals = total_steps % opt.update_html_freq == 0
                print('mode:{} dataset:{}'.format(opt.mode, opt.name))
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch,
                    ncols=1, save_result=save_visuals)
                save_visuals = False

            if total_steps % opt.print_freq == 0:
                losses = model.get_current_errors()
                per_sample = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, losses, per_sample)
                if opt.display_id > 0:
                    progress = float(epoch_iter) / dataset_size
                    visualizer.plot_current_errors(epoch, progress, opt, losses)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            for tag in ('latest', epoch):
                model.save(tag)

        if epoch > opt.niter:
            lr = lr - opt.lr / opt.niter_decay
            model.update_lr(lr)
開發者ID:Xiaoming-Yu,項目名稱:SingleGAN,代碼行數:49,代碼來源:train.py


注:本文中的options.train_options.TrainOptions方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。