当前位置: 首页>>代码示例>>Java>>正文


Java ObjectInspectorUtils.getConstantObjectInspector方法代码示例

本文整理汇总了Java中org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.getConstantObjectInspector方法的典型用法代码示例。如果您正苦于以下问题:Java ObjectInspectorUtils.getConstantObjectInspector方法的具体用法?Java ObjectInspectorUtils.getConstantObjectInspector怎么用?Java ObjectInspectorUtils.getConstantObjectInspector使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils的用法示例。


在下文中一共展示了ObjectInspectorUtils.getConstantObjectInspector方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testReverseTopKWithKey

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
@Test
public void testReverseTopKWithKey() throws Exception {
    // "-k 2 -reverse" reverses the ordering, so top-k behaves like tail-k:
    // the two smallest keys (apple 0.5, banana 0.7) should come back, ascending.
    ObjectInspector[] argOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2 -reverse")};

    final String[] items = {"banana", "apple", "candy"};
    final double[] scores = {0.7, 0.5, 0.8};

    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, argOIs);
    evaluator.reset(agg);

    for (int i = 0; i < items.length; i++) {
        evaluator.iterate(agg, new Object[] {items[i], scores[i]});
    }

    List<Object> result = evaluator.terminate(agg);

    Assert.assertEquals(2, result.size());
    Assert.assertEquals("apple", result.get(0));
    Assert.assertEquals("banana", result.get(1));
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:26,代码来源:UDAFToOrderedListTest.java

示例2: testReverseTailKWithKey

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
@Test
public void testReverseTailKWithKey() throws Exception {
    // A negative k ("-k -2") combined with -reverse flips the ordering twice,
    // so the result is the plain top-k: the two largest keys, descending.
    ObjectInspector[] argOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2 -reverse")};

    final String[] items = {"banana", "apple", "candy"};
    final double[] scores = {0.7, 0.5, 0.8};

    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, argOIs);
    evaluator.reset(agg);

    for (int i = 0; i < items.length; i++) {
        evaluator.iterate(agg, new Object[] {items[i], scores[i]});
    }

    List<Object> result = evaluator.terminate(agg);

    Assert.assertEquals(2, result.size());
    Assert.assertEquals("candy", result.get(0));
    Assert.assertEquals("banana", result.get(1));
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:26,代码来源:UDAFToOrderedListTest.java

示例3: testTopK

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
@Test
public void testTopK() throws Exception {
    // Without an explicit key column the value itself is the sort key;
    // "-k 2" keeps the two lexicographically largest values, descending.
    ObjectInspector[] argOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2")};

    final String[] items = {"banana", "apple", "candy"};

    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, argOIs);
    evaluator.reset(agg);

    for (String item : items) {
        evaluator.iterate(agg, new Object[] {item});
    }

    List<Object> result = evaluator.terminate(agg);

    Assert.assertEquals(2, result.size());
    Assert.assertEquals("candy", result.get(0));
    Assert.assertEquals("banana", result.get(1));
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:23,代码来源:UDAFToOrderedListTest.java

