本文整理汇总了 Python 中 train.main 方法的典型用法代码示例。如果您正苦于以下问题:Python train.main 方法的具体用法?Python train.main 怎么用?Python train.main 使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 train 的用法示例。
在下文中一共展示了train.main方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train_translation_model
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def train_translation_model(data_dir, arch, extra_flags=None):
    """Run a short translation training job through ``train.main``.

    Args:
        data_dir: Passed both as the positional data argument and as the
            ``--save-dir`` value.
        arch: Value supplied for ``--arch``.
        extra_flags: Optional list of additional command-line tokens appended
            after the fixed flags; ``None`` or an empty list adds nothing.
    """
    parser = options.get_training_parser()

    # The fixed flags are grouped by concern; concatenation order is the
    # order in which they reach the parser.
    task_flags = [
        '--task', 'translation',
        data_dir,
        '--save-dir', data_dir,
        '--arch', arch,
    ]
    optim_flags = [
        '--optimizer', 'nag',
        '--lr', '0.05',
        '--max-tokens', '500',
        '--max-epoch', '1',
    ]
    runtime_flags = ['--no-progress-bar', '--distributed-world-size', '1']
    lang_flags = ['--source-lang', 'in', '--target-lang', 'out']

    argv = (
        task_flags + optim_flags + runtime_flags + lang_flags
        + (extra_flags or [])
    )
    parsed_args = options.parse_args_and_arch(parser, argv)
    train.main(parsed_args)
