当前位置: 首页>>代码示例>>Python>>正文


Python train.train方法代码示例

本文整理汇总了Python中train.train方法的典型用法代码示例。如果您正苦于以下问题:Python train.train方法的具体用法?Python train.train怎么用?Python train.train使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在train的用法示例。


在下文中一共展示了train.train方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(_):
    """Register data-derived flags, then train or evaluate.

    Visible side effects: prints every flag value, sets
    ``CUDA_VISIBLE_DEVICES`` from ``config.gpu``, and registers the integer
    flags ``PAD_IDX``, ``UNK_IDX``, ``BOS_IDX``, ``EOS_IDX``, ``n_embed``,
    ``d_embed`` and ``maximum_iterations`` before dispatching on
    ``config.mode`` ('train' or 'eval'; any other mode does nothing).
    """
    config = flags.FLAGS
    print(config.flag_values_dict())

    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu

    print('loading data...')
    (train_batches, dev_batches, test_batches,
     embedding_matrix, vocab, word_to_id) = load_data(config)

    # Expose the vocabulary ids of the special tokens as integer flags.
    special_tokens = (('PAD_IDX', PAD), ('UNK_IDX', UNK),
                      ('BOS_IDX', BOS), ('EOS_IDX', EOS))
    for flag_name, token in special_tokens:
        flags.DEFINE_integer(flag_name, word_to_id[token], flag_name)

    # Rows and columns of the embedding matrix.
    num_rows, num_cols = embedding_matrix.shape
    flags.DEFINE_integer('n_embed', num_rows, 'n_embed')
    flags.DEFINE_integer('d_embed', num_cols, 'd_embed')

    # Largest _max_sent_len over every item of every dev batch; each element
    # of dev_batches is unpacked as a pair (ct, batch) -- TODO confirm what ct is.
    longest = max(
        max(item._max_sent_len(None) for item in batch)
        for _, batch in dev_batches
    )
    flags.DEFINE_integer('maximum_iterations', longest, 'maximum_iterations')

    if config.mode == 'train':
        train(config, train_batches, dev_batches, test_batches, embedding_matrix, vocab)
    elif config.mode == 'eval':
        evaluate(config, test_batches, vocab)
开发者ID:misonuma,项目名称:strsum,代码行数:27,代码来源:cli.py

示例2: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(args):
    """Run the distributed XGBoost job selected by ``args.job_type``.

    Job types:
        "Predict": run ``predict(args)``.
        "Train":   train, then dump the model if ``train`` returned one.
        "All":     train and dump as for "Train", then run ``predict(args)``.

    Args:
        args: parsed arguments; this function reads ``model_storage_type``,
            ``job_type`` and ``model_path``, and passes ``args`` through to
            ``train``/``predict``/``dump_model``.

    Raises:
        Exception: if ``args.model_storage_type`` is not "local" or "oss".
        ValueError: if ``args.job_type`` is not "Predict", "Train" or "All".
    """
    model_storage_type = args.model_storage_type
    if model_storage_type not in ("local", "oss"):
        raise Exception("Only supports storage types like local and OSS")
    print("The storage type is " + model_storage_type)

    job_type = args.job_type
    if job_type == "Predict":
        logging.info("starting the predict job")
        predict(args)
    elif job_type == "Train":
        logging.info("starting the train job")
        _train_and_dump_model(args, model_storage_type)
    elif job_type == "All":
        logging.info("starting the train and predict job")
        # Train (and dump) first so prediction can use the new model;
        # presumably predict() loads it from args.model_path -- confirm.
        # NOTE(review): only a worker whose train() returns a model dumps it;
        # verify predict() on other workers cannot run before the dump exists.
        _train_and_dump_model(args, model_storage_type)
        predict(args)
    else:
        # Fail loudly instead of logging "Finish" for a job that never ran.
        raise ValueError("Unsupported job type: {}".format(job_type))

    logging.info("Finish distributed XGBoost job")


