当前位置: 首页>>代码示例>>Python>>正文


Python train.train函数代码示例

本文整理汇总了Python中train.train函数的典型用法代码示例。如果您正苦于以下问题:Python train.train函数的具体用法?Python train.train怎么用?Python train.train使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了train.train函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: start_offline

def start_offline(dataset):
    """Prepare ``dataset`` and then train on it.

    The process working directory is switched to ``envpath`` (a name defined
    outside this function) before the ``prepare`` and ``train`` modules are
    imported. Each module is imported immediately before it is used, so
    ``train`` is only imported after ``prepare.prepare(dataset)`` has returned.

    :param dataset: passed unchanged to ``prepare.prepare`` and ``train.train``.
    """
    # Alternative that was left disabled by the original author:
    #sys.path.append(envpath)
    # NOTE(review): presumably the chdir is what makes `prepare` and `train`
    # importable (and resolves their relative paths) -- confirm.
    os.chdir(envpath)
    import prepare
    prepare.prepare(dataset)
    import train
    train.train(dataset)
开发者ID:wchgit,项目名称:wlan_positioning,代码行数:7,代码来源:go.py

示例2: main

def main():
    """Build the training matrices from the image folders, train, then test.

    Class directories are discovered under ``train/``; the feature matrix
    ``X`` and label vector ``y`` are allocated here and handed to
    ``parseImage.readImage`` together with the (initially empty) ``files``
    and ``namesClasses`` lists before ``train.train`` and ``test.test`` run.
    """
    print("In Main Experiment\n")

    # Class names come from the directory structure: every entry under
    # "train" that does not match "*.*".
    all_entries = set(glob.glob(os.path.join("train", "*")))
    dotted_entries = set(glob.glob(os.path.join("train", "*.*")))
    directory_names = list(all_entries.difference(dotted_entries))

    # One row for each image in the training dataset.
    num_rows = parseImage.gestNumberofImages(directory_names)

    # Images are rescaled to 25x25 (per the original author's note), giving
    # 25 * 25 pixel columns; 2 + 128 further columns are reserved on top
    # (the original comment attributes these to "our ratio").
    side_pixels = 25
    num_features = side_pixels * side_pixels + 2 + 128

    X = np.zeros((num_rows, num_features), dtype=float)
    y = np.zeros(num_rows)  # numeric class label
    files = []
    namesClasses = []  # class name list

    # Get the image training data.
    parseImage.readImage(True, namesClasses, directory_names, X, y, files)

    print("Training")
    train.train(X, y, namesClasses)

    print("Testing")
    # NOTE(review): a fresh empty list is passed as namesClasses here rather
    # than the list used above -- confirm this is intended.
    test.test(num_rows, num_features, X, y, namesClasses=[])
开发者ID:LvYe-Go,项目名称:Foreign-Exchange,代码行数:28,代码来源:main.py

示例3: parse_command_line

def parse_command_line():
    """Parse the command line and run the requested action.

    Options:
        --test_data   evaluate the given weights via ``test_pairings`` with
                      ``is_test=True``.
        --test_val    evaluate the given weights via ``test_pairings``
                      without ``is_test``.
        --weights     weights file; falls back to ``constants.TRAINED_WEIGHTS``
                      when omitted.
        -t/--threshold  integer passed as ``threshold`` to ``test_pairings``
                      (default 80).

    When neither test flag is given, ``train(True, data=lfw)`` is called.

    Raises:
        SystemExit: with status 1 when the ``CAFFEHOME`` environment variable
            is not set (argparse itself may also exit on invalid arguments).
    """
    parser = argparse.ArgumentParser(
        description="""Train, validate, and test a face detection classifier that will determine if
        two faces are the same or different.""")

    parser.add_argument("--test_data", help="Use preTrain model on test data, calcu the accuracy and ROC.", action="store_true")
    parser.add_argument("--test_val", help="Use preTrain model on validation data, calcu the accuracy and ROC.", action="store_true")
    parser.add_argument("--weights", help="""The trained model weights to use; if not provided
        defaults to the network that was just trained""", type=str, default=None)
    parser.add_argument("-t", "--threshold", help="The margin of two dense", type=int, default=80)

    args = vars(parser.parse_args())

    # Identity comparison is the correct test for None.
    if os.environ.get("CAFFEHOME") is None:
        print("You must set CAFFEHOME to point to where Caffe is installed. Example:")
        print("export CAFFEHOME=/usr/local/caffe")
        # The exit() builtin is injected by `site` for interactive use;
        # raising SystemExit gives the same exit status without relying on it.
        raise SystemExit(1)

    # Ensure the random number generator always starts from the same place for consistent tests.
    random.seed(0)

    lfw = data.Lfw()
    lfw.load_data()
    lfw.pair_data()

    if args["weights"] is None:
        args["weights"] = constants.TRAINED_WEIGHTS

    # Both flags are store_true booleans, so plain truthiness is sufficient.
    if args["test_data"]:
        test_pairings(lfw, weight_file=args["weights"], is_test=True, threshold=args["threshold"])
    elif args["test_val"]:
        test_pairings(lfw, weight_file=args["weights"], threshold=args["threshold"])
    else:
        train(True, data=lfw)
