

Scala Estimator Class Code Examples

This article collects typical usage examples of the Scala class org.apache.spark.ml.Estimator. If you are unsure what the Estimator class does or how to use it in your own code, the curated examples below should help.


Eleven code examples of the Estimator class are presented below, ordered by popularity.
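
Before diving in, it helps to recall the contract all of these examples revolve around: in Spark ML, an Estimator[M] is a learning algorithm whose fit method consumes a DataFrame and returns a Model (which is a Transformer). The following is a minimal, self-contained sketch of that contract using the stock MinMaxScaler estimator; it is illustrative only and not taken from the projects below.

import org.apache.spark.ml.feature.{MinMaxScaler, MinMaxScalerModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object EstimatorContractDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("estimator-demo").getOrCreate()
    import spark.implicits._

    val df = Seq(Vectors.dense(1.0, 10.0), Vectors.dense(3.0, 20.0), Vectors.dense(5.0, 30.0))
      .map(Tuple1.apply).toDF("features")

    // fit() is the Estimator half of the contract: it learns column-wise min/max from the data...
    val scaler = new MinMaxScaler().setInputCol("features").setOutputCol("scaled")
    val model: MinMaxScalerModel = scaler.fit(df)

    // ...and the returned Model is a Transformer: transform() applies what was learned.
    model.transform(df).show(truncate = false)
    spark.stop()
  }
}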

Example 1: KMeans

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.clustering

import org.apache.spark.ml
import org.apache.spark.ml.Estimator
import org.apache.spark.sql._

import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}


object KMeans extends BenchmarkAlgorithm with TestFromTraining {

  override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
    import ctx.params._
    DataGenerator.generateGaussianMixtureData(ctx.sqlContext, k, numExamples, ctx.seed(),
      numPartitions, numFeatures)
  }

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new ml.clustering.KMeans()
      .setK(k)
      .setSeed(randomSeed.toLong)
      .setMaxIter(maxIter)
  }

  // TODO(?) add a scoring method here.
} 
Author: summerDG, Project: spark-sql-perf, Lines: 30, Source: KMeans.scala
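
For reference, fitting the same stock estimator outside the benchmark harness takes only a few lines. A minimal sketch with toy data, runnable in spark-shell (where a SparkSession named spark and its implicits are already in scope):

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

val df = Seq(Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)).map(Tuple1.apply).toDF("features")

val model = new KMeans().setK(2).setSeed(42L).setMaxIter(10).fit(df)
model.clusterCenters.foreach(println)  // two centers, near (0,0) and (9,9)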

Example 2: LDA

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.clustering

import scala.collection.mutable.{HashMap => MHashMap}

import org.apache.commons.math3.random.Well19937c

import org.apache.spark.ml.Estimator
import org.apache.spark.ml
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.ml.linalg.{Vector, Vectors}

import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._


object LDA extends BenchmarkAlgorithm with TestFromTraining {
  // The LDA model is package private, no need to expose it.

  override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
    import ctx.params._
    val rdd = ctx.sqlContext.sparkContext.parallelize(
      0L until numExamples,
      numPartitions
    )
    val seed: Int = randomSeed
    val docLength = ldaDocLength.get
    val numVocab = ldaNumVocabulary.get
    val data: RDD[(Long, Vector)] = rdd.mapPartitionsWithIndex { (idx, partition) =>
      val rng = new Well19937c(seed ^ idx)
      partition.map { docIndex =>
        var currentSize = 0
        val entries = MHashMap[Int, Int]()
        while (currentSize < docLength) {
          val index = rng.nextInt(numVocab)
          entries(index) = entries.getOrElse(index, 0) + 1
          currentSize += 1
        }

        val iter = entries.toSeq.map(v => (v._1, v._2.toDouble))
        (docIndex, Vectors.sparse(numVocab, iter))
      }
    }
    ctx.sqlContext.createDataFrame(data).toDF("docIndex", "features")
  }

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new ml.clustering.LDA()
      .setK(k)
      .setSeed(randomSeed.toLong)
      .setMaxIter(maxIter)
      .setOptimizer(optimizer)
  }

  // TODO(?) add a scoring method here.
} 
Author: summerDG, Project: spark-sql-perf, Lines: 58, Source: LDA.scala
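
A quick way to inspect what the fitted LDA estimator returns, as a spark-shell sketch on a toy corpus of two sparse documents (illustrative only, not from the project above):

import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.linalg.Vectors

val docs = Seq(
  (0L, Vectors.sparse(5, Seq((0, 2.0), (1, 1.0)))),
  (1L, Vectors.sparse(5, Seq((3, 3.0), (4, 1.0))))
).toDF("docIndex", "features")

val model = new LDA().setK(2).setMaxIter(10).setSeed(7L).fit(docs)
model.describeTopics(3).show(truncate = false)   // top terms per topic
println(model.logLikelihood(docs))               // higher indicates a better fit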

Example 3: RandomForestClassification

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.classification

import org.apache.spark.ml.Estimator
import org.apache.spark.ml.classification.RandomForestClassifier

import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._


object RandomForestClassification extends TreeOrForestClassification {

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    // TODO: subsamplingRate, featureSubsetStrategy
    // TODO: cacheNodeIds, checkpoint?
    new RandomForestClassifier()
      .setMaxDepth(depth)
      .setNumTrees(maxIter)
      .setSeed(ctx.seed())
  }
} 
Author: summerDG, Project: spark-sql-perf, Lines: 22, Source: RandomForestClassification.scala
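
After fitting, the resulting model exposes per-feature importances, often the first thing to inspect for a forest. A minimal spark-shell sketch with toy data (not from the benchmark itself):

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors

val train = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(0.9, 0.1))
).toDF("label", "features")

val model = new RandomForestClassifier().setNumTrees(5).setMaxDepth(3).setSeed(1L).fit(train)
println(model.featureImportances)  // sparse vector of normalized importances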

Example 4: LogisticRegression

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.classification

import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator

import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import org.apache.spark.ml
import org.apache.spark.ml.linalg.Vectors


object LogisticRegression extends BenchmarkAlgorithm
  with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {

  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    DataGenerator.generateContinuousFeatures(
      ctx.sqlContext,
      numExamples,
      ctx.seed(),
      numPartitions,
      numFeatures)
  }

  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    val rng = ctx.newGenerator()
    val coefficients =
      Vectors.dense(Array.fill[Double](ctx.params.numFeatures)(2 * rng.nextDouble() - 1))
    // Small intercept to prevent some skew in the data.
    val intercept = 0.01 * (2 * rng.nextDouble() - 1)
    ModelBuilder.newLogisticRegressionModel(coefficients, intercept)
  }

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new ml.classification.LogisticRegression()
      .setTol(tol)
      .setMaxIter(maxIter)
      .setRegParam(regParam)
  }

  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new MulticlassClassificationEvaluator()
} 
Author: summerDG, Project: spark-sql-perf, Lines: 46, Source: LogisticRegression.scala
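
The ScoringWithEvaluator mixin above scores predictions with an Evaluator; outside the harness the same pattern is three calls: fit, transform, evaluate. A spark-shell sketch on toy data (illustrative only):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.Vectors

val data = Seq(
  (0.0, Vectors.dense(0.0, 1.0)), (1.0, Vectors.dense(1.0, 0.0)),
  (0.0, Vectors.dense(0.2, 0.8)), (1.0, Vectors.dense(0.8, 0.2))
).toDF("label", "features")

val predictions = new LogisticRegression().setMaxIter(20).fit(data).transform(data)
val accuracy = new MulticlassClassificationEvaluator().setMetricName("accuracy").evaluate(predictions)
println(s"accuracy = $accuracy")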

Example 5: TreeOrForestClassification

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.classification

import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.sql.DataFrame

import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator


abstract class TreeOrForestClassification extends BenchmarkAlgorithm
  with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {

  import TreeOrForestClassification.getFeatureArity

  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    val featureArity: Array[Int] = getFeatureArity(ctx)
    val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
      ctx.seed(), numPartitions, featureArity)
    TreeUtils.setMetadata(data, "features", featureArity)
  }

  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
      getFeatureArity(ctx), ctx.seed())
  }

  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new MulticlassClassificationEvaluator()
}

object DecisionTreeClassification extends TreeOrForestClassification {

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new DecisionTreeClassifier()
      .setMaxDepth(depth)
      .setSeed(ctx.seed())
  }
}

object TreeOrForestClassification {

  /** Feature arity layout: the first quarter of features are binary categorical, the next
   *  quarter are 20-category categorical, and the remainder are continuous (arity 0). */
  def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
    val numFeatures = ctx.params.numFeatures
    val fourthFeatures = numFeatures / 4
    Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
      Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
      Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
  }
} 
Author: summerDG, Project: spark-sql-perf, Lines: 56, Source: DecisionTreeClassification.scala
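
TreeUtils.setMetadata is internal to this benchmark project; in stock Spark ML the usual way to attach the categorical-arity metadata that tree algorithms read is VectorIndexer, itself an Estimator that infers arity from the data. A minimal spark-shell sketch (maxCategories is set to 2 so that only the binary column counts as categorical in this toy data):

import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.linalg.Vectors

val raw = Seq(
  Vectors.dense(0.0, -1.5), Vectors.dense(1.0, 0.2), Vectors.dense(0.0, 2.7)
).map(Tuple1.apply).toDF("features")

// Columns with at most maxCategories distinct values are flagged categorical;
// here column 0 (values {0, 1}) becomes categorical, column 1 stays continuous.
val indexer = new VectorIndexer()
  .setInputCol("features").setOutputCol("indexed").setMaxCategories(2)
indexer.fit(raw).transform(raw).show(truncate = false)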

Example 6: GBTClassification

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.classification

import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.sql._

import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator


object GBTClassification extends BenchmarkAlgorithm
  with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {

  import TreeOrForestClassification.getFeatureArity

  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    val featureArity: Array[Int] = getFeatureArity(ctx)
    val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
      ctx.seed(), numPartitions, featureArity)
    TreeUtils.setMetadata(data, "features", featureArity)
  }

  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    import ctx.params._
    // We add +1 to the depth to make it more likely that many iterations of boosting are needed
    // to model the true tree.
    ModelBuilder.newDecisionTreeClassificationModel(depth + 1, numClasses, getFeatureArity(ctx),
      ctx.seed())
  }

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    // TODO: subsamplingRate, featureSubsetStrategy
    // TODO: cacheNodeIds, checkpoint?
    new GBTClassifier()
      .setMaxDepth(depth)
      .setMaxIter(maxIter)
      .setSeed(ctx.seed())
  }

  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new MulticlassClassificationEvaluator()
} 
Author: summerDG, Project: spark-sql-perf, Lines: 47, Source: GBTClassification.scala
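
Worth noting when adapting this example: GBTClassifier in Spark ML supports binary classification only, so the labels fed to it must be 0.0 or 1.0 even though the evaluator used here is multiclass-capable. A minimal spark-shell sketch (toy data, illustrative only):

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors

val train = Seq(
  (0.0, Vectors.dense(0.0)), (1.0, Vectors.dense(1.0)),
  (0.0, Vectors.dense(0.1)), (1.0, Vectors.dense(0.9))
).toDF("label", "features")  // labels must be binary (0.0 or 1.0)

val model = new GBTClassifier().setMaxIter(5).setMaxDepth(2).setSeed(3L).fit(train)
println(s"trees in ensemble = ${model.trees.length}")
model.transform(train).show()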

Example 7: GLMRegression

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.regression

import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}

import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator


object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
  TrainingSetFromTransformer with ScoringWithEvaluator {

  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    DataGenerator.generateContinuousFeatures(
      ctx.sqlContext,
      numExamples,
      ctx.seed(),
      numPartitions,
      numFeatures)
  }

  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    import ctx.params._
    val rng = ctx.newGenerator()
    val coefficients =
      Vectors.dense(Array.fill[Double](ctx.params.numFeatures)(2 * rng.nextDouble() - 1))
    // Small intercept to prevent some skew in the data.
    val intercept = 0.01 * (2 * rng.nextDouble() - 1)
    val m = ModelBuilder.newGLR(coefficients, intercept)
    m.set(m.link, link.get)
    m.set(m.family, family.get)
    m
  }

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new GeneralizedLinearRegression()
      .setLink(link)
      .setFamily(family)
      .setRegParam(regParam)
      .setMaxIter(maxIter)
      .setTol(tol)
  }

  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new RegressionEvaluator()
} 
Author: summerDG, Project: spark-sql-perf, Lines: 52, Source: GLMRegression.scala
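
The family/link pair set above controls the assumed label distribution and the mean-link function of the GLM; the fitted model then exposes its coefficients and intercept. A spark-shell sketch using Poisson/log as an illustrative pair (toy count data, not from the benchmark):

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression

val counts = Seq(
  (1.0, Vectors.dense(0.0)), (2.0, Vectors.dense(1.0)),
  (4.0, Vectors.dense(2.0)), (8.0, Vectors.dense(3.0))
).toDF("label", "features")

val glr = new GeneralizedLinearRegression()
  .setFamily("poisson").setLink("log").setMaxIter(25).setRegParam(0.0)
val model = glr.fit(counts)
println(s"coefficients=${model.coefficients} intercept=${model.intercept}")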

Example 8: LinearRegression

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.regression

import org.apache.spark.ml
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}

import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator


object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with
  TrainingSetFromTransformer with ScoringWithEvaluator {

  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    DataGenerator.generateContinuousFeatures(
      ctx.sqlContext,
      numExamples,
      ctx.seed(),
      numPartitions,
      numFeatures)
  }

  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    val rng = ctx.newGenerator()
    val coefficients =
      Vectors.dense(Array.fill[Double](ctx.params.numFeatures)(2 * rng.nextDouble() - 1))
    // Small intercept to prevent some skew in the data.
    val intercept = 0.01 * (2 * rng.nextDouble() - 1)
    ModelBuilder.newLinearRegressionModel(coefficients, intercept)
  }

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new ml.regression.LinearRegression()
      .setSolver("l-bfgs")
      .setRegParam(regParam)
      .setMaxIter(maxIter)
      .setTol(tol)
  }

  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new RegressionEvaluator()
} 
Author: summerDG, Project: spark-sql-perf, Lines: 47, Source: LinearRegression.scala
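
And the regression counterpart of the evaluator pattern: fit, transform, then score with RegressionEvaluator, whose default metric is RMSE. A spark-shell sketch (toy data, illustrative only):

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

val data = Seq(
  (1.0, Vectors.dense(1.0)), (2.0, Vectors.dense(2.0)),
  (3.0, Vectors.dense(3.0)), (4.0, Vectors.dense(4.0))
).toDF("label", "features")

val predictions = new LinearRegression().setMaxIter(10).fit(data).transform(data)
val rmse = new RegressionEvaluator().setMetricName("rmse").evaluate(predictions)
println(s"rmse = $rmse")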

Example 9: GloVe

// Package declaration and imported dependencies
package org.apache.spark.ml.feature

import org.apache.spark.ml.Estimator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.feature
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType

final class GloVe(override val uid: String)
  extends Estimator[GloVeModel] with GloVeBase with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("glove"))

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  def setDim(value: Int): this.type = set(dim, value)

  def setAlpha(value: Double): this.type = set(alpha, value)

  def setWindow(value: Int): this.type = set(window, value)

  def setStepSize(value: Double): this.type = set(stepSize, value)

  def setMaxIter(value: Int): this.type = set(maxIter, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setMinCount(value: Int): this.type = set(minCount, value)

  override def fit(dataset: Dataset[_]): GloVeModel = {
    transformSchema(dataset.schema, logging = true)
    val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
    val wordVectors = new feature.GloVe()
      .setLearningRate($(stepSize))
      .setMinCount($(minCount))
      .setNumIterations($(maxIter))
      .setSeed($(seed))
      .setDim($(dim))
      .fit(input)
    copyValues(new GloVeModel(uid, wordVectors).setParent(this))
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): GloVe = defaultCopy(extra)
} 
Author: mdymczyk, Project: spark-miner, Lines: 52, Source: GloVe.scala
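
Since GloVe extends Estimator[GloVeModel], it plugs into the standard fit/transform flow like any built-in estimator. The following is a hypothetical usage sketch: it assumes the spark-miner classes above are on the classpath, runs in spark-shell, and uses only the setters visible in the snippet; the exact output of GloVeModel.transform depends on code not shown here.

import org.apache.spark.ml.feature.GloVe

val docs = Seq(
  Seq("spark", "ml", "estimator"),
  Seq("glove", "word", "vectors")
).map(Tuple1.apply).toDF("words")

val glove = new GloVe()
  .setInputCol("words").setOutputCol("vectors")
  .setDim(50).setWindow(5).setMaxIter(10).setMinCount(1).setSeed(11L)

val model = glove.fit(docs)   // learns word vectors from the token column
model.transform(docs).show()  // hypothetical: output schema depends on GloVeModel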

Example 10: TreeOrForestClassification

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.classification

import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.sql.DataFrame

import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator


abstract class TreeOrForestClassification extends BenchmarkAlgorithm
  with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {

  import TreeOrForestClassification.getFeatureArity

  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    val featureArity: Array[Int] = getFeatureArity(ctx)
    val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
      ctx.seed(), numPartitions, featureArity)
    TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
  }

  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
      getFeatureArity(ctx), ctx.seed())
  }

  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new MulticlassClassificationEvaluator()
}

object DecisionTreeClassification extends TreeOrForestClassification {

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new DecisionTreeClassifier()
      .setMaxDepth(depth)
      .setSeed(ctx.seed())
  }
}

object TreeOrForestClassification {

  /** Feature arity layout: the first quarter of features are binary categorical, the next
   *  quarter are 20-category categorical, and the remainder are continuous (arity 0). */
  def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
    val numFeatures = ctx.params.numFeatures
    val fourthFeatures = numFeatures / 4
    Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
      Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
      Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
  }
} 
Author: sparkonpower, Project: spark-sql-perf-spark2.0.0, Lines: 56, Source: DecisionTreeClassification.scala

Example 11: GBTClassification

// Package declaration and imported dependencies
package com.databricks.spark.sql.perf.mllib.classification

import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.sql._

import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator


object GBTClassification extends BenchmarkAlgorithm
  with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {

  import TreeOrForestClassification.getFeatureArity

  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    val featureArity: Array[Int] = getFeatureArity(ctx)
    val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
      ctx.seed(), numPartitions, featureArity)
    TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
  }

  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    import ctx.params._
    // We add +1 to the depth to make it more likely that many iterations of boosting are needed
    // to model the true tree.
    ModelBuilder.newDecisionTreeClassificationModel(depth + 1, numClasses, getFeatureArity(ctx),
      ctx.seed())
  }

  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    // TODO: subsamplingRate, featureSubsetStrategy
    // TODO: cacheNodeIds, checkpoint?
    new GBTClassifier()
      .setMaxDepth(depth)
      .setMaxIter(maxIter)
      .setSeed(ctx.seed())
  }

  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new MulticlassClassificationEvaluator()
} 
Author: sparkonpower, Project: spark-sql-perf-spark2.0.0, Lines: 47, Source: GBTClassification.scala


Note: The org.apache.spark.ml.Estimator class examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective authors, who retain copyright over the source code; consult each project's license before using or redistributing it. Do not republish without permission.