def _train_and_dump_model(args, model_storage_type):
    """Train a model and, if ``train`` returned one, dump it to ``args.model_path``."""
    model = train(args)
    if model is not None:
        logging.info("finish the model training, and start to dump model ")
        dump_model(model, model_storage_type, args.model_path, args)
开发者ID:kubeflow,项目名称:xgboost-operator,代码行数:27,代码来源:main.py

示例3: evaluate

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def evaluate():
    """Train a fresh network, reload its best snapshot, and score it.

    Clears the stats, parses the dataset (storing the class list in
    ``cfg.CLASSES``), builds and trains the network, loads the network
    returned by ``train.train`` and tests it.

    Returns:
        numpy.ndarray of shape (1, 1) and dtype float32 holding the MLRAP
        value (first element returned by ``test.test``).
    """
    # Start from clean statistics.
    stats.clearStats(True)

    cfg.CLASSES, train_split, val_split = train.parseDataset()
    net = birdnet.build_model()

    # train.train returns the best network, which is reloaded as a snapshot.
    best = train.train(net, train_split, val_split)
    snapshot = io.loadModel(best)

    # The second value (time per epoch) is not needed here.
    mlrap, _ = test.test(snapshot)
    return np.array([[mlrap]], dtype='float32')
开发者ID:kahst,项目名称:BirdCLEF-Baseline,代码行数:25,代码来源:evaluate.py

示例4: get_kwargs

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def get_kwargs(kwargs, argv=None):
    """Parse command-line options and store them into ``kwargs``.

    Every parsed option is written into ``kwargs`` under its ``dest`` name,
    overwriting any existing key of the same name.

    Args:
        kwargs: mapping updated in place with the parsed option values.
        argv: optional list of argument strings to parse. The default
            ``None`` makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        The same ``kwargs`` object, for convenience.
    """
    parser = argparse.ArgumentParser(description='-f TRAIN_FILE -t TEST_FILE -o OUTPUT_FILE -e EMBEDS_FILE [-l LOGGER_FILE] [--swear-words SWEAR_FILE] [--wrong-words WRONG_WORDS_FILE] [--format-embeds FALSE]')
    parser.add_argument('-f', '--train', dest='train', action='store', help='/path/to/train_file', type=str)
    parser.add_argument('-t', '--test', dest='test', action='store', help='/path/to/test_file', type=str)
    parser.add_argument('-o', '--output', dest='output', action='store', help='/path/to/output_file', type=str)
    parser.add_argument('-we', '--word_embeds', dest='word_embeds', action='store', help='/path/to/embeds_file', type=str)
    parser.add_argument('-ce', '--char_embeds', dest='char_embeds', action='store', help='/path/to/embeds_file', type=str)
    parser.add_argument('-c','--config', dest='config', action='store', help='/path/to/config.json', type=str)
    parser.add_argument('-l', '--logger', dest='logger', action='store', help='/path/to/log_file', type=str, default=None)
    parser.add_argument('--mode', dest='mode', action='store', help='preprocess / train / validate / all', type=str, default='all')
    parser.add_argument('--max-words', dest='max_words', action='store', type=int, default=300000)
    parser.add_argument('--use-only-exists-words', dest='use_only_exists_words', action='store_true')
    parser.add_argument('--swear-words', dest='swear_words', action='store', help='/path/to/swear_words_file', type=str, default=None)
    parser.add_argument('--wrong-words', dest='wrong_words', action='store', help='/path/to/wrong_words_file', type=str, default=None)
    # NOTE(review): the help text lists file|json|pickle|binary but the default is 'raw'.
    parser.add_argument('--format-embeds', dest='format_embeds', action='store', help='file | json | pickle | binary', type=str, default='raw')
    parser.add_argument('--output-dir', dest='output_dir', action='store', help='/path/to/dir', type=str, default='.')
    parser.add_argument('--norm-prob', dest='norm_prob', action='store_true')
    parser.add_argument('--norm-prob-koef', dest='norm_prob_koef', action='store', type=float, default=1)
    parser.add_argument('--gpus', dest='gpus', action='store', help='count GPUs', type=int, default=0)
    parsed = parser.parse_args(argv)
    # vars() exposes the Namespace as a plain dict, so no external
    # iteritems() helper is needed.
    for key, value in vars(parsed).items():
        kwargs[key] = value
    return kwargs
