当前位置: 首页>>代码示例>>Python>>正文


Python train.main方法代码示例

本文整理汇总了 Python 中 train.main 方法的典型用法代码示例。如果您想了解 train.main 方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 train 的其他用法示例。


下文共展示了 train.main 方法的 15 个代码示例，这些例子默认按受欢迎程度排序。注意：其中部分示例（如示例 2、5、6、12、13）并不直接调用 train.main，而是调用 generate.main、interactive.main 或 preprocess.main。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: train_translation_model

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def train_translation_model(data_dir, arch, extra_flags=None):
    """Run ``train.main`` for a ``translation`` task on *data_dir*.

    The command line uses the NAG optimizer with ``--lr 0.05``,
    ``--max-tokens 500``, ``--max-epoch 1``, a single distributed worker and
    source/target languages ``in``/``out``. *data_dir* is both the data path
    and the ``--save-dir``.

    Args:
        data_dir: Data directory; also used as the checkpoint save directory.
        arch: Architecture name forwarded via ``--arch``.
        extra_flags: Optional list of extra command-line tokens appended
            after the defaults (ignored when falsy).
    """
    cli = ['--task', 'translation', data_dir]
    cli = cli + ['--save-dir', data_dir, '--arch', arch]
    cli = cli + ['--optimizer', 'nag', '--lr', '0.05']
    cli = cli + ['--max-tokens', '500', '--max-epoch', '1']
    cli = cli + ['--no-progress-bar', '--distributed-world-size', '1']
    cli = cli + ['--source-lang', 'in', '--target-lang', 'out']
    if extra_flags:
        cli = cli + extra_flags

    parser = options.get_training_parser()
    parsed = options.parse_args_and_arch(parser, cli)
    train.main(parsed)
开发者ID:nusnlp,项目名称:crosentgec,代码行数:22,代码来源:test_binaries.py

