

Python nmt.train Method Code Examples

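This page collects typical usage examples of nmt.train in Python. If you want to see how nmt.train is called in practice, the selected examples below should help. Each one is a main(job_id, params) entry point that reads hyperparameters from params and passes them to train as keyword arguments. You can also look further into other usage examples of the nmt module that contains this method.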


Below are 15 code examples of nmt.train, sorted by popularity by default.
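All fifteen examples share one calling pattern: a main(job_id, params) function receives a dictionary of hyperparameters and forwards them to train as keyword arguments. In examples 1-10 every value is read with [0], which suggests each hyperparameter arrives as a one-element sequence, as it does with Spearmint-style hyperparameter search drivers (an assumption; the page does not name the driver). The runnable sketch below swaps train for a hypothetical stub, fake_train, so the pattern can be exercised without Theano or any data files.

```python
def fake_train(**kwargs):
    """Hypothetical stand-in for nmt.train: records its keyword
    arguments and returns a made-up validation error."""
    fake_train.last_kwargs = kwargs
    return 1.0 / kwargs['dim']


def main(job_id, params):
    # Each hyperparameter is assumed to be a one-element list, hence [0].
    print('job #%d' % job_id)
    validerr = fake_train(saveto=params['model'][0],
                          dim_word=params['dim_word'][0],
                          dim=params['dim'][0],
                          lrate=params['learning-rate'][0],
                          optimizer=params['optimizer'][0],
                          use_dropout=bool(params['use-dropout'][0]))
    return validerr


# Hypothetical hyperparameter values in the one-element-list format.
params = {
    'model': ['model.npz'],
    'dim_word': [512],
    'dim': [1024],
    'learning-rate': [0.0001],
    'optimizer': ['adadelta'],
    'use-dropout': [False],
}
err = main(1, params)
```

The driver only ever sees the single number that main returns, which is why every example ends with return validerr.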

Example 1: main

# Required import (this example calls train() directly):
from nmt import train

def main(job_id, params):
    print('Anything printed here will end up in the output directory for job #%d' % job_id)
    print(params)
    trainerr, validerr, testerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        n_words=params['n-words'][0],
                                        n_words_src=params['n-words'][0],
                                        decay_c=params['decay-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0], 
                                        maxlen=20,
                                        batch_size=16,
                                        valid_batch_size=16,
                                        validFreq=1000,
                                        dispFreq=1,
                                        saveFreq=1000,
                                        sampleFreq=1000,
                                        dataset='wmt14enfr', 
                                        dictionary='/data/lisatmp3/chokyun/wmt14/parallel-corpus/en-fr/vocab.fr.pkl',
                                        use_dropout=bool(params['use-dropout'][0]))
    return validerr 
Developer ID: arctic-nmt, Project: nmt, Lines of code: 25, Source file: evaluate_enfr.py
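decay_c is passed in every example. In dl4mt-style training code it is typically the coefficient of an L2 weight-decay term added to the training cost, with 0 disabling it; treat this as an assumption about nmt.train's internals rather than a documented contract. A minimal pure-Python sketch of such a penalty, using a hypothetical name-to-flat-list parameter layout:

```python
def l2_penalty(weights, decay_c):
    """Return decay_c * (sum of squared entries over all parameters).
    weights maps a parameter name to a flat list of floats (hypothetical layout).
    A non-positive decay_c disables the penalty."""
    if decay_c <= 0.0:
        return 0.0
    total = 0.0
    for values in weights.values():
        total += sum(v * v for v in values)
    return decay_c * total


weights = {'W_emb': [0.5, -0.5], 'b': [1.0, 2.0]}
cost = 3.0  # hypothetical data cost for one minibatch
cost_with_decay = cost + l2_penalty(weights, decay_c=0.01)
```

Here the squared entries sum to 5.5, so the penalty adds 0.055 to the cost.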

Example 2: main

# Required import (this example calls train() directly):
from nmt import train

def main(job_id, params):
    print('Anything printed here will end up in the output directory for job #%d' % job_id)
    print(params)
    trainerr, validerr, testerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        n_words=params['n-words'][0],
                                        n_words_src=params['n-words-src'][0],
                                        decay_c=params['decay-c'][0],
                                        alpha_c=params['alpha-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0], 
                                        maxlen=20,
                                        batch_size=16,
                                        valid_batch_size=16,
                                        validFreq=1000,
                                        dispFreq=1,
                                        saveFreq=500,
                                        sampleFreq=10,
                                        dataset='iwslt14zhen', 
                                        dictionary='/data/lisatmp3/firatorh/nmt/zh-en_lm/trainedModels/unionFinetuneRnd/union_dict.pkl',
                                        use_dropout=bool(params['use-dropout'][0]))
    return validerr 
Developer ID: arctic-nmt, Project: nmt, Lines of code: 26, Source file: evaluate_zhen.py
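Example 2 trains with maxlen=20 and batch_size=16. A common interpretation, assumed here and not confirmed by this page, is that sentence pairs with either side longer than maxlen tokens are skipped before minibatches of batch_size pairs are formed. The sketch below shows that filtering and batching on already tokenized pairs; the function name filter_and_batch is hypothetical.

```python
def filter_and_batch(pairs, maxlen, batch_size):
    """Drop (source, target) pairs where either side has more than maxlen
    tokens, then group the remaining pairs into minibatches of at most
    batch_size pairs each."""
    kept = [(src, trg) for src, trg in pairs
            if len(src) <= maxlen and len(trg) <= maxlen]
    return [kept[i:i + batch_size] for i in range(0, len(kept), batch_size)]


pairs = [
    ('we are here'.split(), 'nous sommes ici'.split()),
    ('a very long sentence indeed'.split(), 'une phrase'.split()),
    ('hello'.split(), 'bonjour'.split()),
]
batches = filter_and_batch(pairs, maxlen=3, batch_size=1)
```

The five-token source sentence is dropped, so two single-pair batches remain. A small maxlen such as 20 keeps batches cheap but discards every longer sentence from training.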

Example 3: main

# Required import (this example calls train() directly):
from nmt import train

def main(job_id, params):
    print('Anything printed here will end up in the output directory for job #%d' % job_id)
    print(params)
    trainerr, validerr, testerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        n_words=params['n-words'][0],
                                        n_words_src=params['n-words-src'][0],
                                        decay_c=params['decay-c'][0],
                                        alpha_c=params['alpha-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0], 
                                        encoder='gru',
                                        decoder='gru_cond', #'gru_cond_simple',
                                        maxlen=30,
                                        batch_size=128,
                                        valid_batch_size=128,
                                        validFreq=1000,
                                        dispFreq=1,
                                        saveFreq=500,
                                        sampleFreq=500,
                                        dataset='trans_enhi', 
                                        dictionary='/data/lisatmp3/chokyun/transliteration/TranslitDataset/vocab.hi.pkl',
                                        dictionary_src='/data/lisatmp3/chokyun/transliteration/TranslitDataset/vocab.en.pkl',
                                        use_dropout=bool(params['use-dropout'][0]))
    return validerr 
Developer ID: arctic-nmt, Project: nmt, Lines of code: 29, Source file: evaluate_transli.py

Example 4: main

# Required import (this example calls train() directly):
from nmt import train

def main(job_id, params):
    print('Anything printed here will end up in the output directory for job #%d' % job_id)
    print(params)
    trainerr, validerr, testerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        encoder='gru',
                                        decoder='gru_cond_simple',
                                        hiero=None, #'gru_hiero', # or None
                                        n_words_src=params['n-words-src'][0],
                                        n_words=params['n-words'][0],
                                        decay_c=params['decay-c'][0],
                                        alpha_c=params['alpha-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0], 
                                        maxlen=100,
                                        batch_size=64,
                                        valid_batch_size=64,
                                        validFreq=1000,
                                        dispFreq=1,
                                        saveFreq=500,
                                        sampleFreq=10,
                                        dataset='stan',
                                        dictionary='./stan/vocab_and_data_sub_europarl/vocab_sub_europarl.fr.pkl',
                                        dictionary_src='./stan/vocab_and_data_sub_europarl/vocab_sub_europarl.en.pkl',
                                        use_dropout=False)
    return validerr 
Developer ID: arctic-nmt, Project: nmt, Lines of code: 30, Source file: evaluate_stan.py
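Examples 3-5 choose the encoder and decoder by name ('gru', 'gru_cond', 'gru_cond_simple'), and examples 11-15 pass a layers object into train. One common way to support this kind of selection is a registry that maps a layer name to a pair of functions (parameter initializer, forward function). The sketch below illustrates the idea with hypothetical toy layers; it is not the project's actual layer code, and the names LAYERS, get_layer, init_ff and apply_ff are assumptions.

```python
def init_ff(dim):
    # Hypothetical initializer: a single scalar weight, for illustration only.
    return {'W': 1.0 / dim}


def apply_ff(params, x):
    # Scale every input value by the layer's weight.
    return [params['W'] * v for v in x]


def init_identity(dim):
    return {}


def apply_identity(params, x):
    return list(x)


# Registry: layer name -> (initializer, forward function).
LAYERS = {
    'ff': (init_ff, apply_ff),
    'identity': (init_identity, apply_identity),
}


def get_layer(name):
    """Look up a layer by name, failing loudly on unknown names."""
    if name not in LAYERS:
        raise ValueError('unknown layer type: %r' % name)
    return LAYERS[name]


init_fn, apply_fn = get_layer('ff')
out = apply_fn(init_fn(4), [4.0, 8.0])
```

With a registry, switching from one decoder to another is a change of string in the training script rather than a code change, which matches how examples 3-5 differ only in their encoder and decoder arguments.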

Example 5: main

# Required import (this example calls train() directly):
from nmt import train

def main(job_id, params):
    print('Anything printed here will end up in the output directory for job #%d' % job_id)
    print(params)
    trainerr, validerr, testerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        encoder='gru',
                                        decoder='gru_cond',
                                        hiero='gru_hiero', # or None
                                        n_words_src=params['n-words-src'][0],
                                        n_words=params['n-words'][0],
                                        decay_c=params['decay-c'][0],
                                        alpha_c=params['alpha-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0], 
                                        maxlen=50,
                                        batch_size=64,
                                        valid_batch_size=64,
                                        validFreq=1000,
                                        dispFreq=1,
                                        saveFreq=500,
                                        sampleFreq=10,
                                        dataset='openmt15zhen', 
                                        dictionary='./openmt15/vocab.en.pkl',
                                        dictionary_src='./openmt15/vocab.zh.pkl',
                                        use_dropout=bool(params['use-dropout'][0]))
    return validerr 
Developer ID: arctic-nmt, Project: nmt, Lines of code: 30, Source file: evaluate_openmt15.py

Example 6: main

# Required imports (this example calls train() directly and reads $USER):
import os

from nmt import train

def main(job_id, params):
    print(params)
    username = os.environ['USER']
    validerr = train(
        saveto=params['model'][0],
        reload_=params['reload'][0],
        dim_word=params['dim_word'][0],
        dim=params['dim'][0],
        n_words=params['n-words'][0],
        n_words_src=params['n-words'][0],
        decay_c=params['decay-c'][0],
        lrate=params['learning-rate'][0],
        optimizer=params['optimizer'][0],
        maxlen=50,
        batch_size=32,
        valid_batch_size=32,
        datasets=[
            '/ichec/home/users/%s/data/all.en.concat.shuf.gz' % username,
            '/ichec/home/users/%s/data/all.fr.concat.shuf.gz' % username],
        valid_datasets=[
            '/ichec/home/users/%s/data/newstest2011.en.tok' % username,
            '/ichec/home/users/%s/data/newstest2011.fr.tok' % username],
        dictionaries=[
            '/ichec/home/users/%s/data/all.en.concat.gz.pkl' % username,
            '/ichec/home/users/%s/data/all.fr.concat.gz.pkl' % username],
        validFreq=5000,
        dispFreq=10,
        saveFreq=5000,
        sampleFreq=1000,
        use_dropout=params['use-dropout'][0],
        overwrite=False)
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-tutorial, Lines of code: 34, Source file: train_nmt_all.py
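Examples 6-8 repeat the same per-user directory in six path strings. A small helper can build them once. os.path.join and os.environ are standard-library APIs; the helper name data_paths and the root/user/data directory layout are assumptions taken from the paths above for illustration.

```python
import os


def data_paths(root, username, *filenames):
    """Return root/username/data/<name> for each file name, in order."""
    base = os.path.join(root, username, 'data')
    return [os.path.join(base, name) for name in filenames]


# os.environ['USER'] raises KeyError when USER is unset (for example in some
# containers or batch jobs); .get with a default avoids that.
username = os.environ.get('USER', 'nobody')
datasets = data_paths('/ichec/home/users', username,
                      'all.en.concat.shuf.gz', 'all.fr.concat.shuf.gz')
valid_datasets = data_paths('/ichec/home/users', username,
                            'newstest2011.en.tok', 'newstest2011.fr.tok')
```

The resulting lists can be passed unchanged as datasets=... and valid_datasets=... in the train call.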

Example 7: main

# Required imports (this example calls train() directly and reads $USER):
import os

from nmt import train

def main(job_id, params):
    print(params)
    validerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        n_words=params['n-words'][0],
                                        n_words_src=params['n-words'][0],
                                        decay_c=params['decay-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0],
                                        maxlen=50,
                                        batch_size=32,
                                        valid_batch_size=32,
                                        datasets=['/ichec/home/users/%s/data/europarl-v7.fr-en.en.tok' % os.environ['USER'],
                                                  '/ichec/home/users/%s/data/europarl-v7.fr-en.fr.tok' % os.environ['USER']],
                                        valid_datasets=['/ichec/home/users/%s/data/newstest2011.en.tok' % os.environ['USER'],
                                                        '/ichec/home/users/%s/data/newstest2011.fr.tok' % os.environ['USER']],
                                        dictionaries=['/ichec/home/users/%s/data/europarl-v7.fr-en.en.tok.pkl' % os.environ['USER'],
                                                      '/ichec/home/users/%s/data/europarl-v7.fr-en.fr.tok.pkl' % os.environ['USER']],
                                        validFreq=5000,
                                        dispFreq=10,
                                        saveFreq=5000,
                                        sampleFreq=1000,
                                        use_dropout=params['use-dropout'][0],
                                        overwrite=False)
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-tutorial, Lines of code: 29, Source file: train_nmt.py

Example 8: main

# Required imports (this example calls train() directly and reads $USER):
import os

from nmt import train

def main(job_id, params):
    print(params)
    validerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        n_words=params['n-words'][0],
                                        n_words_src=params['n-words'][0],
                                        decay_c=params['decay-c'][0],
                                        clip_c=params['clip-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0],
                                        maxlen=50,
                                        batch_size=32,
                                        valid_batch_size=32,
                                        datasets=['/ichec/home/users/%s/data/all.en.concat.shuf.gz' % os.environ['USER'],
                                                  '/ichec/home/users/%s/data/all.fr.concat.shuf.gz' % os.environ['USER']],
                                        valid_datasets=['/ichec/home/users/%s/data/newstest2011.en.tok' % os.environ['USER'],
                                                        '/ichec/home/users/%s/data/newstest2011.fr.tok' % os.environ['USER']],
                                        dictionaries=['/ichec/home/users/%s/data/all.en.concat.gz.pkl' % os.environ['USER'],
                                                      '/ichec/home/users/%s/data/all.fr.concat.gz.pkl' % os.environ['USER']],
                                        validFreq=5000,
                                        dispFreq=10,
                                        saveFreq=5000,
                                        sampleFreq=1000,
                                        use_dropout=params['use-dropout'][0],
                                        overwrite=False)
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-tutorial, Lines of code: 30, Source file: train_nmt_all.py
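Example 8 adds clip_c. In dl4mt-style trainers this is usually a threshold on the global gradient norm: when the L2 norm of all gradients taken together exceeds clip_c, every gradient is rescaled so that the joint norm equals clip_c. That reading is an assumption about nmt.train, not something this page states. A pure-Python sketch of global-norm clipping:

```python
import math


def clip_by_global_norm(grads, clip_c):
    """Rescale a list of gradient vectors so their joint L2 norm is at most
    clip_c. A non-positive clip_c disables clipping."""
    if clip_c <= 0.0:
        return grads
    norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if norm <= clip_c:
        return grads
    scale = clip_c / norm
    return [[g * scale for g in vec] for vec in grads]


# Two one-element gradients with joint norm 5 are scaled down by 1/5.
clipped = clip_by_global_norm([[3.0], [4.0]], clip_c=1.0)
```

Scaling all gradients by the same factor keeps the update direction unchanged and only shortens the step, which is the usual motivation for clipping the global norm rather than each gradient separately.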

Example 9: main

# Required import (this example calls train() directly):
from nmt import train

def main(job_id, params):
    print(params)
    basedir = '/data/lisatmp3/firatorh/nmt/europarlv7'
    validerr = train(saveto=params['model'][0],
                                        reload_=params['reload'][0],
                                        dim_word=params['dim_word'][0],
                                        dim=params['dim'][0],
                                        n_words=params['n-words'][0],
                                        n_words_src=params['n-words'][0],
                                        decay_c=params['decay-c'][0],
                                        clip_c=params['clip-c'][0],
                                        lrate=params['learning-rate'][0],
                                        optimizer=params['optimizer'][0],
                                        maxlen=15,
                                        batch_size=32,
                                        valid_batch_size=32,
                                        datasets=['%s/europarl-v7.fr-en.fr.tok' % basedir,
                                                  '%s/europarl-v7.fr-en.en.tok' % basedir],
                                        valid_datasets=['%s/newstest2011.fr.tok' % basedir,
                                                        '%s/newstest2011.en.tok' % basedir],
                                        dictionaries=['%s/europarl-v7.fr-en.fr.tok.pkl' % basedir,
                                                      '%s/europarl-v7.fr-en.en.tok.pkl' % basedir],
                                        validFreq=500000,
                                        dispFreq=1,
                                        saveFreq=100,
                                        sampleFreq=50,
                                        use_dropout=params['use-dropout'][0],
                                        overwrite=False)
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-tutorial, Lines of code: 31, Source file: train_nmt.py
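The dispFreq, saveFreq, sampleFreq and validFreq arguments are periods counted in parameter updates. Assuming train checks each one with a modulo test on its update counter (a common pattern in dl4mt-style loops, not something this page confirms), the simulation below counts how often each action fires. With Example 9's settings, a 1000-update run would write a checkpoint every 100 updates but, because validFreq=500000, would never validate. The function name run_schedule is hypothetical.

```python
def run_schedule(n_updates, dispFreq, saveFreq, sampleFreq, validFreq):
    """Simulate a training loop of n_updates updates and count how many
    times each periodic action fires under a modulo schedule."""
    counts = {'display': 0, 'save': 0, 'sample': 0, 'validate': 0}
    for uidx in range(1, n_updates + 1):  # uidx: updates completed so far
        # (one parameter update would happen here)
        if uidx % dispFreq == 0:
            counts['display'] += 1
        if uidx % saveFreq == 0:
            counts['save'] += 1
        if uidx % sampleFreq == 0:
            counts['sample'] += 1
        if uidx % validFreq == 0:
            counts['validate'] += 1
    return counts


counts = run_schedule(1000, dispFreq=1, saveFreq=100,
                      sampleFreq=50, validFreq=500000)
```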

Example 10: main

# Required import (this example calls train() directly):
from nmt import train

def main(job_id, params):
    print(params)
    validerr = train(saveto=params['model'][0],
                     reload_=params['reload'][0],
                     dim_word=params['dim_word'][0],
                     dim=params['dim'][0],
                     n_words=params['n-words'][0],
                     n_words_src=params['n-words'][0],
                     decay_c=params['decay-c'][0],
                     clip_c=params['clip-c'][0],
                     lrate=params['learning-rate'][0],
                     optimizer=params['optimizer'][0],
                     patience=1000,
                     maxlen=50,
                     batch_size=32,
                     valid_batch_size=32,
                     validFreq=100,
                     dispFreq=100,
                     saveFreq=1000,
                     sampleFreq=1000,
                     datasets=['/home/chenhd/data/zh2en/tree/corpus.ch',
                               '/home/chenhd/data/zh2en/tree/corpus.en'],
                     valid_datasets=['/home/chenhd/data/zh2en/devntest/MT02/MT02.src',
                                     '/home/chenhd/data/zh2en/devntest/MT02/reference0'],
                     dictionaries=['/home/chenhd/data/zh2en/tree/corpus.ch.pkl',
                                   '/home/chenhd/data/zh2en/tree/corpus.en.pkl'],
                     treeset=['/home/chenhd/data/zh2en/tree/corpus.ch.tree',
                              '/home/chenhd/data/zh2en/devntest/MT02/MT02.ce.tree'],
                     use_dropout=params['use-dropout'][0],
                     # shuffle_each_epoch=True,
                     overwrite=False)
    return validerr 
Developer ID: howardchenhd, Project: Syntax-awared-NMT, Lines of code: 34, Source file: train_nmt_zh2en.py
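Example 10 sets patience=1000, and examples 11-15 read patience from params; it controls early stopping. The exact rule inside nmt.train is not shown on this page, so the sketch uses a simple rule as an assumption: stop once more than patience consecutive validations fail to beat the best validation error seen so far. The function name early_stop_index is hypothetical.

```python
def early_stop_index(valid_errs, patience):
    """Return the index of the validation at which training would stop,
    or None if it never stops. Training stops when more than `patience`
    consecutive validations fail to improve on the best error so far."""
    best = float('inf')
    bad_counter = 0
    for i, err in enumerate(valid_errs):
        if err < best:
            best = err
            bad_counter = 0
        else:
            bad_counter += 1
            if bad_counter > patience:
                return i
    return None


# Errors 4.5, 4.2 and 4.1 all fail to beat 4.0; the third miss exceeds patience=2.
stop = early_stop_index([5.0, 4.0, 4.5, 4.2, 4.1, 3.9], patience=2)
```

A very large patience, as in Example 10, effectively disables early stopping, so the run ends on max_epochs or an update limit instead.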

Example 11: main

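# Required import (this example calls train() directly):
from nmt import train
# layers, init_params, build_model, build_sampler and gen_sample are module-level
# names defined or imported elsewhere in the original script (not shown here).

def main(job_id, params):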
    re_load = False
    save_file_name = 'bpe2char_biscale_decoder_adam'
    source_dataset = params['train_data_path'] + params['source_dataset']
    target_dataset = params['train_data_path'] + params['target_dataset']
    valid_source_dataset = params['dev_data_path'] + params['valid_source_dataset']
    valid_target_dataset = params['dev_data_path'] + params['valid_target_dataset']
    source_dictionary = params['train_data_path'] + params['source_dictionary']
    target_dictionary = params['train_data_path'] + params['target_dictionary']

    print(params, params['save_path'], save_file_name)
    validerr = train(
        max_epochs=int(params['max_epochs']),
        patience=int(params['patience']),
        dim_word=int(params['dim_word']),
        dim_word_src=int(params['dim_word_src']),
        save_path=params['save_path'],
        save_file_name=save_file_name,
        re_load=re_load,
        enc_dim=int(params['enc_dim']),
        dec_dim=int(params['dec_dim']),
        n_words=int(params['n_words']),
        n_words_src=int(params['n_words_src']),
        decay_c=float(params['decay_c']),
        lrate=float(params['learning_rate']),
        optimizer=params['optimizer'],
        maxlen=int(params['maxlen']),
        maxlen_trg=int(params['maxlen_trg']),
        maxlen_sample=int(params['maxlen_sample']),
        batch_size=int(params['batch_size']),
        valid_batch_size=int(params['valid_batch_size']),
        sort_size=int(params['sort_size']),
        validFreq=int(params['validFreq']),
        dispFreq=int(params['dispFreq']),
        saveFreq=int(params['saveFreq']),
        sampleFreq=int(params['sampleFreq']),
        clip_c=int(params['clip_c']),
        datasets=[source_dataset, target_dataset],
        valid_datasets=[valid_source_dataset, valid_target_dataset],
        dictionaries=[source_dictionary, target_dictionary],
        use_dropout=int(params['use_dropout']),
        source_word_level=int(params['source_word_level']),
        target_word_level=int(params['target_word_level']),
        layers=layers,
        save_every_saveFreq=1,
        use_bpe=1,
        init_params=init_params,
        build_model=build_model,
        build_sampler=build_sampler,
        gen_sample=gen_sample
    )
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-cdec, Lines of code: 54, Source file: train_wmt15_fien_adam.py
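In examples 11-15 every numeric value is wrapped in int() or float(), which implies params maps names to strings, for example as read from a plain-text config file. The sketch assumes a hypothetical format of one "name value" pair per line and shows the parsing together with the same conversions and path concatenation. Note that int('1.0') raises ValueError, so with this approach a value such as clip_c, which the examples convert with int(), must be written as an integer.

```python
from collections import OrderedDict


def parse_config(text):
    """Parse 'name value' lines into an OrderedDict of strings.
    Blank lines and lines starting with '#' are ignored."""
    params = OrderedDict()
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        name, value = line.split(None, 1)
        params[name] = value
    return params


# Hypothetical config contents; the key names follow examples 11-15.
config_text = """
# hypothetical example config
batch_size 128
learning_rate 0.0001
optimizer adam
train_data_path /data/wmt15/
source_dataset train.bpe.src
"""
params = parse_config(config_text)
batch_size = int(params['batch_size'])
lrate = float(params['learning_rate'])
# Plain concatenation, as in the examples: train_data_path needs a trailing slash.
source_dataset = params['train_data_path'] + params['source_dataset']
```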

Example 12: main

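# Required import (this example calls train() directly):
from nmt import train
# layers, init_params, build_model, build_sampler and gen_sample are module-level
# names defined or imported elsewhere in the original script (not shown here).

def main(job_id, params):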
    re_load = False
    save_file_name = 'bpe2char_biscale_decoder_attc_adam'
    source_dataset = params['train_data_path'] + params['source_dataset']
    target_dataset = params['train_data_path'] + params['target_dataset']
    valid_source_dataset = params['dev_data_path'] + params['valid_source_dataset']
    valid_target_dataset = params['dev_data_path'] + params['valid_target_dataset']
    source_dictionary = params['train_data_path'] + params['source_dictionary']
    target_dictionary = params['train_data_path'] + params['target_dictionary']

    print(params, params['save_path'], save_file_name)
    validerr = train(
        max_epochs=int(params['max_epochs']),
        patience=int(params['patience']),
        dim_word=int(params['dim_word']),
        dim_word_src=int(params['dim_word_src']),
        save_path=params['save_path'],
        save_file_name=save_file_name,
        re_load=re_load,
        enc_dim=int(params['enc_dim']),
        dec_dim=int(params['dec_dim']),
        n_words=int(params['n_words']),
        n_words_src=int(params['n_words_src']),
        decay_c=float(params['decay_c']),
        lrate=float(params['learning_rate']),
        optimizer=params['optimizer'],
        maxlen=int(params['maxlen']),
        maxlen_trg=int(params['maxlen_trg']),
        maxlen_sample=int(params['maxlen_sample']),
        batch_size=int(params['batch_size']),
        valid_batch_size=int(params['valid_batch_size']),
        sort_size=int(params['sort_size']),
        validFreq=int(params['validFreq']),
        dispFreq=int(params['dispFreq']),
        saveFreq=int(params['saveFreq']),
        sampleFreq=int(params['sampleFreq']),
        clip_c=int(params['clip_c']),
        datasets=[source_dataset, target_dataset],
        valid_datasets=[valid_source_dataset, valid_target_dataset],
        dictionaries=[source_dictionary, target_dictionary],
        use_dropout=int(params['use_dropout']),
        source_word_level=int(params['source_word_level']),
        target_word_level=int(params['target_word_level']),
        layers=layers,
        save_every_saveFreq=1,
        use_bpe=1,
        init_params=init_params,
        build_model=build_model,
        build_sampler=build_sampler,
        gen_sample=gen_sample,
    )
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-cdec, Lines of code: 54, Source file: train_wmt15_deen_attc_adam.py
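Examples 11-15 repeat the same six path concatenations and differ mainly in save_file_name, re_load and the script they come from. One way to remove the duplication is to build the dataset and dictionary lists in a helper. The function name build_data_args is hypothetical; the params keys are the ones the examples already read. Using os.path.join instead of + also tolerates a directory value without a trailing slash.

```python
import os


def build_data_args(params):
    """Build the datasets / valid_datasets / dictionaries keyword arguments
    from the path keys used in examples 11-15."""
    train_dir = params['train_data_path']
    dev_dir = params['dev_data_path']
    return {
        'datasets': [os.path.join(train_dir, params['source_dataset']),
                     os.path.join(train_dir, params['target_dataset'])],
        'valid_datasets': [os.path.join(dev_dir, params['valid_source_dataset']),
                           os.path.join(dev_dir, params['valid_target_dataset'])],
        'dictionaries': [os.path.join(train_dir, params['source_dictionary']),
                         os.path.join(train_dir, params['target_dictionary'])],
    }


# Hypothetical values; note train_data_path has no trailing slash.
args = build_data_args({
    'train_data_path': '/data/train', 'dev_data_path': '/data/dev/',
    'source_dataset': 'src.txt', 'target_dataset': 'trg.txt',
    'valid_source_dataset': 'dev.src', 'valid_target_dataset': 'dev.trg',
    'source_dictionary': 'src.pkl', 'target_dictionary': 'trg.pkl',
})
```

In a script, the result could be spread into the call as train(..., **build_data_args(params)), replacing the three list arguments.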

Example 13: main

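# Required import (this example calls train() directly):
from nmt import train
# layers, init_params, build_model, build_sampler and gen_sample are module-level
# names defined or imported elsewhere in the original script (not shown here).

def main(job_id, params):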
    re_load = False
    save_file_name = 'bpe2char_biscale_decoder_both_adam'
    source_dataset = params['train_data_path'] + params['source_dataset']
    target_dataset = params['train_data_path'] + params['target_dataset']
    valid_source_dataset = params['dev_data_path'] + params['valid_source_dataset']
    valid_target_dataset = params['dev_data_path'] + params['valid_target_dataset']
    source_dictionary = params['train_data_path'] + params['source_dictionary']
    target_dictionary = params['train_data_path'] + params['target_dictionary']

    print(params, params['save_path'], save_file_name)
    validerr = train(
        max_epochs=int(params['max_epochs']),
        patience=int(params['patience']),
        dim_word=int(params['dim_word']),
        dim_word_src=int(params['dim_word_src']),
        save_path=params['save_path'],
        save_file_name=save_file_name,
        re_load=re_load,
        enc_dim=int(params['enc_dim']),
        dec_dim=int(params['dec_dim']),
        n_words=int(params['n_words']),
        n_words_src=int(params['n_words_src']),
        decay_c=float(params['decay_c']),
        lrate=float(params['learning_rate']),
        optimizer=params['optimizer'],
        maxlen=int(params['maxlen']),
        maxlen_trg=int(params['maxlen_trg']),
        maxlen_sample=int(params['maxlen_sample']),
        batch_size=int(params['batch_size']),
        valid_batch_size=int(params['valid_batch_size']),
        sort_size=int(params['sort_size']),
        validFreq=int(params['validFreq']),
        dispFreq=int(params['dispFreq']),
        saveFreq=int(params['saveFreq']),
        sampleFreq=int(params['sampleFreq']),
        clip_c=int(params['clip_c']),
        datasets=[source_dataset, target_dataset],
        valid_datasets=[valid_source_dataset, valid_target_dataset],
        dictionaries=[source_dictionary, target_dictionary],
        use_dropout=int(params['use_dropout']),
        source_word_level=int(params['source_word_level']),
        target_word_level=int(params['target_word_level']),
        layers=layers,
        save_every_saveFreq=1,
        use_bpe=1,
        init_params=init_params,
        build_model=build_model,
        build_sampler=build_sampler,
        gen_sample=gen_sample,
    )
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-cdec, Lines of code: 54, Source file: train_wmt15_deen_both_adam.py

Example 14: main

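# Required import (this example calls train() directly):
from nmt import train
# layers, init_params, build_model, build_sampler and gen_sample are module-level
# names defined or imported elsewhere in the original script (not shown here).

def main(job_id, params):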
    re_load = False
    save_file_name = 'bpe2char_two_layer_gru_decoder_adam'
    source_dataset = params['train_data_path'] + params['source_dataset']
    target_dataset = params['train_data_path'] + params['target_dataset']
    valid_source_dataset = params['dev_data_path'] + params['valid_source_dataset']
    valid_target_dataset = params['dev_data_path'] + params['valid_target_dataset']
    source_dictionary = params['train_data_path'] + params['source_dictionary']
    target_dictionary = params['train_data_path'] + params['target_dictionary']

    print(params, params['save_path'], save_file_name)
    validerr = train(
        max_epochs=int(params['max_epochs']),
        patience=int(params['patience']),
        dim_word=int(params['dim_word']),
        dim_word_src=int(params['dim_word_src']),
        save_path=params['save_path'],
        save_file_name=save_file_name,
        re_load=re_load,
        enc_dim=int(params['enc_dim']),
        dec_dim=int(params['dec_dim']),
        n_words=int(params['n_words']),
        n_words_src=int(params['n_words_src']),
        decay_c=float(params['decay_c']),
        lrate=float(params['learning_rate']),
        optimizer=params['optimizer'],
        maxlen=int(params['maxlen']),
        maxlen_trg=int(params['maxlen_trg']),
        maxlen_sample=int(params['maxlen_sample']),
        batch_size=int(params['batch_size']),
        valid_batch_size=int(params['valid_batch_size']),
        sort_size=int(params['sort_size']),
        validFreq=int(params['validFreq']),
        dispFreq=int(params['dispFreq']),
        saveFreq=int(params['saveFreq']),
        sampleFreq=int(params['sampleFreq']),
        clip_c=int(params['clip_c']),
        datasets=[source_dataset, target_dataset],
        valid_datasets=[valid_source_dataset, valid_target_dataset],
        dictionaries=[source_dictionary, target_dictionary],
        use_dropout=int(params['use_dropout']),
        source_word_level=int(params['source_word_level']),
        target_word_level=int(params['target_word_level']),
        layers=layers,
        save_every_saveFreq=1,
        use_bpe=1,
        init_params=init_params,
        build_model=build_model,
        build_sampler=build_sampler,
        gen_sample=gen_sample
    )
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-cdec, Lines of code: 54, Source file: train_wmt15_csen_bpe2char_adam.py

Example 15: main

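# Required import (this example calls train() directly):
from nmt import train
# layers, init_params, build_model, build_sampler and gen_sample are module-level
# names defined or imported elsewhere in the original script (not shown here).

def main(job_id, params):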
    re_load = True
    save_file_name = 'bpe2char_two_layer_gru_decoder_adam'
    source_dataset = params['train_data_path'] + params['source_dataset']
    target_dataset = params['train_data_path'] + params['target_dataset']
    valid_source_dataset = params['dev_data_path'] + params['valid_source_dataset']
    valid_target_dataset = params['dev_data_path'] + params['valid_target_dataset']
    source_dictionary = params['train_data_path'] + params['source_dictionary']
    target_dictionary = params['train_data_path'] + params['target_dictionary']

    print(params, params['save_path'], save_file_name)
    validerr = train(
        max_epochs=int(params['max_epochs']),
        patience=int(params['patience']),
        dim_word=int(params['dim_word']),
        dim_word_src=int(params['dim_word_src']),
        save_path=params['save_path'],
        save_file_name=save_file_name,
        re_load=re_load,
        enc_dim=int(params['enc_dim']),
        dec_dim=int(params['dec_dim']),
        n_words=int(params['n_words']),
        n_words_src=int(params['n_words_src']),
        decay_c=float(params['decay_c']),
        lrate=float(params['learning_rate']),
        optimizer=params['optimizer'],
        maxlen=int(params['maxlen']),
        maxlen_trg=int(params['maxlen_trg']),
        maxlen_sample=int(params['maxlen_sample']),
        batch_size=int(params['batch_size']),
        valid_batch_size=int(params['valid_batch_size']),
        sort_size=int(params['sort_size']),
        validFreq=int(params['validFreq']),
        dispFreq=int(params['dispFreq']),
        saveFreq=int(params['saveFreq']),
        sampleFreq=int(params['sampleFreq']),
        clip_c=int(params['clip_c']),
        datasets=[source_dataset, target_dataset],
        valid_datasets=[valid_source_dataset, valid_target_dataset],
        dictionaries=[source_dictionary, target_dictionary],
        use_dropout=int(params['use_dropout']),
        source_word_level=int(params['source_word_level']),
        target_word_level=int(params['target_word_level']),
        layers=layers,
        save_every_saveFreq=1,
        use_bpe=1,
        init_params=init_params,
        build_model=build_model,
        build_sampler=build_sampler,
        gen_sample=gen_sample
    )
    return validerr 
Developer ID: nyu-dl, Project: dl4mt-cdec, Lines of code: 54, Source file: train_wmt15_ruen_bpe2char_adam.py
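Example 15 has the same body as Example 14 except for re_load=True, so it resumes from a previously saved model instead of starting fresh. Rather than hard-coding the flag, a script could enable reloading only when a checkpoint already exists. The checkpoint location used below (save_path joined with save_file_name plus '.npz') and the helper name should_reload are assumptions for illustration; check how your version of nmt.train actually names its saved files.

```python
import os
import tempfile


def should_reload(save_path, save_file_name, ext='.npz'):
    """Return True if a checkpoint for this run already exists.
    The file-name pattern is a hypothetical convention."""
    return os.path.isfile(os.path.join(save_path, save_file_name + ext))


# Usage: no checkpoint yet -> start fresh; once one is saved -> resume.
name = 'bpe2char_two_layer_gru_decoder_adam'
with tempfile.TemporaryDirectory() as tmp:
    before = should_reload(tmp, name)
    open(os.path.join(tmp, name + '.npz'), 'w').close()  # simulate a saved model
    after = should_reload(tmp, name)
```

In main, re_load = should_reload(params['save_path'], save_file_name) would then replace the hard-coded True or False.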


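Note: The nmt.train examples on this page were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by many developers, and copyright of the source code remains with the original authors. Please consult each project's license before distributing or using the code, and do not reproduce this page without permission.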