开发者ID:Donskov7,项目名称:toxic_comments,代码行数:23,代码来源:main.py

示例5: function_to_minimize

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def function_to_minimize(hyperparams, gamma='auto', decision_function='ovr'):
    """Objective for the hyperparameter search: train once and report the loss.

    ``decision_function`` and ``gamma`` are taken from ``hyperparams`` when
    the key is present, otherwise the keyword arguments are used.

    On success, increments the global ``current_eval`` and appends a record to
    ``train_history``. If training raises, the history is saved to
    ``train_history.npy`` and the process exits with status 1.

    Returns:
        dict with 'loss' (negated accuracy, so minimising it maximises
        accuracy), 'time' (training seconds, rounded) and 'status'
        (``STATUS_OK``).
    """
    global current_eval
    global max_evals
    decision_function = hyperparams.get('decision_function', decision_function)
    gamma = hyperparams.get('gamma', gamma)
    print("#################################")
    print("       Evaluation {} of {}".format(current_eval, max_evals))
    print("#################################")
    start_time = time.time()
    try:
        accuracy = train(epochs=HYPERPARAMS.epochs_during_hyperopt, decision_function=decision_function, gamma=gamma)
        training_time = int(round(time.time() - start_time))
        current_eval += 1
        train_history.append({'accuracy':accuracy, 'decision_function':decision_function, 'gamma':gamma, 'time':training_time})
    except Exception as e:
        # Training failed: persist what has been collected, then stop the search.
        print("#################################")
        print("Exception during training: {}".format(str(e)))
        print("Saving train history in train_history.npy")
        np.save("train_history.npy", train_history)
        # The site-module exit() is meant for the interactive shell and
        # reports success (status 0); signal the failure explicitly.
        raise SystemExit(1) from e
    return {'loss': -accuracy, 'time': training_time, 'status': STATUS_OK}

# launch the hyperparameter search
开发者ID:amineHorseman,项目名称:facial-expression-recognition-svm,代码行数:25,代码来源:optimize_parameters.py

示例6: _provide_real_images

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def _provide_real_images(batch_size, **kwargs):
  """Provides real images."""
  dataset_name = kwargs.get('dataset_name')
  dataset_file_pattern = kwargs.get('dataset_file_pattern')
  colors = kwargs['colors']
  final_height, final_width = train.make_resolution_schedule(
      **kwargs).final_resolutions
  if dataset_name is not None:
    return data_provider.provide_data(
        dataset_name=dataset_name,
        split_name='train',
        batch_size=batch_size,
        patch_height=final_height,
        patch_width=final_width,
        colors=colors)
  elif dataset_file_pattern is not None:
    return data_provider.provide_data_from_image_files(
        file_pattern=dataset_file_pattern,
        batch_size=batch_size,
        patch_height=final_height,
        patch_width=final_width,
        colors=colors) 
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:24,代码来源:train_main.py

示例7: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(_):
  """Run training once for each stage id returned by ``train.get_stage_ids``.

  Creates ``FLAGS.train_root_dir`` if it does not exist and logs the
  configuration. For each stage it resets the default graph, builds the
  real-image inputs under a '/cpu:0' device scope, and builds, summarizes
  and trains the stage model.
  """
  if not tf.gfile.Exists(FLAGS.train_root_dir):
    tf.gfile.MakeDirs(FLAGS.train_root_dir)

  config = _make_config_from_flags()
  # dict.iteritems() exists only on Python 2; items() works on both 2 and 3.
  logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))

  for stage_id in train.get_stage_ids(**config):
    batch_size = train.get_batch_size(stage_id, **config)
    tf.reset_default_graph()
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # The input pipeline is pinned to the CPU.
      with tf.device('/cpu:0'), tf.name_scope('inputs'):
        real_images = _provide_real_images(batch_size, **config)
      model = train.build_model(stage_id, batch_size, real_images, **config)
      train.add_model_summaries(model, **config)
      train.train(model, **config)
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:19,代码来源:train_main.py

