C++ LabelledClassificationData::getClassLabels Method Code Examples

This article collects typical usage examples of the LabelledClassificationData::getClassLabels method in C++. If you are wondering how LabelledClassificationData::getClassLabels is used in practice, the selected code examples here may help. You can also look into further usage examples of LabelledClassificationData, the class this method belongs to.


Two code examples of the LabelledClassificationData::getClassLabels method are shown below, sorted by popularity by default.

Example 1: train

bool RandomForests::train(LabelledClassificationData trainingData){
    
    //Clear any previous model
    clear();
    
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    
    if( M == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - Training data has zero samples!" << endl;
        return false;
    }
    
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();
    
    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }
    
    //Train the random forest
    //Note: this assignment overrides any forest size that was set before train() was called
    forestSize = 10;
    Random random;
    
    DecisionTree tree;
    tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
    tree.setTrainingMode( DecisionTree::BEST_RANDOM_SPLIT );
    tree.setNumSplittingSteps( numRandomSplits );
    tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
    tree.setMaxDepth( maxDepth );
    
    for(UINT i=0; i<forestSize; i++){
        LabelledClassificationData data = trainingData.getBootstrappedDataset();
        
        if( !tree.train( data ) ){
            errorLog << "train(LabelledClassificationData trainingData) - Failed to train tree at forest index: " << i << endl;
            return false;
        }
        
        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }
    
    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
Developer ID: Mr07, project: MA-Gesture-Recognition, lines of code: 52, source file: RandomForests.cpp
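
The loop above trains every tree on the result of trainingData.getBootstrappedDataset(), but the listing does not show what that call does. The sketch below illustrates the general idea of bootstrap resampling (drawing as many samples as the original set holds, uniformly and with replacement). It is not the library's implementation: the Sample struct and the bootstrapResample function are hypothetical names introduced only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical stand-in for one labelled sample (not the library's own type).
struct Sample {
    unsigned int classLabel;
    std::vector<double> features;
};

// Hypothetical helper: returns a dataset of the same size as the input,
// where each entry is drawn uniformly at random, with replacement.
std::vector<Sample> bootstrapResample(const std::vector<Sample>& data,
                                      std::mt19937& rng) {
    assert(!data.empty()); // resampling an empty dataset is not meaningful

    std::uniform_int_distribution<std::size_t> pick(0, data.size() - 1);

    std::vector<Sample> resampled;
    resampled.reserve(data.size());
    for (std::size_t i = 0; i < data.size(); ++i) {
        // The same original sample may be picked several times, or never.
        resampled.push_back(data[pick(rng)]);
    }
    return resampled;
}
```

Because each tree sees a slightly different resampled dataset, the trees in the forest differ from one another even though they share the same training settings.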

Example 2: train

bool BAG::train(LabelledClassificationData trainingData){
    
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    trained = false;
    classLabels.clear();
    
    if( M == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - Training data has zero samples!" << endl;
        return false;
    }
    
    numFeatures = N;
    numClasses = K;
    classLabels.resize(K);
    ranges = trainingData.getRanges();
    
    UINT ensembleSize = (UINT)ensemble.size();
    
    if( ensembleSize == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - The ensemble size is zero! You need to add some classifiers to the ensemble first." << endl;
        return false;
    }
    
    for(UINT i=0; i<ensembleSize; i++){
        if( ensemble[i] == NULL ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << i << " has not been set!" << endl;
            return false;
        }
    }

    //Train the ensemble
    for(UINT i=0; i<ensembleSize; i++){
        LabelledClassificationData bootstrappedDataset = trainingData.getBootstrappedDataset();
        
        //Train the classifier with the bootstrapped dataset
        if( !ensemble[i]->train( bootstrappedDataset ) ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << i << " failed training!" << endl;
            return false;
        }
    }
    
    //Set the class labels
    classLabels = trainingData.getClassLabels();
    
    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
Developer ID: pixelmaid, project: evodraw, lines of code: 50, source file: BAG.cpp
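
The train method above stores trainingData.getClassLabels() in classLabels, but the listing does not show how those labels are used afterwards. One common use in a bagging ensemble is to map the members' votes back to a single predicted class label. The sketch below is a minimal illustration under stated assumptions, not the library's prediction code: collectClassLabels and majorityVote are hypothetical helpers, and the sketch assumes the stored label list holds one entry per distinct class in ascending order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: reduces the per-sample labels of a dataset to the
// distinct class labels, sorted in ascending order.
std::vector<unsigned int> collectClassLabels(std::vector<unsigned int> labels) {
    std::sort(labels.begin(), labels.end());
    labels.erase(std::unique(labels.begin(), labels.end()), labels.end());
    return labels;
}

// Hypothetical helper: each ensemble member casts one vote (a class label).
// Returns the label with the most votes; on a tie, the label that appears
// first in classLabels wins. Votes for unknown labels are ignored.
unsigned int majorityVote(const std::vector<unsigned int>& votes,
                          const std::vector<unsigned int>& classLabels) {
    assert(!classLabels.empty());

    std::vector<std::size_t> counts(classLabels.size(), 0);
    for (unsigned int vote : votes) {
        // classLabels is sorted, so a binary search finds the label's index.
        auto it = std::lower_bound(classLabels.begin(), classLabels.end(), vote);
        if (it != classLabels.end() && *it == vote) {
            ++counts[static_cast<std::size_t>(it - classLabels.begin())];
        }
    }

    // std::max_element returns the first maximum, which gives the tie rule above.
    std::size_t best = static_cast<std::size_t>(
        std::max_element(counts.begin(), counts.end()) - counts.begin());
    return classLabels[best];
}
```

Keeping the label list alongside the trained ensemble lets the classifier report the dataset's original label values rather than internal indices.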


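Note: The LabelledClassificationData::getClassLabels examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects, and copyright in the source code remains with the original authors. For distribution and use, refer to the License of the corresponding project. Do not reproduce without permission.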