示例2: generate_main

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def generate_main(data_dir):
    """Run batch and then interactive generation from the last checkpoint.

    First parses a generation command line that points at
    ``<data_dir>/checkpoint_last.pt`` (beam 3, batch size 64,
    ``--max-len-b 5``, ``valid`` subset) and runs ``generate.main``.
    It then reuses the same parsed arguments, with ``buffer_size = 0`` and
    ``max_sentences = None``, to run ``interactive.main``. During that call
    ``sys.stdin`` is temporarily replaced by a buffer holding the single
    line ``h e l l o``.

    Args:
        data_dir: Directory holding the data and ``checkpoint_last.pt``.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
        ],
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Restore the real stdin even if interactive.main raises. Otherwise
        # every later consumer in this process would read from the
        # exhausted StringIO.
        sys.stdin = orig_stdin
开发者ID:nusnlp,项目名称:crosentgec,代码行数:27,代码来源:test_binaries.py

示例3: train_language_model

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def train_language_model(data_dir, arch):
    """Run ``train.main`` for a ``language_modeling`` task on *data_dir*.

    Uses the NAG optimizer with ``--lr 1.0``, the ``adaptive_loss``
    criterion with softmax cutoffs ``5,10,15``, the decoder-layer spec
    ``[(850, 3)] * 2 + [(1024,4)]`` and decoder embed dim 280. It sets
    500 max tokens and tokens per sample, ``--max-epoch 1`` and a single
    distributed worker. Checkpoints are saved into *data_dir*.

    Args:
        data_dir: Data directory; also used as ``--save-dir``.
        arch: Architecture name forwarded via ``--arch``.
    """
    argv = ['--task', 'language_modeling', data_dir]
    flag_values = (
        ('--arch', arch),
        ('--optimizer', 'nag'),
        ('--lr', '1.0'),
        ('--criterion', 'adaptive_loss'),
        ('--adaptive-softmax-cutoff', '5,10,15'),
        ('--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]'),
        ('--decoder-embed-dim', '280'),
        ('--max-tokens', '500'),
        ('--tokens-per-sample', '500'),
        ('--save-dir', data_dir),
        ('--max-epoch', '1'),
    )
    for flag, value in flag_values:
        argv.extend((flag, value))
    argv.extend(('--no-progress-bar', '--distributed-world-size', '1'))

    parser = options.get_training_parser()
    train.main(options.parse_args_and_arch(parser, argv))
开发者ID:nusnlp,项目名称:crosentgec,代码行数:25,代码来源:test_binaries.py

示例4: train_language_model

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def train_language_model(data_dir, arch):
    """Run ``train.main`` for a ``language_modeling`` task on *data_dir*.

    Same setup as an adaptive-softmax LM run: NAG with ``--lr 0.1``,
    ``adaptive_loss`` with cutoffs ``5,10,15``, the decoder-layer spec
    ``[(850, 3)] * 2 + [(1024,4)]`` and decoder embed dim 280. It runs a
    single worker for at most one epoch and additionally selects the
    ``no_c10d`` DDP backend. Checkpoints are saved into *data_dir*.

    Args:
        data_dir: Data directory; also used as ``--save-dir``.
        arch: Architecture name forwarded via ``--arch``.
    """
    model_flags = [
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '0.1',
        '--criterion', 'adaptive_loss',
        '--adaptive-softmax-cutoff', '5,10,15',
        '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
        '--decoder-embed-dim', '280',
    ]
    batching_flags = ['--max-tokens', '500', '--tokens-per-sample', '500']
    run_flags = [
        '--save-dir', data_dir,
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
        '--ddp-backend', 'no_c10d',
    ]
    argv = (['--task', 'language_modeling', data_dir]
            + model_flags + batching_flags + run_flags)

    parser = options.get_training_parser()
    parsed = options.parse_args_and_arch(parser, argv)
    train.main(parsed)
开发者ID:kakaobrain,项目名称:helo_word,代码行数:26,代码来源:test_binaries.py

示例5: preprocess_translation_data

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def preprocess_translation_data(data_dir):
    """Run ``preprocess.main`` over the ``in``/``out`` files in *data_dir*.

    The ``train``, ``valid`` and ``test`` prefixes are all resolved inside
    *data_dir*. ``--thresholdtgt`` and ``--thresholdsrc`` are both set to
    ``0``, and *data_dir* is also the ``--destdir``.

    Args:
        data_dir: Directory containing the split files and receiving output.
    """
    parser = preprocess.get_parser()
    join = os.path.join
    argv = [
        '--source-lang', 'in', '--target-lang', 'out',
        '--trainpref', join(data_dir, 'train'),
        '--validpref', join(data_dir, 'valid'),
        '--testpref', join(data_dir, 'test'),
        '--thresholdtgt', '0', '--thresholdsrc', '0',
        '--destdir', data_dir,
    ]
    preprocess.main(parser.parse_args(argv))
开发者ID:nusnlp,项目名称:crosentgec,代码行数:15,代码来源:test_binaries.py

示例6: preprocess_lm_data

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def preprocess_lm_data(data_dir):
    """Run ``preprocess.main`` with ``--only-source`` on the ``*.out`` splits.

    The ``train.out``, ``valid.out`` and ``test.out`` prefixes are resolved
    inside *data_dir*, and output is written back into *data_dir*.

    Args:
        data_dir: Directory containing the split files and receiving output.
    """
    split_prefixes = (
        ('--trainpref', 'train.out'),
        ('--validpref', 'valid.out'),
        ('--testpref', 'test.out'),
    )
    parser = preprocess.get_parser()
    argv = ['--only-source']
    for flag, filename in split_prefixes:
        argv.extend([flag, os.path.join(data_dir, filename)])
    argv.extend(['--destdir', data_dir])
    preprocess.main(parser.parse_args(argv))
开发者ID:nusnlp,项目名称:crosentgec,代码行数:12,代码来源:test_binaries.py

示例7: test_full_flow

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def test_full_flow(self, mock_data_provider):
    """Run ``train.main`` for two steps on all-zero stubbed input.

    ``mock_data_provider`` is presumably injected by a ``mock.patch``
    decorator that is not shown here. Its ``provide_data`` is stubbed to
    return zero-valued 28x28x1 float32 images, int32 labels that are one-hot
    on column 0, and ``None``.
    """
    FLAGS.eval_dir = self.get_temp_dir()
    FLAGS.batch_size = 16
    FLAGS.max_number_of_steps = 2
    FLAGS.noise_dims = 3

    rows = FLAGS.batch_size
    images = np.zeros([rows, 28, 28, 1], dtype=np.float32)
    # One-hot labels over 10 columns, all pointing at column 0.
    labels = np.zeros([rows, 10], dtype=np.int32)
    labels[:, 0] = 1
    mock_data_provider.provide_data.return_value = (images, labels, None)

    train.main(None)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:16,代码来源:train_test.py

示例8: _test_build_graph_helper

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, conditional, use_sync_replicas):
    """Call ``train.main`` with zero steps under the given flag combination.

    ``train.data_provider`` is patched for the duration of the call. Its
    ``provide_data`` returns a 4-tuple: zero-valued 32x32x3 float32 images,
    int32 labels one-hot on column 0, and two ``None`` entries.

    Args:
        conditional: Value assigned to ``FLAGS.conditional``.
        use_sync_replicas: Value assigned to ``FLAGS.use_sync_replicas``.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.conditional = conditional
    FLAGS.use_sync_replicas = use_sync_replicas
    FLAGS.batch_size = 16

    rows = FLAGS.batch_size
    fake_images = np.zeros([rows, 32, 32, 3], dtype=np.float32)
    fake_labels = np.zeros([rows, 10], dtype=np.int32)
    fake_labels[:, 0] = 1
    fake_batch = (fake_images, fake_labels, None, None)

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = fake_batch
        train.main(None)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:17,代码来源:train_test.py