示例4: testPA2EtaWithParameter

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
@Test
public void testPA2EtaWithParameter() throws UDFArgumentException {
    // Verifies the PA2 learning-rate (eta) formula when the aggressiveness
    // parameter is supplied explicitly via "-c 3.0".
    PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA2();
    ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(labelOI);
    ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-c 3.0");

    /* do initialize() with aggressiveness parameter */
    udtf.initialize(new ObjectInspector[] {featuresOI, labelOI, optionOI});

    final float loss = 0.1f;

    // Larger squared norm -> smaller learning rate.
    PredictionResult marginA = new PredictionResult(0.5f).squaredNorm(0.05f);
    assertEquals(0.4615384f, udtf.eta(loss, marginA), 1e-5f);

    PredictionResult marginB = new PredictionResult(0.5f).squaredNorm(0.01f);
    assertEquals(0.5660377f, udtf.eta(loss, marginB), 1e-5f);
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:21,代码来源:PassiveAggressiveUDTFTest.java

示例5: testGaussianInit

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
@Test
public void testGaussianInit() throws HiveException {
    println("--------------------------\n testGaussianInit()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    // FIX: dropped the redundant `new String(...)` wrapper around the literal;
    // it allocated an extra copy for no benefit.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -rankinit gaussian");
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    mf.initialize(argOIs);
    // The "-rankinit gaussian" option must select the gaussian rank-init scheme.
    Assert.assertTrue(mf.rankInit == RankInitScheme.gaussian);

    // Small user x item rating matrix; 0 entries are treated as observed zeros here.
    float[][] rating = { {5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 100;
    // Feed every (row, col, rating) cell num_iters times to train the model.
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
            }
        }
    }
    // After training, every reconstructed cell should be within 0.2 of the input.
    for (int row = 0; row < rating.length; row++) {
        for (int col = 0, size = rating[row].length; col < size; col++) {
            double predicted = mf.predict(row, col);
            print(rating[row][col] + "[" + predicted + "]\t");
            Assert.assertEquals(rating[row][col], predicted, 0.2d);
        }
        println();
    }
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:37,代码来源:MatrixFactorizationSGDUDTFTest.java

示例6: testUnsupportedLossFunction

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
@Test(expected = UDFArgumentException.class)
public void testUnsupportedLossFunction() throws Exception {
    // initialize() must reject an unknown "-loss" name with UDFArgumentException.
    ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ListObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss UnsupportedLoss");

    new GeneralClassifierUDTF().initialize(
        new ObjectInspector[] {featuresOI, labelOI, optionsOI});
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:12,代码来源:GeneralClassifierUDTFTest.java

示例7: binarySetUp

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
/**
 * Prepares the F-measure UDAF under test for a binary-classification scenario.
 * Resolves the actual/predicted ObjectInspectors from the runtime types of the
 * sample values and builds the constant option string "-beta B -average A".
 */
private void binarySetUp(Object actual, Object predicted, double beta, String average)
        throws Exception {
    fmeasure = new FMeasureUDAF();
    inputOIs = new ObjectInspector[3];

    // FIX: replaced fragile class-name string comparisons
    // (getClass().getName().equals("java.lang.Integer")) with instanceof checks,
    // and factored the duplicated resolution logic into one helper.
    inputOIs[0] = primitiveOIFor(actual);
    inputOIs[1] = primitiveOIFor(predicted);

    inputOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-beta " + beta
                + " -average " + average);

    evaluator = fmeasure.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
    agg = (FMeasureUDAF.FMeasureAggregationBuffer) evaluator.getNewAggregationBuffer();
}

/**
 * Maps a sample value to the matching primitive Java ObjectInspector.
 * Returns null for unsupported types, matching the original behavior of
 * leaving the inspector slot unset.
 */
private static ObjectInspector primitiveOIFor(Object value) {
    if (value instanceof Integer) {
        return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT);
    } else if (value instanceof Boolean) {
        return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN);
    } else if (value instanceof String) {
        return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING);
    }
    return null;
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:31,代码来源:FMeasureUDAFTest.java

示例8: testMerge

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
// Verifies that merging partial PLSA-predict aggregations is order-independent:
// three partials built from disjoint slices of the sibling fields words/labels/probs
// are merged in several permutations and must always yield the same topic ranking.
@Test
public void testMerge() throws Exception {
    udaf = new PLSAPredictUDAF();

    // Argument layout: word (string), value (float), label (int), prob (float),
    // plus the constant "-topics 2" option.
    inputOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};

    evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));

    agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();

    // Term frequencies for the test document; "like" appears twice.
    final Map<String, Float> doc = new HashMap<String, Float>();
    doc.put("apples", 1.f);
    doc.put("avocados", 1.f);
    doc.put("colds", 1.f);
    doc.put("flu", 1.f);
    doc.put("like", 2.f);
    doc.put("oranges", 1.f);

    Object[] partials = new Object[3];

    // bin #1 — rows [0, 6) of the sibling test-fixture arrays (words/labels/probs).
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < 6; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
    }
    partials[0] = evaluator.terminatePartial(agg);

    // bin #2 — rows [6, 12).
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 6; i < 12; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
    }
    partials[1] = evaluator.terminatePartial(agg);

    // bin #3 — rows [12, 18).
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 12; i < 18; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
    }

    partials[2] = evaluator.terminatePartial(agg);

    // merge in a different order — the ranking of topic probabilities must not
    // depend on the merge permutation.
    final int[][] orders = new int[][] { {0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}};
    for (int i = 0; i < orders.length; i++) {
        // partialOI is a sibling field describing the partial-aggregation layout.
        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
        evaluator.reset(agg);

        evaluator.merge(agg, partials[orders[i][0]]);
        evaluator.merge(agg, partials[orders[i][1]]);
        evaluator.merge(agg, partials[orders[i][2]]);

        float[] distr = agg.get();
        Assert.assertTrue(distr[0] < distr[1]);
    }
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:66,代码来源:PLSAPredictUDAFTest.java

