当前位置: 首页>>代码示例>>Java>>正文


Java DMatrix类代码示例

本文整理汇总了Java中ml.dmlc.xgboost4j.java.DMatrix类的典型用法代码示例。如果您正在寻找以下问题的答案：Java DMatrix类的具体用法是什么？Java DMatrix类怎么用？有哪些Java DMatrix类的使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。


DMatrix类属于ml.dmlc.xgboost4j.java包，下文共展示了DMatrix类的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Java代码示例。

示例1: close

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Trains an XGBoost model on the buffered rows in {@code featuresList} and forwards one output
 * row {@code (model_id, serialized_model)}.
 *
 * <p>The number of boosting rounds is read from {@code params.get("num_round")}. The training
 * {@link DMatrix} and the {@link Booster} both expose {@code dispose()} to free their native
 * XGBoost handles, so they are released in {@code finally} once the model bytes have been
 * captured.
 *
 * @throws HiveException wrapping any failure during training, serialization or forwarding
 */
@Override
public void close() throws HiveException {
    DMatrix trainData = null;
    Booster booster = null;
    try {
        // Kick off training with XGBoost
        trainData = new DMatrix(featuresList.iterator(), "");
        booster = createXGBooster(params, featuresList);
        final int num_round = (Integer) params.get("num_round");
        for (int i = 0; i < num_round; i++) {
            booster.update(trainData, i);
        }

        // Output the built model; toByteArray() copies it out, so disposal below is safe.
        final String modelId = generateUniqueModelId();
        final byte[] predModel = booster.toByteArray();
        logger.info("model_id:" + modelId + " size:" + predModel.length);
        forward(new Object[] {modelId, predModel});
    } catch (Exception e) {
        throw new HiveException(e);
    } finally {
        // Release native memory even when training or forwarding failed.
        if (booster != null) {
            booster.dispose();
        }
        if (trainData != null) {
            trainData.dispose();
        }
    }
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:21,代码来源:XGBoostUDTF.java

示例2: buildClassifier

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Trains the XGBoost booster on a Weka dataset.
 *
 * <p>The data is checked against this classifier's capabilities. Training then runs on a copy
 * from which rows with a missing class value have been removed, converted to a {@link DMatrix}.
 * {@code num_class} defaults to the dataset's class count when the caller has not set it.
 * Training runs for {@code numRound} rounds with the training matrix as the only watch entry.
 * When {@code usePredictor} is set, the trained model is also reloaded into a
 * {@code Predictor} from its serialized bytes.
 *
 * @param instances training data; rows are deleted only from an internal copy
 * @throws Exception if the capability check or training fails
 */
public void buildClassifier(Instances instances) throws Exception {
    // Fail fast on data this classifier cannot handle.
    getCapabilities().testWithFail(instances);

    // Work on a copy so that dropping missing-class rows does not touch the caller's dataset.
    Instances trainingSet = new Instances(instances);
    trainingSet.deleteWithMissingClass();

    DMatrix trainMatrix = DMatrixLoader.instancesToDMatrix(trainingSet);

    Map<String, DMatrix> watchList = new HashMap<>();
    watchList.put("train", trainMatrix);

    if (!params.containsKey("num_class")) {
        params.put("num_class", trainingSet.numClasses());
    }

    booster = ml.dmlc.xgboost4j.java.XGBoost.train(trainMatrix, params, numRound, watchList, null, null);

    if (!usePredictor) {
        return;
    }
    // Round-trip the trained model through its byte form to build the Predictor.
    byte[] modelBytes = booster.toByteArray();
    this.predictor = new Predictor(new ByteArrayInputStream(modelBytes));
}
 
开发者ID:SigDelta,项目名称:weka-xgboost,代码行数:26,代码来源:XGBoost.java

示例3: instancesToDMatrix

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Converts a Weka dataset into a CSR-format {@link DMatrix} with labels attached.
 *
 * <p>Each instance becomes one row. {@code processInstance} appends that row's entries to the
 * value and column lists. The row-pointer array records the offset at which each row starts,
 * plus a final entry holding the total entry count. The column count excludes the class
 * attribute.
 *
 * @param instances dataset to convert
 * @return a CSR matrix with one row per instance, labelled with each instance's class value
 * @throws XGBoostError if the native matrix cannot be created or labelled
 */
public static DMatrix instancesToDMatrix(Instances instances) throws XGBoostError {
    final int rowCount = instances.size();
    // Java zero-initializes the array, so row 0 already starts at offset 0.
    long[] rowPointers = new long[rowCount + 1];
    List<Float> values = new ArrayList<>();
    List<Integer> columns = new ArrayList<>();
    float[] labels = new float[rowCount];

    int row = 0;
    while (row < rowCount) {
        Instance current = instances.get(row);
        rowPointers[row] = values.size();
        processInstance(current, values, columns);
        labels[row] = (float) current.classValue();
        row++;
    }
    // Closing pointer: total number of stored entries.
    rowPointers[rowCount] = values.size();

    DMatrix matrix = createDMatrix(rowPointers, values, columns, instances.numAttributes() - 1);
    matrix.setLabel(labels);
    return matrix;
}
 
开发者ID:SigDelta,项目名称:weka-xgboost,代码行数:22,代码来源:DMatrixLoader.java

示例4: instanceToDenseDMatrix

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a dense single-row {@link DMatrix} from one Weka instance.
 *
 * <p>Attribute values are written in enumeration order. Any attribute whose index equals the
 * class attribute's index is skipped, so the row has {@code numAttributes() - 1} columns. No
 * label is set.
 *
 * @param instance instance to convert
 * @return a 1 x (numAttributes - 1) dense matrix
 * @throws XGBoostError if the native matrix cannot be created
 */
public static DMatrix instanceToDenseDMatrix(Instance instance) throws XGBoostError {
    final int classIndex = instance.classAttribute().index();
    final int cols = instance.numAttributes() - 1;
    final int rows = 1;

    float[] row = new float[cols * rows];
    int next = 0;
    for (Enumeration<Attribute> attrs = instance.enumerateAttributes(); attrs.hasMoreElements(); ) {
        Attribute attr = attrs.nextElement();
        // Only feature columns go into the matrix; the label column is left out.
        if (attr.index() != classIndex) {
            row[next++] = (float) instance.value(attr);
        }
    }
    return new DMatrix(row, rows, cols);
}
 
开发者ID:SigDelta,项目名称:weka-xgboost,代码行数:24,代码来源:DMatrixLoader.java

示例5: makeMetrics

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Computes model metrics for {@code data} by running predictions through {@code makePreds}.
 *
 * <p>If {@code predFrameKey} is non-null, the prediction frame is stored in the DKV under that
 * key. Otherwise it is removed once metrics are obtained. All pending DKV work is awaited
 * before returning.
 *
 * @param booster trained booster used for prediction
 * @param data matrix to score
 * @param dataFrame H2O frame the metrics are associated with
 * @param description description attached to the resulting metrics
 * @param predFrameKey key for keeping the predictions, or {@code null} to discard them
 * @return the metrics, bound to this model and {@code dataFrame}
 * @throws XGBoostError if prediction fails
 */
private ModelMetrics makeMetrics(Booster booster, DMatrix data, Frame dataFrame, String description,
                                 Key<Frame> predFrameKey) throws XGBoostError {
  final Futures pending = new Futures();
  // One-element out-parameter; makePreds is expected to fill slot 0 with the metrics.
  final ModelMetrics[] metricsOut = new ModelMetrics[1];
  final Frame preds = makePreds(booster, data, metricsOut, true, predFrameKey, pending);
  if (predFrameKey != null) {
    // Caller asked to keep the predictions: publish them under the requested key.
    DKV.put(preds, pending);
  } else {
    // Predictions were only needed to compute metrics.
    preds.remove(pending);
  }
  pending.blockForPending();
  return metricsOut[0].withModelAndFrame(this, dataFrame).withDescription(description);
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:17,代码来源:XGBoostModel.java

示例6: score

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Scores {@code fr} with the trained booster and stores the predictions under
 * {@code destination_key}.
 *
 * <p>The frame is first adapted to the training layout, and adaptation warnings are logged.
 * Metrics are computed only when requested and, for supervised models, only when the adapted
 * frame has a response column that is present and not {@code isBad()}. When computed, the
 * metrics are added to this model and the model is re-put into the DKV. The native scoring
 * {@link DMatrix} is disposed before returning.
 *
 * @param fr frame to score
 * @param destination_key key under which the prediction frame is stored
 * @param j job handle (not referenced by this implementation)
 * @param computeMetrics whether to compute model metrics on {@code fr}
 * @return the frame stored under {@code destination_key}
 * @throws RuntimeException wrapping any {@link XGBoostError}
 */
@Override
public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics) throws IllegalArgumentException {
  Frame adaptFr = new Frame(fr);
  computeMetrics = computeMetrics && (!isSupervised() || (adaptFr.vec(_output.responseName()) != null && !adaptFr.vec(_output.responseName()).isBad()));
  String[] msg = adaptTestForTrain(adaptFr,true, computeMetrics);   // Adapt
  if (msg.length > 0) {
    for (String s : msg)
      Log.warn(s);
  }
  DMatrix scoringMat = null;
  try {
    scoringMat = convertFrametoDMatrix( model_info()._dataInfoKey, adaptFr,
        _parms._response_column, _parms._weights_column, _parms._fold_column, null, _output._sparse);
    Key<Frame> destFrameKey = Key.make(destination_key);
    if (computeMetrics) {
      ModelMetrics mm = makeMetrics(model_info().booster(), scoringMat, fr, "Prediction on frame " + fr._key, destFrameKey);
      // Update model with newly computed model metrics
      this.addModelMetrics(mm);
      DKV.put(this);
    } else {
      makePredsOnly(model_info().booster(), scoringMat, destFrameKey);
    }
    return destFrameKey.get();
  } catch (XGBoostError xgBoostError) {
    throw new RuntimeException(xgBoostError);
  } finally {
    // Free the native matrix built for this call. NOTE(review): assumes makePredsOnly, like
    // makeMetrics, has finished using the matrix by the time it returns — confirm.
    if (scoringMat != null) {
      scoringMat.dispose();
    }
  }
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:26,代码来源:XGBoostModel.java

示例7: BasicModel

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Trains a 10-round binary logistic model on the train matrix and logs up to the first 10
 * predictions on the test matrix.
 */
@Test public void BasicModel() throws XGBoostError {
  // load file from text file, also binary buffer generated by xgboost4j
  DMatrix[] mat = getMatrices();
  DMatrix trainMat = mat[0];
  DMatrix testMat = mat[1];

  HashMap<String, Object> params = new HashMap<>();
  params.put("eta", 0.1);
  params.put("max_depth", 5);
  params.put("silent", 1);
  params.put("objective", "binary:logistic");

  HashMap<String, DMatrix> watches = new HashMap<>();
  watches.put("train", trainMat);
  watches.put("test",  testMat);

  Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null);
  float[][] preds = booster.predict(testMat);
  // Bound by the actual row count so a small test set does not cause an index error.
  int shown = Math.min(10, preds.length);
  for (int i = 0; i < shown; ++i) {
    Log.info(preds[i][0]);
  }
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:22,代码来源:XGBoostTest.java

示例8: checkpoint

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Creates a booster with zero initial rounds, then advances it one round at a time. After each
 * round it logs up to the first 10 predictions on the test matrix.
 */
@Test
public void checkpoint() throws XGBoostError, IOException {
  // load file from text file, also binary buffer generated by xgboost4j
  DMatrix[] mat = getMatrices();
  DMatrix trainMat = mat[0];
  DMatrix testMat = mat[1];

  HashMap<String, Object> params = new HashMap<>();
  params.put("eta", 0.1);
  params.put("max_depth", 5);
  params.put("silent", 1);
  params.put("objective", "binary:logistic");

  HashMap<String, DMatrix> watches = new HashMap<>();
  watches.put("train", trainMat);

  // Zero rounds: only sets up the booster; the rounds are driven manually below.
  Booster booster = XGBoost.train(trainMat, params, 0, watches, null, null);

  // Train for 10 iterations
  for (int i = 0; i < 10; ++i) {
    booster.update(trainMat, i);
    float[][] preds = booster.predict(testMat);
    // Bound by the actual row count so a small test set does not cause an index error.
    int shown = Math.min(10, preds.length);
    for (int j = 0; j < shown; ++j) {
      Log.info(preds[j][0]);
    }
  }
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:27,代码来源:XGBoostTest.java

示例9: predictWithDMatrix

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Scores a single Weka instance with the trained booster.
 *
 * @param instance instance to score, converted with {@code DMatrixLoader.instanceToDMatrix}
 * @return the booster's output row for this instance, widened from float to double
 * @throws XGBoostError if matrix creation or prediction fails
 */
private double[] predictWithDMatrix(Instance instance) throws XGBoostError {
    DMatrix dmat = DMatrixLoader.instanceToDMatrix(instance);
    try {
        // The matrix holds exactly one row, so only the first prediction row is relevant.
        float[] row = booster.predict(dmat)[0];
        double[] predictDouble = new double[row.length];
        for (int i = 0; i < row.length; i++) {
            predictDouble[i] = row[i];
        }
        return predictDouble;
    } finally {
        // A new native matrix is built on every call; free it so repeated scoring does not
        // accumulate native memory.
        dmat.dispose();
    }
}
 
开发者ID:SigDelta,项目名称:weka-xgboost,代码行数:11,代码来源:XGBoost.java

示例10: instancesToDenseDMatrix

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Converts a Weka dataset into a dense row-major {@link DMatrix} with labels attached.
 *
 * <p>Each instance contributes one row of {@code numAttributes() - 1} feature values, written in
 * attribute-enumeration order. Any attribute whose index equals the class attribute's index is
 * skipped. Labels are the instances' class values.
 *
 * @param instances dataset to convert
 * @return a dense matrix with one row per instance
 * @throws XGBoostError if the native matrix cannot be created or labelled
 */
public static DMatrix instancesToDenseDMatrix(Instances instances) throws XGBoostError {
    final int rows = instances.size();
    final int cols = instances.numAttributes() - 1;

    float[] values = new float[cols * rows];
    float[] labels = new float[rows];
    final int classIndex = instances.classAttribute().index();

    // Single running cursor across all rows: values are packed row after row.
    int cursor = 0;
    for (int r = 0; r < rows; r++) {
        Instance inst = instances.get(r);
        labels[r] = (float) inst.classValue();
        for (Enumeration<Attribute> attrs = inst.enumerateAttributes(); attrs.hasMoreElements(); ) {
            Attribute attr = attrs.nextElement();
            if (attr.index() == classIndex) {
                continue; // label column is stored separately via setLabel
            }
            values[cursor++] = (float) inst.value(attr);
        }
    }

    DMatrix matrix = new DMatrix(values, rows, cols);
    matrix.setLabel(labels);
    return matrix;
}
 
开发者ID:SigDelta,项目名称:weka-xgboost,代码行数:33,代码来源:DMatrixLoader.java

示例11: instanceToDMatrix

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a single-row CSR {@link DMatrix} from one Weka instance.
 *
 * <p>{@code processInstance} collects the row's entries. The row pointers are
 * {@code {0, entryCount}}, and the column count excludes the class attribute. No label is set.
 *
 * @param instance instance to convert
 * @return a one-row sparse matrix
 * @throws XGBoostError if the native matrix cannot be created
 */
public static DMatrix instanceToDMatrix(Instance instance) throws XGBoostError {
    List<Float> values = new ArrayList<>();
    List<Integer> columns = new ArrayList<>();
    processInstance(instance, values, columns);

    // One CSR row spanning every collected entry.
    long[] rowPointers = {0L, values.size()};
    int featureCount = instance.numAttributes() - 1;
    return createDMatrix(rowPointers, values, columns, featureCount);
}
 
开发者ID:SigDelta,项目名称:weka-xgboost,代码行数:12,代码来源:DMatrixLoader.java

示例12: createDMatrix

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a CSR {@link DMatrix} from boxed value/column lists.
 *
 * @param rowHeaders CSR row pointers (offset of each row, plus a final total-count entry)
 * @param dataList stored values, in row order
 * @param colList column index of each stored value; must be at least as long as
 *     {@code dataList}
 * @param colNum number of columns in the matrix
 * @return the sparse matrix
 * @throws XGBoostError if the native matrix cannot be created
 */
protected static DMatrix createDMatrix(long[] rowHeaders, List<Float> dataList, List<Integer> colList, int colNum) throws XGBoostError {
    final int entryCount = dataList.size();
    float[] data = new float[entryCount];
    int[] colIndices = new int[entryCount];
    // No sentinel write into colIndices[0]: Java zero-initializes the array, and a row with no
    // stored entries yields zero-length arrays, where such a write would throw
    // ArrayIndexOutOfBoundsException.
    for (int i = 0; i < entryCount; i++) {
        data[i] = dataList.get(i);
        colIndices[i] = colList.get(i);
    }

    return new DMatrix(rowHeaders, colIndices, data, DMatrix.SparseType.CSR, colNum);
}
 
开发者ID:SigDelta,项目名称:weka-xgboost,代码行数:12,代码来源:DMatrixLoader.java

示例13: createXGBooster

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Creates a {@link Booster} for {@code params} through its {@code (Map, DMatrix[])} constructor.
 *
 * <p>The constructor is obtained with {@code getDeclaredConstructor} and made accessible before
 * use. Presumably this is because it is not public in the xgboost4j version in use; confirm
 * when upgrading. A single cache matrix is built from {@code input}.
 * NOTE(review): that cache {@link DMatrix} is never disposed here.
 *
 * @param params booster parameters
 * @param input rows used to build the cache matrix
 * @return the newly constructed booster
 */
@Nonnull
private static Booster createXGBooster(final Map<String, Object> params,
        final List<LabeledPoint> input) throws NoSuchMethodException, XGBoostError,
        IllegalAccessException, InvocationTargetException, InstantiationException {
    final Constructor<Booster> constructor =
            Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
    constructor.setAccessible(true);
    final DMatrix[] cacheMatrices = {new DMatrix(input.iterator(), "")};
    return constructor.newInstance(params, cacheMatrices);
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:11,代码来源:XGBoostUDTF.java

示例14: createDMatrix

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a {@link DMatrix} from the labeled points wrapped in {@code data}, discarding the row
 * ids.
 *
 * @param data buffered rows; their order is preserved in the matrix
 * @return a matrix over the unwrapped points
 * @throws XGBoostError if the native matrix cannot be created
 */
@Nonnull
private static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data)
        throws XGBoostError {
    // Unwrap each buffered row to its LabeledPoint, keeping the original order.
    final List<LabeledPoint> unwrapped = new ArrayList<>(data.size());
    for (final LabeledPointWithRowId row : data) {
        final LabeledPoint point = row.point;
        unwrapped.add(point);
    }
    final DMatrix matrix = new DMatrix(unwrapped.iterator(), "");
    return matrix;
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:10,代码来源:XGBoostPredictUDTF.java

示例15: predictAndFlush

import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Predicts all buffered rows with {@code model}, forwards the results and clears the buffer.
 *
 * <p>The temporary test {@link DMatrix} is disposed as soon as prediction finishes, so repeated
 * flushes do not accumulate native memory.
 *
 * @param model booster used for prediction
 * @param buf buffered rows; cleared after the predictions are forwarded
 * @throws HiveException wrapping any {@link XGBoostError}, or propagated from forwarding
 */
private void predictAndFlush(final Booster model, final List<LabeledPointWithRowId> buf)
        throws HiveException {
    final float[][] predicted;
    DMatrix testData = null;
    try {
        testData = createDMatrix(buf);
        predicted = model.predict(testData);
    } catch (XGBoostError e) {
        throw new HiveException(e);
    } finally {
        if (testData != null) {
            testData.dispose();
        }
    }
    forwardPredicted(buf, predicted);
    buf.clear();
}
 
开发者ID:apache,项目名称:incubator-hivemall,代码行数:14,代码来源:XGBoostPredictUDTF.java


注：本文中的ml.dmlc.xgboost4j.java.DMatrix类示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。