示例9: test_run_one_train_step

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def test_run_one_train_step(self, mock_data_provider):
    """Run ``train.main`` for exactly one step with ``gan_type='unconditional'``.

    ``mock_data_provider`` is presumably injected by a ``mock.patch``
    decorator that is not shown here. Its ``provide_data`` returns
    zero-valued 28x28x1 float32 images, int32 labels one-hot on column 0,
    and ``None``.
    """
    FLAGS.max_number_of_steps = 1
    FLAGS.gan_type = 'unconditional'
    FLAGS.batch_size = 5
    FLAGS.grid_size = 1
    tf.set_random_seed(1234)  # pin the graph-level seed for repeatability

    count = FLAGS.batch_size
    stub_labels = np.zeros([count, 10], dtype=np.int32)
    stub_labels[:, 0] = 1
    stub_images = np.zeros([count, 28, 28, 1], dtype=np.float32)
    mock_data_provider.provide_data.return_value = (
        stub_images, stub_labels, None)

    train.main(None)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:17,代码来源:train_test.py

示例10: _test_build_graph_helper

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, gan_type):
    """Call ``train.main`` with zero steps for the given *gan_type*.

    The batch size is whatever ``FLAGS.batch_size`` currently holds; this
    helper does not set it. ``train.data_provider`` is patched for the
    call so that ``provide_data`` returns zero-valued 28x28x1 float32
    images, int32 labels one-hot on column 0, and ``None``.

    Args:
        gan_type: Value assigned to ``FLAGS.gan_type``.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.gan_type = gan_type

    size = FLAGS.batch_size
    stub_images = np.zeros([size, 28, 28, 1], dtype=np.float32)
    stub_labels = np.zeros([size, 10], dtype=np.int32)
    stub_labels[:, 0] = 1

    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = (stub_images, stub_labels, None)
        train.main(None)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:15,代码来源:train_test.py

示例11: _test_build_graph_helper

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, weight_factor):
    """Call ``train.main`` with zero steps using a batch of 3 16x16 patches.

    ``train.data_provider`` is patched for the call so that
    ``provide_data`` returns a single zero-valued float32 array of shape
    ``(3, 16, 16, 3)``.

    Args:
        weight_factor: Value assigned to ``FLAGS.weight_factor``.
    """
    batch, patch = 3, 16

    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = batch
    FLAGS.patch_size = patch

    blank_patches = np.zeros([batch, patch, patch, 3], dtype=np.float32)
    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = blank_patches
        train.main(None)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:17,代码来源:train_test.py

示例12: preprocess_translation_data

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def preprocess_translation_data(data_dir, extra_flags=None):
    """Run ``preprocess.main`` over the ``in``/``out`` files in *data_dir*.

    The ``train``, ``valid`` and ``test`` prefixes are resolved inside
    *data_dir*, ``--thresholdtgt`` and ``--thresholdsrc`` are both ``0``,
    and *data_dir* is also the ``--destdir``.

    Args:
        data_dir: Directory containing the split files and receiving output.
        extra_flags: Optional list of extra command-line tokens appended
            after the defaults (ignored when falsy).
    """
    split_prefixes = (
        ('--trainpref', 'train'),
        ('--validpref', 'valid'),
        ('--testpref', 'test'),
    )
    parser = preprocess.get_parser()
    argv = ['--source-lang', 'in', '--target-lang', 'out']
    for flag, split in split_prefixes:
        argv.extend((flag, os.path.join(data_dir, split)))
    argv.extend(('--thresholdtgt', '0', '--thresholdsrc', '0',
                 '--destdir', data_dir))
    if extra_flags:
        argv = argv + extra_flags
    preprocess.main(parser.parse_args(argv))
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:17,代码来源:test_binaries.py

示例13: generate_main

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def generate_main(data_dir, extra_flags=None):
    """Run batch and then interactive generation from the last checkpoint.

    First parses a generation command line that points at
    ``<data_dir>/checkpoint_last.pt`` (beam 3, batch size 64,
    ``--max-len-b 5``, ``valid`` subset, ``--print-alignment``) plus any
    *extra_flags*, and runs ``generate.main``. It then reuses the same
    parsed arguments, with ``buffer_size = 0`` and ``max_sentences = None``,
    to run ``interactive.main``. During that call ``sys.stdin`` is
    temporarily replaced by a buffer holding the single line ``h e l l o``.

    Args:
        data_dir: Directory holding the data and ``checkpoint_last.pt``.
        extra_flags: Optional list of extra command-line tokens appended
            after the defaults.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
            '--print-alignment',
        ] + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Restore the real stdin even if interactive.main raises. Otherwise
        # every later consumer in this process would read from the
        # exhausted StringIO.
        sys.stdin = orig_stdin
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:28,代码来源:test_binaries.py

