本文整理汇总了Python中utils.TextLoader.next_batch_tr方法的典型用法代码示例。如果您正苦于以下问题:Python TextLoader.next_batch_tr方法的具体用法?Python TextLoader.next_batch_tr怎么用?Python TextLoader.next_batch_tr使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类utils.TextLoader的用法示例。
下文共展示了TextLoader.next_batch_tr方法的1个代码示例,这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞,您的评价将有助于系统推荐出更好的Python代码示例。
示例1: train
# Required import: from utils import TextLoader [as alias]
# next_batch_tr is an instance method of TextLoader; call it on an instance, e.g. data_loader.next_batch_tr()
def train(args):
print("training on \'"+args.data_dir+"\'")
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
args.vocab_size = data_loader.vocab_size
# check compatibility if training is continued from previously saved model
if args.init_from is not None:
print("RELOADING FROM CHECKPOING")
# check if all necessary files exist
assert os.path.isdir(args.init_from)," %s must be a a path" % args.init_from
assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
ckpt = tf.train.get_checkpoint_state(args.init_from)
assert ckpt,"No checkpoint found"
assert ckpt.model_checkpoint_path,"No model path found in checkpoint"
# open old config and check if models are compatible
with open(os.path.join(args.init_from, 'config.pkl')) as f:
saved_model_args = cPickle.load(f)
need_be_same=["model","rnn_size","num_layers","seq_length"]
for checkme in need_be_same:
assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme
# open saved vocab/dict and check if vocabs/dicts are compatible
with open(os.path.join(args.init_from, 'chars_vocab.pkl')) as f:
saved_chars, saved_vocab = cPickle.load(f)
assert saved_chars==data_loader.chars, "Data and loaded model disagreee on character set!"
assert saved_vocab==data_loader.vocab, "Data and loaded model disagreee on dictionary mappings!"
with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
cPickle.dump(args, f)
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
cPickle.dump((data_loader.chars, data_loader.vocab), f)
print("====================================")
printargs(args)
print("====================================")
model = Model(args)
def validateonce(expectationdropout=True, TrueIfVal_FalseIfTrain=True):
    """Run one evaluation pass and collect per-token losses and predictions.

    Evaluates ``data_loader.num_batches_te`` batches. Batches come from
    ``next_batch_te`` when ``TrueIfVal_FalseIfTrain`` is True, otherwise from
    ``next_batch_tr``. The recurrent state returned for one batch
    (``model.final_state``) is fed as ``model.initial_state`` of the next.

    Args:
        expectationdropout: forwarded to ``model.resetweights``; its meaning
            is defined by ``Model`` (not visible here).
        TrueIfVal_FalseIfTrain: True to evaluate on batches from
            ``next_batch_te``, False to use ``next_batch_tr``.

    Returns:
        Tuple ``(losses, truths, entrps, allprobs, testtimeperbatch)``:
          losses -- 1-D array of per-token ``-log2(p(target))`` values.
          truths -- targets reshaped to ``(batch_size, seq_length)`` and
              accumulated with ``tryconcat`` along axis 1.
          entrps -- ``model.pred_entropy`` reshaped to
              ``(1, batch_size, seq_length)`` and accumulated along axis 2.
          allprobs -- ``model.probs`` reshaped to
              ``(1, batch_size, seq_length, vocab_size)`` and accumulated
              along axis 2.
          testtimeperbatch -- wall-clock seconds per evaluated batch
              (0.0 when there are no batches).
        ``tryconcat`` is defined elsewhere; presumably it concatenates along
        the given axis and treats ``None`` as "nothing yet" -- confirm.

    Uses ``sess`` from the enclosing scope, so it must be called while that
    session is open.
    """
    data_loader.reset_batch_pointers()
    model.resetweights(expectationdropout=expectationdropout)
    state = model.resetstate()
    start = time.time()
    num_batches = data_loader.num_batches_te
    batch_losses = []
    # Drawing training batches advances pointer_tr; save it so training can
    # resume from the same position after this pass.
    # NOTE(review): this is captured after reset_batch_pointers(). If that
    # call also rewinds pointer_tr, the saved value is the rewound one rather
    # than the pre-evaluation position -- confirm against TextLoader.
    backupptrtr = data_loader.pointer_tr
    entrps = None
    truths = None
    allprobs = None
    try:
        for _ in range(num_batches):
            if TrueIfVal_FalseIfTrain:
                x, y = data_loader.next_batch_te()
            else:
                x, y = data_loader.next_batch_tr()
            # shapes of x and y are (batchsize, seqlength); each element is an integer from 0 to (vocabsize-1)
            feed = {model.input_data: x, model.targets: y, model.initial_state: state}
            feed = model.extrafeed(feed)
            state, probs, entropies = sess.run(
                [model.final_state, model.probs, model.pred_entropy], feed)
            theseprobs = np.reshape(probs, (1, args.batch_size, args.seq_length, args.vocab_size))
            thesey = np.reshape(y, (args.batch_size, args.seq_length))
            allprobs = tryconcat(allprobs, theseprobs, axis=2)
            truths = tryconcat(truths, thesey, axis=1)
            flat_y = np.asarray(y).flatten()
            # probs has one row per flattened target position; gather each
            # target's probability in a single vectorized indexing step
            # instead of a Python loop over every token.
            batch_losses.append(-np.log2(probs[np.arange(flat_y.size), flat_y]))
            thesentropies = np.reshape(entropies, (1, args.batch_size, args.seq_length))
            entrps = tryconcat(entrps, thesentropies, axis=2)
    finally:
        # Restore the training pointer even if sess.run (or a loader call)
        # raises, so an aborted evaluation does not shift the training data.
        data_loader.pointer_tr = backupptrtr
    end = time.time()
    losses = np.concatenate(batch_losses) if batch_losses else np.array([])
    # Guard against an empty evaluation set instead of dividing by zero.
    testtimeperbatch = (end - start) / float(num_batches) if num_batches else 0.0
    return (losses, truths, entrps, allprobs, testtimeperbatch)
# for tensorboard
valsumplh_cost = tf.placeholder(tf.float32, (1,), name="validation_summary_placeholder_cost")
valsumplh_pent = tf.placeholder(tf.float32, (1,), name="validation_summary_placeholder_prediction_entropy")
#reduce_sum fixes tensorflow scalar handling being weird (vector of size 1)
valsumscs_cost = tf.scalar_summary('cost_val', tf.reduce_sum(valsumplh_cost))
valsumscs_pent = tf.scalar_summary('prediction_entropy_val', tf.reduce_sum(valsumplh_pent))
sumwriter = tf.train.SummaryWriter(args.save_dir, graph=tf.get_default_graph())
befstarttime = time.time()
with tf.Session() as sess:
tf.initialize_all_variables().run()
saver = tf.train.Saver(tf.all_variables())
print("====================================")
allvars = tf.all_variables()
trainablevars = tf.trainable_variables()
trainableMB = 0
for tvar in allvars:
#print(type(tvar))
#print(tvar.name+" -- "+str(tvar.dtype)+" -- "+str(tvar.get_shape()))
if tvar in trainablevars:
print("@@@ "+tvar.name+" -- "+str(tvar.get_shape()))
trainableMB += 4*tvar.get_shape().num_elements()
else:
print(tvar.name+" -- "+str(tvar.get_shape()))
print(" ")
print("trainable megabytes: "+str(float(trainableMB)/1e6))
#.........这里部分代码省略.........