本文整理汇总了Python中train.train方法的典型用法代码示例。如果您正苦于以下问题:Python train.train方法的具体用法?Python train.train怎么用?Python train.train使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块train的用法示例。
在下文中一共展示了train.train方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(_):
    """Load the data, register derived integer flags, then train or evaluate.

    The flags defined here (special-token ids, embedding-matrix dimensions and
    the dev-set ``maximum_iterations`` bound) are added to ``flags.FLAGS``,
    which is the same object passed on as ``config``; presumably
    ``train``/``evaluate`` read them from there -- TODO confirm.
    """
    config = flags.FLAGS
    print(str(config.flag_values_dict()))
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu

    print('loading data...')
    (train_batches, dev_batches, test_batches,
     embedding_matrix, vocab, word_to_id) = load_data(config)

    # Expose the vocabulary ids of the special tokens as flags.
    special_tokens = (
        ('PAD_IDX', PAD),
        ('UNK_IDX', UNK),
        ('BOS_IDX', BOS),
        ('EOS_IDX', EOS),
    )
    for flag_name, token in special_tokens:
        flags.DEFINE_integer(flag_name, word_to_id[token], flag_name)

    # Vocabulary size and embedding width are taken from the matrix shape.
    n_embed, d_embed = embedding_matrix.shape
    flags.DEFINE_integer('n_embed', n_embed, 'n_embed')
    flags.DEFINE_integer('d_embed', d_embed, 'd_embed')

    # Largest _max_sent_len over all dev batches (only the dev split is scanned).
    maximum_iterations = max(
        max(datum._max_sent_len(None) for datum in batch)
        for _count, batch in dev_batches
    )
    flags.DEFINE_integer('maximum_iterations', maximum_iterations, 'maximum_iterations')

    if config.mode == 'train':
        train(config, train_batches, dev_batches, test_batches, embedding_matrix, vocab)
    elif config.mode == 'eval':
        evaluate(config, test_batches, vocab)
示例2: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(args):
    """Dispatch a distributed XGBoost job according to ``args.job_type``.

    ``args.model_storage_type`` must be "local" or "oss". Job types:
      * "Predict" -- run ``predict(args)`` only;
      * "Train"   -- train, then dump the model to ``args.model_path``;
      * "All"     -- train, dump the model, then run ``predict(args)``.
    Any other job type is logged as a warning and nothing is run.

    Raises:
        Exception: if the storage type is not supported.
    """
    model_storage_type = args.model_storage_type
    if model_storage_type in ("local", "oss"):
        print("The storage type is " + model_storage_type)
    else:
        raise Exception("Only supports storage types like local and OSS")

    def _train_and_dump():
        # Train and, when a model was produced, persist it to args.model_path.
        model = train(args)
        if model is not None:
            logging.info("finish the model training, and start to dump model ")
            dump_model(model, model_storage_type, args.model_path, args)

    if args.job_type == "Predict":
        logging.info("starting the predict job")
        predict(args)
    elif args.job_type == "Train":
        logging.info("starting the train job")
        _train_and_dump()
    elif args.job_type == "All":
        logging.info("starting the train and predict job")
        # Train and dump first so the prediction step can use the fresh model.
        # NOTE(review): assumes predict() loads from args.model_path -- confirm.
        _train_and_dump()
        predict(args)
    else:
        logging.warning("unknown job type %r, nothing to do", args.job_type)

    logging.info("Finish distributed XGBoost job")
示例3: evaluate
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def evaluate():
    """Train a fresh model, reload the best resulting net and test it.

    Returns:
        A 1x1 float32 numpy array containing the MLRAP value reported by
        ``test.test``.
    """
    stats.clearStats(True)

    cfg.CLASSES, train_split, val_split = train.parseDataset()
    net = birdnet.build_model()

    # train.train hands back the best net, which is then loaded as a snapshot.
    snapshot = io.loadModel(train.train(net, train_split, val_split))

    mlrap, _time_per_epoch = test.test(snapshot)
    return np.array([[mlrap]], dtype='float32')
