当前位置: 首页>>代码示例>>C++>>正文


C++ LabelledClassificationData::getClassTracker方法代码示例

本文整理汇总了C++中LabelledClassificationData::getClassTracker方法的典型用法代码示例。如果您正苦于以下问题：C++ LabelledClassificationData::getClassTracker方法的具体用法是什么？C++ LabelledClassificationData::getClassTracker怎么用？C++ LabelledClassificationData::getClassTracker有哪些使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类LabelledClassificationData的用法示例。


在下文中一共展示了LabelledClassificationData::getClassTracker方法的6个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: merge

/**
 * Appends every sample of another dataset to this dataset and copies across its class names.
 *
 * Any previous cross-validation setup is invalidated, because the contents of the dataset change.
 *
 * @param labelledData  The dataset whose samples and class names are merged into this one. It must
 *                      have the same number of dimensions as this dataset. It may be *this.
 * @return true on success; false (with a message written to errorLog) if the dimensions differ.
 */
bool LabelledClassificationData::merge(LabelledClassificationData &labelledData){

    if( labelledData.getNumDimensions() != numDimensions ){
        errorLog << "merge(LabelledClassificationData &labelledData) - The number of dimensions in the labelledData (" << labelledData.getNumDimensions() << ") does not match the number of dimensions of this dataset (" << numDimensions << ")" << endl;
        return false;
    }

    //The dataset has changed so flag that any previous cross validation setup will now not work
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    //Read the sample count once, up front. If labelledData is *this (a self-merge), addSample()
    //presumably grows the very dataset being read. Re-evaluating getNumSamples() on every
    //iteration would then keep moving the loop bound and the loop would never finish.
    const UINT numSamplesToAdd = labelledData.getNumSamples();

    //Add the data from the labelledData to this instance
    for(UINT i=0; i<numSamplesToAdd; i++){
        addSample(labelledData[i].getClassLabel(), labelledData[i].getSample());
    }

    //Set the class names from the dataset. Take a copy rather than a reference, so the loop is not
    //affected if labelledData is *this and the setter modifies the tracker being iterated.
    const vector< ClassTracker > classTracker = labelledData.getClassTracker();
    for(UINT i=0; i<classTracker.size(); i++){
        setClassNameForCorrespondingClassLabel(classTracker[i].className, classTracker[i].classLabel);
    }

    sortClassLabels();

    return true;
}
开发者ID:gaurav38,项目名称:HackDuke13,代码行数:26,代码来源:LabelledClassificationData.cpp

示例2: train

/**
 * Trains one softmax regression model for each class in the training data.
 *
 * The dataset is taken by value, so the in-place scaling applied when useScaling is set does not
 * modify the caller's dataset.
 *
 * @param trainingData  The labelled samples to train on. It must contain at least one sample.
 * @return true if a model was trained for every class; false otherwise. Errors are written to
 *         errorLog.
 */
bool Softmax::train(LabelledClassificationData trainingData){
    
    //Clear any previous model
    clear();
    
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    
    if( M == 0 ){
        errorLog << "train(LabelledClassificationData labelledTrainingData) - Training data has zero samples!" << endl;
        return false;
    }
    
    numFeatures = N;
    numClasses = K;
    models.resize(K);
    classLabels.resize(K);
    //Capture the ranges before scaling, so they describe the unscaled input data
    ranges = trainingData.getRanges();
    
    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }
    
    //Fetch the class tracker once. Calling getClassTracker() inside the loop may copy the whole
    //tracker vector on every iteration.
    const vector< ClassTracker > classTracker = trainingData.getClassTracker();
    
    //Train a regression model for each class in the training data
    for(UINT k=0; k<numClasses; k++){
        
        //Set the class label
        classLabels[k] = classTracker[k].classLabel;
        
        //Train the model
        if( !trainSoftmaxModel(classLabels[k],models[k],trainingData) ){
            errorLog << "train(LabelledClassificationData labelledTrainingData) - Failed to train model for class: " << classLabels[k] << endl;
            return false;
        }
    }
    
    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
开发者ID:elaye,项目名称:AST_diabolo_tracking,代码行数:43,代码来源:Softmax.cpp

示例3: train_

