當前位置: 首頁>>代碼示例>>Python>>正文


Python RPropMinusTrainer.trainOnDataset方法代碼示例

本文整理匯總了Python中pybrain.supervised.RPropMinusTrainer.trainOnDataset方法的典型用法代碼示例。如果您正苦於以下問題:Python RPropMinusTrainer.trainOnDataset方法的具體用法?Python RPropMinusTrainer.trainOnDataset怎麼用?Python RPropMinusTrainer.trainOnDataset使用的例子?那麼,這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類pybrain.supervised.RPropMinusTrainer的用法示例。


在下文中一共展示了RPropMinusTrainer.trainOnDataset方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: learn

# 需要導入模塊: from pybrain.supervised import RPropMinusTrainer [as 別名]
# 注意: trainOnDataset 是 RPropMinusTrainer 的實例方法,無法單獨導入;請先建立 trainer = RPropMinusTrainer(net, ...),再調用 trainer.trainOnDataset(dataset=..., epochs=...)

#.........這裏部分代碼省略.........
                                        self.dictIn["WORD_"+word]=index
                                        index+=1
            self.TOTALSIZEOFSENTENCEFeature=index
            f=open(self.FileNameofNumSentenceFeature,"wb")
            pickle.dump(self.TOTALSIZEOFSENTENCEFeature,f)
            f.close()
        elif self.isUseSentenceRepresentationInsteadofBOW:
            index=0
            for i in range(0,LSTMWithBOWTracker.D2V_VECTORSIZE):
                self.dictIn[str(index)+"thElemPV"]=index
                index+=1
            index=0
            for i in range(0,LSTMWithBOWTracker.D2V_VECTORSIZE):
                self.dictIn[str(index)+"thAvrWord"]=index
                index+=1
            assert self.D2V_VECTORSIZE == LSTMWithBOWTracker.D2V_VECTORSIZE, "D2V_VECTORSIZE is restrected to be same over the class"
        else:
            assert False, "Unexpected block" 
        #--(sub input vector 3) Features M1s defined
        index=0
        if self.isEnableToUseM1sFeature:
            rejisteredFeatures=self.__rejisterM1sInputFeatureLabel(self.tagsets,dataset)
            for rFeature in rejisteredFeatures:
                assert rFeature not in self.dictIn, rFeature +" already registered in input vector. Use different label name. "
                self.dictIn[rFeature]=index
                index+=1
            self.TOTALSIZEOFM1DEFINEDFeature=index
            f=open(self.FileNameofNumM1Feature,"wb")
            pickle.dump(self.TOTALSIZEOFM1DEFINEDFeature,f)
            f.close()

        print "inputSize:"+str(len(self.dictIn.keys()))
        assert self.dictIn["CLASS_INFO"] == 0, "Unexpected index CLASS_INFO should has value 0"
        assert self.dictIn["CLASS_Fort Siloso"] == 334, "Unexpected index CLASS_Fort Siloso should has value 334"
        assert self.dictIn["CLASS_Yunnan"] == 1344, "Unexpected index CLASS_Yunnan should has value 1611"
        #--write 
        fileObject = open('dictInput.pic', 'w')
        pickle.dump(self.dictIn, fileObject)
        fileObject.close()
        fileObject = open('dictOutput.pic', 'w')
        pickle.dump(self.dictOut, fileObject)
        fileObject.close()
        
        #Build RNN frame work
        print "Start learning Network"
        #Capability of network is: (30 hidden units can represents 1048576 relations) wherease (10 hidden units can represents 1024)
        #Same to Henderson (http://www.aclweb.org/anthology/W13-4073)?
        net = buildNetwork(len(self.dictIn.keys()), numberOfHiddenUnit, len(self.dictOut.keys()), hiddenclass=LSTMLayer, outclass=SigmoidLayer, outputbias=False, recurrent=True)
        
        #Train network
        #-convert training data into sequence of vector 
        convDataset=[]#[call][uttr][input,targetvec]
        iuttr=0
        convCall=[]
        for elemDataset in dataset:
            for call in elemDataset:
                for (uttr,label) in call:
                    if self.isIgnoreUtterancesNotRelatedToMainTask:
                        if uttr['segment_info']['target_bio'] == "O":
                            continue
                    #-input
                    convInput=self._translateUtteranceIntoInputVector(uttr,call)
                    #-output
                    convOutput=[0.0]*len(self.dictOut.keys())#Occured:+1, Not occured:0
                    if "frame_label" in label:
                        for slot in label["frame_label"].keys():
                            for value in label["frame_label"][slot]:
                                convOutput[self.dictOut[uttr["segment_info"]["topic"]+"_"+slot+"_"+value]]=1
                    #-post proccess
                    if self.isSeparateDialogIntoSubDialog:
                        if uttr['segment_info']['target_bio'] == "B":
                            if len(convCall) > 0:
                                convDataset.append(convCall)
                            convCall=[]
                    convCall.append([convInput,convOutput])
                    #print "Converted utterance" + str(iuttr)
                    iuttr+=1
                if not self.isSeparateDialogIntoSubDialog:
                    if len(convCall) > 0:
                        convDataset.append(convCall)
                    convCall=[]
        #Online learning
        trainer = RPropMinusTrainer(net,weightdecay=weightdecayw)
        EPOCHS = EPOCHS_PER_CYCLE * CYCLES
        for i in xrange(CYCLES):
            #Shuffle order
            ds = SequentialDataSet(len(self.dictIn.keys()),len(self.dictOut.keys()))
            datInd=range(0,len(convDataset))
            random.shuffle(datInd)#Backpropergation already implemeted data shuffling, however though RpropMinus don't. 
            for ind in datInd:
                ds.newSequence()
                for convuttr in convDataset[ind]:
                    ds.addSample(convuttr[0],convuttr[1])
            #Evaluation and Train
            epoch = (i+1) * EPOCHS_PER_CYCLE
            print "\r epoch {}/{} Error={}".format(epoch, EPOCHS,trainer.testOnData(dataset=ds))
            stdout.flush()
            trainer.trainOnDataset(dataset=ds,epochs=EPOCHS_PER_CYCLE)
            NetworkWriter.writeToFile(trainer.module, "LSTM_"+"Epoche"+str(i+1)+".rnnw")
            NetworkWriter.writeToFile(trainer.module, "LSTM.rnnw")
開發者ID:cuihengbin,項目名稱:Dialogue-State-Tracking-using-LSTM,代碼行數:104,代碼來源:LSTMWithBOW.py


注:本文中的pybrain.supervised.RPropMinusTrainer.trainOnDataset方法示例由純淨天空整理自GitHub/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。