示例4: get_kwargs
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def get_kwargs(kwargs, argv=None):
    """Parse command-line options into ``kwargs`` (updated in place).

    Args:
        kwargs: dict that receives one entry per option, keyed by the
            option's ``dest``; existing keys with the same name are overwritten.
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        The same ``kwargs`` dict, for convenience (callers that ignore the
        return value are unaffected).
    """
    parser = argparse.ArgumentParser(description='-f TRAIN_FILE -t TEST_FILE -o OUTPUT_FILE -e EMBEDS_FILE [-l LOGGER_FILE] [--swear-words SWEAR_FILE] [--wrong-words WRONG_WORDS_FILE] [--format-embeds FALSE]')
    parser.add_argument('-f', '--train', dest='train', action='store', help='/path/to/train_file', type=str)
    parser.add_argument('-t', '--test', dest='test', action='store', help='/path/to/test_file', type=str)
    parser.add_argument('-o', '--output', dest='output', action='store', help='/path/to/output_file', type=str)
    parser.add_argument('-we', '--word_embeds', dest='word_embeds', action='store', help='/path/to/embeds_file', type=str)
    parser.add_argument('-ce', '--char_embeds', dest='char_embeds', action='store', help='/path/to/embeds_file', type=str)
    parser.add_argument('-c', '--config', dest='config', action='store', help='/path/to/config.json', type=str)
    parser.add_argument('-l', '--logger', dest='logger', action='store', help='/path/to/log_file', type=str, default=None)
    parser.add_argument('--mode', dest='mode', action='store', help='preprocess / train / validate / all', type=str, default='all')
    parser.add_argument('--max-words', dest='max_words', action='store', type=int, default=300000)
    parser.add_argument('--use-only-exists-words', dest='use_only_exists_words', action='store_true')
    parser.add_argument('--swear-words', dest='swear_words', action='store', help='/path/to/swear_words_file', type=str, default=None)
    parser.add_argument('--wrong-words', dest='wrong_words', action='store', help='/path/to/wrong_words_file', type=str, default=None)
    parser.add_argument('--format-embeds', dest='format_embeds', action='store', help='file | json | pickle | binary', type=str, default='raw')
    parser.add_argument('--output-dir', dest='output_dir', action='store', help='/path/to/dir', type=str, default='.')
    parser.add_argument('--norm-prob', dest='norm_prob', action='store_true')
    parser.add_argument('--norm-prob-koef', dest='norm_prob_koef', action='store', type=float, default=1)
    parser.add_argument('--gpus', dest='gpus', action='store', help='count GPUs', type=int, default=0)
    # vars() + update() copies every parsed option without needing six.iteritems.
    kwargs.update(vars(parser.parse_args(argv)))
    return kwargs
示例5: function_to_minimize
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def function_to_minimize(hyperparams, gamma='auto', decision_function='ovr'):
    """Objective for the hyperparameter search: train once, return negated accuracy.

    Values present in ``hyperparams`` ('decision_function', 'gamma') override
    the keyword defaults; missing keys fall back to them instead of raising
    ``KeyError``.

    On success the run is appended to the global ``train_history`` and
    ``current_eval`` is incremented. If training raises, the history is saved
    to ``train_history.npy`` and the process exits with status 1.

    Returns:
        dict with 'loss' (negated accuracy, so minimizing it maximizes
        accuracy), 'time' (training wall time in whole seconds) and 'status'.
    """
    import sys

    global current_eval
    global max_evals

    decision_function = hyperparams.get('decision_function', decision_function)
    gamma = hyperparams.get('gamma', gamma)

    print( "#################################")
    print( " Evaluation {} of {}".format(current_eval, max_evals))
    print( "#################################")
    start_time = time.time()
    try:
        accuracy = train(epochs=HYPERPARAMS.epochs_during_hyperopt, decision_function=decision_function, gamma=gamma)
        training_time = int(round(time.time() - start_time))
        current_eval += 1
        train_history.append({'accuracy':accuracy, 'decision_function':decision_function, 'gamma':gamma, 'time':training_time})
    except Exception as e:
        # Persist what has been collected so far, then stop the whole search.
        print( "#################################")
        print( "Exception during training: {}".format(str(e)))
        print( "Saving train history in train_history.npy")
        np.save("train_history.npy", train_history)
        # sys.exit (not the site-provided exit()) with a non-zero status so
        # callers/shells can see that the search failed.
        sys.exit(1)
    return {'loss': -accuracy, 'time': training_time, 'status': STATUS_OK}
