

Python Dataset.next_batch Method Code Examples

This article collects typical usage examples of the Python dataset.Dataset.next_batch method. If you are unsure exactly what Dataset.next_batch does or how to call it, the selected code examples below should help. You can also explore further usage examples of the enclosing dataset.Dataset class.


The two code examples of the Dataset.next_batch method shown below are ordered by popularity by default.
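Neither example includes the Dataset class itself, so for context here is a minimal, hypothetical sketch of the interface both examples assume: a dataset that keeps a cursor per split and whose next_batch(batch_size, phase) method returns an (images, labels) pair, wrapping around when a split is exhausted. Apart from the next_batch name and its arguments, everything in this sketch is an assumption for illustration; the real dataset.Dataset classes in the projects below are constructed from image-list files and load images from disk, and are not shown here.

import numpy as np


class Dataset(object):
    # Toy in-memory dataset exposing the next_batch interface assumed below (hypothetical).
    def __init__(self, train_images, train_labels, test_images, test_labels):
        # Each split keeps its arrays plus a cursor marking the next read position.
        self._splits = {
            "train": [np.asarray(train_images), np.asarray(train_labels), 0],
            "test": [np.asarray(test_images), np.asarray(test_labels), 0],
        }
        self.train_size = len(train_images)
        self.test_size = len(test_images)

    def next_batch(self, batch_size, phase="train"):
        # Return the next (images, labels) pair for the given split,
        # restarting from the beginning once the split is exhausted.
        images, labels, cursor = self._splits[phase]
        if cursor + batch_size > len(images):
            cursor = 0  # a real loader would typically reshuffle here
        self._splits[phase][2] = cursor + batch_size
        return images[cursor:cursor + batch_size], labels[cursor:cursor + batch_size]

Example 1 below uses the two-argument form, next_batch(batch_size, "train") and next_batch(batch_size, "test"), while Example 2 works with a single split and passes only the batch size.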

Example 1: main

# Required import: from dataset import Dataset [as alias]
# Alternatively: from dataset.Dataset import next_batch [as alias]
from __future__ import print_function

import sys
from datetime import datetime

import tensorflow as tf

from dataset import Dataset
# Model.alexnet and load_with_skip are helpers defined elsewhere in the source project.
def main():
    # Dataset path
    train_list = "/path/to/data/flickr_style/train.txt"
    test_list = "/path/to/data/flickr_style/test.txt"

    # Learning params
    learning_rate = 0.001
    training_iters = 12800  # 10 epochs
    batch_size = 50
    display_step = 20
    test_step = 640  # 0.5 epoch

    # Network params
    n_classes = 20
    keep_rate = 0.5

    # Graph input
    x = tf.placeholder(tf.float32, [batch_size, 227, 227, 3])
    y = tf.placeholder(tf.float32, [None, n_classes])
    keep_var = tf.placeholder(tf.float32)

    # Model
    pred = Model.alexnet(x, keep_var)

    # Loss and optimizer
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

    # Evaluation
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Init
    init = tf.initialize_all_variables()

    # Load dataset
    dataset = Dataset(train_list, test_list)

    # Launch the graph
    with tf.Session() as sess:
        print "Init variable"
        sess.run(init)

        # Load pretrained model
        load_with_skip("caffenet.npy", sess, ["fc8"])  # Skip weights from fc8

        print "Start training"
        step = 1
        while step < training_iters:
            batch_xs, batch_ys = dataset.next_batch(batch_size, "train")
            sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys, keep_var: keep_rate})

            # Display testing status
            if step % test_step == 0:
                test_acc = 0.0
                test_count = 0
                for _ in range(int(dataset.test_size / batch_size)):
                    batch_tx, batch_ty = dataset.next_batch(batch_size, "test")
                    acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, keep_var: 1.0})
                    test_acc += acc
                    test_count += 1
                test_acc /= test_count
                print("{} Iter {}: Testing Accuracy = {:.4f}".format(datetime.now(), step, test_acc), file=sys.stderr)

            # Display training status
            if step % display_step == 0:
                acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys, keep_var: 1.0})
                batch_loss = sess.run(loss, feed_dict={x: batch_xs, y: batch_ys, keep_var: 1.0})
                print("{} Iter {}: Training Loss = {:.4f}, Accuracy = {:.4f}".format(
                    datetime.now(), step, batch_loss, acc
                ), file=sys.stderr)

            step += 1
        print "Finish!"
Developer: joelthchao | Project: tensorflow-finetune-flickr-style | Lines: 76 | Source file: finetune.py

Example 2: range

# Required import: from dataset import Dataset [as alias]
# Alternatively: from dataset.Dataset import next_batch [as alias]
import tensorflow as tf

# Note: this is a fragment of a larger script; sess, x, y_, keep_prob, logits,
# cross_entropy, train_dataset, and the TRAIN/ITERATIONS/BATCH_SIZE/LEARNING_RATE
# constants are defined earlier in the source file and not shown here.
ce_summ = tf.scalar_summary('cross entropy', cross_entropy)
train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc_summ = tf.scalar_summary('accuracy', accuracy)

merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter('/tmp/vitruvian_logs', sess.graph_def)

saver = tf.train.Saver()

if TRAIN:
    sess.run(tf.initialize_all_variables())

    for i in range(ITERATIONS):
        batch_x, batch_y = train_dataset.next_batch(BATCH_SIZE)
        feed = {x: batch_x, y_: batch_y, keep_prob: 1.0}
        train_step.run(feed_dict=feed)  # run one optimization step on this batch
        train_accuracy = accuracy.eval(feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0})
        loss_val = cross_entropy.eval(feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0})
        print("step %d, training accuracy %g, loss %g"%(i, train_accuracy, loss_val))

        summary = merged.eval(feed_dict={
            x: batch_x, y_: batch_y, keep_prob: 1.0})
        writer.add_summary(summary, i)

        if i%1000 == 0:
            save_path = saver.save(sess, 'model_%d.ckpt' % i)
            print("model saved in file: %s" % save_path)
Developer: jamesaanderson | Project: vitruvian | Lines: 32 | Source file: model.py


Note: The dataset.Dataset.next_batch examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects; copyright remains with the original authors, and any use or redistribution should follow each project's license. Do not republish without permission.