当前位置: 首页>>代码示例>>Java>>正文


Java Vectors类代码示例

本文整理汇总了Java中org.apache.spark.ml.linalg.Vectors类的典型用法代码示例。如果您想了解Java Vectors类的具体用法、调用方式或使用示例,下面精选的代码示例或许能为您提供帮助。


Vectors类属于org.apache.spark.ml.linalg包,在下文中一共展示了Vectors类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: fieldCall

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
/**
 * Converts the raw string tokens of one field into the object representation
 * selected by the field's declared data type.
 *
 * <p>{@code STRING_DATATYPE} fields are returned unchanged as the original
 * {@code String[]}. {@code DOUBLE_DATATYPE}, {@code INTEGER_DATATYPE} and
 * {@code LONG_DATATYPE} fields are parsed element-wise with
 * {@link Double#parseDouble(String)} and returned as a dense Spark ML vector,
 * so every numeric type ends up stored as {@code double}.
 *
 * @param info  metadata describing the field; its data type selects the conversion
 * @param value the raw string tokens of the field
 * @return the unchanged {@code value} array for string fields, otherwise a
 *         {@code Vectors.dense} vector holding the parsed values
 * @throws CantConverException if a token cannot be parsed as a number (the parse
 *         exception is attached as the cause), or if the data type is not supported
 * @throws Exception as declared by the original signature
 */
public Object fieldCall(FieldInfo info, String[] value) throws Exception {
    switch (info.getDataType()) {
        case FieldInfo.STRING_DATATYPE: {
            return value;
        }
        case FieldInfo.DOUBLE_DATATYPE:
        case FieldInfo.INTEGER_DATATYPE:
        case FieldInfo.LONG_DATATYPE: {
            double[] vect = new double[value.length];
            for (int i = 0; i < value.length; i++) {
                try {
                    // parseDouble avoids the Double box/unbox round trip of Double.valueOf.
                    vect[i] = Double.parseDouble(value[i]);
                } catch (NumberFormatException | NullPointerException e) {
                    // Name the offending token and keep the original exception as the cause,
                    // rather than forwarding only e.getMessage() (null for a bare NPE).
                    CantConverException ex = new CantConverException(
                            "cannot convert value[" + i + "] = " + value[i] + " to a number: " + e.getMessage());
                    ex.initCause(e);
                    throw ex;
                }
            }
            return Vectors.dense(vect);
        }
        default:
            throw new CantConverException("不合法类型");
    }
}
 
开发者ID:hays2hong,项目名称:stonk,代码行数:30,代码来源:LineParse.java

示例2: testDataFrameSumDMLVectorWithIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumDMLVectorWithIDColumn() {
	System.out.println("MLContextTest - DataFrame sum DML, vector with ID column");

	// Rows (1, [1,2,3]), (2, [4,5,6]), (3, [7,8,9]); the vector entries total 45.
	List<Tuple2<Double, Vector>> indexedVectors = new ArrayList<>();
	for (int row = 0; row < 3; row++) {
		double first = 3 * row + 1;
		indexedVectors.add(new Tuple2<Double, Vector>(row + 1.0, Vectors.dense(first, first + 1, first + 2)));
	}

	// Schema: the ID column followed by a single ML vector column.
	List<StructField> columns = new ArrayList<>();
	columns.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
	columns.add(DataTypes.createStructField("C1", new VectorUDT(), true));

	JavaRDD<Row> rows = sc.parallelize(indexedVectors).map(new DoubleVectorRow());
	Dataset<Row> dataFrame = spark.createDataFrame(rows, DataTypes.createStructType(columns));

	Script script = dml("print('sum: ' + sum(M));")
			.in("M", dataFrame, new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX));
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:24,代码来源:MLContextTest.java

示例3: testDataFrameSumPYDMLVectorWithIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumPYDMLVectorWithIDColumn() {
	System.out.println("MLContextTest - DataFrame sum PYDML, vector with ID column");

	// Each row is (id, vector); the nine entries 1..9 add up to 45.
	double[][] entries = { { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 }, { 7.0, 8.0, 9.0 } };
	List<Tuple2<Double, Vector>> idAndVector = new ArrayList<>();
	for (int r = 0; r < entries.length; r++) {
		double id = r + 1;
		idAndVector.add(new Tuple2<Double, Vector>(id, Vectors.dense(entries[r])));
	}

	List<StructField> fields = new ArrayList<>();
	fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
	fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
	StructType schema = DataTypes.createStructType(fields);

	JavaRDD<Row> javaRddRow = sc.parallelize(idAndVector).map(new DoubleVectorRow());
	Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);

	MatrixMetadata withIndex = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
	Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, withIndex);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:24,代码来源:MLContextTest.java