# launch the hyperparameters search
示例6: _provide_real_images
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def _provide_real_images(batch_size, **kwargs):
    """Provides real images.

    Images come from the 'train' split of the named dataset
    (``dataset_name``) or, when no dataset name is given, from image files
    matching ``dataset_file_pattern``. Patches are sized to the final
    resolution of ``train.make_resolution_schedule(**kwargs)``.

    Args:
        batch_size: number of images per batch.
        **kwargs: training configuration; must contain ``colors`` and at
            least one of ``dataset_name`` / ``dataset_file_pattern``
            (``dataset_name`` takes precedence when both are set).

    Returns:
        Whatever the selected ``data_provider`` function returns.

    Raises:
        ValueError: if neither ``dataset_name`` nor ``dataset_file_pattern``
            is set.
    """
    dataset_name = kwargs.get('dataset_name')
    dataset_file_pattern = kwargs.get('dataset_file_pattern')
    colors = kwargs['colors']
    final_height, final_width = train.make_resolution_schedule(
        **kwargs).final_resolutions
    if dataset_name is not None:
        return data_provider.provide_data(
            dataset_name=dataset_name,
            split_name='train',
            batch_size=batch_size,
            patch_height=final_height,
            patch_width=final_width,
            colors=colors)
    if dataset_file_pattern is not None:
        return data_provider.provide_data_from_image_files(
            file_pattern=dataset_file_pattern,
            batch_size=batch_size,
            patch_height=final_height,
            patch_width=final_width,
            colors=colors)
    # Without this the function would return None and the failure would only
    # surface later, far from its cause.
    raise ValueError(
        'Either dataset_name or dataset_file_pattern must be provided.')
示例7: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(_):
    """Train the model one stage at a time.

    For every id yielded by ``train.get_stage_ids(**config)`` the default
    graph is reset and rebuilt. Each rebuild creates an input pipeline pinned
    to the CPU, the stage's model and its summaries, and the model is then
    trained.
    """
    if not tf.gfile.Exists(FLAGS.train_root_dir):
        tf.gfile.MakeDirs(FLAGS.train_root_dir)

    config = _make_config_from_flags()
    # dict.items() works on Python 2 and 3; iteritems() is Python-2 only.
    logging.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))

    for stage_id in train.get_stage_ids(**config):
        batch_size = train.get_batch_size(stage_id, **config)
        tf.reset_default_graph()
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Build the input pipeline on the CPU under an 'inputs' name scope.
            with tf.device('/cpu:0'), tf.name_scope('inputs'):
                real_images = _provide_real_images(batch_size, **config)
            model = train.build_model(stage_id, batch_size, real_images, **config)
            train.add_model_summaries(model, **config)
            train.train(model, **config)
示例8: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main():
    """Run training, evaluation or inference depending on ``flags.mode``.

    Raises:
        ValueError: if ``flags.mode`` is not 'train', 'eval' or 'inference'.
    """
    flags = parse_flags()
    hparams = parse_hparams(flags.hparams)
    if flags.mode == 'train':
        train.train(model_name=flags.model, hparams=hparams,
                    class_map_path=flags.class_map_path,
                    train_csv_path=flags.train_csv_path,
                    train_clip_dir=flags.train_clip_dir,
                    train_dir=flags.train_dir)
    elif flags.mode == 'eval':
        evaluation.evaluate(model_name=flags.model, hparams=hparams,
                            class_map_path=flags.class_map_path,
                            eval_csv_path=flags.eval_csv_path,
                            eval_clip_dir=flags.eval_clip_dir,
                            checkpoint_path=flags.checkpoint_path)
    elif flags.mode == 'inference':
        inference.predict(model_name=flags.model, hparams=hparams,
                          class_map_path=flags.class_map_path,
                          test_clip_dir=flags.test_clip_dir,
                          checkpoint_path=flags.checkpoint_path,
                          predictions_csv_path=flags.predictions_csv_path)
    else:
        # Explicit error rather than `assert`, which `python -O` strips (an
        # unknown mode would then silently run inference).
        raise ValueError(
            'Unknown mode {!r}; expected train, eval or inference'.format(flags.mode))
