当前位置: 首页>>代码示例>>Scala>>正文


Scala RandomForest类代码示例

本文整理汇总了Scala中org.apache.spark.mllib.tree.RandomForest的典型用法代码示例。如果您正苦于以下问题:Scala RandomForest类的具体用法?Scala RandomForest怎么用?Scala RandomForest使用的例子?那么恭喜您,这里精选的类代码示例或许可以为您提供帮助。


在下文中一共展示了RandomForest类的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Scala代码示例。

示例1: MLLibRandomForestModel

//设置package包名称以及导入依赖的类
package com.asto.dmp.articlecate.biz

import com.asto.dmp.articlecate.base.Props
import com.asto.dmp.articlecate.utils.FileUtils
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import com.asto.dmp.articlecate.biz.ClsFeaturesParser._
import scala.collection._

/**
 * Trains, persists and applies an MLlib random-forest classifier.
 *
 * @param sc        Spark context used for loading data and saving/loading the model
 * @param modelPath location the trained model is saved to and later loaded from
 */
class MLLibRandomForestModel(val sc: SparkContext, val modelPath: String) extends scala.Serializable with Logging {

  /**
   * Trains a random-forest classifier on LibSVM-formatted data, replaces whatever
   * is stored at `modelPath` with the new model and, if enabled through the
   * `model_test` property, logs an error rate.
   *
   * All hyper-parameters are read from [[Props]] (`model_numTrees`,
   * `model_featureSubsetStrategy`, `model_impurity`, `model_maxDepth`, `model_maxBins`).
   *
   * @param svmTrainDataPath path of the LibSVM training file
   */
  def genRandomForestModel(svmTrainDataPath: String): Unit = {
    val numClasses = ClsFeaturesParser.clsNameToCodeMap.size //Util.parseMapFrom(clsIndicesPath, nameToCode = true).size
    // `immutable.Map` is spelled out because `scala.collection._` is imported in this file.
    // An empty map means no feature is declared categorical.
    val categoricalFeaturesInfo = immutable.Map[Int, Int]()
    val numTrees = Props.get("model_numTrees").toInt
    val featureSubsetStrategy = Props.get("model_featureSubsetStrategy")
    val impurity = Props.get("model_impurity")
    val maxDepth = Props.get("model_maxDepth").toInt
    val maxBins = Props.get("model_maxBins").toInt

    val trainingData = MLUtils.loadLibSVMFile(sc, svmTrainDataPath).cache()
    try {
      val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

      // Clear the previous model first so that save() does not hit an existing path.
      FileUtils.deleteFilesInHDFS(modelPath)
      model.save(sc, modelPath)

      testErrorRate(trainingData, model)
    } finally {
      // Release the cached partitions once training and evaluation are done.
      trainingData.unpersist()
    }
  }

  /**
   * Logs the misclassification rate of `model` on a random sample of `trainingData`
   * when the `model_test` property is true.
   *
   * The sample is drawn from the training set itself (fraction taken from
   * `model_sampleRate`), so the reported figure is a training error rather than a
   * held-out test error.
   *
   * NOTE(review): the two "????" log messages look like text that lost its encoding
   * in the original source; they are kept verbatim here.
   */
  private def testErrorRate(trainingData: RDD[LabeledPoint], model: RandomForestModel): Unit = {
    if (Props.get("model_test").toBoolean) {
      val sampled = trainingData.sample(withReplacement = false, Props.get("model_sampleRate").toDouble)
      // Cached because both the total and the misclassified counts are taken from it.
      val labelAndPreds = sampled.map { point =>
        (point.label, model.predict(point.features))
      }.cache()
      try {
        val total = labelAndPreds.count()
        if (total == 0L) {
          // Avoid reporting 0/0 = NaN when the sample turns out to be empty.
          logWarning("Evaluation sample is empty; skipping error-rate computation")
        } else {
          val misclassified = labelAndPreds.filter { case (label, prediction) => label != prediction }.count()
          val testError = misclassified.toDouble / total
          logInfo(s"????????????$testError")
        }
      } finally {
        labelAndPreds.unpersist()
      }
    } else {
      logInfo(s"???????????")
    }
  }