示例4: testDataFrameSumDMLMllibVectorWithIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumDMLMllibVectorWithIDColumn() {
	System.out.println("MLContextTest - DataFrame sum DML, mllib vector with ID column");

	// Three (id, old-style mllib vector) pairs whose entries total 45.
	List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> pairs = new ArrayList<>();
	for (int k = 1; k <= 3; k++) {
		double start = 3.0 * k - 2.0;
		org.apache.spark.mllib.linalg.Vector vec = org.apache.spark.mllib.linalg.Vectors.dense(start, start + 1.0,
				start + 2.0);
		pairs.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>((double) k, vec));
	}

	List<StructField> schemaFields = new ArrayList<>();
	schemaFields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
	schemaFields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));

	JavaRDD<Row> rowRdd = sc.parallelize(pairs).map(new DoubleMllibVectorRow());
	Dataset<Row> dataFrame = spark.createDataFrame(rowRdd, DataTypes.createStructType(schemaFields));

	MatrixMetadata metadata = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
	Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, metadata);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:27,代码来源:MLContextTest.java

示例5: testDataFrameSumPYDMLMllibVectorWithIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumPYDMLMllibVectorWithIDColumn() {
	System.out.println("MLContextTest - DataFrame sum PYDML, mllib vector with ID column");

	// Vector payloads for ids 1, 2, 3; all entries together sum to 45.
	double[][] payload = { { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 }, { 7.0, 8.0, 9.0 } };
	List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> tuples = new ArrayList<>();
	int id = 1;
	for (double[] entries : payload) {
		tuples.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>((double) id,
				org.apache.spark.mllib.linalg.Vectors.dense(entries)));
		id++;
	}

	List<StructField> fields = new ArrayList<>();
	fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
	fields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));
	StructType schema = DataTypes.createStructType(fields);

	Dataset<Row> dataFrame = spark.createDataFrame(sc.parallelize(tuples).map(new DoubleMllibVectorRow()), schema);

	Script script = pydml("print('sum: ' + sum(M))")
			.in("M", dataFrame, new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX));
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:27,代码来源:MLContextTest.java

示例6: testDataFrameSumDMLVectorWithNoIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumDMLVectorWithNoIDColumn() {
	System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column");

	// Vectors [1,2,3], [4,5,6], [7,8,9] (sum 45), generated from their first element.
	List<Vector> vectors = new ArrayList<>();
	for (double start = 1.0; start < 10.0; start += 3.0) {
		vectors.add(Vectors.dense(start, start + 1.0, start + 2.0));
	}

	// Single vector column, no ID column.
	List<StructField> onlyColumn = new ArrayList<>();
	onlyColumn.add(DataTypes.createStructField("C1", new VectorUDT(), true));

	JavaRDD<Row> rows = sc.parallelize(vectors).map(new VectorRow());
	Dataset<Row> dataFrame = spark.createDataFrame(rows, DataTypes.createStructType(onlyColumn));

	Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, new MatrixMetadata(MatrixFormat.DF_VECTOR));
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:23,代码来源:MLContextTest.java

示例7: testDataFrameSumPYDMLVectorWithNoIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumPYDMLVectorWithNoIDColumn() {
	System.out.println("MLContextTest - DataFrame sum PYDML, vector with no ID column");

	// Row contents; the expected total of all entries is 45.
	double[][] table = { { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 }, { 7.0, 8.0, 9.0 } };
	List<Vector> list = new ArrayList<>();
	for (double[] rowValues : table) {
		list.add(Vectors.dense(rowValues));
	}

	List<StructField> fields = new ArrayList<>();
	fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
	StructType schema = DataTypes.createStructType(fields);

	JavaRDD<Vector> vectorRdd = sc.parallelize(list);
	Dataset<Row> dataFrame = spark.createDataFrame(vectorRdd.map(new VectorRow()), schema);

	MatrixMetadata vectorFormat = new MatrixMetadata(MatrixFormat.DF_VECTOR);
	Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, vectorFormat);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:23,代码来源:MLContextTest.java

示例8: testDataFrameSumDMLMllibVectorWithNoIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumDMLMllibVectorWithNoIDColumn() {
	System.out.println("MLContextTest - DataFrame sum DML, mllib vector with no ID column");

	// Three old-style mllib vectors covering 1..9 (sum 45).
	List<org.apache.spark.mllib.linalg.Vector> mllibVectors = new ArrayList<>();
	for (int r = 0; r < 3; r++) {
		double base = 3 * r;
		mllibVectors.add(org.apache.spark.mllib.linalg.Vectors.dense(base + 1, base + 2, base + 3));
	}

	List<StructField> fields = new ArrayList<>();
	fields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));

	JavaRDD<Row> javaRddRow = sc.parallelize(mllibVectors).map(new MllibVectorRow());
	Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, DataTypes.createStructType(fields));

	MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR);
	Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:23,代码来源:MLContextTest.java