示例8: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main():
  """Parse flags/hparams and run the job selected by ``flags.mode``.

  Modes: 'train' calls ``train.train``, 'eval' calls
  ``evaluation.evaluate``, and 'inference' calls ``inference.predict``.

  Raises:
    ValueError: if ``flags.mode`` is not one of the modes above.
  """
  flags = parse_flags()
  hparams = parse_hparams(flags.hparams)

  if flags.mode == 'train':
    train.train(model_name=flags.model, hparams=hparams,
                class_map_path=flags.class_map_path,
                train_csv_path=flags.train_csv_path,
                train_clip_dir=flags.train_clip_dir,
                train_dir=flags.train_dir)

  elif flags.mode == 'eval':
    evaluation.evaluate(model_name=flags.model, hparams=hparams,
                        class_map_path=flags.class_map_path,
                        eval_csv_path=flags.eval_csv_path,
                        eval_clip_dir=flags.eval_clip_dir,
                        checkpoint_path=flags.checkpoint_path)

  elif flags.mode == 'inference':
    inference.predict(model_name=flags.model, hparams=hparams,
                      class_map_path=flags.class_map_path,
                      test_clip_dir=flags.test_clip_dir,
                      checkpoint_path=flags.checkpoint_path,
                      predictions_csv_path=flags.predictions_csv_path)

  else:
    # An assert would be stripped under `python -O` and silently run inference.
    raise ValueError('Unknown mode: {!r}'.format(flags.mode))
开发者ID:DCASE-REPO,项目名称:dcase2018_baseline,代码行数:27,代码来源:main.py

示例9: parse_args

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def parse_args():
    """Build the training command-line parser and return the parsed namespace.

    Note that ``--tfidf`` uses ``store_false``: it defaults to True and
    passing the flag turns it off.
    """
    option_specs = [
        ('--task', dict(type=str, default='vqa', help='vqa or flickr')),
        ('--epochs', dict(type=int, default=13)),
        ('--num_hid', dict(type=int, default=1280)),
        ('--model', dict(type=str, default='ban')),
        ('--op', dict(type=str, default='c')),
        ('--gamma', dict(type=int, default=8, help='glimpse')),
        ('--use_both', dict(action='store_true', help='use both train/val datasets to train?')),
        ('--use_vg', dict(action='store_true', help='use visual genome dataset to train?')),
        ('--tfidf', dict(action='store_false', help='tfidf word embedding?')),
        ('--input', dict(type=str, default=None)),
        ('--output', dict(type=str, default='saved_models/ban')),
        ('--batch_size', dict(type=int, default=256)),
        ('--seed', dict(type=int, default=1204, help='random seed')),
    ]
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
开发者ID:jnhwkim,项目名称:ban-vqa,代码行数:19,代码来源:main.py

示例10: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main():
    """Train the model on ``h36m('train', opt)`` and log per-epoch loss/accuracy.

    The model comes from ``opt.loadModel``:
    'none' uses ``inflate(opt)``, 'scratch' builds a new ``Pose3D``, and any
    other value is passed to ``torch.load`` as a path. The model is moved to
    the GPU with ``.cuda()`` and optimized with RMSprop. The logger is always
    closed, even if model construction or training raises.
    """
    opt = opts().parse()
    now = datetime.datetime.now()
    logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))
    try:
        if opt.loadModel == 'none':
            model = inflate(opt).cuda()
        elif opt.loadModel == 'scratch':
            model = Pose3D(opt.nChannels, opt.nStack, opt.nModules, opt.numReductions,
                           opt.nRegModules, opt.nRegFrames, ref.nJoints).cuda()
        else:
            model = torch.load(opt.loadModel).cuda()

        train_loader = torch.utils.data.DataLoader(
            h36m('train', opt),
            batch_size=opt.dataloaderSize,
            # NOTE(review): unshuffled training data is unusual; presumably
            # intentional for this overfitting script -- confirm.
            shuffle=False,
            num_workers=int(ref.nThreads),
        )

        optimizer = torch.optim.RMSprop(
            [{'params': model.parameters(), 'lr': opt.LRhg}],
            alpha=ref.alpha,
            eps=ref.epsilon,
            weight_decay=ref.weightDecay,
            momentum=ref.momentum,
        )

        for epoch in range(1, opt.nEpochs + 1):
            loss_train, acc_train = train(epoch, opt, train_loader, model, optimizer)
            logger.scalar_summary('loss_train', loss_train, epoch)
            logger.scalar_summary('acc_train', acc_train, epoch)
            logger.write('{:8f} {:8f} \n'.format(loss_train, acc_train))
    finally:
        logger.close()
