當前位置: 首頁>>代碼示例>>Python>>正文


Python TextLoader.reset_batch_pointers方法代碼示例

本文整理匯總了Python中utils.TextLoader.reset_batch_pointers方法的典型用法代碼示例。如果您正苦於以下問題:Python TextLoader.reset_batch_pointers方法的具體用法是什麼?Python TextLoader.reset_batch_pointers怎麼用?有哪些Python TextLoader.reset_batch_pointers的使用例子?那麼,這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類utils.TextLoader的用法示例。


在下文中一共展示了TextLoader.reset_batch_pointers方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: train

# Requires the import: from utils import TextLoader [as alias]
# NOTE(review): the page also suggests "from utils.TextLoader import reset_batch_pointers",
# but below reset_batch_pointers is called as a method on a TextLoader instance
# (data_loader.reset_batch_pointers()), not imported on its own.
def train(args):
    """Prepare data, checkpoint files, the model and a TF session for char-RNN training.

    This listing is truncated: only the setup portion of the function is shown,
    and the rest (presumably the actual training loop) is omitted.

    Visible steps:
      1. Build a TextLoader from args.data_dir / args.batch_size / args.seq_length
         and copy its vocab_size onto args.
      2. If args.init_from is set, check that it holds config.pkl, chars_vocab.pkl
         and a TF checkpoint, and that the saved config and vocabulary agree
         with the current args and data.
      3. Pickle args and (chars, vocab) into args.save_dir.
      4. Build Model(args), define the validateonce() evaluation closure,
         create TensorBoard summary ops and a SummaryWriter.
      5. Open a session, initialize all variables, create a Saver, and print
         every variable plus an estimate of the trainable-parameter size.

    Args:
        args: attribute namespace (presumably from argparse). The function reads
            data_dir, batch_size, seq_length, init_from and save_dir, and compares
            model, rnn_size, num_layers and seq_length via vars(args). It is
            mutated: args.vocab_size is set here.

    Raises:
        AssertionError: if a resumed checkpoint is missing files or is
            incompatible (only while assertions are enabled; they are stripped
            under `python -O`).
    """
    print("training on \'"+args.data_dir+"\'")
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    # Stored on args so it is pickled into config.pkl below and is available to
    # validateonce's reshape (and presumably to Model(args)).
    args.vocab_size = data_loader.vocab_size
    
    # When resuming from a previously saved model, check that it is compatible.
    if args.init_from is not None:
        # NOTE(review): "CHECKPOING" is a typo in this runtime message; left unchanged.
        print("RELOADING FROM CHECKPOING")
        # Verify that every file needed to resume exists.
        # NOTE(review): assert-based validation disappears under `python -O`.
        assert os.path.isdir(args.init_from)," %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_from
        # NOTE(review): the message below says "chars_vocab.pkl.pkl" (typo); left unchanged.
        assert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        # Falsy (None) when no checkpoint index is found in init_from.
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt,"No checkpoint found"
        assert ckpt.model_checkpoint_path,"No model path found in checkpoint"

        # Load the saved config and require the architecture-defining options to match.
        # NOTE(review): opened in text mode although it is written with 'wb' below;
        # this works with Python 2 cPickle on POSIX, but Python 3 pickle needs 'rb'.
        # Unpickling can execute arbitrary code, so only resume from trusted directories.
        with open(os.path.join(args.init_from, 'config.pkl')) as f:
            saved_model_args = cPickle.load(f)
        need_be_same=["model","rnn_size","num_layers","seq_length"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme],"Command line argument and saved model disagree on '%s' "%checkme
        
        # Load the saved (chars, vocab) pair and require it to equal the current data's.
        # NOTE(review): same text-mode / untrusted-pickle caveat as above;
        # "disagreee" in the messages is a typo, left unchanged.
        with open(os.path.join(args.init_from, 'chars_vocab.pkl')) as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars==data_loader.chars, "Data and loaded model disagreee on character set!"
        assert saved_vocab==data_loader.vocab, "Data and loaded model disagreee on dictionary mappings!"

    # Persist the run configuration and vocabulary next to the checkpoints
    # (overwrites existing files; args.save_dir must already exist or open() raises).
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    print("====================================")
    printargs(args)
    print("====================================")
    model = Model(args)

    def validateonce(expectationdropout=True, TrueIfVal_FalseIfTrain=True):
        """Run one evaluation pass and collect per-symbol losses and predictions.

        Args:
            expectationdropout: forwarded to model.resetweights().
            TrueIfVal_FalseIfTrain: True reads batches via next_batch_te(),
                False reads them via next_batch_tr().

        Returns:
            Tuple (losses, truths, entrps, allprobs, testtimeperbatch):
              losses: 1-D np.array of -log2 p(target) for every target symbol.
              truths: targets concatenated along axis 1 (batch_size, seq_length*k).
              entrps: prediction entropies concatenated along axis 2
                  (1, batch_size, seq_length*k).
              allprobs: probabilities concatenated along axis 2
                  (1, batch_size, seq_length*k, vocab_size).
              testtimeperbatch: wall-clock seconds per batch.
            Here k = data_loader.num_batches_te. If k == 0 the division at the
            end raises ZeroDivisionError.
        """
        data_loader.reset_batch_pointers()
        model.resetweights(expectationdropout=expectationdropout)
        state = model.resetstate()
        start = time.time()
        losses = []
        # NOTE(review): pointer_tr is saved *after* reset_batch_pointers() above, so
        # the restore at the end only returns it to its post-reset value. If the
        # intent is to leave the training position untouched by this pass, the save
        # would have to precede the reset; confirm against TextLoader (not shown).
        backupptrtr = data_loader.pointer_tr
        entrps = None
        truths = None
        allprobs = None
        # Always runs num_batches_te batches, even when reading training batches.
        for b in range(data_loader.num_batches_te):
            if TrueIfVal_FalseIfTrain:
                x, y = data_loader.next_batch_te()
            else:
                x, y = data_loader.next_batch_tr()
            # shapes of x and y are (batchsize, seqlength); each element is an integer from 0 to (vocabsize-1)
            # The RNN state is threaded from one batch to the next via initial_state.
            feed = {model.input_data: x, model.targets: y, model.initial_state: state}
            feed = model.extrafeed(feed)
            # `sess` is not a parameter: it is captured from train()'s scope and bound
            # by the `with tf.Session() as sess:` block below, so this function may
            # only be called inside that block.
            state, probs, entropies = sess.run([model.final_state, model.probs, model.pred_entropy], feed)
            theseprobs = np.reshape(probs, (1, args.batch_size, args.seq_length, args.vocab_size))
            thesey = np.reshape(y, (args.batch_size, args.seq_length))
            # tryconcat is defined elsewhere; presumably it returns the new array
            # when the accumulator is still None, otherwise concatenates on `axis`.
            allprobs = tryconcat(allprobs, theseprobs, axis=2)
            truths = tryconcat(truths, thesey, axis=1)
            y = y.flatten()
            # Per-symbol loss in bits. probs[ii, y[ii]] assumes probs is 2-D with rows
            # ordered like the flattened y — TODO confirm against Model.probs.
            # A zero probability yields an infinite loss.
            for ii in range(y.size):
                losses.append(-np.log2(probs[ii,y[ii]]))
            thesentropies = np.reshape(entropies,(1,args.batch_size,args.seq_length))
            entrps = tryconcat(entrps, thesentropies, axis=2)
        data_loader.pointer_tr = backupptrtr
        end = time.time()
        testtimeperbatch = (end-start) / float(data_loader.num_batches_te)
        return (np.array(losses), truths, entrps, allprobs, testtimeperbatch)

    # for tensorboard: placeholders that are fed validation metrics and written out
    # as scalar summaries.
    # NOTE(review): tf.scalar_summary / tf.train.SummaryWriter (and, below,
    # tf.initialize_all_variables / tf.all_variables) are pre-1.0 TensorFlow names.
    valsumplh_cost = tf.placeholder(tf.float32, (1,), name="validation_summary_placeholder_cost")
    valsumplh_pent = tf.placeholder(tf.float32, (1,), name="validation_summary_placeholder_prediction_entropy")
    #reduce_sum fixes tensorflow scalar handling being weird (vector of size 1)
    valsumscs_cost = tf.scalar_summary('cost_val', tf.reduce_sum(valsumplh_cost))
    valsumscs_pent = tf.scalar_summary('prediction_entropy_val', tf.reduce_sum(valsumplh_pent))
    sumwriter = tf.train.SummaryWriter(args.save_dir, graph=tf.get_default_graph())
    
    # Wall-clock start time; presumably consumed by the omitted remainder of train().
    befstarttime = time.time()
    
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())

        # List every variable; trainable ones are prefixed with "@@@".
        print("====================================")
        allvars = tf.all_variables()
        trainablevars = tf.trainable_variables()
        trainableMB = 0
        for tvar in allvars:
            #print(type(tvar))
            #print(tvar.name+" -- "+str(tvar.get_shape()))
            if tvar in trainablevars:
                print("@@@ "+tvar.name+" -- "+str(tvar.get_shape()))
                # Size estimate assumes 4 bytes per element (float32).
                trainableMB += 4*tvar.get_shape().num_elements()
            else:
                print(tvar.name+" -- "+str(tvar.get_shape()))
        print(" ")
        # "megabytes" here means 1e6 bytes.
        print("trainable megabytes: "+str(float(trainableMB)/1e6))
#.........這裏部分代碼省略.........
開發者ID:jasonbunk,項目名稱:char-rnn-tensorflow,代碼行數:103,代碼來源:train.py


注:本文中的utils.TextLoader.reset_batch_pointers方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。