示例9: testDataFrameSumPYDMLMllibVectorWithNoIDColumn

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumPYDMLMllibVectorWithNoIDColumn() {
	System.out.println("MLContextTest - DataFrame sum PYDML, mllib vector with no ID column");

	// Raw values for the three mllib vectors; their grand total is 45.
	double[][] raw = { { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 }, { 7.0, 8.0, 9.0 } };
	List<org.apache.spark.mllib.linalg.Vector> list = new ArrayList<>();
	for (int i = 0; i < raw.length; i++) {
		list.add(org.apache.spark.mllib.linalg.Vectors.dense(raw[i]));
	}

	List<StructField> structFields = new ArrayList<>();
	structFields.add(DataTypes.createStructField("C1", new org.apache.spark.mllib.linalg.VectorUDT(), true));
	StructType schema = DataTypes.createStructType(structFields);

	JavaRDD<org.apache.spark.mllib.linalg.Vector> javaRddVector = sc.parallelize(list);
	Dataset<Row> dataFrame = spark.createDataFrame(javaRddVector.map(new MllibVectorRow()), schema);

	Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame,
			new MatrixMetadata(MatrixFormat.DF_VECTOR));
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:23,代码来源:MLContextTest.java

示例10: testDataFrameSumDMLVectorWithIDColumnNoFormatSpecified

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumDMLVectorWithIDColumnNoFormatSpecified() {
	System.out.println("MLContextTest - DataFrame sum DML, vector with ID column, no format specified");

	// (id, vector) rows with entries 1..9; no MatrixMetadata is passed, so the
	// format has to be inferred from the DataFrame schema.
	List<Tuple2<Double, Vector>> tuples = new ArrayList<>();
	for (int n = 1; n <= 3; n++) {
		double head = 3.0 * (n - 1) + 1.0;
		tuples.add(new Tuple2<Double, Vector>((double) n, Vectors.dense(head, head + 1.0, head + 2.0)));
	}

	List<StructField> columns = new ArrayList<>();
	columns.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
	columns.add(DataTypes.createStructField("C1", new VectorUDT(), true));
	StructType schema = DataTypes.createStructType(columns);

	JavaRDD<Row> rowRdd = sc.parallelize(tuples).map(new DoubleVectorRow());
	Dataset<Row> dataFrame = spark.createDataFrame(rowRdd, schema);

	Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:22,代码来源:MLContextTest.java

示例11: testDataFrameSumPYDMLVectorWithIDColumnNoFormatSpecified

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumPYDMLVectorWithIDColumnNoFormatSpecified() {
	System.out.println("MLContextTest - DataFrame sum PYDML, vector with ID column, no format specified");

	// (id, vector) rows whose vector entries are 1..9, so the expected sum is 45.
	List<Tuple2<Double, Vector>> list = new ArrayList<Tuple2<Double, Vector>>();
	list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
	list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0)));
	list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0)));
	JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list);

	JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
	List<StructField> fields = new ArrayList<StructField>();
	fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true));
	fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
	StructType schema = DataTypes.createStructType(fields);
	Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);

	// This is the PYDML variant of the test. It previously built the script with dml(...),
	// which exercised the DML parser twice and never the PYDML one.
	Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:22,代码来源:MLContextTest.java

示例12: testDataFrameSumDMLVectorWithNoIDColumnNoFormatSpecified

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumDMLVectorWithNoIDColumnNoFormatSpecified() {
	System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column, no format specified");

	// Three vectors whose entries add up to 45; the matrix format is left for MLContext to infer.
	Vector[] source = { Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0), Vectors.dense(7.0, 8.0, 9.0) };
	List<Vector> list = new ArrayList<>();
	for (Vector v : source) {
		list.add(v);
	}

	List<StructField> fields = new ArrayList<>();
	fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));

	JavaRDD<Row> javaRddRow = sc.parallelize(list).map(new VectorRow());
	Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, DataTypes.createStructType(fields));

	Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:21,代码来源:MLContextTest.java