bool KNN::train_(LabelledClassificationData &trainingData,UINT K){

    //Clear any previous models
    clear();

    if( trainingData.getNumSamples() == 0 ){
        errorLog << "train(LabelledClassificationData &trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    //Set the dimensionality of the input data
    this->K = K;
    this->numFeatures = trainingData.getNumDimensions();
    this->numClasses = trainingData.getNumClasses();

    //TODO: In the future need to build a kdtree from the training data to allow better realtime prediction
    this->trainingData = trainingData;

    if( useScaling ){
        ranges = this->trainingData.getRanges();
        this->trainingData.scale(ranges, 0, 1);
    }

    //Set the class labels
    classLabels.resize(numClasses);
    for(UINT k=0; k<numClasses; k++){
        classLabels[k] = trainingData.getClassTracker()[k].classLabel;
    }

    //Flag that the algorithm has been trained so we can compute the rejection thresholds
    trained = true;
    
    //If null rejection is enabled then compute the null rejection thresholds
    if( useNullRejection ){

        //Set the null rejection to false so we can compute the values for it (this will be set back to its current value later)
        bool tempUseNullRejection = useNullRejection;
        useNullRejection = false;
        rejectionThresholds.clear();

        //Compute the rejection thresholds for each of the K classes
        VectorDouble counter(numClasses,0);
        trainingMu.resize( numClasses, 0 );
        trainingSigma.resize( numClasses, 0 );
        rejectionThresholds.resize( numClasses, 0 );

        //Compute Mu for each of the classes
        const unsigned int numTrainingExamples = trainingData.getNumSamples();
        vector< IndexedDouble > predictionResults( numTrainingExamples );
        for(UINT i=0; i<numTrainingExamples; i++){
            predict( trainingData[i].getSample(), K);

            UINT classLabelIndex = 0;
            for(UINT k=0; k<numClasses; k++){
                if( predictedClassLabel == classLabels[k] ){
                    classLabelIndex = k;
                    break;
                }
            }

            predictionResults[ i ].index = classLabelIndex;
            predictionResults[ i ].value = classDistances[ classLabelIndex ];

            trainingMu[ classLabelIndex ] += predictionResults[ i ].value;
            counter[ classLabelIndex ]++;
        }

        for(UINT j=0; j<numClasses; j++){
            trainingMu[j] /= counter[j];
        }

        //Compute Sigma for each of the classes
        for(UINT i=0; i<numTrainingExamples; i++){
            trainingSigma[predictionResults[i].index] += SQR(predictionResults[i].value - trainingMu[predictionResults[i].index]);
        }

        for(UINT j=0; j<numClasses; j++){
            double count = counter[j];
            if( count > 1 ){
                trainingSigma[ j ] = sqrt( trainingSigma[j] / (count-1) );
            }else{
                trainingSigma[ j ] = 1.0;
            }
        }

        //Check to see if any of the mu or sigma values are zero or NaN
        bool errorFound = false;
        for(UINT j=0; j<numClasses; j++){
            if( trainingMu[j] == 0 ){
                warningLog << "TrainingMu[ " << j << " ] is zero for a K value of " << K << endl;
            }
            if( trainingSigma[j] == 0 ){
                warningLog << "TrainingSigma[ " << j << " ] is zero for a K value of " << K << endl;
            }
            if( isnan( trainingMu[j] ) ){
                errorLog << "TrainingMu[ " << j << " ] is NAN for a K value of " << K << endl;
                errorFound = true;
            }
            if( isnan( trainingSigma[j] ) ){
                errorLog << "TrainingSigma[ " << j << " ] is NAN for a K value of " << K << endl;
//.........这里部分代码省略.........
开发者ID:pixelmaid,项目名称:shape-recog,代码行数:101,代码来源:KNN.cpp

示例4: train

bool ANBC::train(LabelledClassificationData &labelledTrainingData,double gamma) {

    const unsigned int M = labelledTrainingData.getNumSamples();
    const unsigned int N = labelledTrainingData.getNumDimensions();
    const unsigned int K = labelledTrainingData.getNumClasses();
    trained = false;
    models.clear();
    classLabels.clear();

    if( M == 0 ) {
        errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - Training data has zero samples!" << endl;
        return false;
    }

    if( weightsDataSet ) {
        if( weightsData.getNumDimensions() != N ) {
            errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - The number of dimensions in the weights data (" << weightsData.getNumDimensions() << ") is not equal to the number of dimensions of the training data (" << N << ")" << endl;
            return false;
        }
    }

    numFeatures = N;
    numClasses = K;
    models.resize(K);
    classLabels.resize(K);
    ranges = labelledTrainingData.getRanges();

    //Train each of the models
    for(UINT k=0; k<numClasses; k++) {

        //Get the class label for the kth class
        UINT classLabel = labelledTrainingData.getClassTracker()[k].classLabel;

        //Set the kth class label
        classLabels[k] = classLabel;

        //Get the weights for this class
        VectorDouble weights(numFeatures);
        if( weightsDataSet ) {
            bool weightsFound = false;
            for(UINT i=0; i<weightsData.getNumSamples(); i++) {
                if( weightsData[i].getClassLabel() == classLabel ) {
                    weights = weightsData[i].getSample();
                    weightsFound = true;
                    break;
                }
            }

            if( !weightsFound ) {
                errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - Failed to find the weights for class " << classLabel << endl;
                return false;
            }
        } else {
            //If the weights data has not been set then all the weights are 1
            for(UINT j=0; j<numFeatures; j++) weights[j] = 1.0;
        }

        //Get all the training data for this class
        LabelledClassificationData classData = labelledTrainingData.getClassData(classLabel);
        MatrixDouble data(classData.getNumSamples(),N);

        //Copy the training data into a matrix, scaling the training data if needed
        for(UINT i=0; i<data.getNumRows(); i++) {
            for(UINT j=0; j<data.getNumCols(); j++) {
                if( useScaling ) {
                    data[i][j] = scale(classData[i][j],ranges[j].minValue,ranges[j].maxValue,MIN_SCALE_VALUE,MAX_SCALE_VALUE);
                } else data[i][j] = classData[i][j];
            }
        }

        //Train the model for this class
        models[k].gamma = gamma;
        if( !models[k].train(classLabel,data,weights) ) {
            errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - Failed to train model for class: " << classLabel << endl;

            //Try and work out why the training failed
            if( models[k].N == 0 ) {
                errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - N == 0!" << endl;
                models.clear();
                return false;
            }
            for(UINT j=0; j<numFeatures; j++) {
                if( models[k].mu[j] == 0 ) {
                    errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - The mean of column " << j+1 << " is zero! Check the training data" << endl;
                    models.clear();
                    return false;
                }
            }
            models.clear();
            return false;
        }

    }

    //Store the null rejection thresholds
    nullRejectionThresholds.resize(numClasses);
    for(UINT k=0; k<numClasses; k++) {
        nullRejectionThresholds[k] = models[k].threshold;
    }

//.........这里部分代码省略.........
开发者ID:pixelmaid,项目名称:shape-recog,代码行数:101,代码来源:ANBC.cpp

示例5: train

/**
 * Trains one minimum-distance model for each class in the labelled training data.
 *
 * When useScaling is set, each sample value is scaled into [0, 1] using the ranges of the whole
 * dataset before it is handed to the per-class model. The caller's dataset is not modified.
 *
 * @param labelledTrainingData  The labelled samples to train on. It must contain more samples
 *                              than numClusters.
 * @param gamma                 The value passed to each per-class model via setGamma().
 * @return true if every per-class model trained; false otherwise. Errors are written to errorLog,
 *         and on a per-class training failure the models are cleared.
 */
bool MinDist::train(LabelledClassificationData &labelledTrainingData,double gamma){
    
    const unsigned int M = labelledTrainingData.getNumSamples();
    const unsigned int N = labelledTrainingData.getNumDimensions();
    const unsigned int K = labelledTrainingData.getNumClasses();
    trained = false;
    models.clear();
    classLabels.clear();
    
    if( M == 0 ){
        errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - Training data has zero samples!" << endl;
        return false;
    }
    
    if( M <= numClusters ){
        errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - There are not enough training samples for the number of clusters. Either reduce the number of clusters or increase the number of training samples!" << endl;
        return false;
    }

    numFeatures = N;
    numClasses = K;
    models.resize(K);
    classLabels.resize(K);
    ranges = labelledTrainingData.getRanges();
    
    //Fetch the class tracker once. Calling getClassTracker() inside the loop may copy the whole
    //tracker vector on every iteration.
    const vector< ClassTracker > classTracker = labelledTrainingData.getClassTracker();
    
    //Train each of the models
    for(UINT k=0; k<numClasses; k++){
        
        //Get the class label for the kth class
        UINT classLabel = classTracker[k].classLabel;
        
        //Set the kth class label
        classLabels[k] = classLabel;
        
        //Get all the training data for this class
        LabelledClassificationData classData = labelledTrainingData.getClassData(classLabel);
        MatrixDouble data(classData.getNumSamples(),N);
        
        //Copy the training data into a matrix, scaling the training data if needed
        for(UINT i=0; i<data.getNumRows(); i++){
            for(UINT j=0; j<data.getNumCols(); j++){
                if( useScaling ){
                    data[i][j] = scale(classData[i][j],ranges[j].minValue,ranges[j].maxValue,0,1);
                }else data[i][j] = classData[i][j];
            }
        }
        
        //Train the model for this class
        models[k].setGamma( gamma );
        if( !models[k].train(classLabel,data,numClusters) ){
            errorLog << "train(LabelledClassificationData &labelledTrainingData,double gamma) - Failed to train model for class: " << classLabel;
            errorLog << ". This is might be because this class does not have enough training samples! You should reduce the number of clusters or increase the number of training samples for this class." << endl;
            models.clear();
            return false;
        }
        
    }
    
    trained = true;
    return true;
}
开发者ID:MarkusKonk,项目名称:Geographic-Interaction,代码行数:61,代码来源:MinDist.cpp

示例6: train

bool GMM::train(LabelledClassificationData trainingData){
    
    //Clear any old models
    models.clear();
    trained = false;
    numFeatures = 0;
    numClasses = 0;
    
    if( trainingData.getNumSamples() == 0 ){
        errorLog << "train(LabelledClassificationData &trainingData) - Training data is empty!" << endl;
        return false;
    }
    
    //Set the number of features and number of classes and resize the models buffer
    numFeatures = trainingData.getNumDimensions();
    numClasses = trainingData.getNumClasses();
    models.resize(numClasses);
    
    if( numFeatures >= 6 ){
        warningLog << "train(LabelledClassificationData &trainingData) - The number of features in your training data is high (" << numFeatures << ").  The GMMClassifier does not work well with high dimensional data, you might get better results from one of the other classifiers." << endl;
    }
    
    //Get the ranges of the training data if the training data is going to be scaled
    if( useScaling ){
        ranges = trainingData.getRanges();
    }

    //Fit a Mixture Model to each class (independently)
    for(UINT k=0; k<numClasses; k++){
        UINT classLabel = trainingData.getClassTracker()[k].classLabel;
        LabelledClassificationData classData = trainingData.getClassData( classLabel );
        
        //Scale the training data if needed
        if( useScaling ){
            if( !classData.scale(ranges,GMM_MIN_SCALE_VALUE, GMM_MAX_SCALE_VALUE) ){
                errorLog << "train(LabelledClassificationData &trainingData) - Failed to scale training data!" << endl;
                return false;

            }
        }
        
        //Convert the labelled data to unlabelled data
        UnlabelledClassificationData unlabelledData = classData.reformatAsUnlabelledClassificationData();
        
        //Train the Mixture Model for this class
        GaussianMixtureModels gaussianMixtureModel;
        gaussianMixtureModel.setMinChange( minChange );
        gaussianMixtureModel.setMaxIter( maxIter );
        if( !gaussianMixtureModel.train(unlabelledData, numMixtureModels) ){
            errorLog << "train(LabelledClassificationData &trainingData) - Failed to train Mixture Model for class " << classLabel << endl;
            return false;
        }
        
        //Setup the model container
        models[k].resize( numMixtureModels );
        models[k].setClassLabel( classLabel );
        
        //Store the mixture model in the container
        for(UINT j=0; j<numMixtureModels; j++){
            models[k][j].mu = gaussianMixtureModel.getMu().getRowVector(j);
            models[k][j].sigma = gaussianMixtureModel.getSigma()[j];
            
            //Compute the determinant and invSigma for the realtime prediction
            LUDecomposition ludcmp(models[k][j].sigma);
            if( !ludcmp.inverse( models[k][j].invSigma ) ){
                models.clear();
                errorLog << "train(LabelledClassificationData &trainingData) - Failed to invert Matrix for class " << classLabel << "!" << endl;
                return false;
            }
            models[k][j].det = ludcmp.det();
        }
        
        //Compute the normalize factor
        models[k].recomputeNormalizationFactor();
        
        //Compute the rejection thresholds
        double mu = 0;
        double sigma = 0;
        VectorDouble predictionResults(classData.getNumSamples(),0);
        for(UINT i=0; i<classData.getNumSamples(); i++){
            vector< double > sample = classData[i].getSample();
            predictionResults[i] = models[k].computeMixtureLikelihood( sample );
            mu += predictionResults[i];
        }
        
        //Update mu
        mu /= double( classData.getNumSamples() );
        
        //Calculate the standard deviation
        for(UINT i=0; i<classData.getNumSamples(); i++) 
            sigma += SQR( (predictionResults[i]-mu) );
        sigma = sqrt( sigma / (double(classData.getNumSamples())-1.0) );
        sigma = 0.2;
        
        //Set the models training mu and sigma 
        models[k].setTrainingMuAndSigma(mu,sigma);
        
        if( !models[k].recomputeNullRejectionThreshold(nullRejectionCoeff) && useNullRejection ){
            warningLog << "train(LabelledClassificationData &trainingData) - Failed to recompute rejection threshold for class " << classLabel << " - the nullRjectionCoeff value is too high!" << endl;
        }
//.........这里部分代码省略.........
开发者ID:gaurav38,项目名称:HackDuke13,代码行数:101,代码来源:GMM.cpp


注：本文中的LabelledClassificationData::getClassTracker方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。