当前位置: 首页>>代码示例>>Python>>正文


Python Stats.record_training_cost方法代码示例

本文整理汇总了Python中stats.Stats.record_training_cost方法的典型用法代码示例。如果您正苦于以下问题:Python Stats.record_training_cost方法的具体用法?Python Stats.record_training_cost怎么用?Python Stats.record_training_cost使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在stats.Stats的用法示例。


在下文中一共展示了Stats.record_training_cost方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: str

# 需要导入模块: from stats import Stats [as 别名]
# 或者: from stats.Stats import record_training_cost [as 别名]
    ckpt_file = str(int(time.time()))
    stats.last_ckpt = ckpt_file
    full_ckpt_path = os.path.join(opts.ckpt_dir, RUN_ID, ckpt_file)
    saver.save(sess, full_ckpt_path)
    log("ckpt saved to %s" % full_ckpt_path)

log("training")
epoch = 0
egs = zip(train_x, train_y)
last_ckpt = ""
while epoch != int(opts.num_epochs):
    random.shuffle(egs)
    for n, (eg, true_labels) in enumerate(egs):
        # train a batch
        # TODO: move this into ONE matrix and slice these things out.
        # TODO: move away from feeddict and to queueing egs.
        batch_cost, _opt = sess.run([cost, optimizer],
                                    feed_dict=eg_and_label_to_feeddict(eg, true_labels))
        stats.record_training_cost(batch_cost)
        # occasionally check dev set
        if stats.n_batches_trained % opts.dev_run_freq == 0:
            stats_from_dev_set(stats)
            stats.flush_to_stdout(epoch)
        # occasionally write a checkpoint
        if opts.ckpt_dir and stats.n_batches_trained % opts.ckpt_freq == 0:
            save_ckpt()
    epoch += 1

if opts.ckpt_dir:
    save_ckpt()
开发者ID:BinbinBian,项目名称:snli_nn_tf,代码行数:32,代码来源:nn_baseline.py

示例2: test_fn

# 需要导入模块: from stats import Stats [as 别名]
# 或者: from stats.Stats import record_training_cost [as 别名]
    for (s1, s2), y in zip(dev_x, dev_y):
        pred_y, cost = test_fn(s1, s2, [y])
        actuals.append(y)
        predicteds.append(pred_y)
        stats.record_dev_cost(cost)
    dev_c = confusion_matrix(actuals, predicteds)
    dev_accuracy = util.accuracy(dev_c)
    stats.set_dev_accuracy(dev_accuracy)
    print "dev confusion\n %s (%s)" % (dev_c, dev_accuracy)


log("training")
epoch = 0
training_early_stop_time = opts.max_run_time_sec + time.time()
stats = Stats(os.path.basename(__file__), opts)
egs = zip(train_x, train_y)
while epoch != opts.num_epochs:
    random.shuffle(egs)
    for (s1, s2), y in egs:
        cost, = train_fn(s1, s2, [y])
        stats.record_training_cost(cost)
        early_stop = False
        if opts.max_run_time_sec != -1 and time.time() > training_early_stop_time:
            early_stop = True
        if stats.n_egs_trained % opts.dev_run_freq == 0 or early_stop:
            stats_from_dev_set(stats)
            stats.flush_to_stdout(epoch)
        if early_stop:
            exit(0)
    epoch += 1
开发者ID:BinbinBian,项目名称:snli_nn,代码行数:32,代码来源:nn_seq2seq.py


注:本文中的stats.Stats.record_training_cost方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。