开发者ID:SundayDX,项目名称:LFW-adventure,代码行数:34,代码来源:main.py

示例4: main

def main():
  """Run training, evaluation or inference depending on ``flags.mode``.

  * ``'train'``: resamples the training clips, then trains on the
    ``/resampled`` sub-directory of ``flags.train_clip_dir``.
  * ``'eval'``: evaluates on the ``/resampled`` sub-directory of
    ``flags.eval_clip_dir`` (the resampling call itself is disabled).
  * anything else is asserted to be ``'inference'``: resamples the test
    clips and writes predictions.
  """
  flags = parse_flags()
  hparams = parse_hparams(flags.hparams)
  mode = flags.mode

  if mode == 'train':
    utils.resample(sample_rate=flags.sample_rate,
                   dir=flags.train_clip_dir,
                   csv_path=flags.train_csv_path)
    train.train(
        model_name=flags.model,
        hparams=hparams,
        class_map_path=flags.class_map_path,
        train_csv_path=flags.train_csv_path,
        train_clip_dir=flags.train_clip_dir+'/resampled',
        train_dir=flags.train_dir,
        sample_rate=flags.sample_rate)
    return

  if mode == 'eval':
    # TODO: re-enable the resampling step below.
    #utils.resample(sample_rate=flags.sample_rate, dir=flags.eval_clip_dir, csv_path=flags.eval_csv_path)
    evaluation.evaluate(
        model_name=flags.model,
        hparams=hparams,
        class_map_path=flags.class_map_path,
        eval_csv_path=flags.eval_csv_path,
        eval_clip_dir=flags.eval_clip_dir+'/resampled',
        checkpoint_path=flags.checkpoint_path)
    return

  # Only one mode remains valid at this point.
  assert mode == 'inference'
  utils.resample(sample_rate=flags.sample_rate,
                 dir=flags.test_clip_dir,
                 csv_path='test')
  # Unlike the other modes, the clip dir is passed without '/resampled'.
  inference.predict(
      model_name=flags.model,
      hparams=hparams,
      class_map_path=flags.class_map_path,
      test_clip_dir=flags.test_clip_dir,
      checkpoint_path=flags.checkpoint_path,
      predictions_csv_path=flags.predictions_csv_path)
开发者ID:ssgalitsky,项目名称:Research-Audio-classification-using-Audioset-Freesound-Databases,代码行数:29,代码来源:main.py

示例5: main

def main():
    """Train a single ``NeuralNet`` on the 'lin' data, then on the 'xor' data."""
    from train import train

    net = NeuralNet(n_features=2, n_hidden=10)
    net.optimizer.lr = 0.2

    # The same model object is handed to both training runs, in this order.
    for dataset_name in ('lin', 'xor'):
        train(model=net, data=dataset_name)
    train(model=lr, data='xor')
开发者ID:ticcky,项目名称:nn_intro,代码行数:8,代码来源:neural_net.py

示例6: main