开发者ID:Naman-ntc,项目名称:3D-HourGlass-Network,代码行数:38,代码来源:overfit.py

示例11: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(unused_argv):
    """Entry point: build default hparams from FLAGS and delegate to run_main.

    ``train.train`` and ``inference.inference`` are handed to ``run_main``
    as the training and inference callables; ``unused_argv`` is ignored.
    """
    hparams = create_hparams(FLAGS)
    run_main(FLAGS, hparams, train.train, inference.inference)
开发者ID:neccam,项目名称:nslt,代码行数:7,代码来源:nmt.py

示例12: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(unused_argv):
  """Entry point: build default hparams from FLAGS and delegate to run_main.

  ``train.train`` and ``inference.inference`` are handed to ``run_main``
  as the training and inference callables; ``unused_argv`` is ignored.
  """
  hparams = create_hparams(FLAGS)
  run_main(FLAGS, hparams, train.train, inference.inference)
开发者ID:snuspl,项目名称:parallax,代码行数:7,代码来源:nmt.py

示例13: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main():
    """Parse arguments, seed RNGs, load the YAML config, and dispatch on ``args.mode``.

    Modes: 'train' calls ``train.train``, and 'analysis', 'classify',
    'classify_coqa' and 'classify_final' call the corresponding ``analysis``
    function. 'test' and 'generate' are no-ops, and any other mode is
    ignored. ``device`` is read from the enclosing module scope.
    """
    args = parser.parse_args()
    modify_arguments(args)

    # Seed every RNG the pipeline may draw from, so runs are reproducible.
    # torch.cuda.manual_seed alone leaves the CPU generator unseeded.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    with open(args.config_file, 'r') as stream:
        # yaml.load without an explicit Loader can construct arbitrary Python
        # objects and is rejected (TypeError) by PyYAML >= 6; a config file
        # only needs plain YAML types.
        config = yaml.safe_load(stream)
    args.config = Munch(modify_config(args, config))
    logger.info(args)

    if args.mode == 'train':
        train.train(args, device)
    elif args.mode == 'test':
        pass
    elif args.mode == 'analysis':
        analysis.analyze(args, device)
    elif args.mode == 'generate':
        pass
    elif args.mode == 'classify':
        analysis.classify(args, device)
    elif args.mode == 'classify_coqa':
        analysis.classify_coqa(args, device)
    elif args.mode == 'classify_final':
        analysis.classify_final(args, device)
开发者ID:martiansideofthemoon,项目名称:squash-generation,代码行数:30,代码来源:main.py

