

Python TextLoader.reset_batch_pointer Method Code Examples

This article compiles typical usage examples of the Python method utils.TextLoader.reset_batch_pointer. If you are unsure what this method does or how to call it, the curated examples below should help; you can also explore further usage examples of the enclosing class, utils.TextLoader.


Fourteen code examples of the TextLoader.reset_batch_pointer method are shown below, drawn from open-source projects and sorted by popularity.
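
None of the examples below include the loader itself. For orientation, here is a minimal sketch of the kind of TextLoader these projects rely on. It is an assumption reconstructed purely from how the examples call the class, not the actual utils.py of any project on this page: reset_batch_pointer rewinds to the first batch at the start of each epoch, and next_batch returns an (input, target) pair and advances the pointer.

import numpy as np

class TextLoader(object):
    # Hypothetical minimal loader matching the call sites below.
    def __init__(self, tensor, batch_size, seq_length):
        # tensor: 1-D array of token ids covering the whole corpus
        self.num_batches = len(tensor) // (batch_size * seq_length)
        tensor = np.asarray(tensor[:self.num_batches * batch_size * seq_length])
        targets = np.roll(tensor, -1)  # target = input shifted one step left
        self.x_batches = np.split(tensor.reshape(batch_size, -1), self.num_batches, 1)
        self.y_batches = np.split(targets.reshape(batch_size, -1), self.num_batches, 1)
        self.pointer = 0

    def reset_batch_pointer(self):
        # Rewind to the first batch; every example calls this once per epoch.
        self.pointer = 0

    def next_batch(self):
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y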

Example 1: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'w') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'w') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        for e in xrange(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in xrange(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start)
                if (e * data_loader.num_batches + b) % args.save_every == 0:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
                    print "model saved to {}".format(checkpoint_path)
Author: nakosung | Project: char-rnn-tensorflow | Lines: 34 | Source: train.py
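
Each train() here expects an argparse-style namespace. A hypothetical minimal invocation might look like the following; the flag names simply mirror the attributes the function reads, and the defaults are illustrative guesses rather than the projects' actual values:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='data/tinyshakespeare')
parser.add_argument('--save_dir', default='save')
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--seq_length', type=int, default=50)
parser.add_argument('--num_epochs', type=int, default=50)
parser.add_argument('--learning_rate', type=float, default=0.002)
parser.add_argument('--decay_rate', type=float, default=0.97)
parser.add_argument('--save_every', type=int, default=1000)
args = parser.parse_args()
train(args)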

Example 2: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from)," %s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt,"No checkpoint found"
        assert ckpt.model_checkpoint_path,"No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same=["model","rnn_size","num_layers","seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                    or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
Author: owen-d | Project: tensorflow_practice | Lines: 61 | Source: train.py

Example 3: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    print(args)
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            #print("model learning rate is {}".format(model.lr.eval()))
            data_loader.reset_batch_pointer('train')

            state = model.initial_state.eval()
            for b in xrange(data_loader.ntrain):
                start = time.time()
                x, y = data_loader.next_batch('train')

                # tmp = ''
                # for c in x:
                #   for i in c:
                #     tmp += np.array(data_loader.chars)[i]
                # print(tmp)

                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.ntrain + b,
                            args.num_epochs * data_loader.ntrain,
                            e, train_loss, end - start))
                if (e * data_loader.ntrain + b) % args.save_every == 0:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.ntrain + b)
                    print("model saved to {}".format(checkpoint_path))


            # eval validation loss
            data_loader.reset_batch_pointer('validation')
            validation_state = model.initial_state.eval()
            val_losses = 0
            for n in xrange(data_loader.nvalidation):
                x, y = data_loader.next_batch('validation')
                feed = {model.input_data: x, model.targets: y, model.initial_state: validation_state}
                validation_loss, validation_state = sess.run([model.cost, model.final_state], feed)
                val_losses += validation_loss

            validation_loss = val_losses / data_loader.nvalidation
            print("validation loss is {}".format(validation_loss))
Author: jiongye | Project: char-rnn-tensorflow | Lines: 58 | Source: train.py
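
Examples 3 and 10 pass a split name ('train' or 'validation') to reset_batch_pointer and next_batch, and read per-split batch counts (ntrain, nvalidation), which implies a loader that keeps one pointer per split. A hypothetical variant, building on the sketch near the top of this page:

class SplitTextLoader(object):
    # Hypothetical split-aware loader; one pointer per split so validation
    # passes do not disturb the training iteration order.
    def __init__(self, splits, batch_size, seq_length):
        # splits: dict mapping 'train' / 'validation' to token-id arrays
        self.loaders = {name: TextLoader(data, batch_size, seq_length)
                        for name, data in splits.items()}
        self.ntrain = self.loaders['train'].num_batches
        self.nvalidation = self.loaders['validation'].num_batches

    def reset_batch_pointer(self, split):
        self.loaders[split].reset_batch_pointer()

    def next_batch(self, split):
        return self.loaders[split].next_batch()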

Example 4: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'w') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'w') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        train_loss_iterations = {'iteration': [], 'epoch': [], 'train_loss': [], 'val_loss': []}

        for e in xrange(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in xrange(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                batch_idx = e * data_loader.num_batches + b
                print "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(batch_idx,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start)
                train_loss_iterations['iteration'].append(batch_idx)
                train_loss_iterations['epoch'].append(e)
                train_loss_iterations['train_loss'].append(train_loss)

                if batch_idx % args.save_every == 0:

                    # evaluate
                    state_val = model.initial_state.eval()
                    avg_val_loss = 0
                    for x_val, y_val in data_loader.val_batches:
                        feed_val = {model.input_data: x_val, model.targets: y_val, model.initial_state: state_val}
                        # run only cost/final_state here; including model.train_op
                        # would update the weights on validation data
                        val_loss, state_val = sess.run([model.cost, model.final_state], feed_val)
                        avg_val_loss += val_loss / len(data_loader.val_batches)
                    print 'val_loss: {:.3f}'.format(avg_val_loss)
                    train_loss_iterations['val_loss'].append(avg_val_loss)

                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
                    print "model saved to {}".format(checkpoint_path)
                else:
                    train_loss_iterations['val_loss'].append(None)

            pd.DataFrame(data=train_loss_iterations,
                         columns=train_loss_iterations.keys()).to_csv(os.path.join(args.save_dir, 'log.csv'))
Author: gfortaine | Project: grid-lstm-tensorflow | Lines: 57 | Source: train.py

Example 5: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    # Load data
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    # Set vocabulary size
    args.vocab_size = data_loader.vocab_size

    # Create the save directory if it does not exist
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Save the configuration and the vocab, used to reload models when sampling
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    # Create models with arguments
    model = Model(args)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(e * data_loader.num_batches + b,
                              args.num_epochs * data_loader.num_batches,
                              e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0:
                    checkpoint_path = os.path.join(args.save_dir, 'models.ckpt')
                    saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
                    print("models saved to {}".format(checkpoint_path))
        # Save the final state
        saver.save(sess, os.path.join(args.save_dir, 'models.ckpt'),
                   global_step=args.num_epochs * data_loader.num_batches)
Author: Zbot21 | Project: char-rnn-tensorflow | Lines: 45 | Source: train.py

Example 6: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length, args.input_encoding)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from)," %s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"words_vocab.pkl")),"words_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt,"No checkpoint found"
        assert ckpt.model_checkpoint_path,"No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same=["model","rnn_size","num_layers","seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'words_vocab.pkl'), 'rb') as f:
            saved_words, saved_vocab = cPickle.load(f)
        assert saved_words==data_loader.words, "Data and loaded model disagree on word set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.words, data_loader.vocab), f)

    model = Model(args)

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(args.log_dir)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        train_writer.add_graph(sess.graph)
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(model.epoch_pointer.eval(), args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            speed = 0
            if args.init_from is None:
                assign_op = model.epoch_pointer.assign(e)
                sess.run(assign_op)
            if args.init_from is not None:
                data_loader.pointer = model.batch_pointer.eval()
                args.init_from = None
            for b in range(data_loader.pointer, data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state,
                        model.batch_time: speed}
                summary, train_loss, state, _, _ = sess.run([merged, model.cost, model.final_state,
                                                             model.train_op, model.inc_batch_pointer_op], feed)
                train_writer.add_summary(summary, e * data_loader.num_batches + b)
                speed = time.time() - start
                if (e * data_loader.num_batches + b) % args.batch_size == 0:
                    print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                        .format(e * data_loader.num_batches + b,
                                args.num_epochs * data_loader.num_batches,
                                e, train_loss, speed))
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
        train_writer.close()
Author: Sr-vZ | Project: word-rnn-tensorflow | Lines: 78 | Source: train.py
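
Example 6 can resume mid-epoch because the epoch and batch positions live in the TensorFlow graph itself, so they are saved and restored together with the weights. A sketch of the graph-side state it assumes the Model exposes (variable names mirror the example; the construction is a guess):

import tensorflow as tf

# Non-trainable counters that the Saver checkpoints alongside the weights.
epoch_pointer = tf.Variable(0, name='epoch_pointer', trainable=False)
batch_pointer = tf.Variable(0, name='batch_pointer', trainable=False)
# Run once per batch; corresponds to model.inc_batch_pointer_op above.
inc_batch_pointer_op = tf.assign_add(batch_pointer, 1)
# After saver.restore(), batch_pointer.eval() yields the batch to resume
# from, which the example copies into data_loader.pointer.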

Example 7: train2

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train2(args):
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length, args.reprocess)
    args.vocab_size = data_loader.vocab_size

    totalTask = args.num_epochs * data_loader.num_batches

    lastCheckpoint = tf.train.latest_checkpoint(args.save_dir) 
    if lastCheckpoint is None:
        startEpoch = 0
    else:
        print "Last checkpoint :", lastCheckpoint
        startEpoch = int(lastCheckpoint.split("-")[-1])

    print "startEpoch = ", startEpoch

    with open(os.path.join(args.save_dir, 'config.pkl'), 'w') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'w') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = ConstrainedModel(args)

    etaCount = 0
    etaString = "-" 
    etaStart = time.time()
    etaTime = 0

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        if startEpoch > 0: # load latest checkpoint
            print "Loading last checkpoint"
            saver.restore(sess, lastCheckpoint)

        for e in xrange(startEpoch, args.num_epochs):
            sess.run(tf.assign(model.lr, decayForEpoch(args, e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in xrange(data_loader.num_batches):
                start = time.time()
                x, y, con = data_loader.next_batch()

                feed = {model.input_data: x, model.targets: y, model.initial_state: state, model.con_data:con}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                #time.sleep(0.01)
                #train_loss = 5
                end = time.time()

                taskNum = (e * data_loader.num_batches + b)
                etaCount += 1
                if (etaCount) % 25 == 0:
                    duration = time.time() - etaStart
                    etaTime = (totalTask - (taskNum + 1)) / 25.0 * duration  # float division so the ETA is not truncated
                    m, s = divmod(etaTime, 60)
                    h, m = divmod(m, 60)
                    etaString = "%d:%02d:%02d" % (h, m, s)
                    etaStart = time.time()

                print "{}/{} (epoch {}), loss = {:.3f}, time/batch = {:.3f}, ETA: {} ({})" \
                    .format(taskNum, totalTask, e, train_loss, end - start, time.ctime(time.time()+etaTime), etaString)

            if (e + 1) % args.save_every == 0 or e == args.num_epochs - 1:
                checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step = e + 1)
                print "model saved to {}".format(checkpoint_path)
Author: supasorn | Project: constrained-char-rnn | Lines: 67 | Source: train.py

Example 8: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    model_name = args.data_dir.split("/")[-1]
    # make a dir to store checkpoints
    args.save_dir = os.path.join('checkpoints', model_name)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from)," %s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        # instrument for tensorboard
        summaries = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
                os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S")))
        writer.add_graph(sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = sess.run(model.initial_state)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h

                # instrument for tensorboard
                summ, train_loss, state, _ = sess.run([summaries, model.cost, model.final_state, model.train_op], feed)
                writer.add_summary(summ, e * data_loader.num_batches + b)

                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(e * data_loader.num_batches + b,
                              args.num_epochs * data_loader.num_batches,
                              e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                        or (e == args.num_epochs-1 and b == data_loader.num_batches-1):
                    # remove previous checkpoints
                    current_checkpoints = [f for f in os.listdir(args.save_dir) if os.path.isfile(os.path.join(args.save_dir, f))]
                    for f in current_checkpoints:
                        if model_name in f:
                            os.remove(os.path.join(args.save_dir, f))
                    # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, model_name)
                    saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
                    final_model = '{}-{}'.format(model_name, e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

    # get the vocab
    model_vocab = getModelVocab(model_name)
    # dump the checkpoints to javascript
    dump_checkpoints(model_vocab, model_name, final_model)
Author: scottleedavis | Project: ml5-data-and-training | Lines: 94 | Source: train.py

Example 9: cross_validation

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def cross_validation(args):
    data_loader = TextLoader(args.utils_dir, args.data_path, args.batch_size, args.seq_length, None, None)
    args.vocab_size = data_loader.vocab_size
    args.label_size = data_loader.label_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)
    with open(os.path.join(args.save_dir, 'labels.pkl'), 'wb') as f:
        pickle.dump(data_loader.labels, f)

    data = data_loader.tensor.copy()
    np.random.shuffle(data)
    data_list = np.array_split(data, 10, axis=0)

    model = Model(args)
    accuracy_list = []

    with tf.Session() as sess:
        for n in range(10):
            init = tf.initialize_all_variables()
            sess.run(init)
            saver = tf.train.Saver(tf.all_variables())

            test_data = data_list[n].copy()
            train_data = np.concatenate(map(lambda i: data_list[i], [j for j in range(10) if j!=n]), axis=0)
            data_loader.tensor = train_data

            for e in range(args.num_epochs):
                sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
                data_loader.reset_batch_pointer()

                for b in range(data_loader.num_batches):
                    start = time.time()
                    state = model.initial_state.eval()
                    x, y = data_loader.next_batch()
                    feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                    train_loss, state, _, accuracy = sess.run([model.cost, model.final_state, model.optimizer, model.accuracy], feed_dict=feed)
                    end = time.time()
                    print '{}/{} (epoch {}), train_loss = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}'\
                        .format(e * data_loader.num_batches + b + 1,
                                args.num_epochs * data_loader.num_batches,
                                e + 1,
                                train_loss,
                                accuracy,
                                end - start)
                    if (e*data_loader.num_batches+b+1) % args.save_every == 0 \
                        or (e==args.num_epochs-1 and b==data_loader.num_batches-1):
                        checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=e*data_loader.num_batches+b+1)
                        print 'model saved to {}'.format(checkpoint_path)

            n_chunks = len(test_data) / args.batch_size
            if len(test_data) % args.batch_size:
                n_chunks += 1
            test_data_list = np.array_split(test_data, n_chunks, axis=0)

            correct_total = 0.0
            num_total = 0.0
            for m in range(n_chunks):
                start = time.time()
                x = test_data_list[m][:, :-1]
                y = test_data_list[m][:, -1]
                results = model.predict_class(sess, x)
                correct_num = np.sum(results==y)
                end = time.time()

                correct_total += correct_num
                num_total += len(x)

            accuracy_total = correct_total / num_total
            accuracy_list.append(accuracy_total)
            print 'total_num = {}, total_accuracy = {:.6f}'.format(int(num_total), accuracy_total)

    accuracy_average = np.average(accuracy_list)
    print 'The average accuracy of cross_validation is {}'.format(accuracy_average)
Author: 12190143 | Project: RNN-Classification | Lines: 79 | Source: cross_validation.py

Example 10: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    print(args)
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length, args.training_data_ratio)
    args.vocab_size = data_loader.vocab_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    #sess = tf.InteractiveSession()
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()
        summary_writer = tf.train.SummaryWriter('/tmp', sess.graph)

        step = 0
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            #print("model learning rate is {}".format(model.lr.eval()))
            data_loader.reset_batch_pointer('train')

            state = model.initial_state.eval()
            for b in xrange(data_loader.ntrain):
                start = time.time()
                x, y = data_loader.next_batch('train')

                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                step = e * data_loader.ntrain + b
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(step,
                            args.num_epochs * data_loader.ntrain,
                            e, train_loss, end - start))

                if step % args.write_summary_every == 0:
                    # training loss
                    summary_str = sess.run(summary_op, feed_dict=feed)
                    summary_writer.add_summary(summary_str, step)

                if step % args.save_every == 0 or (step + 1) == (args.num_epochs * data_loader.ntrain):
                    # eval validation loss
                    data_loader.reset_batch_pointer('validation')
                    validation_state = model.initial_state.eval()
                    val_losses = 0
                    for n in xrange(data_loader.nvalidation):
                        x, y = data_loader.next_batch('validation')
                        val_feed = {model.input_data: x, model.targets: y, model.initial_state: validation_state}
                        validation_loss, validation_state = sess.run([model.cost, model.final_state], val_feed)
                        val_losses += validation_loss

                    validation_loss = val_losses / data_loader.nvalidation
                    print("validation loss is {}".format(validation_loss))

                    # write top 5 validation loss to a json file
                    args_dict = vars(args)
                    args_dict['step'] = step
                    val_loss_file = args.save_dir + '/val_loss.json'
                    loss_json = ''
                    save_new_checkpoint = False
                    time_int = int(time.time())
                    args_dict['checkpoint_path'] = os.path.join(args.save_dir, 'model.ckpt-'+str(time_int))
                    if os.path.exists(val_loss_file):
                        with open(val_loss_file, "r") as text_file:
                            text = text_file.read()
                            if text == '':
                                loss_json = {validation_loss: args_dict}
                                save_new_checkpoint = True
                            else:
                                loss_json = json.loads(text)
                                losses = loss_json.keys()
                                if len(losses) > 3:
                                    losses.sort(key=lambda x: float(x), reverse=True)
                                    loss = losses[0]
                                    if validation_loss < float(loss):
                                        to_be_remove_ckpt_file_path =  loss_json[loss]['checkpoint_path']
                                        to_be_remove_ckpt_meta_file_path = to_be_remove_ckpt_file_path + '.meta'
                                        print("removed checkpoint {}".format(to_be_remove_ckpt_file_path))
                                        if os.path.exists(to_be_remove_ckpt_file_path):
                                            os.remove(to_be_remove_ckpt_file_path)
                                        if os.path.exists(to_be_remove_ckpt_meta_file_path):
                                            os.remove(to_be_remove_ckpt_meta_file_path)
                                        del(loss_json[loss])
                                        loss_json[validation_loss] = args_dict
                                        save_new_checkpoint = True
                                else:
                                    loss_json[validation_loss] = args_dict
                                    save_new_checkpoint = True
                    else:
                        loss_json = {validation_loss: args_dict}
                        save_new_checkpoint = True

                    if save_new_checkpoint:
                        checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
#......... (part of the code omitted here) .........
Author: jtoy | Project: word-rnn-tf | Lines: 103 | Source: train.py
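
The JSON bookkeeping in Example 10 implements a "keep only the N best checkpoints by validation loss" policy. The same idea condensed into a hypothetical helper (not from the project):

import heapq
import os

def keep_best_checkpoints(records, new_loss, new_path, n=4):
    # records: heap of (-loss, path); negating the loss makes heappop
    # return the worst (largest-loss) checkpoint first.
    heapq.heappush(records, (-new_loss, new_path))
    while len(records) > n:
        _, worst_path = heapq.heappop(records)
        for p in (worst_path, worst_path + '.meta'):
            if os.path.exists(p):
                os.remove(p)  # drop the evicted checkpoint's files
    return records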

Example 11: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):

    display_step = 100
    num_train = 20000
    train_input, train_output, train_length, max_length = get_training_data(args, 'train', num_train, 0)
    test_input, test_output, test_length, max_length = get_training_data(args, 'test', 25000, 50000)
    val_input, val_output, val_length, max_length = get_training_data(args, 'val', 25000, 75000)

    #for i in range(2):
    #  print('i: ' + str(i) + ' => ' + str(train_input[i,:]))

    train_input = train_input.astype(int)

    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = 50000 #data_loader.vocab_size
    
    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist 
        assert os.path.isdir(args.init_from)," %s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt,"No checkpoint found"
        assert ckpt.model_checkpoint_path,"No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same=["model","rnn_size","num_layers","seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme
        
        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
        
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)
        
    model = Model(args)

    print("num_layers: ", args.num_layers)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()

            step = 0
            ptr = 0

            print('train_input: ', train_input.shape)

            while step < num_train / args.batch_size:
                b = step
                # for b in range(data_loader.num_batches):
                step += 1
                start = time.time()

                # inputs batch
                x = np.squeeze(train_input[ptr:ptr+args.batch_size, :args.batch_size])

                # output batch
                y = np.squeeze(train_input[ptr:ptr+args.batch_size, 1:args.batch_size+1])
                ptr += args.batch_size+1
                #x, y = data_loader.next_batch()
                #print('x: ', x.shape)
                #print('y: ', y.shape)
                #print('x: ', x[1])
                #print('y: ', y)
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                tt, calc_res, reg_cost, train_loss, state, _ = sess.run([model.target_vector, model.logits, model.reg_cost, model.cost, model.final_state, model.train_op], feed)
                print('out len: ', len(tt))
                print('target: ', tt)
                print('calc_res: ', calc_res)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}, reg_cost = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start, reg_cost))
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                    or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

                if step % display_step == 0:
                    print('x: ', x[1])
Author: bottiger | Project: Integer-Sequence-Learning | Lines: 101 | Source: train.py

Example 12: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):
    if args.continue_training in ['True', 'true']:
        args.continue_training = True
    else:
        args.continue_training = False

    data_loader = TextLoader(True, args.utils_dir, args.data_path, args.batch_size, args.seq_length, None, None)
    args.vocab_size = data_loader.vocab_size
    args.label_size = data_loader.label_size

    if args.continue_training:
        assert os.path.isfile(os.path.join(args.save_dir, 'config.pkl')), 'config.pkl file does not exist in path %s' % args.save_dir
        assert os.path.isfile(os.path.join(args.utils_dir, 'chars_vocab.pkl')), 'chars_vocab.pkl file does not exist in path %s' % args.utils_dir
        assert os.path.isfile(os.path.join(args.utils_dir, 'labels.pkl')), 'labels.pkl file does not exist in path %s' % args.utils_dir
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

        with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
            saved_model_args = pickle.load(f)
        need_be_same = ['model', 'rnn_size', 'num_layers', 'seq_length']
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme], 'command line argument and saved model disagree on %s' % checkme

        with open(os.path.join(args.utils_dir, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = pickle.load(f)
        with open(os.path.join(args.utils_dir, 'labels.pkl'), 'rb') as f:
            saved_labels = pickle.load(f)
        assert saved_chars==data_loader.chars, 'data and loaded model disagree on character set'
        assert saved_vocab==data_loader.vocab, 'data and loaded model disagree on dictionary mappings'
        assert saved_labels==data_loader.labels, 'data and loaded model disagree on label dictionary mappings'

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.utils_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)
    with open(os.path.join(args.utils_dir, 'labels.pkl'), 'wb') as f:
        pickle.dump(data_loader.labels, f)

    model = Model(args)

    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        sess.run(init)
        saver = tf.train.Saver(tf.all_variables())

        if args.continue_training:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()

            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                train_loss, state, _, accuracy = sess.run([model.cost, model.final_state, model.optimizer, model.accuracy], feed_dict=feed)
                end = time.time()
                print '{}/{} (epoch {}), train_loss = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}'\
                    .format(e * data_loader.num_batches + b + 1,
                            args.num_epochs * data_loader.num_batches,
                            e + 1,
                            train_loss,
                            accuracy,
                            end - start)
                if (e*data_loader.num_batches+b+1) % args.save_every == 0 \
                    or (e==args.num_epochs-1 and b==data_loader.num_batches-1):
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=e*data_loader.num_batches+b+1)
                    print 'model saved to {}'.format(checkpoint_path)
Author: AlexYoung757 | Project: RNN-Classification | Lines: 73 | Source: train.py

Example 13: main

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def main(_):
  pp.pprint(FLAGS.__flags)

  if not os.path.exists(FLAGS.checkpoint_dir):
    print(" [*] Creating checkpoint directory...")
    os.makedirs(FLAGS.checkpoint_dir)

  data_loader = TextLoader(os.path.join(FLAGS.data_dir, FLAGS.dataset_name),
                           FLAGS.batch_size, FLAGS.seq_length)
  vocab_size = data_loader.vocab_size
  valid_size = 50
  valid_window = 100

  with tf.variable_scope('model'):
    train_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    simple_model = CharRNN(vocab_size, 1, FLAGS.rnn_size,
                           FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                           1, FLAGS.keep_prob,
                           FLAGS.grad_clip)

  with tf.variable_scope('model', reuse=True):
    valid_model = CharRNN(vocab_size, FLAGS.batch_size, FLAGS.rnn_size,
                          FLAGS.layer_depth, FLAGS.num_units, FLAGS.rnn_type,
                          FLAGS.seq_length, FLAGS.keep_prob,
                          FLAGS.grad_clip)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()

    train_model.load(sess, FLAGS.checkpoint_dir, FLAGS.dataset_name)

    best_val_pp = float('inf')
    best_val_epoch = 0
    valid_loss = 0
    valid_perplexity = 0
    start = time.time()

    if FLAGS.export:
      print("Eval...")
      final_embeddings = train_model.embedding.eval(sess)
      emb_file = os.path.join(FLAGS.data_dir, FLAGS.dataset_name, 'emb.npy')
      print("Embedding shape: {}".format(final_embeddings.shape))
      np.save(emb_file, final_embeddings)

    else: # Train
      current_step = 0
      similarity, valid_examples, _ = compute_similarity(train_model, valid_size, valid_window, 6)

      # save hyper-parameters
      cPickle.dump(FLAGS.__flags, open(FLAGS.log_dir + "/hyperparams.pkl", 'wb'))

      # run it!
      for e in range(FLAGS.num_epochs):
        data_loader.reset_batch_pointer()

        # decay learning rate
        sess.run(tf.assign(train_model.lr, FLAGS.learning_rate))

        # iterate by batch
        for b in range(data_loader.num_batches):
          x, y = data_loader.next_batch()
          res, time_batch = run_epochs(sess, x, y, train_model)
          train_loss = res["loss"]
          train_perplexity = np.exp(train_loss)
          iterate = e * data_loader.num_batches + b

          # print log
          print("{}/{} (epoch {}) loss = {:.2f}({:.2f}) perplexity(train/valid) = {:.2f}({:.2f}) time/batch = {:.2f} chars/sec = {:.2f}k"\
              .format(e * data_loader.num_batches + b,
                      FLAGS.num_epochs * data_loader.num_batches,
                      e, train_loss, valid_loss, train_perplexity, valid_perplexity,
                      time_batch, (FLAGS.batch_size * FLAGS.seq_length) / time_batch / 1000))

          current_step = tf.train.global_step(sess, train_model.global_step)

        # validate
        valid_loss = 0

        for vb in range(data_loader.num_valid_batches):
          res, valid_time_batch = run_epochs(sess, data_loader.x_valid[vb], data_loader.y_valid[vb], valid_model, False)
          valid_loss += res["loss"]

        valid_loss = valid_loss / data_loader.num_valid_batches
        valid_perplexity = np.exp(valid_loss)

        print("### valid_perplexity = {:.2f}, time/batch = {:.2f}".format(valid_perplexity, valid_time_batch))

        log_str = ""

        # Generate sample
        smp1 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"我喜歡做")
        smp2 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"他吃飯時會用")
        smp3 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"人類總要重複同樣的")
        smp4 = simple_model.sample(sess, data_loader.chars, data_loader.vocab, UNK_ID, 5, u"天色暗了,好像快要")

#......... (part of the code omitted here) .........
Author: indiejoseph | Project: chinese-char-rnn | Lines: 103 | Source: train.py

Example 14: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import reset_batch_pointer [as alias]
def train(args):

    data_loader = TextLoader(args.data_path, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size
    args.file_size = data_loader.file_size
    print("Vocab size: ",args.vocab_size)
    print("File size: ",args.file_size)
    args.lower_bound = 0 #If we know the entropy then we set it to this
    data_info = {}
    if args.info_path is not None:
        assert os.path.isfile(args.info_path),"Info file not found in the path: %s"%args.info_path

        #Open the info file
        with open(args.info_path, 'rb') as f:
            data_info = json.load(f)
            #Assuming we know entropy
            args.lower_bound = data_info['Entropy']
            print(data_info)

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist 
        assert os.path.isdir(args.init_from)," %s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt,"No checkpoint found"
        assert ckpt.model_checkpoint_path,"No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same=["model","rnn_size","num_layers","seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme
        
        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
        
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)
        
    
    ##################################################
    # Get the model
    ##################################################
    model = Model(args)
    print("model Loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        writer = tf.summary.FileWriter(args.summary_dir,sess.graph)
        # restore model
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        
        ######################################################
        # Perform the training
        #####################################################
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer() #Need to check what this does
            state = sess.run(model.initial_state) #What is this initial state
            cumul_loss = 0
             
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                
                for i, (c, h) in enumerate(model.initial_state):
                    feed[c] = state[i].c
                    feed[h] = state[i].h
                summary, train_loss, state, _ = sess.run([model.merged_summaries, model.cost, model.final_state, model.train_op], feed) #what is the training loss
                train_loss /= np.log(2)
                cumul_loss += train_loss
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % args.save_every == 0\
                    or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))

                if b%10 == 0:
                    writer.add_summary(summary,e*data_loader.num_batches + b)
             
            cumul_loss /= data_loader.num_batches
            print("Epoch {}: Cumulative Loss for the epoch: {:.3f}".format(e,cumul_loss))
            if (abs(cumul_loss - args.lower_bound) < 0.1):
                print("Stopping Training as we get a good loss.. :) ... ") 
#......... (part of the code omitted here) .........
Author: jessehui | Project: NN_compression | Lines: 103 | Source: train.py
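
Example 14 divides the graph's cross-entropy by np.log(2) because TensorFlow computes the loss in nats (natural logarithm), while the entropy bound read from the info file is presumably in bits; training stops once the two agree to within 0.1 bit. The conversion in isolation:

import numpy as np

loss_nats = 1.386                    # cross-entropy as reported by the graph
loss_bits = loss_nats / np.log(2)    # ~= 2.0 bits per symbol
# Stopping rule from the example: abs(loss_bits - entropy) < 0.1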


Note: the utils.TextLoader.reset_batch_pointer examples in this article were compiled by 纯净天空 from open-source code hosted on GitHub, MSDocs, and similar platforms. The snippets come from open-source projects contributed by their respective authors; copyright remains with those authors, and any use or redistribution should follow the corresponding project's License. Please do not repost without permission.