示例13: testDataFrameSumPYDMLVectorWithNoIDColumnNoFormatSpecified

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testDataFrameSumPYDMLVectorWithNoIDColumnNoFormatSpecified() {
	System.out.println("MLContextTest - DataFrame sum PYDML, vector with no ID column, no format specified");

	// Three vectors with entries 1..9, so the expected sum is 45.
	List<Vector> list = new ArrayList<Vector>();
	list.add(Vectors.dense(1.0, 2.0, 3.0));
	list.add(Vectors.dense(4.0, 5.0, 6.0));
	list.add(Vectors.dense(7.0, 8.0, 9.0));
	JavaRDD<Vector> javaRddVector = sc.parallelize(list);

	JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow());
	List<StructField> fields = new ArrayList<StructField>();
	fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
	StructType schema = DataTypes.createStructType(fields);
	Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);

	// This is the PYDML variant of the test. It previously built the script with dml(...),
	// so the PYDML path was never exercised.
	Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame);
	setExpectedStdOut("sum: 45.0");
	ml.execute(script);
}
 
开发者ID:apache,项目名称:systemml,代码行数:21,代码来源:MLContextTest.java

示例14: testMinMaxScaler

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testMinMaxScaler() {
    // One labelled row per fixture vector; labels 1.0..4.0 define the row order compared below.
    Row[] labelled = new Row[4];
    for (int i = 0; i < labelled.length; i++) {
        labelled[i] = RowFactory.create(i + 1.0, Vectors.dense(data[i]));
    }
    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(labelled));

    StructField[] columns = {
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty())
    };
    Dataset<Row> df = spark.createDataFrame(jrdd, new StructType(columns));

    // Fit Spark's MinMaxScaler, rescaling features into [-5, 5].
    MinMaxScaler scaler = new MinMaxScaler()
            .setInputCol("features")
            .setOutputCol("scaled")
            .setMin(-5)
            .setMax(5);
    MinMaxScalerModel sparkModel = scaler.fit(df);

    // Round-trip the model through the exporter/importer to obtain the local transformer.
    final Transformer transformer = ModelImporter.importAndGetTransformer(ModelExporter.export(sparkModel));

    // Spark's own output, ordered by label, is checked against the expected values and the transformer.
    List<Row> sparkOutput = sparkModel.transform(df)
            .orderBy("label")
            .select("features", "scaled")
            .collectAsList();
    assertCorrectness(sparkOutput, expected, transformer);
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:35,代码来源:MinMaxScalerBridgeTest.java

示例15: testStandardScaler

import org.apache.spark.ml.linalg.Vectors; //导入依赖的package包/类
@Test
public void testStandardScaler() {
    // Three labelled rows built from the shared fixture.
    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
            RowFactory.create(1.0, Vectors.dense(data[0])),
            RowFactory.create(2.0, Vectors.dense(data[1])),
            RowFactory.create(3.0, Vectors.dense(data[2]))
    ));

    StructField labelField = new StructField("label", DataTypes.DoubleType, false, Metadata.empty());
    StructField featuresField = new StructField("features", new VectorUDT(), false, Metadata.empty());
    Dataset<Row> df = spark.createDataFrame(jrdd, new StructType(new StructField[]{labelField, featuresField}));

    // Scaler configurations as {withMean, withStd}, in the order: none, mean, std, both.
    boolean[][] configs = {{false, false}, {true, false}, {false, true}, {true, true}};

    // Train one Spark model per configuration.
    StandardScalerModel[] models = new StandardScalerModel[configs.length];
    for (int k = 0; k < configs.length; k++) {
        models[k] = new StandardScaler()
                .setInputCol("features")
                .setOutputCol("scaledOutput")
                .setWithMean(configs[k][0])
                .setWithStd(configs[k][1])
                .fit(df);
    }

    // Export each model, import it back and keep the resulting transformer.
    Transformer[] transformers = new Transformer[models.length];
    for (int k = 0; k < models.length; k++) {
        byte[] exportedModel = ModelExporter.export(models[k]);
        transformers[k] = ModelImporter.importAndGetTransformer(exportedModel);
    }

    // Compare each model's Spark output, ordered by label, against its expected values.
    assertCorrectness(models[0].transform(df).orderBy("label").select("features", "scaledOutput").collectAsList(),
            data, transformers[0]);
    assertCorrectness(models[1].transform(df).orderBy("label").select("features", "scaledOutput").collectAsList(),
            resWithMean, transformers[1]);
    assertCorrectness(models[2].transform(df).orderBy("label").select("features", "scaledOutput").collectAsList(),
            resWithStd, transformers[2]);
    assertCorrectness(models[3].transform(df).orderBy("label").select("features", "scaledOutput").collectAsList(),
            resWithBoth, transformers[3]);
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:76,代码来源:StandardScalerBridgeTest.java


注:本文中的org.apache.spark.ml.linalg.Vectors类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。