示例2: generate_main
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def generate_main(data_dir):
    """Run generation in batch mode, then interactively with canned stdin.

    The generation arguments point ``--path`` at ``checkpoint_last.pt`` inside
    ``data_dir``. After ``generate.main`` runs, the same argument object is
    reused for ``interactive.main`` with ``buffer_size`` and ``max_sentences``
    overridden, while ``sys.stdin`` is temporarily replaced by an in-memory
    stream containing a single line.

    Args:
        data_dir: Passed as the positional data argument; also the directory
            in which ``checkpoint_last.pt`` is looked up.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
        ],
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Restore the real stdin even if interactive.main raises; otherwise
        # the rest of the process would keep reading the exhausted StringIO.
        sys.stdin = orig_stdin
示例3: train_language_model
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def train_language_model(data_dir, arch):
    """Run a short language-modeling training job through ``train.main``.

    Args:
        data_dir: Passed both as the positional data argument and as the
            ``--save-dir`` value.
        arch: Value supplied for ``--arch``.
    """
    parser = options.get_training_parser()

    # Flags grouped by concern; concatenation order is the order in which
    # they reach the parser.
    task_flags = ['--task', 'language_modeling', data_dir, '--arch', arch]
    optim_flags = [
        '--optimizer', 'nag',
        '--lr', '1.0',
        '--criterion', 'adaptive_loss',
        '--adaptive-softmax-cutoff', '5,10,15',
    ]
    model_flags = [
        '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
        '--decoder-embed-dim', '280',
    ]
    batch_flags = ['--max-tokens', '500', '--tokens-per-sample', '500']
    run_flags = [
        '--save-dir', data_dir,
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
    ]

    argv = task_flags + optim_flags + model_flags + batch_flags + run_flags
    parsed_args = options.parse_args_and_arch(parser, argv)
    train.main(parsed_args)
示例4: train_language_model
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def train_language_model(data_dir, arch):
    """Run a short language-modeling training job through ``train.main``.

    This variant passes ``--lr 0.1`` and ``--ddp-backend no_c10d``.

    Args:
        data_dir: Passed both as the positional data argument and as the
            ``--save-dir`` value.
        arch: Value supplied for ``--arch``.
    """
    parser = options.get_training_parser()

    # Flags grouped by concern; concatenation order is the order in which
    # they reach the parser.
    task_flags = ['--task', 'language_modeling', data_dir, '--arch', arch]
    optim_flags = [
        '--optimizer', 'nag',
        '--lr', '0.1',
        '--criterion', 'adaptive_loss',
        '--adaptive-softmax-cutoff', '5,10,15',
    ]
    model_flags = [
        '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
        '--decoder-embed-dim', '280',
    ]
    batch_flags = ['--max-tokens', '500', '--tokens-per-sample', '500']
    run_flags = [
        '--save-dir', data_dir,
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
        '--ddp-backend', 'no_c10d',
    ]

    argv = task_flags + optim_flags + model_flags + batch_flags + run_flags
    parsed_args = options.parse_args_and_arch(parser, argv)
    train.main(parsed_args)
示例5: preprocess_translation_data
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def preprocess_translation_data(data_dir):
    """Run ``preprocess.main`` on the train/valid/test prefixes in ``data_dir``.

    Args:
        data_dir: Directory joined with ``train``, ``valid`` and ``test`` to
            form the input prefixes; also passed as ``--destdir``.
    """
    parser = preprocess.get_parser()

    argv = ['--source-lang', 'in', '--target-lang', 'out']
    split_prefixes = (
        ('--trainpref', 'train'),
        ('--validpref', 'valid'),
        ('--testpref', 'test'),
    )
    for flag, prefix in split_prefixes:
        argv.extend((flag, os.path.join(data_dir, prefix)))
    argv.extend((
        '--thresholdtgt', '0',
        '--thresholdsrc', '0',
        '--destdir', data_dir,
    ))

    preprocess.main(parser.parse_args(argv))
示例6: preprocess_lm_data
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def preprocess_lm_data(data_dir):
    """Run ``preprocess.main`` in ``--only-source`` mode on ``*.out`` files.

    Args:
        data_dir: Directory joined with ``train.out``, ``valid.out`` and
            ``test.out`` to form the input prefixes; also passed as
            ``--destdir``.
    """
    parser = preprocess.get_parser()

    argv = ['--only-source']
    split_prefixes = (
        ('--trainpref', 'train.out'),
        ('--validpref', 'valid.out'),
        ('--testpref', 'test.out'),
    )
    for flag, prefix in split_prefixes:
        argv.extend((flag, os.path.join(data_dir, prefix)))
    argv.extend(('--destdir', data_dir))

    preprocess.main(parser.parse_args(argv))
示例7: test_full_flow
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def test_full_flow(self, mock_data_provider):
    """Call ``train.main`` with FLAGS set for two steps and mocked input data.

    Args:
        mock_data_provider: Mock whose ``provide_data`` return value is set
            here (presumably injected by a patch decorator outside this
            view -- confirm).
    """
    FLAGS.eval_dir = self.get_temp_dir()
    FLAGS.batch_size = 16
    FLAGS.max_number_of_steps = 2
    FLAGS.noise_dims = 3

    # Construct mock inputs.
    num_examples = FLAGS.batch_size
    images = np.zeros([num_examples, 28, 28, 1], dtype=np.float32)
    # Each label row is a 1 followed by nine 0s.
    ones_column = np.ones([num_examples, 1], dtype=np.int32)
    zero_columns = np.zeros([num_examples, 9], dtype=np.int32)
    labels = np.concatenate((ones_column, zero_columns), axis=1)
    mock_data_provider.provide_data.return_value = (images, labels, None)

    train.main(None)
示例8: _test_build_graph_helper
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, conditional, use_sync_replicas):
    """Call ``train.main`` with zero steps and a patched ``data_provider``.

    Args:
        conditional: Value assigned to ``FLAGS.conditional``.
        use_sync_replicas: Value assigned to ``FLAGS.use_sync_replicas``.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.conditional = conditional
    FLAGS.use_sync_replicas = use_sync_replicas
    FLAGS.batch_size = 16

    # Mock input pipeline.
    num_examples = FLAGS.batch_size
    images = np.zeros([num_examples, 32, 32, 3], dtype=np.float32)
    # Each label row is a 1 followed by nine 0s.
    ones_column = np.ones([num_examples, 1], dtype=np.int32)
    zero_columns = np.zeros([num_examples, 9], dtype=np.int32)
    labels = np.concatenate((ones_column, zero_columns), axis=1)
    provider_output = (images, labels, None, None)

    with mock.patch.object(train, 'data_provider') as patched_provider:
        patched_provider.provide_data.return_value = provider_output
        train.main(None)