示例9: parse_args
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), kept in help-output order.
    option_specs = [
        ('--task', dict(type=str, default='vqa', help='vqa or flickr')),
        ('--epochs', dict(type=int, default=13)),
        ('--num_hid', dict(type=int, default=1280)),
        ('--model', dict(type=str, default='ban')),
        ('--op', dict(type=str, default='c')),
        ('--gamma', dict(type=int, default=8, help='glimpse')),
        ('--use_both', dict(action='store_true', help='use both train/val datasets to train?')),
        ('--use_vg', dict(action='store_true', help='use visual genome dataset to train?')),
        # NOTE: store_false -- tfidf is on by default and passing --tfidf turns it off.
        ('--tfidf', dict(action='store_false', help='tfidf word embedding?')),
        ('--input', dict(type=str, default=None)),
        ('--output', dict(type=str, default='saved_models/ban')),
        ('--batch_size', dict(type=int, default=256)),
        ('--seed', dict(type=int, default=1204, help='random seed')),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
示例10: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main():
    """Build or load the model and train it for ``opt.nEpochs`` epochs.

    ``opt.loadModel`` selects the model source: 'none' -> ``inflate(opt)``,
    'scratch' -> a new ``Pose3D``, anything else -> ``torch.load`` of that
    path. Each epoch's training loss/accuracy is sent to a ``Logger`` created
    at ``opt.saveDir + '/logs_<timestamp>'``. The logger is closed even if
    model loading or training raises.
    """
    opt = opts().parse()
    now = datetime.datetime.now()
    logger = Logger(opt.saveDir + '/logs_{}'.format(now.isoformat()))
    try:
        if opt.loadModel == 'none':
            model = inflate(opt).cuda()
        elif opt.loadModel == 'scratch':
            model = Pose3D(opt.nChannels, opt.nStack, opt.nModules, opt.numReductions,
                           opt.nRegModules, opt.nRegFrames, ref.nJoints).cuda()
        else:
            model = torch.load(opt.loadModel).cuda()

        train_loader = torch.utils.data.DataLoader(
            h36m('train', opt),
            batch_size=opt.dataloaderSize,
            # NOTE(review): training data is not shuffled -- confirm intentional.
            shuffle=False,
            num_workers=int(ref.nThreads),
        )

        optimizer = torch.optim.RMSprop(
            [{'params': model.parameters(), 'lr': opt.LRhg}],
            alpha=ref.alpha,
            eps=ref.epsilon,
            weight_decay=ref.weightDecay,
            momentum=ref.momentum,
        )

        for epoch in range(1, opt.nEpochs + 1):
            loss_train, acc_train = train(epoch, opt, train_loader, model, optimizer)
            logger.scalar_summary('loss_train', loss_train, epoch)
            logger.scalar_summary('acc_train', acc_train, epoch)
            logger.write('{:8f} {:8f} \n'.format(loss_train, acc_train))
    finally:
        # Previously skipped when anything above raised; presumably close()
        # flushes what was logged so far -- confirm against Logger.
        logger.close()
示例11: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(unused_argv):
    """Entry point: call ``run_main`` with default hparams and the train/inference functions."""
    run_main(FLAGS, create_hparams(FLAGS), train.train, inference.inference)
示例12: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(unused_argv):
    """Entry point: call ``run_main`` with default hparams and the train/inference functions."""
    run_main(FLAGS, create_hparams(FLAGS), train.train, inference.inference)
示例13: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main():
    """Parse arguments, seed the RNGs, load the YAML config and dispatch on ``args.mode``.

    Modes 'test' and 'generate' are accepted but do nothing; other
    unrecognised modes also do nothing.
    """
    args = parser.parse_args()
    modify_arguments(args)

    # Seed every RNG in use. torch.manual_seed seeds the CPU generator (the
    # original only seeded CUDA); torch.cuda.manual_seed is kept as before.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    with open(args.config_file, 'r') as stream:
        # safe_load: yaml.load without a Loader can build arbitrary Python
        # objects and is a TypeError on PyYAML >= 6.
        config = yaml.safe_load(stream)
    args.config = Munch(modify_config(args, config))
    logger.info(args)

    if args.mode == 'train':
        train.train(args, device)
    elif args.mode == 'test':
        pass
    elif args.mode == 'analysis':
        analysis.analyze(args, device)
    elif args.mode == 'generate':
        pass
    elif args.mode == 'classify':
        analysis.classify(args, device)
    elif args.mode == 'classify_coqa':
        analysis.classify_coqa(args, device)
    elif args.mode == 'classify_final':
        analysis.classify_final(args, device)
示例14: function_to_minimize
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def function_to_minimize(hyperparams, optimizer=HYPERPARAMS.optimizer, optimizer_param=HYPERPARAMS.optimizer_param,
                         learning_rate=HYPERPARAMS.learning_rate, keep_prob=HYPERPARAMS.keep_prob,
                         learning_rate_decay=HYPERPARAMS.learning_rate_decay):
    """Objective for the hyperparameter search: train once, return negated accuracy.

    Keys present in ``hyperparams`` ('learning_rate', 'learning_rate_decay',
    'keep_prob', 'optimizer', 'optimizer_param') override the corresponding
    keyword defaults, which come from ``HYPERPARAMS`` at definition time.

    On success the run is appended to the global ``train_history`` and
    ``current_eval`` is incremented. If training raises, the history is saved
    to ``train_history.npy`` and the process exits with status 1.

    Returns:
        dict with 'loss' (negated accuracy, so minimizing it maximizes
        accuracy), 'time' (training wall time in whole seconds) and 'status'.
    """
    import sys

    global current_eval
    global max_evals

    learning_rate = hyperparams.get('learning_rate', learning_rate)
    learning_rate_decay = hyperparams.get('learning_rate_decay', learning_rate_decay)
    keep_prob = hyperparams.get('keep_prob', keep_prob)
    optimizer = hyperparams.get('optimizer', optimizer)
    optimizer_param = hyperparams.get('optimizer_param', optimizer_param)

    print( "#################################")
    print( " Evaluation {} of {}".format(current_eval, max_evals))
    print( "#################################")
    start_time = time.time()
    try:
        accuracy = train(learning_rate=learning_rate, learning_rate_decay=learning_rate_decay,
                         optimizer=optimizer, optimizer_param=optimizer_param, keep_prob=keep_prob)
        training_time = int(round(time.time() - start_time))
        current_eval += 1
        train_history.append({'accuracy':accuracy, 'learning_rate':learning_rate, 'learning_rate_decay':learning_rate_decay,
                              'optimizer':optimizer, 'optimizer_param':optimizer_param, 'keep_prob':keep_prob, 'time':training_time})
    except Exception as e:
        # An exception occurred during training: save the history and stop.
        print( "#################################")
        print( "Exception during training: {}".format(str(e)))
        print( "Saving train history in train_history.npy")
        np.save("train_history.npy", train_history)
        # sys.exit (not the site-provided exit()) with a non-zero status so
        # callers/shells can see that the search failed.
        sys.exit(1)
    return {'loss': -accuracy, 'time': training_time, 'status': STATUS_OK}
# launch the hyperparameters search
开发者ID:amineHorseman,项目名称:facial-expression-recognition-using-cnn,代码行数:38,代码来源:optimize_hyperparams.py
示例15: main
# 需要导入模块: import train [as 别名]
# 或者: from train import train [as 别名]
def main(argv):
    """Run training, evaluation or inference depending on ``flags.mode``.

    Args:
        argv: positional command-line arguments (unused here).

    Raises:
        ValueError: if ``flags.mode`` is not 'train', 'eval' or 'inference'.
    """
    hparams = model.parse_hparams(flags.hparams)
    if flags.mode == 'train':
        def split_csv(scopes):
            # Comma-separated scope string -> list of scopes; empty/None -> None.
            return scopes.split(',') if scopes else None

        train.train(model_name=flags.model, hparams=hparams,
                    class_map_path=flags.class_map_path,
                    train_csv_path=flags.train_csv_path,
                    train_clip_dir=flags.train_clip_dir,
                    train_dir=flags.train_dir,
                    epoch_batches=flags.epoch_num_batches,
                    warmstart_checkpoint=flags.warmstart_checkpoint,
                    warmstart_include_scopes=split_csv(flags.warmstart_include_scopes),
                    warmstart_exclude_scopes=split_csv(flags.warmstart_exclude_scopes))
    elif flags.mode == 'eval':
        evaluation.evaluate(model_name=flags.model, hparams=hparams,
                            class_map_path=flags.class_map_path,
                            eval_csv_path=flags.eval_csv_path,
                            eval_clip_dir=flags.eval_clip_dir,
                            eval_dir=flags.eval_dir,
                            train_dir=flags.train_dir)
    elif flags.mode == 'inference':
        inference.predict(model_name=flags.model, hparams=hparams,
                          class_map_path=flags.class_map_path,
                          inference_clip_dir=flags.inference_clip_dir,
                          inference_checkpoint=flags.inference_checkpoint,
                          predictions_csv_path=flags.predictions_csv_path)
    else:
        # Explicit error rather than `assert`, which `python -O` strips (an
        # unknown mode would then silently run inference).
        raise ValueError(
            'Unknown mode {!r}; expected train, eval or inference'.format(flags.mode))