def main(FLAGS):
    """Dispatch to ``train`` or ``infer`` according to ``FLAGS.mode``.

    Args:
        FLAGS: object exposing a ``mode`` attribute; it is forwarded as-is
            to the selected function.

    Raises:
        Exception: if ``FLAGS.mode`` is neither "train" nor "infer".
    """
    mode = FLAGS.mode

    if mode == "train":
        train(FLAGS)
        return

    if mode == "infer":
        infer(FLAGS)
        return

    raise Exception("Choose --mode=<train|infer>")
开发者ID:GKarmakar,项目名称:oreilly-pytorch,代码行数:10,代码来源:main.py

示例7: train_model

def train_model(db_file, entity_db_file, vocab_file, word2vec, **kwargs):
    """Load the training resources and hand them to ``train.train``.

    :param db_file: path opened read-only (``'r'``) with ``AbstractDB``.
    :param entity_db_file: path given to ``EntityDB.load``.
    :param vocab_file: path given to ``Vocab.load``.
    :param word2vec: when truthy, wrapped in ``ModelReader``; otherwise
        ``None`` is passed to the trainer in its place.
    :param kwargs: forwarded unchanged to ``train.train``.
    """
    abstract_db = AbstractDB(db_file, 'r')
    entities = EntityDB.load(entity_db_file)
    vocabulary = Vocab.load(vocab_file)

    # The reader is optional.
    pretrained = ModelReader(word2vec) if word2vec else None

    train.train(abstract_db, entities, vocabulary, pretrained, **kwargs)
开发者ID:studio-ousia,项目名称:ntee,代码行数:11,代码来源:cli.py

示例8: test_train_success

  def test_train_success(self):
    """Build, summarise and train a model for every stage id without error."""
    config = self._config
    root_dir = config['train_root_dir']

    # Create the training root directory if it does not exist yet.
    if not tf.gfile.Exists(root_dir):
      tf.gfile.MakeDirs(root_dir)

    for stage in train.get_stage_ids(**config):
      # Each stage is built in a fresh default graph.
      tf.reset_default_graph()
      images = provide_random_data()
      stage_model = train.build_model(stage, images, **config)
      train.add_model_summaries(stage_model, **config)
      train.train(stage_model, **config)
开发者ID:ALISCIFP,项目名称:models,代码行数:11,代码来源:train_test.py

示例9: train_PNet

def train_PNet(base_dir, prefix, end_epoch, display, lr):
    """
    Train PNet by delegating to ``train`` with ``P_Net`` as the net factory.

    :param base_dir: forwarded as the fourth positional argument of ``train``
        (the original documentation describes it as the tfrecord path)
    :param prefix: forwarded as the second positional argument
    :param end_epoch: forwarded as the third positional argument
    :param display: forwarded as the ``display`` keyword
    :param lr: forwarded as the ``base_lr`` keyword
    :return: None
    """
    train(P_Net, prefix, end_epoch, base_dir, display=display, base_lr=lr)
开发者ID:jiapei100,项目名称:MTCNN-Tensorflow,代码行数:12,代码来源:train_PNet.py

示例10: main

def main():
    """Run the pre-flight checks, build hparams from the YAML config, and train."""
    util.check_tensorflow_version()
    util.check_and_mkdir()

    yaml_config = load_yaml()
    check_config(yaml_config)

    hparams = create_hparams(yaml_config)
    print(hparams.values())

    # The logger is attached to hparams before it is handed to the trainer.
    run_log = Log(hparams)
    hparams.logger = run_log.logger

    train.train(hparams)
开发者ID:zeroToAll,项目名称:tensorflow_practice,代码行数:13,代码来源:main.py

示例11: predict

def predict(corpusPath, modelsPath, dummy, corpusId=None, connection=None, directed="both"):
    for model in getModels(corpusPath, modelsPath, corpusId, directed):
        if os.path.exists(model["model"]):
            print "Skipping existing target", model["model"]
            continue
        print "Processing target", model["model"], "directed =", model["directed"]
        if dummy:
            continue
        train.train(model["model"], task=CORPUS_ID, corpusDir=model["corpusDir"], connection=connection,
                    exampleStyles={"examples":model["exampleStyle"]}, parse="McCC",
                    classifierParams={"examples":"c=1,10,100,500,1000,1500,2500,3500,4000,4500,5000,7500,10000,20000,25000,27500,28000,29000,30000,35000,40000,50000,60000,65000"})
        for dataset in ("devel", "test"):
            if os.path.exists(model[dataset]):
                evaluate(model[dataset], model[dataset + "-gold"], model[dataset + "-eval"])