示例14: function_to_minimize

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def function_to_minimize(hyperparams, optimizer=HYPERPARAMS.optimizer, optimizer_param=HYPERPARAMS.optimizer_param,
        learning_rate=HYPERPARAMS.learning_rate, keep_prob=HYPERPARAMS.keep_prob,
        learning_rate_decay=HYPERPARAMS.learning_rate_decay):
    """Objective for the hyperparameter search: train once and report the loss.

    Each of learning_rate, learning_rate_decay, keep_prob, optimizer and
    optimizer_param is taken from ``hyperparams`` when the key is present.
    Otherwise the keyword argument is used; its default is read from
    ``HYPERPARAMS`` when the function is defined.

    On success, increments the global ``current_eval`` and appends a record to
    ``train_history``. If training raises, the history is saved to
    ``train_history.npy`` and the process exits with status 1.

    Returns:
        dict with 'loss' (negated accuracy, so minimising it maximises
        accuracy), 'time' (training seconds, rounded) and 'status'
        (``STATUS_OK``).
    """
    global current_eval
    global max_evals
    learning_rate = hyperparams.get('learning_rate', learning_rate)
    learning_rate_decay = hyperparams.get('learning_rate_decay', learning_rate_decay)
    keep_prob = hyperparams.get('keep_prob', keep_prob)
    optimizer = hyperparams.get('optimizer', optimizer)
    optimizer_param = hyperparams.get('optimizer_param', optimizer_param)
    print("#################################")
    print("       Evaluation {} of {}".format(current_eval, max_evals))
    print("#################################")
    start_time = time.time()
    try:
        accuracy = train(learning_rate=learning_rate, learning_rate_decay=learning_rate_decay,
                         optimizer=optimizer, optimizer_param=optimizer_param, keep_prob=keep_prob)
        training_time = int(round(time.time() - start_time))
        current_eval += 1
        train_history.append({'accuracy':accuracy, 'learning_rate':learning_rate, 'learning_rate_decay':learning_rate_decay,
                                  'optimizer':optimizer, 'optimizer_param':optimizer_param, 'keep_prob':keep_prob, 'time':training_time})
    except Exception as e:
        # Training failed: persist the history collected so far, then stop the search.
        print("#################################")
        print("Exception during training: {}".format(str(e)))
        print("Saving train history in train_history.npy")
        np.save("train_history.npy", train_history)
        # The site-module exit() is meant for the interactive shell and
        # reports success (status 0); signal the failure explicitly.
        raise SystemExit(1) from e
    return {'loss': -accuracy, 'time': training_time, 'status': STATUS_OK}

# launch the hyperparameter search
开发者ID:amineHorseman,项目名称:facial-expression-recognition-using-cnn,代码行数:38,代码来源:optimize_hyperparams.py

示例15: main

# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(argv):
  """Parse hparams and run the job selected by ``flags.mode``.

  Modes: 'train' calls ``train.train`` (with optional warm-start scope
  lists), 'eval' calls ``evaluation.evaluate``, and 'inference' calls
  ``inference.predict``. ``argv`` is unused.

  Raises:
    ValueError: if ``flags.mode`` is not one of the modes above.
  """
  hparams = model.parse_hparams(flags.hparams)

  if flags.mode == 'train':
    def split_csv(scopes):
      """Split a comma-separated scope list; an empty value or None yields None."""
      return scopes.split(',') if scopes else None
    train.train(model_name=flags.model, hparams=hparams,
                class_map_path=flags.class_map_path,
                train_csv_path=flags.train_csv_path,
                train_clip_dir=flags.train_clip_dir,
                train_dir=flags.train_dir,
                epoch_batches=flags.epoch_num_batches,
                warmstart_checkpoint=flags.warmstart_checkpoint,
                warmstart_include_scopes=split_csv(flags.warmstart_include_scopes),
                warmstart_exclude_scopes=split_csv(flags.warmstart_exclude_scopes))

  elif flags.mode == 'eval':
    evaluation.evaluate(model_name=flags.model, hparams=hparams,
                        class_map_path=flags.class_map_path,
                        eval_csv_path=flags.eval_csv_path,
                        eval_clip_dir=flags.eval_clip_dir,
                        eval_dir=flags.eval_dir,
                        train_dir=flags.train_dir)

  elif flags.mode == 'inference':
    inference.predict(model_name=flags.model, hparams=hparams,
                      class_map_path=flags.class_map_path,
                      inference_clip_dir=flags.inference_clip_dir,
                      inference_checkpoint=flags.inference_checkpoint,
                      predictions_csv_path=flags.predictions_csv_path)

  else:
    # An assert would be stripped under `python -O` and silently run inference.
    raise ValueError('Unknown mode: {!r}'.format(flags.mode))
开发者ID:DCASE-REPO,项目名称:dcase2019_task2_baseline,代码行数:33,代码来源:runner.py


注:本文中的train.train方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。