当前位置: 首页>>代码示例>>Python>>正文


Python Evaluator.tune方法代码示例

本文整理汇总了 Python 中 evaluator.Evaluator.tune 方法的典型用法代码示例。如果您想了解 Python 中 Evaluator.tune 方法的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在类 evaluator.Evaluator 的其他用法示例。


下文共展示了 Evaluator.tune 方法的 1 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的 Python 代码示例。

示例1: main

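# Required import: from evaluator import Evaluator [as alias]
# Note: tune is an instance method of the Evaluator class. Call it on an instance, e.g. Evaluator(...).tune(); it cannot be imported via "from evaluator.Evaluator import tune".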

#.........这里部分代码省略.........
    flags.DEFINE_string("hadoop_weights", "", "hadoop weights (formatted specially)")
    flags.DEFINE_boolean("add_features", False, "add features to training data")
    flags.DEFINE_boolean("prune_train", False, "prune before decoding")
    flags.DEFINE_boolean("no_lm", False, "don't use the unigram language model")
    flags.DEFINE_boolean("pickleinput", False, "assumed input is pickled")
    flags.DEFINE_string("oracle_forests", None, "oracle forests", short_name="o")
    flags.DEFINE_string("feature_map_file", None, "file with the integer to feature mapping (for lbfgs)")
    flags.DEFINE_boolean("cache_input", False, "cache input sentences (only works for pruned input)")
    flags.DEFINE_string("rm_features", None, "list of features to remove")
    flags.DEFINE_boolean("just_basic", False, "remove all features but basic")

    argv = FLAGS(sys.argv)

    if FLAGS.weights:
        weights = Model.cmdline_model()
    else:
        vector = Vector()
        assert glob.glob(FLAGS.hadoop_weights)
        for file in glob.glob(FLAGS.hadoop_weights):
            for l in open(file):
                if not l.strip():
                    continue
                f, v = l.strip().split()
                vector[f] = float(v)
        weights = Model(vector)

    rm_features = set()
    if FLAGS.rm_features:
        for l in open(FLAGS.rm_features):
            rm_features.add(l.strip())

    lm = Ngram.cmdline_ngram()
    if FLAGS.no_lm:
        lm = None

    if argv[1] == "train":
        local_decode = ChiangPerceptronDecoder(weights, lm)
    elif argv[1] == "sgd" or argv[1] == "crf":
        local_decode = MarginalDecoder(weights, lm)
    else:
        local_decode = MarginalDecoder(weights, lm)

    if FLAGS.add_features:
        tdm = local_features.TargetDataManager()
        local_decode.feature_adder = FeatureAdder(tdm)
    local_decode.prune_train = FLAGS.prune_train
    local_decode.use_pickle = FLAGS.pickleinput
    local_decode.cache_input = FLAGS.cache_input
    print >> logs, "Cache input is %s" % FLAGS.cache_input
    if FLAGS.debuglevel > 0:
        print >> logs, "beam size = %d" % FLAGS.beam

    if argv[1] == "train":
        if not FLAGS.dist:
            perc = trainer.Perceptron.cmdline_perc(local_decode)
        else:
            train_files = [FLAGS.prefix + file.strip() for file in sys.stdin]
            perc = distributed_trainer.DistributedPerceptron.cmdline_perc(local_decode)
            perc.set_training(train_files)
        perc.train()
    elif argv[1] == "sgd":
        crf = sgd.BaseCRF.cmdline_crf(local_decode)
        crf.set_oracle_files([FLAGS.oracle_forests])
        crf.train()

    elif argv[1] == "crf":
        if not FLAGS.dist:
            crf = CRF.LBFGSCRF.cmdline_crf(local_decode)
            crf.set_oracle_files([FLAGS.oracle_forests])
            crf.set_feature_mappers(add_features.read_features(FLAGS.feature_map_file))
            crf.rm_features(rm_features)
            if FLAGS.just_basic:
                print "Enforcing Basic"
                crf.enforce_just_basic()
            crf.train()
        else:
            # train_files = [FLAGS.prefix+file.strip() for file in sys.stdin]
            # oracle_files = [file+".oracle" for file in train_files]
            print >> sys.stderr, "DistributedCRF"
            crf = distCRF.DistributedCRF.cmdline_distibuted_crf(local_decode)
            # os.system("~/.python/bin/dumbo rm train_input -hadoop /home/nlg-03/mt-apps/hadoop/0.20.1+169.89/")
            # os.system("~/.python/bin/dumbo put "+crf.trainfiles[0]+" train_input -hadoop /home/nlg-03/mt-apps/hadoop/0.20.1+169.89/")
            crf.set_feature_mappers(add_features.read_features(FLAGS.feature_map_file))
            crf.rm_features(rm_features)
            if FLAGS.just_basic:
                print "Enforcing Basic"
                crf.enforce_just_basic()

            # crf.set_oracle_files(oracle_files)
            crf.train()

    else:
        if not FLAGS.dist:
            print "Evaluating"
            eval = Evaluator(local_decode, [FLAGS.dev])
            eval.tune()
        else:
            dev_files = [FLAGS.prefix + file.strip() for file in sys.stdin]
            eval = Evaluator(local_decode, dev_files)
        print eval.eval(verbose=True).compute_score()
开发者ID：srush，项目名称：tf-fork，代码行数：104，代码来源：train_manager.py


注：本文中的 evaluator.Evaluator.tune 方法示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的许可证（License）；未经允许，请勿转载。