  /**
   * Loads the model stored at `modelPath`, predicts a class for every vector and
   * writes one `<class name>\t<line>` row per input to `resultPath`.
   *
   * @param lineAndVectors pairs of (original line, feature vector) to classify
   * @param resultPath     destination handed to `FileUtils.saveFileToHDFS`
   */
  def predictAndSave(lineAndVectors: Array[(String, org.apache.spark.mllib.linalg.Vector)], resultPath: String) = {
    val model = RandomForestModel.load(sc, modelPath)
    val result = lineAndVectors.map { case (line, vector) =>
      // The predicted label is a Double class code; the lookup map is keyed by its integer string form.
      val className = clsCodeToNameMap(model.predict(vector).toInt.toString)
      s"$className\t$line"
    }.mkString("\n")
    FileUtils.saveFileToHDFS(resultPath, result)
  }
}
开发者ID:luciuschina,项目名称:ArticleCategories,代码行数:56,代码来源:MLLibRandomForestModel.scala

示例2: RandomForestTest

//设置package包名称以及导入依赖的类
package cn.edu.bjtu


import org.apache.spark.SparkConf
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession
/**
 * Trains a binary random-forest classifier on a LibSVM file and prints the area
 * under the ROC curve, sensitivity, specificity and accuracy for a 30% hold-out split.
 */
object RandomForestTest {
  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf()
      .setAppName("RandomForestTest")
      .setMaster("spark://master:7077")
      .setJars(Array("/home/hadoop/RandomForest.jar"))

    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()

    try {
      spark.sparkContext.setLogLevel("WARN")

      // Load and parse the data file.
      val data = MLUtils.loadLibSVMFile(spark.sparkContext, "hdfs://master:9000/sample_formatted.txt")

      // Split the data into training and test sets (30% held out for testing)
      val splits = data.randomSplit(Array(0.7, 0.3))
      val (training, test) = (splits(0), splits(1))

      // Train a RandomForest model.
      // Empty categoricalFeaturesInfo indicates all features are continuous.
      val numClasses = 2
      val categoricalFeaturesInfo = Map[Int, Int]()
      val numTrees = 3 // Use more in practice.
      // A numeric string is an explicit setting, not "auto".
      // NOTE(review): per the MLlib docs an integer n > 1 means "use n features per split" - confirm for the Spark version in use.
      val featureSubsetStrategy = "18"
      val impurity = "gini"
      val maxDepth = 14
      val maxBins = 16384

      val model = RandomForest.trainClassifier(training, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

      // Cached: both the ROC metrics and the confusion counts below read this RDD.
      val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
        val prediction = model.predict(features)
        (prediction, label)
      }.cache()

      val metrics = new BinaryClassificationMetrics(predictionAndLabels)
      val auROC = metrics.areaUnderROC()
      println("Area under ROC = " + auROC)

      // One pass over the data: the number of occurrences of every distinct
      // (prediction, label) pair. For a classifier this map is tiny.
      val confusion = predictionAndLabels.countByValue()

      // Sums the occurrences of all (prediction, label) pairs accepted by `keep`.
      def tally(keep: ((Double, Double)) => Boolean): Double =
        confusion.iterator.collect { case (pair, n) if keep(pair) => n }.sum.toDouble

      println("Sensitivity = " + tally { case (p, l) => p == l && p == 1.0 } / tally { case (_, l) => l == 1.0 })
      println("Specificity = " + tally { case (p, l) => p == l && p == 0.0 } / tally { case (_, l) => l == 0.0 })
      println("Accuracy = " + tally { case (p, l) => p == l } / tally(_ => true))
    } finally {
      // Always release the cluster resources, even when loading or training fails.
      spark.stop()
    }
  }
}
开发者ID:XiaoyuGuo,项目名称:DataFusionClass,代码行数:57,代码来源:RandomForestTest.scala

示例3: RandomForestAlgorithmParams

//设置package包名称以及导入依赖的类
package org.template.classification

import org.apache.predictionio.controller.P2LAlgorithm
import org.apache.predictionio.controller.Params

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkContext

import grizzled.slf4j.Logger

/**
 * Engine parameters for [[RandomForestAlgorithm]].
 *
 * Each field is passed unchanged by `RandomForestAlgorithm.train` to the
 * argument of the same name of MLlib's `RandomForest.trainClassifier`.
 *
 * @param numClasses            number of classes, forwarded as `numClasses`
 * @param numTrees              number of trees, forwarded as `numTrees`
 * @param featureSubsetStrategy feature-subset strategy string, forwarded as is
 * @param impurity              impurity measure name, forwarded as is
 * @param maxDepth              maximum tree depth, forwarded as `maxDepth`
 * @param maxBins               maximum number of bins, forwarded as `maxBins`
 */