示例9: test_run_one_train_step
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def test_run_one_train_step(self, mock_data_provider):
    """Call ``train.main`` for one step of the ``'unconditional'`` GAN type.

    Args:
        mock_data_provider: Mock whose ``provide_data`` return value is set
            here (presumably injected by a patch decorator outside this
            view -- confirm).
    """
    FLAGS.max_number_of_steps = 1
    FLAGS.gan_type = 'unconditional'
    FLAGS.batch_size = 5
    FLAGS.grid_size = 1
    tf.set_random_seed(1234)

    # Mock input pipeline.
    num_examples = FLAGS.batch_size
    images = np.zeros([num_examples, 28, 28, 1], dtype=np.float32)
    # Each label row is a 1 followed by nine 0s.
    ones_column = np.ones([num_examples, 1], dtype=np.int32)
    zero_columns = np.zeros([num_examples, 9], dtype=np.int32)
    labels = np.concatenate((ones_column, zero_columns), axis=1)
    mock_data_provider.provide_data.return_value = (images, labels, None)

    train.main(None)
示例10: _test_build_graph_helper
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, gan_type):
    """Call ``train.main`` with zero steps for the given GAN type.

    ``FLAGS.batch_size`` is read but not set here, so the mock arrays take
    whatever batch size is already configured.

    Args:
        gan_type: Value assigned to ``FLAGS.gan_type``.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.gan_type = gan_type

    # Mock input pipeline.
    num_examples = FLAGS.batch_size
    images = np.zeros([num_examples, 28, 28, 1], dtype=np.float32)
    # Each label row is a 1 followed by nine 0s.
    ones_column = np.ones([num_examples, 1], dtype=np.int32)
    zero_columns = np.zeros([num_examples, 9], dtype=np.int32)
    labels = np.concatenate((ones_column, zero_columns), axis=1)
    provider_output = (images, labels, None)

    with mock.patch.object(train, 'data_provider') as patched_provider:
        patched_provider.provide_data.return_value = provider_output
        train.main(None)
示例11: _test_build_graph_helper
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, weight_factor):
    """Call ``train.main`` with zero steps and a patched ``data_provider``.

    The patched ``provide_data`` returns a single float32 zero array of
    shape ``[3, 16, 16, 3]``.

    Args:
        weight_factor: Value assigned to ``FLAGS.weight_factor``.
    """
    num_examples = 3
    side = 16

    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = num_examples
    FLAGS.patch_size = side

    image_shape = [num_examples, side, side, 3]
    images = np.zeros(image_shape, dtype=np.float32)

    with mock.patch.object(train, 'data_provider') as patched_provider:
        patched_provider.provide_data.return_value = images
        train.main(None)
示例12: preprocess_translation_data
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def preprocess_translation_data(data_dir, extra_flags=None):
    """Run ``preprocess.main`` on the train/valid/test prefixes in ``data_dir``.

    Args:
        data_dir: Directory joined with ``train``, ``valid`` and ``test`` to
            form the input prefixes; also passed as ``--destdir``.
        extra_flags: Optional list of additional command-line tokens appended
            after the fixed flags; ``None`` or an empty list adds nothing.
    """
    parser = preprocess.get_parser()

    fixed_flags = ['--source-lang', 'in', '--target-lang', 'out']
    split_prefixes = (
        ('--trainpref', 'train'),
        ('--validpref', 'valid'),
        ('--testpref', 'test'),
    )
    for flag, prefix in split_prefixes:
        fixed_flags.extend((flag, os.path.join(data_dir, prefix)))
    fixed_flags.extend((
        '--thresholdtgt', '0',
        '--thresholdsrc', '0',
        '--destdir', data_dir,
    ))

    # Plain list concatenation, so extra_flags must itself be a list.
    argv = fixed_flags + (extra_flags or [])
    preprocess.main(parser.parse_args(argv))
示例13: generate_main
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def generate_main(data_dir, extra_flags=None):
    """Run generation in batch mode, then interactively with canned stdin.

    The generation arguments point ``--path`` at ``checkpoint_last.pt`` inside
    ``data_dir`` and include ``--print-alignment``. After ``generate.main``
    runs, the same argument object is reused for ``interactive.main`` with
    ``buffer_size`` and ``max_sentences`` overridden, while ``sys.stdin`` is
    temporarily replaced by an in-memory stream containing a single line.

    Args:
        data_dir: Passed as the positional data argument; also the directory
            in which ``checkpoint_last.pt`` is looked up.
        extra_flags: Optional list of additional command-line tokens appended
            after the fixed flags; ``None`` or an empty list adds nothing.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
            '--print-alignment',
        ] + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Restore the real stdin even if interactive.main raises; otherwise
        # the rest of the process would keep reading the exhausted StringIO.
        sys.stdin = orig_stdin