示例9: getDecisionTreeFromDenseInput

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
/**
 * Downloads an ARFF dataset from the given URL, trains a single-tree random
 * forest on its dense features, and returns the deserialized decision tree.
 *
 * @param urlString URL of an ARFF file whose 5th attribute (index 4) is the class label
 * @return the root node of the trained decision tree
 * @throws IOException on download/read failure
 * @throws ParseException if the ARFF content is malformed
 * @throws HiveException if UDTF initialization or processing fails
 */
private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);

    final AttributeDataset iris;
    // FIX: the input stream was never closed (leaked on every call and on
    // parse failure); try-with-resources guarantees closure.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    // Single tree with a fixed seed for a deterministic model.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // Reuse one list buffer per row to feed dense feature vectors.
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    // The serialized model is forwarded as the third column of the UDTF output.
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    // The model text is Base91-encoded; decode then deserialize the tree.
    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
    return node;
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:48,代码来源:RandomForestClassifierUDTFTest.java

示例10: testClassification

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
// End-to-end classification sanity test for the factorization machine UDTF:
// trains on a synthetic two-cluster dataset and expects >95% accuracy.
@Test
public void testClassification() throws HiveException {
    final int ROW = 10, COL = 40;

    FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
    ListObjectInspector xOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-adareg -init_v gaussian -factors 20 -classification -seed 31 -iters 10");
    udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
    // String features should select the string-keyed feature-map model.
    FactorizationMachineModel model = udtf.initModel(udtf._params);
    Assert.assertTrue("Actual class: " + model.getClass().getName(),
        model instanceof FMStringFeatureMapModel);

    float accuracy = 0.f;
    final Random rnd = new Random(201L);
    for (int numberOfIteration = 0; numberOfIteration < 10000; numberOfIteration++) {
        ArrayList<StringFeature[]> fArrayList = new ArrayList<StringFeature[]>();
        ArrayList<Double> ans = new ArrayList<Double>();
        for (int i = 0; i < ROW; i++) {
            ArrayList<StringFeature> feature = new ArrayList<StringFeature>();
            for (int j = 1; j <= COL; j++) {
                // First half of the rows (negative class): feature 1 always set,
                // plus sparse random features in the lower half of the columns.
                if (i < (0.5f * ROW)) {
                    if (j == 1) {
                        feature.add(new StringFeature(j, 1.d));
                    } else if (j < 0.5 * COL) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                } else {
                    // Second half (positive class): sparse random features in the
                    // upper half of the columns only.
                    if (j > 0.5f * COL) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                }
            }
            StringFeature[] x = new StringFeature[feature.size()];
            feature.toArray(x);
            fArrayList.add(x);

            // Labels: -1 for the first half, +1 for the second half.
            final double y;
            if (i < ROW * 0.5f) {
                y = -1.0d;
            } else {
                y = 1.0d;
            }
            ans.add(y);

            udtf.process(new Object[] {toStringArray(x), y});
        }
        // Evaluate accuracy on this batch: map labels {-1, 1} -> {0, 1} and
        // threshold the model's probability output at 0.5.
        int bingo = 0;
        int total = fArrayList.size();
        for (int i = 0; i < total; i++) {
            double tmpAns = ans.get(i);
            if (tmpAns < 0) {
                tmpAns = 0;
            } else {
                tmpAns = 1;
            }
            double p = model.predict(fArrayList.get(i));
            int predicted = p > 0.5 ? 1 : 0;
            if (predicted == tmpAns) {
                bingo++;
            }
        }
        accuracy = bingo / (float) total;
        println("Accuracy = " + accuracy);
    }
    udtf.runTrainingIteration(10);
    // Assert on the accuracy of the final batch.
    Assert.assertTrue(accuracy > 0.95f);
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:75,代码来源:StringFeatureMapModelTest.java

示例11: testNews20MultiClassSparse

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
// Trains a 10-tree random forest (with stratified sampling and a fixed seed)
// on the news20 multi-class dataset and checks the out-of-bag error rate.
@Test
public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-stratified_sampling -seed 71 -trees " + numTrees);
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // FIX: try-with-resources — the reader was leaked if parsing or
    // udtf.process() threw before the explicit close() call.
    final ArrayList<String> features = new ArrayList<String>();
    try (BufferedReader news20 = readFile("news20-multiclass.gz")) {
        String line = news20.readLine();
        while (line != null) {
            // Each line: "<label> <feature> <feature> ..." (space separated).
            StringTokenizer tokens = new StringTokenizer(line, " ");
            int label = Integer.parseInt(tokens.nextToken());
            while (tokens.hasMoreTokens()) {
                features.add(tokens.nextToken());
            }
            Assert.assertFalse(features.isEmpty());
            udtf.process(new Object[] {features, label});

            features.clear();
            line = news20.readLine();
        }
    }

    // Collect per-tree out-of-bag statistics from the forwarded rows
    // (columns 4/5 hold OOB errors/tests).
    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(numTrees, count.getValue());
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    // TODO why multi-class classification so bad??
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8);
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:49,代码来源:RandomForestClassifierUDTFTest.java