示例14: test_main

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def test_main(self, mock_gan_train, mock_define_train_ops, mock_cyclegan_loss,
              mock_define_model, mock_data_provider, mock_gfile):
    """Check that ``train.main`` forwards flag values to each stage.

    All collaborators arrive as mocks, presumably from ``mock.patch``
    decorators that are not shown here. The test sets the flags, stubs
    ``provide_custom_datasets`` to return two ``[1, 2]`` float32 zero
    tensors, and runs ``train.main(None)``. It then asserts that the data
    provider, ``define_model``, ``cyclegan_loss``, ``define_train_ops`` and
    ``gan_train`` mocks were each called exactly once with the expected
    arguments. ``mock_gfile`` is not referenced in the body.
    """
    x_pattern = '/tmp/x/*.jpg'
    y_pattern = '/tmp/y/*.jpg'
    log_dir = '/tmp/foo'

    FLAGS.image_set_x_file_pattern = x_pattern
    FLAGS.image_set_y_file_pattern = y_pattern
    FLAGS.batch_size = 3
    FLAGS.patch_size = 8
    FLAGS.generator_lr = 0.02
    FLAGS.discriminator_lr = 0.3
    FLAGS.train_log_dir = log_dir
    FLAGS.master = 'master'
    FLAGS.task = 0
    FLAGS.cycle_consistency_loss_weight = 2.0
    FLAGS.max_number_of_steps = 1

    provider = mock_data_provider.provide_custom_datasets
    provider.return_value = (
        tf.zeros([1, 2], dtype=tf.float32),
        tf.zeros([1, 2], dtype=tf.float32),
    )

    train.main(None)

    provider.assert_called_once_with(
        [x_pattern, y_pattern], batch_size=3, patch_size=8)
    mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)
    model = mock_define_model.return_value
    mock_cyclegan_loss.assert_called_once_with(
        model, cycle_consistency_loss_weight=2.0, tensor_pool_fn=mock.ANY)
    mock_define_train_ops.assert_called_once_with(
        model, mock_cyclegan_loss.return_value)
    mock_gan_train.assert_called_once_with(
        mock_define_train_ops.return_value,
        log_dir,
        get_hooks_fn=mock.ANY,
        hooks=mock.ANY,
        master='master',
        is_chief=True)
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:36,代码来源:train_test.py

示例15: _test_build_graph_helper

# 需要导入模块: import train [as 别名]
# 或者: from train import main [as 别名]
def _test_build_graph_helper(self, weight_factor):
    """Call ``train.main`` with zero steps using a batch of 9 32x32 patches.

    ``train.data_provider`` is patched for the call so that
    ``provide_data`` returns a single zero-valued float32 array shaped
    ``(FLAGS.batch_size, FLAGS.patch_size, FLAGS.patch_size, 3)``.

    Args:
        weight_factor: Value assigned to ``FLAGS.weight_factor``.
    """
    FLAGS.max_number_of_steps = 0
    FLAGS.weight_factor = weight_factor
    FLAGS.batch_size = 9
    FLAGS.patch_size = 32

    side = FLAGS.patch_size
    shape = [FLAGS.batch_size, side, side, 3]
    blank_patches = np.zeros(shape, dtype=np.float32)
    with mock.patch.object(train, 'data_provider') as provider:
        provider.provide_data.return_value = blank_patches
        train.main(None)
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:14,代码来源:train_test.py


注：本文中的 train.main 方法示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。