当前位置: 首页>>代码示例>>Python>>正文


Python RPropMinusTrainer.trainOnDataset方法代码示例

本文整理汇总了Python中pybrain.supervised.RPropMinusTrainer.trainOnDataset方法的典型用法代码示例。如果您正苦于以下问题:RPropMinusTrainer.trainOnDataset方法的具体用法是什么?RPropMinusTrainer.trainOnDataset怎么用?有哪些RPropMinusTrainer.trainOnDataset的使用示例?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在的类pybrain.supervised.RPropMinusTrainer的用法示例。


在下文中一共展示了RPropMinusTrainer.trainOnDataset方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: learn

# 需要导入模块: from pybrain.supervised import RPropMinusTrainer [as 别名]
# 注意: trainOnDataset 是 RPropMinusTrainer 类的实例方法,无法通过 import 单独导入;请先创建训练器实例(例如 trainer = RPropMinusTrainer(net)),再调用 trainer.trainOnDataset(dataset, epochs)

#.........这里部分代码省略.........
                                        self.dictIn["WORD_"+word]=index
                                        index+=1
            self.TOTALSIZEOFSENTENCEFeature=index
            f=open(self.FileNameofNumSentenceFeature,"wb")
            pickle.dump(self.TOTALSIZEOFSENTENCEFeature,f)
            f.close()
        elif self.isUseSentenceRepresentationInsteadofBOW:
            index=0
            for i in range(0,LSTMWithBOWTracker.D2V_VECTORSIZE):
                self.dictIn[str(index)+"thElemPV"]=index
                index+=1
            index=0
            for i in range(0,LSTMWithBOWTracker.D2V_VECTORSIZE):
                self.dictIn[str(index)+"thAvrWord"]=index
                index+=1
            assert self.D2V_VECTORSIZE == LSTMWithBOWTracker.D2V_VECTORSIZE, "D2V_VECTORSIZE is restrected to be same over the class"
        else:
            assert False, "Unexpected block" 
        #--(sub input vector 3) Features M1s defined
        index=0
        if self.isEnableToUseM1sFeature:
            rejisteredFeatures=self.__rejisterM1sInputFeatureLabel(self.tagsets,dataset)
            for rFeature in rejisteredFeatures:
                assert rFeature not in self.dictIn, rFeature +" already registered in input vector. Use different label name. "
                self.dictIn[rFeature]=index
                index+=1
            self.TOTALSIZEOFM1DEFINEDFeature=index
            f=open(self.FileNameofNumM1Feature,"wb")
            pickle.dump(self.TOTALSIZEOFM1DEFINEDFeature,f)
            f.close()

        print "inputSize:"+str(len(self.dictIn.keys()))
        assert self.dictIn["CLASS_INFO"] == 0, "Unexpected index CLASS_INFO should has value 0"
        assert self.dictIn["CLASS_Fort Siloso"] == 334, "Unexpected index CLASS_Fort Siloso should has value 334"
        assert self.dictIn["CLASS_Yunnan"] == 1344, "Unexpected index CLASS_Yunnan should has value 1611"
        #--write 
        fileObject = open('dictInput.pic', 'w')
        pickle.dump(self.dictIn, fileObject)
        fileObject.close()
        fileObject = open('dictOutput.pic', 'w')
        pickle.dump(self.dictOut, fileObject)
        fileObject.close()
        
        #Build RNN frame work
        print "Start learning Network"
        #Capability of network is: (30 hidden units can represents 1048576 relations) wherease (10 hidden units can represents 1024)
        #Same to Henderson (http://www.aclweb.org/anthology/W13-4073)?
        net = buildNetwork(len(self.dictIn.keys()), numberOfHiddenUnit, len(self.dictOut.keys()), hiddenclass=LSTMLayer, outclass=SigmoidLayer, outputbias=False, recurrent=True)
        
        #Train network
        #-convert training data into sequence of vector 
        convDataset=[]#[call][uttr][input,targetvec]
        iuttr=0
        convCall=[]
        for elemDataset in dataset:
            for call in elemDataset:
                for (uttr,label) in call:
                    if self.isIgnoreUtterancesNotRelatedToMainTask:
                        if uttr['segment_info']['target_bio'] == "O":
                            continue
                    #-input
                    convInput=self._translateUtteranceIntoInputVector(uttr,call)
                    #-output
                    convOutput=[0.0]*len(self.dictOut.keys())#Occured:+1, Not occured:0
                    if "frame_label" in label:
                        for slot in label["frame_label"].keys():
                            for value in label["frame_label"][slot]:
                                convOutput[self.dictOut[uttr["segment_info"]["topic"]+"_"+slot+"_"+value]]=1
                    #-post proccess
                    if self.isSeparateDialogIntoSubDialog:
                        if uttr['segment_info']['target_bio'] == "B":
                            if len(convCall) > 0:
                                convDataset.append(convCall)
                            convCall=[]
                    convCall.append([convInput,convOutput])
                    #print "Converted utterance" + str(iuttr)
                    iuttr+=1
                if not self.isSeparateDialogIntoSubDialog:
                    if len(convCall) > 0:
                        convDataset.append(convCall)
                    convCall=[]
        #Online learning
        trainer = RPropMinusTrainer(net,weightdecay=weightdecayw)
        EPOCHS = EPOCHS_PER_CYCLE * CYCLES
        for i in xrange(CYCLES):
            #Shuffle order
            ds = SequentialDataSet(len(self.dictIn.keys()),len(self.dictOut.keys()))
            datInd=range(0,len(convDataset))
            random.shuffle(datInd)#Backpropergation already implemeted data shuffling, however though RpropMinus don't. 
            for ind in datInd:
                ds.newSequence()
                for convuttr in convDataset[ind]:
                    ds.addSample(convuttr[0],convuttr[1])
            #Evaluation and Train
            epoch = (i+1) * EPOCHS_PER_CYCLE
            print "\r epoch {}/{} Error={}".format(epoch, EPOCHS,trainer.testOnData(dataset=ds))
            stdout.flush()
            trainer.trainOnDataset(dataset=ds,epochs=EPOCHS_PER_CYCLE)
            NetworkWriter.writeToFile(trainer.module, "LSTM_"+"Epoche"+str(i+1)+".rnnw")
            NetworkWriter.writeToFile(trainer.module, "LSTM.rnnw")
开发者ID:cuihengbin,项目名称:Dialogue-State-Tracking-using-LSTM,代码行数:104,代码来源:LSTMWithBOW.py


注:本文中的pybrain.supervised.RPropMinusTrainer.trainOnDataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。