本文整理汇总了Scala中org.apache.spark.mllib.util.MLlibTestSparkContext类的典型用法代码示例。如果您正苦于以下问题：Scala MLlibTestSparkContext类的具体用法？Scala MLlibTestSparkContext怎么用？Scala MLlibTestSparkContext使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
下文共展示了MLlibTestSparkContext类的1个代码示例，示例默认按受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Scala代码示例。
示例1: BMRMSuite
// Declare the package name and import the required classes
package org.apache.spark.mllib.optimization.bmrm
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import scala.util.Random
/** Test fixtures shared by the BMRMSuite tests. */
object BMRMSuite {

  /**
   * Builds a reproducible synthetic problem for the BMRM tests.
   *
   * Every generated value is a small positive integer stored as a Double:
   * labels are drawn uniformly from {1, ..., 5}, while feature entries and
   * initial-weight entries are drawn uniformly from {1, ..., 10}. The random
   * stream is consumed in a fixed order (labels, then points row by row, then
   * weights), so the same seed always yields the same data.
   *
   * @param nPoint number of points, and therefore of labels, to generate
   * @param dim    length of every point vector and of the initial weights
   * @param seed   seed for the `scala.util.Random` generator
   * @return a tuple `(points, labels, initialWeights)`
   */
  def generateSubInput(
      nPoint: Int,
      dim: Int,
      seed: Int): (Array[Vector], Array[Double], Vector) = {
    val rng = new Random(seed)

    // Uniform integer in [1, bound], widened to Double.
    def draw(bound: Int): Double = rng.nextInt(bound) + 1.0

    // Dense vector of `dim` draws from {1, ..., 10}.
    def randomVector(): Vector = Vectors.dense(Array.fill(dim)(draw(10)))

    val labels = Array.fill(nPoint)(draw(5))
    val points = Array.fill[Vector](nPoint)(randomVector())
    val initialWeights = randomVector()
    (points, labels, initialWeights)
  }
}
/**
 * Smoke tests for [[BMRM]] and its components [[NdcgSubGradient]] and
 * [[DaiFletcherUpdater]].
 *
 * Each test builds a synthetic problem via `BMRMSuite.generateSubInput` with
 * seed 45 and prints the intermediate results to stdout.
 *
 * NOTE(review): none of these tests assert anything, so each passes as long
 * as no exception is thrown. The MLlibTestSparkContext mixin is also not
 * referenced by the tests shown here. Add real checks (e.g. on result sizes
 * or finiteness) once the return types of `compute`/`optimize` are confirmed.
 */
class BMRMSuite extends FunSuite with MLlibTestSparkContext {

  /** Prints every value on its own line, in argument order (debug output only). */
  private def printEach(values: Any*): Unit = values.foreach(println)

  test("Test the loss and gradient of first iteration") {
    val ndcg = new NdcgSubGradient()
    val (points, labels, start) = BMRMSuite.generateSubInput(100, 100, 45)
    val (grad, lossValue) = ndcg.compute(points, labels, start)
    printEach(grad, lossValue)
  }

  test("Test the update of the weights of first iteration") {
    val ndcg = new NdcgSubGradient()
    // Wider problem (1000 features) than the gradient-only test above.
    val (points, labels, start) = BMRMSuite.generateSubInput(100, 1000, 45)
    val (grad, lossValue) = ndcg.compute(points, labels, start)
    val updater = new DaiFletcherUpdater()
    // NOTE(review): the meaning of the trailing 1.0 is defined by
    // DaiFletcherUpdater.compute (presumably a regularization constant) — TODO confirm.
    val (updated, objective) = updater.compute(start, grad, lossValue, 1.0)
    printEach(start, lossValue, updated, objective)
  }

  test("Test the BMRM optimization") {
    val ndcg = new NdcgSubGradient()
    val updater = new DaiFletcherUpdater()
    val optimizer = new BMRM(ndcg, updater)
    val (points, labels, start) = BMRMSuite.generateSubInput(100, 10, 45)
    printEach(start)
    val optimized = optimizer.optimize(points, labels, start)
    printEach(optimized)
  }
}