This article collects typical usage examples of the Java method org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.getConstantObjectInspector. If you have been wondering what ObjectInspectorUtils.getConstantObjectInspector does, how to call it, and what real-world code that uses it looks like, the hand-picked code samples below should help. You can also look further into the enclosing class, org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils,
for more usage examples.
Below are 15 code examples of ObjectInspectorUtils.getConstantObjectInspector, sorted by popularity by default; they show how the method is called in real projects.
Example 1: testReverseTopKWithKey
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testReverseTopKWithKey() throws Exception {
    // = tail-k
    ObjectInspector[] inputOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2 -reverse")};
    final String[] values = new String[] {"banana", "apple", "candy"};
    final double[] keys = new double[] {0.7, 0.5, 0.8};
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < values.length; i++) {
        evaluator.iterate(agg, new Object[] {values[i], keys[i]});
    }
    List<Object> res = evaluator.terminate(agg);
    Assert.assertEquals(2, res.size());
    Assert.assertEquals("apple", res.get(0));
    Assert.assertEquals("banana", res.get(1));
}
Example 2: testReverseTailKWithKey
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testReverseTailKWithKey() throws Exception {
    // = top-k
    ObjectInspector[] inputOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2 -reverse")};
    final String[] values = new String[] {"banana", "apple", "candy"};
    final double[] keys = new double[] {0.7, 0.5, 0.8};
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < values.length; i++) {
        evaluator.iterate(agg, new Object[] {values[i], keys[i]});
    }
    List<Object> res = evaluator.terminate(agg);
    Assert.assertEquals(2, res.size());
    Assert.assertEquals("candy", res.get(0));
    Assert.assertEquals("banana", res.get(1));
}
Example 3: testTopK
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testTopK() throws Exception {
    ObjectInspector[] inputOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2")};
    final String[] values = new String[] {"banana", "apple", "candy"};
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < values.length; i++) {
        evaluator.iterate(agg, new Object[] {values[i]});
    }
    List<Object> res = evaluator.terminate(agg);
    Assert.assertEquals(2, res.size());
    Assert.assertEquals("candy", res.get(0));
    Assert.assertEquals("banana", res.get(1));
}
Example 4: testPA2EtaWithParameter
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testPA2EtaWithParameter() throws UDFArgumentException {
    PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA2();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-c 3.0");
    /* do initialize() with aggressiveness parameter */
    udtf.initialize(new ObjectInspector[] {intListOI, intOI, param});
    float loss = 0.1f;
    PredictionResult margin1 = new PredictionResult(0.5f).squaredNorm(0.05f);
    float expectedLearningRate1 = 0.4615384f;
    assertEquals(expectedLearningRate1, udtf.eta(loss, margin1), 1e-5f);
    PredictionResult margin2 = new PredictionResult(0.5f).squaredNorm(0.01f);
    float expectedLearningRate2 = 0.5660377f;
    assertEquals(expectedLearningRate2, udtf.eta(loss, margin2), 1e-5f);
}
Example 5: testGaussianInit
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testGaussianInit() throws HiveException {
    println("--------------------------\n testGaussianInit()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factor 3 -rankinit gaussian");
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    mf.initialize(argOIs);
    Assert.assertTrue(mf.rankInit == RankInitScheme.gaussian);
    float[][] rating = { {5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 100;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
            }
        }
    }
    for (int row = 0; row < rating.length; row++) {
        for (int col = 0, size = rating[row].length; col < size; col++) {
            double predicted = mf.predict(row, col);
            print(rating[row][col] + "[" + predicted + "]\t");
            Assert.assertEquals(rating[row][col], predicted, 0.2d);
        }
        println();
    }
}
Example 6: testUnsupportedLossFunction
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test(expected = UDFArgumentException.class)
public void testUnsupportedLossFunction() throws Exception {
    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss UnsupportedLoss");
    udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});
}
Example 7: binarySetUp
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
private void binarySetUp(Object actual, Object predicted, double beta, String average)
        throws Exception {
    fmeasure = new FMeasureUDAF();
    inputOIs = new ObjectInspector[3];
    String actualClassName = actual.getClass().getName();
    if (actualClassName.equals("java.lang.Integer")) {
        inputOIs[0] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT);
    } else if (actualClassName.equals("java.lang.Boolean")) {
        inputOIs[0] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN);
    } else if (actualClassName.equals("java.lang.String")) {
        inputOIs[0] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING);
    }
    String predictedClassName = predicted.getClass().getName();
    if (predictedClassName.equals("java.lang.Integer")) {
        inputOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT);
    } else if (predictedClassName.equals("java.lang.Boolean")) {
        inputOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN);
    } else if (predictedClassName.equals("java.lang.String")) {
        inputOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING);
    }
    inputOIs[2] = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-beta " + beta + " -average " + average);
    evaluator = fmeasure.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
    agg = (FMeasureUDAF.FMeasureAggregationBuffer) evaluator.getNewAggregationBuffer();
}
Example 8: testMerge
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testMerge() throws Exception {
    udaf = new PLSAPredictUDAF();
    inputOIs = new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING),
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT),
            PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
            ObjectInspectorUtils.getConstantObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
    evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
    agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
    final Map<String, Float> doc = new HashMap<String, Float>();
    doc.put("apples", 1.f);
    doc.put("avocados", 1.f);
    doc.put("colds", 1.f);
    doc.put("flu", 1.f);
    doc.put("like", 2.f);
    doc.put("oranges", 1.f);
    Object[] partials = new Object[3];
    // bin #1
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 0; i < 6; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
    }
    partials[0] = evaluator.terminatePartial(agg);
    // bin #2
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 6; i < 12; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
    }
    partials[1] = evaluator.terminatePartial(agg);
    // bin #3
    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
    evaluator.reset(agg);
    for (int i = 12; i < 18; i++) {
        evaluator.iterate(agg, new Object[] {words[i], doc.get(words[i]), labels[i], probs[i]});
    }
    partials[2] = evaluator.terminatePartial(agg);
    // merge in different orders
    final int[][] orders = new int[][] { {0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}};
    for (int i = 0; i < orders.length; i++) {
        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
        evaluator.reset(agg);
        evaluator.merge(agg, partials[orders[i][0]]);
        evaluator.merge(agg, partials[orders[i][1]]);
        evaluator.merge(agg, partials[orders[i][2]]);
        float[] distr = agg.get();
        Assert.assertTrue(distr[0] < distr[1]);
    }
}
Example 9: getDecisionTreeFromDenseInput
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);
    InputStream is = new BufferedInputStream(url.openStream());
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }
    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);
    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
    return node;
}
Example 10: testClassification
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testClassification() throws HiveException {
    final int ROW = 10, COL = 40;
    FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
    ListObjectInspector xOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-adareg -init_v gaussian -factors 20 -classification -seed 31 -iters 10");
    udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
    FactorizationMachineModel model = udtf.initModel(udtf._params);
    Assert.assertTrue("Actual class: " + model.getClass().getName(),
        model instanceof FMStringFeatureMapModel);
    float accuracy = 0.f;
    final Random rnd = new Random(201L);
    for (int numberOfIteration = 0; numberOfIteration < 10000; numberOfIteration++) {
        ArrayList<StringFeature[]> fArrayList = new ArrayList<StringFeature[]>();
        ArrayList<Double> ans = new ArrayList<Double>();
        for (int i = 0; i < ROW; i++) {
            ArrayList<StringFeature> feature = new ArrayList<StringFeature>();
            for (int j = 1; j <= COL; j++) {
                if (i < (0.5f * ROW)) {
                    if (j == 1) {
                        feature.add(new StringFeature(j, 1.d));
                    } else if (j < 0.5 * COL) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                } else {
                    if (j > 0.5f * COL) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                }
            }
            StringFeature[] x = new StringFeature[feature.size()];
            feature.toArray(x);
            fArrayList.add(x);
            final double y;
            if (i < ROW * 0.5f) {
                y = -1.0d;
            } else {
                y = 1.0d;
            }
            ans.add(y);
            udtf.process(new Object[] {toStringArray(x), y});
        }
        int bingo = 0;
        int total = fArrayList.size();
        for (int i = 0; i < total; i++) {
            double tmpAns = ans.get(i);
            if (tmpAns < 0) {
                tmpAns = 0;
            } else {
                tmpAns = 1;
            }
            double p = model.predict(fArrayList.get(i));
            int predicted = p > 0.5 ? 1 : 0;
            if (predicted == tmpAns) {
                bingo++;
            }
        }
        accuracy = bingo / (float) total;
        println("Accuracy = " + accuracy);
    }
    udtf.runTrainingIteration(10);
    Assert.assertTrue(accuracy > 0.95f);
}
Example 11: testNews20MultiClassSparse
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-stratified_sampling -seed 71 -trees " + numTrees);
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    BufferedReader news20 = readFile("news20-multiclass.gz");
    ArrayList<String> features = new ArrayList<String>();
    String line = news20.readLine();
    while (line != null) {
        StringTokenizer tokens = new StringTokenizer(line, " ");
        int label = Integer.parseInt(tokens.nextToken());
        while (tokens.hasMoreTokens()) {
            features.add(tokens.nextToken());
        }
        Assert.assertFalse(features.isEmpty());
        udtf.process(new Object[] {features, label});
        features.clear();
        line = news20.readLine();
    }
    news20.close();
    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Assert.assertEquals(numTrees, count.getValue());
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    // TODO why multi-class classification so bad??
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8);
}
Example 12: testRegression
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testRegression() throws HiveException {
    final int ROW = 1000, COL = 80;
    FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
    ListObjectInspector xOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-factors 20 -seed 31 -eta 0.001 -lambda0 0.1 -sigma 0.1");
    udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
    FactorizationMachineModel model = udtf.initModel(udtf._params);
    Assert.assertTrue("Actual class: " + model.getClass().getName(),
        model instanceof FMStringFeatureMapModel);
    double diff = 0.d;
    final Random rnd = new Random(201L);
    for (int numberOfIteration = 0; numberOfIteration < 100; numberOfIteration++) {
        ArrayList<StringFeature[]> fArrayList = new ArrayList<StringFeature[]>();
        ArrayList<Double> ans = new ArrayList<Double>();
        for (int i = 0; i < ROW; i++) {
            ArrayList<StringFeature> feature = new ArrayList<StringFeature>();
            for (int j = 1; j <= COL; j++) {
                if (i < (0.5f * ROW)) {
                    if (j == 1) {
                        feature.add(new StringFeature(j, 1.d));
                    } else if (j < 0.5 * COL) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                } else {
                    if (j > (0.5f * COL)) {
                        if (rnd.nextFloat() < 0.2f) {
                            feature.add(new StringFeature(j, rnd.nextDouble()));
                        }
                    }
                }
            }
            StringFeature[] x = new StringFeature[feature.size()];
            feature.toArray(x);
            fArrayList.add(x);
            final double y;
            if (i < ROW * 0.5f) {
                y = 0.1d;
            } else {
                y = 0.4d;
            }
            ans.add(y);
            udtf.process(new Object[] {toStringArray(x), y});
        }
        diff = 0.d;
        for (int i = 0; i < fArrayList.size(); i++) {
            double predicted = model.predict(fArrayList.get(i));
            double actual = ans.get(i);
            double tmpDiff = predicted - actual;
            diff += tmpDiff * tmpDiff;
        }
        println("diff = " + diff);
    }
    Assert.assertTrue("diff = " + diff, diff < 5.d);
}
Example 13: testFileBackedIterationsCloseNoConverge
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testFileBackedIterationsCloseNoConverge() throws HiveException {
    println("--------------------------\n testFileBackedIterationsCloseNoConverge()");
    OnlineMatrixFactorizationUDTF mf = new MatrixFactorizationSGDUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    int iters = 5;
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-disable_cv -factor 3 -iterations " + iters);
    ObjectInspector[] argOIs = new ObjectInspector[] {intOI, intOI, floatOI, param};
    MapredContext mrContext = MapredContextAccessor.create(true, null);
    mf.configure(mrContext);
    mf.initialize(argOIs);
    final MutableInt numCollected = new MutableInt(0);
    mf.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            numCollected.addValue(1);
        }
    });
    Assert.assertTrue(mf.rankInit == RankInitScheme.random);
    float[][] rating = { {5, 3, 0, 1}, {4, 0, 0, 1}, {1, 1, 0, 5}, {1, 0, 0, 4}, {0, 1, 5, 4}};
    Object[] args = new Object[3];
    final int num_iters = 500;
    int trainingExamples = 0;
    for (int iter = 0; iter < num_iters; iter++) {
        for (int row = 0; row < rating.length; row++) {
            for (int col = 0, size = rating[row].length; col < size; col++) {
                args[0] = row;
                args[1] = col;
                args[2] = (float) rating[row][col];
                mf.process(args);
                trainingExamples++;
            }
        }
    }
    File tmpFile = mf.fileIO.getFile();
    mf.close();
    Assert.assertEquals(trainingExamples * iters, mf.count);
    Assert.assertEquals(5, numCollected.intValue());
    Assert.assertFalse(tmpFile.exists());
}
Example 14: testNews20BinarySparse
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
@Test
public void testNews20BinarySparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        "-seed 71 -trees " + numTrees);
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});
    BufferedReader news20 = readFile("news20-small.binary.gz");
    ArrayList<String> features = new ArrayList<String>();
    String line = news20.readLine();
    while (line != null) {
        StringTokenizer tokens = new StringTokenizer(line, " ");
        int label = Integer.parseInt(tokens.nextToken());
        if (label == -1) {
            label = 0;
        }
        while (tokens.hasMoreTokens()) {
            features.add(tokens.nextToken());
        }
        if (!features.isEmpty()) {
            udtf.process(new Object[] {features, label});
            features.clear();
        }
        line = news20.readLine();
    }
    news20.close();
    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();
    Assert.assertEquals(numTrees, count.getValue());
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.3);
}
Example 15: run
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; // import the package/class on which this method depends
private void run(@Nonnull String options) throws Exception {
    println(options);
    ArrayList<List<String>> samplesList = new ArrayList<List<String>>();
    samplesList.add(Arrays.asList("1:-2", "2:-1"));
    samplesList.add(Arrays.asList("1:-1", "2:-1"));
    samplesList.add(Arrays.asList("1:-1", "2:-2"));
    samplesList.add(Arrays.asList("1:1", "2:1"));
    samplesList.add(Arrays.asList("1:1", "2:2"));
    samplesList.add(Arrays.asList("1:2", "2:1"));
    int[] labels = new int[] {0, 0, 0, 1, 1, 1};
    GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
    ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
    ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);
    udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});
    for (int i = 0, size = samplesList.size(); i < size; i++) {
        udtf.process(new Object[] {samplesList.get(i), labels[i]});
    }
    udtf.finalizeTraining();
    double cumLoss = udtf.getCumulativeLoss();
    println("Cumulative loss: " + cumLoss);
    double normalizedLoss = cumLoss / samplesList.size();
    Assert.assertTrue("cumLoss: " + cumLoss + ", normalizedLoss: " + normalizedLoss
            + "\noptions: " + options, normalizedLoss < 0.5d);
    int numTests = 0;
    int numCorrect = 0;
    for (int i = 0, size = samplesList.size(); i < size; i++) {
        int label = labels[i];
        float score = udtf.predict(udtf.parseFeatures(samplesList.get(i)));
        int predicted = score > 0.f ? 1 : 0;
        println("Score: " + score + ", Predicted: " + predicted + ", Actual: " + label);
        if (predicted == label) {
            ++numCorrect;
        }
        ++numTests;
    }
    float accuracy = numCorrect / (float) numTests;
    println("Accuracy: " + accuracy);
    Assert.assertTrue(accuracy == 1.f);
}