本文整理汇总了Python中utils.TextLoader.cue_batch_pointer_to_epoch_fraction方法的典型用法代码示例。如果您正苦于以下问题：Python中TextLoader.cue_batch_pointer_to_epoch_fraction方法的具体用法是什么？该怎么用？有哪些使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类utils.TextLoader的用法示例。
在下文中一共展示了TextLoader.cue_batch_pointer_to_epoch_fraction方法的1个代码示例，示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。
示例1: train
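# Required import: from utils import TextLoader [as alias]
# The method is then called on a TextLoader instance, e.g.
# data_loader.cue_batch_pointer_to_epoch_fraction(...); it cannot be imported
# directly, because utils.TextLoader is a class, not a module.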
def train(args):
# Create the data_loader object, which loads up all of our batches, vocab dictionary, etc.
# from utils.py (and creates them if they don't already exist).
# These files go in the data directory.
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
args.vocab_size = data_loader.vocab_size
load_model = False
if not os.path.exists(args.save_dir):
print("Creating directory %s" % args.save_dir)
os.mkdir(args.save_dir)
elif (os.path.exists(os.path.join(args.save_dir, 'config.pkl'))):
# Trained model already exists
ckpt = tf.train.get_checkpoint_state(args.save_dir)
if ckpt and ckpt.model_checkpoint_path:
with open(os.path.join(args.save_dir, 'config.pkl')) as f:
saved_args = cPickle.load(f)
args.rnn_size = saved_args.rnn_size
args.num_layers = saved_args.num_layers
args.model = saved_args.model
print("Found a previous checkpoint. Overwriting model description arguments to:")
print(" model: {}, rnn_size: {}, num_layers: {}".format(
saved_args.model, saved_args.rnn_size, saved_args.num_layers))
load_model = True
# Save all arguments to config.pkl in the save directory -- NOT the data directory.
with open(os.path.join(args.save_dir, 'config.pkl'), 'w') as f:
cPickle.dump(args, f)
# Save a tuple of the characters list and the vocab dictionary to chars_vocab.pkl in
# the save directory -- NOT the data directory.
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'w') as f:
cPickle.dump((data_loader.chars, data_loader.vocab), f)
# Create the model!
print("Building the model")
model = Model(args)
config = tf.ConfigProto(log_device_placement=False)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
tf.initialize_all_variables().run()
saver = tf.train.Saver(model.save_variables_list())
if (load_model):
print("Loading saved parameters")
saver.restore(sess, ckpt.model_checkpoint_path)
global_epoch_fraction = sess.run(model.global_epoch_fraction)
global_seconds_elapsed = sess.run(model.global_seconds_elapsed)
if load_model: print("Resuming from global epoch fraction {:.3f},"
" total trained time: {}, learning rate: {}".format(
global_epoch_fraction, global_seconds_elapsed, sess.run(model.lr)))
data_loader.cue_batch_pointer_to_epoch_fraction(global_epoch_fraction)
initial_batch_step = int((global_epoch_fraction
- int(global_epoch_fraction)) * data_loader.total_batch_count)
epoch_range = (int(global_epoch_fraction),
args.num_epochs + int(global_epoch_fraction))
writer = tf.train.SummaryWriter(args.save_dir, graph=tf.get_default_graph())
outputs = [model.cost, model.final_state, model.train_op, model.summary_op]
is_lstm = args.model == 'lstm'
global_step = epoch_range[0] * data_loader.total_batch_count + initial_batch_step
try:
for e in xrange(*epoch_range):
# e iterates through the training epochs.
# Reset the model state, so it does not carry over from the end of the previous epoch.
state = sess.run(model.initial_state)
batch_range = (initial_batch_step, data_loader.total_batch_count)
initial_batch_step = 0
for b in xrange(*batch_range):
global_step += 1
if global_step % args.decay_steps == 0:
# Set the model.lr element of the model to track
# the appropriately decayed learning rate.
current_learning_rate = sess.run(model.lr)
current_learning_rate *= args.decay_rate
sess.run(tf.assign(model.lr, current_learning_rate))
print("Decayed learning rate to {}".format(current_learning_rate))
start = time.time()
# Pull the next batch inputs (x) and targets (y) from the data loader.
x, y = data_loader.next_batch()
# feed is a dictionary of variable references and respective values for initialization.
# Initialize the model's input data and target data from the batch,
# and initialize the model state to the final state from the previous batch, so that
# model state is accumulated and carried over between batches.
feed = {model.input_data: x, model.targets: y}
if is_lstm:
for i, (c, h) in enumerate(model.initial_state):
feed[c] = state[i].c
feed[h] = state[i].h
else:
for i, c in enumerate(model.initial_state):
feed[c] = state[i]
# Run the session! Specifically, tell TensorFlow to compute the graph to calculate
# the values of cost, final state, and the training op.
# Cost is used to monitor progress.
# Final state is used to carry over the state into the next batch.
# Training op is not used, but we want it to be calculated, since that calculation
# is what updates parameter states (i.e. that is where the training happens).
train_loss, state, _, summary = sess.run(outputs, feed)
elapsed = time.time() - start
global_seconds_elapsed += elapsed
#.........这里部分代码省略.........