case class RandomForestAlgorithmParams(
  numClasses: Int,
  numTrees: Int,
  featureSubsetStrategy: String,
  impurity: String,
  maxDepth: Int,
  maxBins: Int
) extends Params

/**
 * PredictionIO algorithm that wraps MLlib's random-forest classifier.
 *
 * It extends P2LAlgorithm because, per the original author's note, the MLlib
 * random-forest model does not contain an RDD.
 *
 * @param ap hyper-parameters forwarded to `RandomForest.trainClassifier`
 */
class RandomForestAlgorithm(val ap: RandomForestAlgorithmParams)
  extends P2LAlgorithm[PreparedData, RandomForestModel, Query, PredictedResult] {

  @transient lazy val logger = Logger[this.type]

  /**
   * Trains a random-forest classifier on the labeled points of `data`.
   *
   * @param sc   Spark context (not used directly by this method)
   * @param data prepared data whose `labeledPoints` form the training set
   * @return the trained model
   */
  def train(sc: SparkContext, data: PreparedData): RandomForestModel = {
    // No categorical features are declared, so all features are treated as continuous.
    val noCategoricalFeatures = Map.empty[Int, Int]
    RandomForest.trainClassifier(
      data.labeledPoints, ap.numClasses, noCategoricalFeatures, ap.numTrees,
      ap.featureSubsetStrategy, ap.impurity, ap.maxDepth, ap.maxBins)
  }

  /**
   * Predicts the label for a single query.
   *
   * The feature vector is built from the query's `voice_usage`, `data_usage`
   * and `text_usage` fields, in that order.
   */
  def predict(model: RandomForestModel, query: Query): PredictedResult =
    new PredictedResult(
      model.predict(
        Vectors.dense(Array(query.voice_usage, query.data_usage, query.text_usage))))

}
开发者ID:wmfongsf,项目名称:predictionio-engine-classification,代码行数:50,代码来源:RandomForestAlgorithm.scala

示例4: RandomForestAlgorithmTest

//设置package包名称以及导入依赖的类
package org.template.classification

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel

import org.scalatest.FlatSpec
import org.scalatest.Matchers

/**
 * Smoke test for [[RandomForestAlgorithm]]: training on a tiny in-memory data set
 * must yield a `RandomForestModel`.
 */
class RandomForestAlgorithmTest
  extends FlatSpec with SharedSingletonContext with Matchers {

  // Hyper-parameters used for the algorithm under test.
  val params = RandomForestAlgorithmParams(
    numClasses = 7,
    numTrees = 12,
    featureSubsetStrategy = "auto",
    impurity = "gini",
    maxDepth = 4,
    maxBins = 100)

  val algorithm = new RandomForestAlgorithm(params)

  // Three points, one per label 0..2; the feature at the label's index is 1000, the others are 10.
  val dataSource = Seq(
    LabeledPoint(0.0, Vectors.dense(1000.0, 10.0, 10.0)),
    LabeledPoint(1.0, Vectors.dense(10.0, 1000.0, 10.0)),
    LabeledPoint(2.0, Vectors.dense(10.0, 10.0, 1000.0))
  )

  "train" should "return RandomForest model" in {
    val prepared = new PreparedData(labeledPoints = sparkContext.parallelize(dataSource))
    val trained = algorithm.train(sparkContext, prepared)
    trained shouldBe a [RandomForestModel]
  }
}
开发者ID:wmfongsf,项目名称:predictionio-engine-classification,代码行数:36,代码来源:RandomForestAlgorithmTest.scala

示例5: ModelTrainer

//设置package包名称以及导入依赖的类
package modelmanager

import java.io.File

import com.typesafe.config.Config
import org.apache.commons.io.FileUtils
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.streaming.StreamingContext

import scala.collection.JavaConversions._

/**
 * Trains one random-forest classifier per configured model name and optionally
 * saves each of them.
 */
object ModelTrainer {
  // Number of classes passed to the classifier.
  val nClasses: Int = 2
  // Label constants; they are not referenced inside this object.
  val positiveLabel: Double = 1.0
  val negativeLabel: Double = 0.0
  // Maximum number of bins passed to the classifier.
  val maxBins: Int = 100

