This article collects typical usage examples of the Scala class org.apache.spark.mllib.tree.RandomForest. If you are unsure what the Scala RandomForest class does, or how and where to use it, the selected class code examples below should help.
Six code examples of the RandomForest class are shown, ordered by popularity by default.
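All six examples funnel into the same entry point, RandomForest.trainClassifier. For orientation, this is its signature in the Spark MLlib API; categoricalFeaturesInfo maps a feature index to its number of categories, and an empty map declares every feature continuous:

def trainClassifier(
    input: RDD[LabeledPoint],
    numClasses: Int,
    categoricalFeaturesInfo: Map[Int, Int],
    numTrees: Int,
    featureSubsetStrategy: String,
    impurity: String,
    maxDepth: Int,
    maxBins: Int,
    seed: Int = Utils.random.nextInt()): RandomForestModel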
Example 1: MLLibRandomForestModel
// Package declaration and required imports
package com.asto.dmp.articlecate.biz
import com.asto.dmp.articlecate.base.Props
import com.asto.dmp.articlecate.utils.FileUtils
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import com.asto.dmp.articlecate.biz.ClsFeaturesParser._
import scala.collection._
class MLLibRandomForestModel(val sc: SparkContext, val modelPath: String) extends scala.Serializable with Logging {

  def genRandomForestModel(svmTrainDataPath: String) = {
    val numClasses = ClsFeaturesParser.clsNameToCodeMap.size
    val categoricalFeaturesInfo = immutable.Map[Int, Int]() // empty map: all features are continuous
    val numTrees = Props.get("model_numTrees").toInt
    val featureSubsetStrategy = Props.get("model_featureSubsetStrategy") // e.g. "auto" to let the algorithm choose
    val impurity = Props.get("model_impurity")
    val maxDepth = Props.get("model_maxDepth").toInt
    val maxBins = Props.get("model_maxBins").toInt
    val trainingData = MLUtils.loadLibSVMFile(sc, svmTrainDataPath).cache()
    val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
    FileUtils.deleteFilesInHDFS(modelPath)
    model.save(sc, modelPath)
    testErrorRate(trainingData, model)
  }

  private def testErrorRate(trainingData: RDD[LabeledPoint], model: RandomForestModel) = {
    if (Props.get("model_test").toBoolean) {
      val testData = trainingData.sample(false, Props.get("model_sampleRate").toDouble)
      val labelAndPreds = testData.map { point =>
        val prediction = model.predict(point.features)
        (point.label, prediction)
      }
      val testError = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
      logInfo(s"Test error rate: $testError")
    } else {
      logInfo("Model testing is disabled; skipping error-rate evaluation.")
    }
  }

  def predictAndSave(lineAndVectors: Array[(String, org.apache.spark.mllib.linalg.Vector)], resultPath: String) = {
    val model = RandomForestModel.load(sc, modelPath)
    val result = lineAndVectors.map(lv => s"${clsCodeToNameMap(model.predict(lv._2).toInt.toString)}\t${lv._1}").mkString("\n")
    FileUtils.saveFileToHDFS(resultPath, result)
  }
}
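A plausible way to drive this class, assuming the Props keys referenced above (model_numTrees, model_impurity, and so on) exist in the job configuration; the paths and app name below are placeholders, not from the original project:

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical driver code; paths are placeholders.
val sc = new SparkContext(new SparkConf().setAppName("article-classification"))
val rf = new MLLibRandomForestModel(sc, modelPath = "/tmp/rf-model")
rf.genRandomForestModel(svmTrainDataPath = "/tmp/train.libsvm")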
Example 2: RandomForestTest
// Package declaration and required imports
package cn.edu.bjtu
import org.apache.spark.SparkConf
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession
object RandomForestTest {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("RandomForestTest")
      .setMaster("spark://master:7077")
      .setJars(Array("/home/hadoop/RandomForest.jar"))
    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    // Load and parse the data file.
    val data = MLUtils.loadLibSVMFile(spark.sparkContext, "hdfs://master:9000/sample_formatted.txt")
    // Split the data into training and test sets (30% held out for testing).
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (training, test) = (splits(0), splits(1))

    // Train a RandomForest model.
    // An empty categoricalFeaturesInfo indicates all features are continuous.
    val numClasses = 2
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "18" // Consider 18 features per split; "auto" would let the algorithm choose.
    val impurity = "gini"
    val maxDepth = 14
    val maxBins = 16384
    val model = RandomForest.trainClassifier(training, numClasses, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

    val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
      val prediction = model.predict(features)
      (prediction, label)
    }
    val metrics = new BinaryClassificationMetrics(predictionAndLabels)
    val auROC = metrics.areaUnderROC()
    println("Area under ROC = " + auROC)
    println("Sensitivity = " + predictionAndLabels.filter(x => x._1 == x._2 && x._1 == 1.0).count().toDouble / predictionAndLabels.filter(x => x._2 == 1.0).count().toDouble)
    println("Specificity = " + predictionAndLabels.filter(x => x._1 == x._2 && x._1 == 0.0).count().toDouble / predictionAndLabels.filter(x => x._2 == 0.0).count().toDouble)
    println("Accuracy = " + predictionAndLabels.filter(x => x._1 == x._2).count().toDouble / predictionAndLabels.count().toDouble)
  }
}
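The hand-rolled sensitivity, specificity, and accuracy counts above can be cross-checked with MLlib's MulticlassMetrics, which derives the same numbers from a confusion matrix. A minimal sketch reusing the predictionAndLabels RDD (the accuracy field requires Spark 2.0+):

import org.apache.spark.mllib.evaluation.MulticlassMetrics

// MulticlassMetrics consumes (prediction, label) pairs, the shape built above.
val mcMetrics = new MulticlassMetrics(predictionAndLabels)
println("Confusion matrix:\n" + mcMetrics.confusionMatrix)
println("Sensitivity (recall of class 1.0) = " + mcMetrics.recall(1.0))
println("Specificity (recall of class 0.0) = " + mcMetrics.recall(0.0))
println("Accuracy = " + mcMetrics.accuracy)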
Example 3: RandomForestAlgorithmParams
// Package declaration and required imports
package org.template.classification
import org.apache.predictionio.controller.P2LAlgorithm
import org.apache.predictionio.controller.Params
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.SparkContext
import grizzled.slf4j.Logger
case class RandomForestAlgorithmParams(
  numClasses: Int,
  numTrees: Int,
  featureSubsetStrategy: String,
  impurity: String,
  maxDepth: Int,
  maxBins: Int
) extends Params

// Extends P2LAlgorithm because MLlib's RandomForestModel doesn't contain an RDD.
class RandomForestAlgorithm(val ap: RandomForestAlgorithmParams)
  extends P2LAlgorithm[PreparedData, RandomForestModel, Query, PredictedResult] {

  @transient lazy val logger = Logger[this.type]

  def train(sc: SparkContext, data: PreparedData): RandomForestModel = {
    // An empty categoricalFeaturesInfo indicates all features are continuous.
    val categoricalFeaturesInfo = Map[Int, Int]()
    RandomForest.trainClassifier(
      data.labeledPoints,
      ap.numClasses,
      categoricalFeaturesInfo,
      ap.numTrees,
      ap.featureSubsetStrategy,
      ap.impurity,
      ap.maxDepth,
      ap.maxBins)
  }

  def predict(model: RandomForestModel, query: Query): PredictedResult = {
    val features = Vectors.dense(
      Array(query.voice_usage, query.data_usage, query.text_usage)
    )
    val label = model.predict(features)
    new PredictedResult(label)
  }
}
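This algorithm compiles against Query, PreparedData, and PredictedResult types defined elsewhere in the PredictionIO template. A minimal sketch of what those definitions could look like; the Query field names come from the predict method above, everything else is an assumption:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical companion types; the real template defines these in its Engine/Preparator sources.
case class Query(voice_usage: Double, data_usage: Double, text_usage: Double)
case class PredictedResult(label: Double)
class PreparedData(val labeledPoints: RDD[LabeledPoint]) extends Serializable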
Example 4: RandomForestAlgorithmTest
// Package declaration and required imports
package org.template.classification
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.scalatest.FlatSpec
import org.scalatest.Matchers
class RandomForestAlgorithmTest
  extends FlatSpec with SharedSingletonContext with Matchers {

  val params = RandomForestAlgorithmParams(
    numClasses = 7,
    numTrees = 12,
    featureSubsetStrategy = "auto",
    impurity = "gini",
    maxDepth = 4,
    maxBins = 100)
  val algorithm = new RandomForestAlgorithm(params)
  val dataSource = Seq(
    LabeledPoint(0, Vectors.dense(1000, 10, 10)),
    LabeledPoint(1, Vectors.dense(10, 1000, 10)),
    LabeledPoint(2, Vectors.dense(10, 10, 1000))
  )

  "train" should "return a RandomForestModel" in {
    val dataSourceRDD = sparkContext.parallelize(dataSource)
    val preparedData = new PreparedData(labeledPoints = dataSourceRDD)
    val model = algorithm.train(sparkContext, preparedData)
    model shouldBe a [RandomForestModel]
  }
}
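The spec mixes in SharedSingletonContext, a test-harness trait (not shown in the example) that supplies the sparkContext used above. A rough sketch of such a trait, with the lifecycle and names being assumptions rather than the template's actual code:

import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, Suite}

// Hypothetical harness: one local SparkContext shared across the whole suite.
trait SharedSingletonContext extends BeforeAndAfterAll { self: Suite =>
  @transient private var _sc: SparkContext = _
  def sparkContext: SparkContext = _sc

  override def beforeAll(): Unit = {
    super.beforeAll()
    _sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("test"))
  }

  override def afterAll(): Unit = {
    if (_sc != null) _sc.stop()
    super.afterAll()
  }
}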
Example 5: ModelTrainer
// Package declaration and required imports
package modelmanager
import java.io.File
import com.typesafe.config.Config
import org.apache.commons.io.FileUtils
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.streaming.StreamingContext
import scala.collection.JavaConversions._
object ModelTrainer {
  val nClasses: Int = 2
  val positiveLabel: Double = 1.0
  val negativeLabel: Double = 0.0
  val maxBins: Int = 100

  def trainModels(ssc: StreamingContext, config: Config) = {
    // Load configuration
    val depth = config.getInt("models.trainingConfiguration.depth")
    val impurity = config.getString("models.trainingConfiguration.impurity")
    val strategy = config.getString("models.trainingConfiguration.strategy")
    val seed = config.getInt("models.trainingConfiguration.seed")
    val forestSize = config.getInt("models.trainingConfiguration.forestSize")
    val dataPath = config.getString("models.trainingConfiguration.pathToTrainingData")
    val modelsPath = config.getString("models.pathToModels")
    val events = config.getStringList("models.models")
    // Every feature is declared as a binary categorical feature (two categories).
    val categoricalInfo = Range(0, config.getInt("eventsCount")).map((_, 2)).toMap

    // Train one forest per event, in parallel across models.
    val models = events.par.map(modelName => {
      (modelName,
        RandomForest.trainClassifier(
          MLUtils.loadLibSVMFile(ssc.sparkContext, dataPath + modelName + ".libsvm"),
          nClasses,
          categoricalInfo,
          forestSize,
          strategy,
          impurity,
          depth,
          maxBins,
          seed))
    })
    if (config.getBoolean("models.saveModels"))
      models.seq.foreach(x => {
        FileUtils.deleteQuietly(new File(modelsPath + x._1))
        x._2.save(ssc.sparkContext, modelsPath + x._1)
      })
    models
  }
}
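For reference, a configuration shape that satisfies every lookup in trainModels. The key names are taken from the code; the values are made-up placeholders, shown as a parseString literal so the sketch stays self-contained:

import com.typesafe.config.ConfigFactory

// Hypothetical HOCON matching the keys read above; values are placeholders.
val config = ConfigFactory.parseString("""
  eventsCount = 10
  models {
    models = ["clickEvent", "purchaseEvent"]
    pathToModels = "/tmp/models/"
    saveModels = true
    trainingConfiguration {
      depth = 5
      impurity = "gini"
      strategy = "auto"
      seed = 42
      forestSize = 20
      pathToTrainingData = "/tmp/training/"
    }
  }
""")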
Example 6: main
// Package declaration and required imports
package BlkFish
import java.util.Calendar
import org.apache.spark._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.rdd.RDD
import com.typesafe.config.ConfigFactory
import org.apache.hadoop.mapred._
import Preprocess._
// The enclosing object declaration was missing from the original snippet;
// "Main" is a stand-in name, and the SparkContext is created here so the code compiles.
object Main {

  val sc = new SparkContext(new SparkConf().setAppName("BlkFish"))

  def main(args: Array[String]) = {
    val conf = ConfigFactory.load()
    val categoricalFeatureInfo = Map[Int, Int]()

    val trainData = sc.objectFile[LabeledPoint](conf.getString("ml.path.trainData"))
    val testDataBytes: RDD[(String, String)] = sc.wholeTextFiles(conf.getString("ml.path.testData"))
    val testLabeledPoints = toLabeledPoints(bytesToInt(byteCount(removeMemPath(testDataBytes))))

    val model = RandomForest.trainClassifier(
      trainData,
      conf.getInt("ml.algo.numberOfClasses"),
      categoricalFeatureInfo,
      conf.getInt("ml.algo.numberOfTrees"),
      conf.getString("ml.algo.featureSubsetStrategy"),
      conf.getString("ml.algo.costFunction"),
      conf.getInt("ml.algo.maxDepth"),
      conf.getInt("ml.algo.maxBins")
    )

    val predictions = testLabeledPoints.map { point => model.predict(point.features) }
    val formattedPredictions = predictions.map(prediction => prediction.toInt + 1)

    try {
      formattedPredictions.saveAsTextFile(conf.getString("ml.path.predictionsOutput"))
    } catch {
      case ex: FileAlreadyExistsException =>
        println("Prediction file already exists; attempting to save with a timestamp appended")
        try {
          formattedPredictions.saveAsTextFile(conf.getString("ml.path.predictionsOutput") + Calendar.getInstance.getTimeInMillis.toString)
        } catch {
          case ex: FileAlreadyExistsException =>
            println("Failed to save with appended file name.")
            println("File will not be saved")
          case _: Throwable => println("Unknown error at save predictions")
        }
      case _: Throwable => println("Unknown error at save predictions")
    }
  }
}
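main leans on four helpers imported from Preprocess that the snippet does not include. Only the outermost types can be inferred from the call chain (wholeTextFiles yields RDD[(String, String)], and the result must be RDD[LabeledPoint]); the intermediate shapes below are guesses, and the bodies are left as stubs:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical signatures for the missing helpers; intermediate types are assumptions.
object Preprocess {
  def removeMemPath(raw: RDD[(String, String)]): RDD[(String, String)] = ???
  def byteCount(named: RDD[(String, String)]): RDD[(String, Map[String, Int])] = ???
  def bytesToInt(counts: RDD[(String, Map[String, Int])]): RDD[(String, Array[Int])] = ???
  def toLabeledPoints(features: RDD[(String, Array[Int])]): RDD[LabeledPoint] = ???
}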