示例12: testRegression

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
// Regression sanity test for the factorization machine UDTF: trains on a
// synthetic two-cluster dataset with targets 0.1 / 0.4 and checks that the
// final squared-error sum is small.
@Test
public void testRegression() throws HiveException {
    final int ROW = 1000, COL = 80;

    FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
    ListObjectInspector xOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factors 20 -seed 31 -eta 0.001 -lambda0 0.1 -sigma 0.1");
    udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
    // String features should select the string-keyed feature-map model.
    FactorizationMachineModel model = udtf.initModel(udtf._params);
    Assert.assertTrue("Actual class: " + model.getClass().getName(),
        model instanceof FMStringFeatureMapModel);

    double diff = 0.d;
    final Random rnd = new Random(201L);
    for (int numberOfIteration = 0; numberOfIteration < 100; numberOfIteration++) {
        ArrayList<StringFeature[]> fArrayList = new ArrayList<StringFeature[]>();
        ArrayList<Double> ans = new ArrayList<Double>();
        for (int i = 0; i < ROW; i++) {
            ArrayList<StringFeature> feature = new ArrayList<StringFeature>();
            for (int j = 1; j <= COL; j++) {
                // First half of the rows: feature 1 always set, plus sparse
                // random features in the lower half of the columns.
                if (i < (0.5f * ROW)) {
                    if (j == 1) {
                        feature.add(new StringFeature(j, 1.d));
                    } else if (j < 0.5 * COL) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                } else {
                    // Second half: sparse random features in the upper half only.
                    if (j > (0.5f * COL)) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                }
            }
            StringFeature[] x = new StringFeature[feature.size()];
            feature.toArray(x);
            fArrayList.add(x);

            // Targets: 0.1 for the first cluster, 0.4 for the second.
            final double y;
            if (i < ROW * 0.5f) {
                y = 0.1d;
            } else {
                y = 0.4d;
            }
            ans.add(y);

            udtf.process(new Object[] {toStringArray(x), y});
        }

        // Sum of squared prediction errors over this batch.
        diff = 0.d;
        for (int i = 0; i < fArrayList.size(); i++) {
            double predicted = model.predict(fArrayList.get(i));
            double actual = ans.get(i);
            double tmpDiff = predicted - actual;
            diff += tmpDiff * tmpDiff;
        }
        println("diff = " + diff);
    }
    // Assert on the error of the final batch.
    Assert.assertTrue("diff = " + diff, diff < 5.d);
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:66,代码来源:StringFeatureMapModelTest.java

示例13: testFileBackedIterationsCloseNoConverge

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
// Verifies file-backed multi-iteration training with convergence checking
// disabled: the UDTF must replay the buffered examples `iters` times, forward
// one row per rating cell, and delete its temp file on close().
@Test
public void testFileBackedIterationsCloseNoConverge() throws HiveException {
    println("--------------------------\n testFileBackedIterationsCloseNoConverge()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();

    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 5;
    // FIX: dropped the redundant `new String(...)` wrapper around the option string.
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-disable_cv -factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    // Without an explicit -rankinit option the scheme defaults to random.
    Assert.assertTrue(mf.rankInit == RankInitScheme.random);

    float[][] rating = { {5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];

    // Feed every rating cell num_iters times, counting the examples supplied.
    final int num_iters = 500;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }

    File tmpFile = mf.fileIO.getFile();
    mf.close();
    // close() replays the file-backed examples (iters - 1) more times, so the
    // total processed count is trainingExamples * iters.
    Assert.assertEquals(trainingExamples * iters, mf.count);
    // One forwarded row per user (5 rows in the rating matrix).
    Assert.assertEquals(5, numCollected.intValue());
    // The temporary backing file must be removed on close.
    Assert.assertFalse(tmpFile.exists());
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:48,代码来源:MatrixFactorizationSGDUDTFTest.java

示例14: testNews20BinarySparse

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
// Trains a 10-tree random forest on a small news20 binary dataset and checks
// that the out-of-bag error rate stays below 0.3.
@Test
public void testNews20BinarySparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-seed 71 -trees "
                + numTrees);
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // FIX: try-with-resources — the reader was leaked if parsing or
    // udtf.process() threw before the explicit close() call.
    final ArrayList<String> features = new ArrayList<String>();
    try (BufferedReader news20 = readFile("news20-small.binary.gz")) {
        String line = news20.readLine();
        while (line != null) {
            // Each line: "<label> <feature> ..."; remap the -1 label to 0.
            StringTokenizer tokens = new StringTokenizer(line, " ");
            int label = Integer.parseInt(tokens.nextToken());
            if (label == -1) {
                label = 0;
            }
            while (tokens.hasMoreTokens()) {
                features.add(tokens.nextToken());
            }
            if (!features.isEmpty()) {
                udtf.process(new Object[] {features, label});
                features.clear();
            }
            line = news20.readLine();
        }
    }

    // Collect per-tree out-of-bag statistics from the forwarded rows
    // (columns 4/5 hold OOB errors/tests).
    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(numTrees, count.getValue());
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.3);
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:50,代码来源:RandomForestClassifierUDTFTest.java

示例15: run

import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; //导入方法依赖的package包/类
/**
 * Trains the general classifier with the given option string on six linearly
 * separable samples, then checks that the normalized cumulative loss is below
 * 0.5 and that every training sample is classified correctly.
 */
private void run(@Nonnull String options) throws Exception {
    println(options);

    // Three negative samples (third quadrant) and three positive ones (first quadrant).
    final List<List<String>> samples = new ArrayList<List<String>>();
    samples.add(Arrays.asList("1:-2", "2:-1"));
    samples.add(Arrays.asList("1:-1", "2:-1"));
    samples.add(Arrays.asList("1:-1", "2:-2"));
    samples.add(Arrays.asList("1:1", "2:1"));
    samples.add(Arrays.asList("1:1", "2:2"));
    samples.add(Arrays.asList("1:2", "2:1"));

    final int[] labels = new int[] {0, 0, 0, 1, 1, 1};

    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

    udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});

    for (int i = 0; i < labels.length; i++) {
        udtf.process(new Object[] {samples.get(i), labels[i]});
    }

    udtf.finalizeTraining();

    // The average per-sample loss must be under 0.5 for the run to count as converged.
    double cumLoss = udtf.getCumulativeLoss();
    println("Cumulative loss: " + cumLoss);
    double normalizedLoss = cumLoss / samples.size();
    Assert.assertTrue("cumLoss: " + cumLoss + ", normalizedLoss: " + normalizedLoss
            + "\noptions: " + options, normalizedLoss < 0.5d);

    // Re-score every training sample; a positive score maps to class 1.
    int numTests = 0;
    int numCorrect = 0;
    for (int i = 0; i < labels.length; i++) {
        final int label = labels[i];
        final float score = udtf.predict(udtf.parseFeatures(samples.get(i)));
        final int predicted = score > 0.f ? 1 : 0;

        println("Score: " + score + ", Predicted: " + predicted + ", Actual: " + label);

        if (predicted == label) {
            ++numCorrect;
        }
        ++numTests;
    }

    float accuracy = numCorrect / (float) numTests;
    println("Accuracy: " + accuracy);
    Assert.assertTrue(accuracy == 1.f);
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:56,代码来源:GeneralClassifierUDTFTest.java


注:本文中的org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.getConstantObjectInspector方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。