  /**
   * Trains, in parallel, one classifier for every name listed under `models.models`.
   *
   * The training set of a model is the LibSVM file at
   * `<pathToTrainingData><modelName>.libsvm`. When `models.saveModels` is true,
   * each model is saved to `<pathToModels><modelName>` after any existing local
   * file or directory at that path has been removed.
   *
   * @param ssc    streaming context whose Spark context is used for loading and saving
   * @param config application configuration holding the `models.*` settings and `eventsCount`
   * @return the (model name, trained model) pairs, as a parallel collection
   */
  def trainModels(ssc: StreamingContext, config: Config) = {

    // Training settings, read in the same order as before so that a missing key fails identically.
    val depth = config.getInt("models.trainingConfiguration.depth")
    val impurity = config.getString("models.trainingConfiguration.impurity")
    val strategy = config.getString("models.trainingConfiguration.strategy")
    val seed = config.getInt("models.trainingConfiguration.seed")
    val forestSize = config.getInt("models.trainingConfiguration.forestSize")
    val dataPath = config.getString("models.trainingConfiguration.pathToTrainingData")
    val modelsPath = config.getString("models.pathToModels")
    val events = config.getStringList("models.models")
    // Every feature index in [0, eventsCount) is declared categorical with two categories.
    val categoricalInfo = (0 until config.getInt("eventsCount")).map(index => index -> 2).toMap

    val models = events.par.map { modelName =>
      val trainingSet = MLUtils.loadLibSVMFile(ssc.sparkContext, dataPath + modelName + ".libsvm")
      val forest = RandomForest.trainClassifier(
        trainingSet, nClasses, categoricalInfo, forestSize, strategy, impurity, depth, maxBins, seed)
      (modelName, forest)
    }

    if (config.getBoolean("models.saveModels")) {
      models.seq.foreach { case (name, forest) =>
        val target = modelsPath + name
        FileUtils.deleteQuietly(new File(target))
        forest.save(ssc.sparkContext, target)
      }
    }
    models
  }
}
开发者ID:jandion,项目名称:SparkOFP,代码行数:54,代码来源:ModelTrainer.scala

示例6: main

//设置package包名称以及导入依赖的类
package BlkFish

import java.util.Calendar

import org.apache.spark._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.rdd.RDD
import com.typesafe.config.ConfigFactory
import org.apache.hadoop.mapred._
import Preprocess._


  /**
   * Trains a random-forest classifier from the configured training data, predicts
   * a class for every test point and writes the predictions (class index + 1) as text.
   *
   * If the output path already exists, one retry is made with the current time in
   * milliseconds appended to the path. Non-fatal save failures are reported on
   * stdout instead of being rethrown; fatal errors propagate.
   */
  def main(args: Array[String]): Unit = {

    val conf = ConfigFactory.load()

    // No categorical features are declared.
    val categoricalFeatureInfo = Map[Int, Int]()

    val trainData = sc.objectFile[LabeledPoint](conf.getString("ml.path.trainData"))
    val testDataBytes: RDD[(String, String)] = sc.wholeTextFiles(conf.getString("ml.path.testData"))
    val testLabeledPoints = toLabeledPoints(bytesToInt(byteCount(removeMemPath(testDataBytes))))

    val model = RandomForest.trainClassifier(
      trainData,
      conf.getInt("ml.algo.numberOfClasses"),
      categoricalFeatureInfo,
      conf.getInt("ml.algo.numberOfTrees"),
      conf.getString("ml.algo.featureSubsetStrategy"),
      conf.getString("ml.algo.costFunction"),
      conf.getInt("ml.algo.maxDepth"),
      conf.getInt("ml.algo.maxBins")
    )

    val predictions = testLabeledPoints.map { point => model.predict(point.features) }

    // Predicted labels are shifted by one before being written out.
    val formattedPredictions = predictions.map(prediction => prediction.toInt + 1)

    try {
      formattedPredictions.saveAsTextFile(conf.getString("ml.path.predictionsOutput"))
    } catch {
      case _: FileAlreadyExistsException =>
        println("Prediction file already exists attempting to save with count append")
        try {
          formattedPredictions.saveAsTextFile(conf.getString("ml.path.predictionsOutput") + Calendar.getInstance.getTimeInMillis.toString)
        } catch {
          case _: FileAlreadyExistsException =>
            println("Failed to save with appended file name.")
            println("File will not be saved")
          // NonFatal lets OutOfMemoryError, InterruptedException etc. propagate.
          case scala.util.control.NonFatal(e) =>
            println(s"Unknown error at save predictions: $e")
        }

      case scala.util.control.NonFatal(e) =>
        println(s"Unknown error at save predictions: $e")
    }

  }
} 
开发者ID:snakes-in-the-box,项目名称:blk-fish,代码行数:57,代码来源:Driver.scala


注:本文中的org.apache.spark.mllib.tree.RandomForest类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。