当前位置: 首页>>代码示例>>Python>>正文


Python TextLoader.next_batch_tr方法代码示例

本文整理汇总了Python中utils.TextLoader.next_batch_tr方法的典型用法代码示例。如果您正苦于以下问题：Python TextLoader.next_batch_tr方法的具体用法？Python TextLoader.next_batch_tr怎么用？Python TextLoader.next_batch_tr使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类utils.TextLoader的用法示例。


在下文中一共展示了TextLoader.next_batch_tr方法的1个代码示例，示例默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。

示例1: train

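# Required import: from utils import TextLoader [as alias]
# next_batch_tr is a method of the TextLoader class (not a module member), so it cannot be
# imported directly; call it on a TextLoader instance, e.g. data_loader.next_batch_tr()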
def train(args):
    print("training on \'"+args.data_dir+"\'")
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size
    
    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        print("RELOADING FROM CHECKPOING")
        # check if all necessary files exist 
        assert os.path.isdir(args.init_from)," %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt,"No checkpoint found"
        assert ckpt.model_checkpoint_path,"No model path found in checkpoint"

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl')) as f:
            saved_model_args = cPickle.load(f)
        need_be_same=["model","rnn_size","num_layers","seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme
        
        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl')) as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagreee on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagreee on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    print("====================================")
    printargs(args)
    print("====================================")
    model = Model(args)

    def validateonce(expectationdropout=True, TrueIfVal_FalseIfTrain=True):
        """Run one evaluation pass over ``data_loader.num_batches_te`` batches.

        The recurrent state returned by each ``sess.run`` is fed back in as the
        initial state of the next batch. ``sess`` is the session bound by the
        enclosing ``with tf.Session() as sess`` block, so this closure must only
        be called inside that block.

        Args:
            expectationdropout: forwarded to ``model.resetweights``.
            TrueIfVal_FalseIfTrain: if True, batches come from
                ``data_loader.next_batch_te()``; if False, the same number of
                batches is drawn from ``data_loader.next_batch_tr()``.

        Returns:
            ``(losses, truths, entrps, allprobs, testtimeperbatch)`` where
            ``losses`` is a 1-D array of per-symbol losses ``-log2 p(target)``
            (bits); ``truths``, ``entrps`` and ``allprobs`` are the per-batch
            targets, prediction entropies and probabilities accumulated with
            ``tryconcat`` along the sequence axis (presumably None when no
            batch was run -- depends on ``tryconcat``, TODO confirm);
            ``testtimeperbatch`` is the mean wall-clock seconds per batch
            (0.0 if there were no batches).
        """
        num_batches = data_loader.num_batches_te
        # Save the training pointer BEFORE resetting: if reset_batch_pointers()
        # also rewinds pointer_tr (not visible here -- TODO confirm), saving it
        # afterwards would make the restore below a no-op and silently restart
        # the training epoch after every evaluation.
        backupptrtr = data_loader.pointer_tr
        data_loader.reset_batch_pointers()
        model.resetweights(expectationdropout=expectationdropout)
        state = model.resetstate()
        next_batch = data_loader.next_batch_te if TrueIfVal_FalseIfTrain else data_loader.next_batch_tr
        start = time.time()
        loss_chunks = []
        entrps = None
        truths = None
        allprobs = None
        try:
            for _ in range(num_batches):
                x, y = next_batch()
                # shapes of x and y are (batchsize, seqlength); each element is an integer from 0 to (vocabsize-1)
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                feed = model.extrafeed(feed)
                state, probs, entropies = sess.run([model.final_state, model.probs, model.pred_entropy], feed)
                theseprobs = np.reshape(probs, (1, args.batch_size, args.seq_length, args.vocab_size))
                thesey = np.reshape(y, (args.batch_size, args.seq_length))
                allprobs = tryconcat(allprobs, theseprobs, axis=2)
                truths = tryconcat(truths, thesey, axis=1)
                flat_y = y.flatten()
                # probs is indexed as (row, symbol) with one row per flattened
                # target; gather each row's probability of its true target at once.
                loss_chunks.append(-np.log2(probs[np.arange(flat_y.size), flat_y]))
                thesentropies = np.reshape(entropies, (1, args.batch_size, args.seq_length))
                entrps = tryconcat(entrps, thesentropies, axis=2)
        finally:
            # Restore even if sess.run raises, so training resumes where it left off.
            data_loader.pointer_tr = backupptrtr
        end = time.time()
        losses = np.concatenate(loss_chunks) if loss_chunks else np.array([])
        # Guard against a loader with zero evaluation batches (ZeroDivisionError).
        testtimeperbatch = (end - start) / float(max(num_batches, 1))
        return (losses, truths, entrps, allprobs, testtimeperbatch)

    # for tensorboard
    valsumplh_cost = tf.placeholder(tf.float32, (1,), name="validation_summary_placeholder_cost")
    valsumplh_pent = tf.placeholder(tf.float32, (1,), name="validation_summary_placeholder_prediction_entropy")
    #reduce_sum fixes tensorflow scalar handling being weird (vector of size 1)
    valsumscs_cost = tf.scalar_summary('cost_val', tf.reduce_sum(valsumplh_cost))
    valsumscs_pent = tf.scalar_summary('prediction_entropy_val', tf.reduce_sum(valsumplh_pent))
    sumwriter = tf.train.SummaryWriter(args.save_dir, graph=tf.get_default_graph())
    
    befstarttime = time.time()
    
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())

        print("====================================")
        allvars = tf.all_variables()
        trainablevars = tf.trainable_variables()
        trainableMB = 0
        for tvar in allvars:
            #print(type(tvar))
            #print(tvar.name+" -- "+str(tvar.dtype)+" -- "+str(tvar.get_shape()))
            if tvar in trainablevars:
                print("@@@ "+tvar.name+" -- "+str(tvar.get_shape()))
                trainableMB += 4*tvar.get_shape().num_elements()
            else:
                print(tvar.name+" -- "+str(tvar.get_shape()))
        print(" ")
        print("trainable megabytes: "+str(float(trainableMB)/1e6))
#.........这里部分代码省略.........
开发者ID:jasonbunk,项目名称:char-rnn-tensorflow,代码行数:103,代码来源:train.py


注:本文中的utils.TextLoader.next_batch_tr方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。