开发者ID:jbjorne,项目名称:TEES,代码行数:14,代码来源:SemEval2010Task8Tools.py

示例12: main

def main():
    """
    Train the network built by ``arrows.build_net``.

    Positional command-line args: data_dir save_dir logs_dir
    """
    argv = sys.argv
    data_dir, save_dir, logs_dir = argv[1], argv[2], argv[3]

    # NOTE(review): the session is installed as default for the training
    # call but never closed explicitly here.
    session = tf.Session()
    with session.as_default():
        train_data, test_data = arrows.get_input_producers(data_dir)
        train.train(
            arrows.build_net,
            train_data,
            test_data,
            logs_dir=logs_dir,
            save_dir=save_dir,
        )
开发者ID:vlpolyansky,项目名称:video-cnn,代码行数:14,代码来源:arrows_train.py

示例13: train_dataset

def train_dataset(dataset, train_params):
    """Train one model on ``dataset`` and return the generated model name.

    The data directory is ``<dataset_dir>/<dataset>`` and the model directory
    is ``<data_dir>/<models_dir>/<model_name>``, where ``dataset_dir`` and
    ``models_dir`` are names defined outside this function.

    :param dataset: name of the dataset sub-directory.
    :param train_params: object providing ``num_layers``, ``rnn_size``,
        ``model`` and ``get_training_arguments(data_dir, model_dir)``.
    :return: the model name, formatted as ``<num_layers>_<rnn_size>_<model>``.
    """
    data_dir = os.path.join(dataset_dir, dataset)
    print("Data Directory: %s" % data_dir)

    # Model name: layers_size_model
    name_fields = (train_params.num_layers, train_params.rnn_size, train_params.model)
    model_name = "%d_%d_%s" % name_fields

    model_dir = os.path.join(data_dir, models_dir, model_name)
    print("Model Dir: %s" % model_dir)

    # Arguments are built first; the default graph is reset just before training.
    training_arguments = train_params.get_training_arguments(data_dir, model_dir)
    tf.reset_default_graph()
    train.train(training_arguments)

    return model_name
开发者ID:Zbot21,项目名称:char-rnn-tensorflow,代码行数:15,代码来源:automated_testing.py

示例14: main

def main():
    """
    Train the network built by ``movie.build_net``.

    Positional command-line args: data_dir save_dir logs_dir
    """
    argv = sys.argv
    data_dir, save_dir, logs_dir = argv[1], argv[2], argv[3]

    # NOTE(review): the session is installed as default for the training
    # call but never closed explicitly here.
    session = tf.Session()
    with session.as_default():
        train_data, test_data = movie.get_input_producers(data_dir)
        train.train(
            movie.build_net,
            train_data,
            test_data,
            logs_dir=logs_dir,
            save_dir=save_dir,
            need_load=True,
            init_rate=0.0005,
            test_only=False,
        )
开发者ID:vlpolyansky,项目名称:video-cnn,代码行数:15,代码来源:movie_train.py

示例15: k_result

 def k_result(k):
     """Train on a random sample of ``k`` entries of ``train_set`` and return the test result.

     The sampled entries are written to ``<tempdir>/scp_k``; that file is
     passed to ``train``, and whatever ``train`` returns is forwarded to
     ``test`` as its second argument.
     """
     sampled_entries = random.sample(train_set, k)
     sample_scp = os.path.join(tempdir, 'scp_k')

     # Entries are written as-is; no separator is added between them.
     with open(sample_scp, 'w') as scp_file:
         scp_file.writelines(sampled_entries)

     trained_dir = train(outdir, config, sample_scp, proto, htk_dict, words_mlf, monophones, tempdir)
     return test(outdir, trained_dir, wdnet, htk_dict, monophones, scp_test, words_mlf, tempdir)
开发者ID:Tdebel,项目名称:HTK-scripts,代码行数:7,代码来源:graph.py


注:本文中的train.train函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。