示例14: test_main
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def test_main(self, mock_gan_train, mock_define_train_ops, mock_cyclegan_loss,
              mock_define_model, mock_data_provider, mock_gfile):
    """Run ``train.main`` with mocked collaborators and verify their calls.

    The mock parameters are presumably injected by patch decorators outside
    this view -- confirm their order against those decorators.
    ``mock_gfile`` is not referenced in the body.
    """
    x_pattern = '/tmp/x/*.jpg'
    y_pattern = '/tmp/y/*.jpg'
    log_dir = '/tmp/foo'
    master_name = 'master'

    FLAGS.image_set_x_file_pattern = x_pattern
    FLAGS.image_set_y_file_pattern = y_pattern
    FLAGS.batch_size = 3
    FLAGS.patch_size = 8
    FLAGS.generator_lr = 0.02
    FLAGS.discriminator_lr = 0.3
    FLAGS.train_log_dir = log_dir
    FLAGS.master = master_name
    FLAGS.task = 0
    FLAGS.cycle_consistency_loss_weight = 2.0
    FLAGS.max_number_of_steps = 1

    # Two separate zero tensors stand in for the two provided datasets.
    images_x = tf.zeros([1, 2], dtype=tf.float32)
    images_y = tf.zeros([1, 2], dtype=tf.float32)
    mock_data_provider.provide_custom_datasets.return_value = (
        images_x, images_y)

    train.main(None)

    # Each collaborator must be called exactly once, fed by the previous
    # collaborator's return value.
    model = mock_define_model.return_value
    loss = mock_cyclegan_loss.return_value
    train_ops = mock_define_train_ops.return_value

    mock_data_provider.provide_custom_datasets.assert_called_once_with(
        [x_pattern, y_pattern], batch_size=3, patch_size=8)
    mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
    mock_cyclegan_loss.assert_called_once_with(
        model,
        cycle_consistency_loss_weight=2.0,
        tensor_pool_fn=mock.ANY)
    mock_define_train_ops.assert_called_once_with(model, loss)
    mock_gan_train.assert_called_once_with(
        train_ops,
        log_dir,
        get_hooks_fn=mock.ANY,
        hooks=mock.ANY,
        master=master_name,
        is_chief=True)
示例15: _test_build_graph_helper
# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, weight_factor):
    """Call ``train.main`` with zero steps and a patched ``data_provider``.

    The patched ``provide_data`` returns a single float32 zero array whose
    shape is built from ``FLAGS.batch_size`` (9) and ``FLAGS.patch_size`` (32).

    Args:
        weight_factor: Value assigned to ``FLAGS.weight_factor``.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = 9
    FLAGS.patch_size = 32

    image_shape = [FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, 3]
    images = np.zeros(image_shape, dtype=np.float32)

    with mock.patch.object(train, 'data_provider') as patched_provider:
        patched_provider.provide_data.return_value = images
        train.main(None)