本文整理汇总了Scala中org.apache.spark.ml.regression.DecisionTreeRegressor类的典型用法代码示例。如果您正苦于以下问题:Scala DecisionTreeRegressor类的具体用法?Scala DecisionTreeRegressor怎么用?Scala DecisionTreeRegressor使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。
在下文中一共展示了DecisionTreeRegressor类的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Scala代码示例。
示例1: DTreeRegressionJob
//设置package包名称以及导入依赖的类
import io.hydrosphere.mist.api._
import io.hydrosphere.mist.api.ml._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.sql.SparkSession
object DTreeRegressionJob extends MLMistJob {
def session: SparkSession = SparkSession
.builder()
.appName(context.appName)
.config(context.getConf)
.getOrCreate()
def train(datasetPath: String, savePath: String): Map[String, Any] = {
val dataset = session.read.format("libsvm").load(datasetPath)
val featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)
.fit(dataset)
// Train a DecisionTree model.
val dt = new DecisionTreeRegressor()
.setLabelCol("label")
.setFeaturesCol("indexedFeatures")
// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline()
.setStages(Array(featureIndexer, dt))
// Train model. This also runs the indexers.
val model = pipeline.fit(dataset)
model.write.overwrite().save(savePath)
Map.empty
}
def serve(modelPath: String, features: List[Array[Double]]): Map[String, Any] = {
import LocalPipelineModel._
val pipeline = PipelineLoader.load(modelPath)
val data = LocalData(LocalDataColumn("features", features))
val result: LocalData = pipeline.transform(data)
Map("result" -> result.select("prediction").toMapList)
}
}