

Python TextLoader.cue_batch_pointer_to_epoch_fraction Method Code Examples

This article collects typical usage examples of the Python method utils.TextLoader.cue_batch_pointer_to_epoch_fraction. If you have been wondering what this method does and how to call it, the selected code example below should help; you can also look further into other usage examples of utils.TextLoader.


The following presents 1 code example of the TextLoader.cue_batch_pointer_to_epoch_fraction method, collected from an open-source project.
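The idea the method encodes, as used in the example below, is that training progress is tracked as a single floating-point "epoch fraction": the integer part counts completed epochs and the fractional part says how far into the current epoch training got. A quick worked illustration (the numbers are hypothetical; the formula is the one used in the example):

# Hypothetical numbers, for illustration only.
global_epoch_fraction = 2.25   # 2 full epochs completed, 25% into the third
total_batch_count = 400        # batches per epoch

# The fractional part maps to a batch index within the current epoch:
initial_batch_step = int((global_epoch_fraction - int(global_epoch_fraction))
        * total_batch_count)   # -> 100

Calling cue_batch_pointer_to_epoch_fraction(epoch_fraction) tells the data loader to position its internal batch pointer at exactly that point, so a resumed run continues from where the checkpoint left off rather than restarting the epoch.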

Example 1: train

# Required import: from utils import TextLoader [as alias]
# Or: from utils.TextLoader import cue_batch_pointer_to_epoch_fraction [as alias]
# (The excerpt below also relies on os, time, cPickle, tensorflow as tf, and the project's Model class.)
def train(args):
    # Create the data_loader object, which loads up all of our batches, vocab dictionary, etc.
    # from utils.py (and creates them if they don't already exist).
    # These files go in the data directory.
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    load_model = False
    if not os.path.exists(args.save_dir):
        print("Creating directory %s" % args.save_dir)
        os.mkdir(args.save_dir)
    elif (os.path.exists(os.path.join(args.save_dir, 'config.pkl'))):
        # Trained model already exists
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            with open(os.path.join(args.save_dir, 'config.pkl')) as f:
                saved_args = cPickle.load(f)
                args.rnn_size = saved_args.rnn_size
                args.num_layers = saved_args.num_layers
                args.model = saved_args.model
                print("Found a previous checkpoint. Overwriting model description arguments to:")
                print(" model: {}, rnn_size: {}, num_layers: {}".format(
                    saved_args.model, saved_args.rnn_size, saved_args.num_layers))
                load_model = True

    # Save all arguments to config.pkl in the save directory -- NOT the data directory.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'w') as f:
        cPickle.dump(args, f)
    # Save a tuple of the characters list and the vocab dictionary to chars_vocab.pkl in
    # the save directory -- NOT the data directory.
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'w') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    # Create the model!
    print("Building the model")
    model = Model(args)

    config = tf.ConfigProto(log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(model.save_variables_list())
        if (load_model):
            print("Loading saved parameters")
            saver.restore(sess, ckpt.model_checkpoint_path)
        global_epoch_fraction = sess.run(model.global_epoch_fraction)
        global_seconds_elapsed = sess.run(model.global_seconds_elapsed)
        if load_model: print("Resuming from global epoch fraction {:.3f},"
                " total trained time: {}, learning rate: {}".format(
                global_epoch_fraction, global_seconds_elapsed, sess.run(model.lr)))
        data_loader.cue_batch_pointer_to_epoch_fraction(global_epoch_fraction)
        initial_batch_step = int((global_epoch_fraction
                - int(global_epoch_fraction)) * data_loader.total_batch_count)
        epoch_range = (int(global_epoch_fraction),
                args.num_epochs + int(global_epoch_fraction))
        writer = tf.train.SummaryWriter(args.save_dir, graph=tf.get_default_graph())
        outputs = [model.cost, model.final_state, model.train_op, model.summary_op]
        is_lstm = args.model == 'lstm'
        global_step = epoch_range[0] * data_loader.total_batch_count + initial_batch_step
        try:
            for e in xrange(*epoch_range):
                # e iterates through the training epochs.
                # Reset the model state, so it does not carry over from the end of the previous epoch.
                state = sess.run(model.initial_state)
                batch_range = (initial_batch_step, data_loader.total_batch_count)
                initial_batch_step = 0
                for b in xrange(*batch_range):
                    global_step += 1
                    if global_step % args.decay_steps == 0:
                        # Set the model.lr element of the model to track
                        # the appropriately decayed learning rate.
                        current_learning_rate = sess.run(model.lr)
                        current_learning_rate *= args.decay_rate
                        sess.run(tf.assign(model.lr, current_learning_rate))
                        print("Decayed learning rate to {}".format(current_learning_rate))
                    start = time.time()
                    # Pull the next batch inputs (x) and targets (y) from the data loader.
                    x, y = data_loader.next_batch()

                    # feed is a dictionary of variable references and respective values for initialization.
                    # Initialize the model's input data and target data from the batch,
                    # and initialize the model state to the final state from the previous batch, so that
                    # model state is accumulated and carried over between batches.
                    feed = {model.input_data: x, model.targets: y}
                    if is_lstm:
                        for i, (c, h) in enumerate(model.initial_state):
                            feed[c] = state[i].c
                            feed[h] = state[i].h
                    else:
                        for i, c in enumerate(model.initial_state):
                            feed[c] = state[i]
                    # Run the session! Specifically, tell TensorFlow to compute the graph to calculate
                    # the values of cost, final state, and the training op.
                    # Cost is used to monitor progress.
                    # Final state is used to carry over the state into the next batch.
                    # Training op is not used, but we want it to be calculated, since that calculation
                    # is what updates parameter states (i.e. that is where the training happens).
                    train_loss, state, _, summary = sess.run(outputs, feed)
                    elapsed = time.time() - start
                    global_seconds_elapsed += elapsed
#......... part of the code omitted here .........
Developer ID: Navdevl, Project: chatbot-rnn, Lines of code: 103, Source file: train.py
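
The example above only shows the call site; the TextLoader class from utils.py is not reproduced on this page. As a minimal sketch of what cue_batch_pointer_to_epoch_fraction plausibly does (the attribute names batches, pointer and total_batch_count are assumptions inferred from how the loader is used above, not the actual chatbot-rnn implementation):

class TextLoader(object):
    # ... the constructor is assumed to build self.batches (a list of (x, y)
    # pairs) and to set self.total_batch_count ...

    def cue_batch_pointer_to_epoch_fraction(self, epoch_fraction):
        # Keep only the fractional part -- how far into the current epoch
        # training had progressed -- and convert it to a batch index.
        # (Assumed implementation; the real utils.py may differ.)
        fraction = epoch_fraction - int(epoch_fraction)
        self.pointer = int(fraction * self.total_batch_count)

    def next_batch(self):
        # Serve the batch at the current pointer and advance, wrapping at
        # the end of the epoch.
        x, y = self.batches[self.pointer]
        self.pointer = (self.pointer + 1) % self.total_batch_count
        return x, y

This matches how train() uses the loader: after restoring global_epoch_fraction from the checkpoint, the loader is cued so that next_batch() resumes from the interrupted position instead of starting the epoch over.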


Note: The utils.TextLoader.cue_batch_pointer_to_epoch_fraction examples in this article were compiled from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their original authors, and copyright remains with them; please consult the corresponding project's license before using or redistributing the code, and do not repost without permission.