本文整理汇总了Java中ml.dmlc.xgboost4j.java.DMatrix类的典型用法代码示例。如果您正苦于以下问题：Java DMatrix类的具体用法？Java DMatrix怎么用？Java DMatrix使用的例子？那么，这里精选的类代码示例或许可以为您提供帮助。
DMatrix类属于ml.dmlc.xgboost4j.java包，在下文中一共展示了DMatrix类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: close
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
@Override
public void close() throws HiveException {
    // Declared outside the try so the finally block can release their native handles.
    DMatrix trainData = null;
    Booster booster = null;
    try {
        // Kick off training with XGBoost
        trainData = new DMatrix(featuresList.iterator(), "");
        booster = createXGBooster(params, featuresList);
        final int num_round = (Integer) params.get("num_round");
        for (int i = 0; i < num_round; i++) {
            booster.update(trainData, i);
        }
        // Output the built model
        final String modelId = generateUniqueModelId();
        final byte[] predModel = booster.toByteArray();
        logger.info("model_id:" + modelId + " size:" + predModel.length);
        forward(new Object[] {modelId, predModel});
    } catch (HiveException e) {
        // Already the declared type (e.g. from forward); do not wrap it a second time.
        throw e;
    } catch (Exception e) {
        throw new HiveException(e);
    } finally {
        // DMatrix/Booster hold native memory that the JVM GC does not reclaim promptly;
        // both are unused once the serialized model has been forwarded.
        if (booster != null) {
            booster.dispose();
        }
        if (trainData != null) {
            trainData.dispose();
        }
    }
}
示例2: buildClassifier
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
public void buildClassifier(Instances instances) throws Exception {
    // Reject data this classifier cannot handle before doing any work.
    getCapabilities().testWithFail(instances);

    // Copy first so that deleting rows below does not touch the caller's dataset.
    final Instances cleaned = new Instances(instances);
    cleaned.deleteWithMissingClass();

    final DMatrix trainMatrix = DMatrixLoader.instancesToDMatrix(cleaned);
    final Map<String, DMatrix> watchList = new HashMap<>();
    watchList.put("train", trainMatrix);

    // Only fill in num_class when the user has not configured it explicitly.
    if (!params.containsKey("num_class")) {
        params.put("num_class", cleaned.numClasses());
    }

    booster = ml.dmlc.xgboost4j.java.XGBoost.train(trainMatrix, params, numRound, watchList, null, null);

    if (usePredictor) {
        // Round-trip the trained model through its byte form to build the Predictor.
        final byte[] serialized = booster.toByteArray();
        predictor = new Predictor(new ByteArrayInputStream(serialized));
    }
}
示例3: instancesToDMatrix
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Converts a Weka dataset into a labeled CSR-format DMatrix.
 *
 * @param instances source rows; each row's class value becomes its label
 * @return a DMatrix with one row per instance and numAttributes()-1 columns
 * @throws XGBoostError if the native matrix cannot be created
 */
public static DMatrix instancesToDMatrix(Instances instances) throws XGBoostError {
    final int rowCount = instances.size();
    // CSR row pointer array: entry r is the offset of row r's first value; the extra
    // trailing slot holds the total number of stored values.
    final long[] rowOffsets = new long[rowCount + 1];
    final List<Float> values = new ArrayList<>();
    final List<Integer> columns = new ArrayList<>();
    final float[] labels = new float[rowCount];

    for (int row = 0; row < rowCount; row++) {
        final Instance current = instances.get(row);
        rowOffsets[row] = values.size();
        processInstance(current, values, columns);
        labels[row] = (float) current.classValue();
    }
    rowOffsets[rowCount] = values.size();

    final DMatrix result = createDMatrix(rowOffsets, values, columns, instances.numAttributes() - 1);
    result.setLabel(labels);
    return result;
}
示例4: instanceToDenseDMatrix
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a dense single-row DMatrix from one Weka instance, omitting the class attribute.
 *
 * @param instance the instance to convert
 * @return a 1 x (numAttributes()-1) DMatrix of the instance's attribute values
 * @throws XGBoostError if the native matrix cannot be created
 */
public static DMatrix instanceToDenseDMatrix(Instance instance) throws XGBoostError {
    final int skipIndex = instance.classAttribute().index();
    final int featureCount = instance.numAttributes() - 1;
    final float[] row = new float[featureCount];

    int pos = 0;
    for (Enumeration<Attribute> attrs = instance.enumerateAttributes(); attrs.hasMoreElements(); ) {
        final Attribute attr = attrs.nextElement();
        // Defensive skip of the class column so it never lands in the feature row.
        if (attr.index() != skipIndex) {
            row[pos] = (float) instance.value(attr);
            pos++;
        }
    }
    return new DMatrix(row, 1, featureCount);
}
示例5: makeMetrics
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Scores {@code data} with {@code booster} and returns the resulting model metrics.
 *
 * <p>The prediction frame is stored in the DKV when {@code predFrameKey} is non-null;
 * otherwise it is removed. Either way, all pending futures complete before returning.
 */
private ModelMetrics makeMetrics(Booster booster, DMatrix data, Frame dataFrame, String description,
                                 Key<Frame> predFrameKey) throws XGBoostError {
    final Futures pending = new Futures();
    // makePreds reports the computed metrics through this one-element out-array.
    final ModelMetrics[] metricsOut = new ModelMetrics[1];
    final Frame preds = makePreds(booster, data, metricsOut, true, predFrameKey, pending);

    final boolean keepPredictions = predFrameKey != null;
    if (keepPredictions) {
        DKV.put(preds, pending);
    } else {
        preds.remove(pending);
    }
    pending.blockForPending();

    return metricsOut[0].withModelAndFrame(this, dataFrame).withDescription(description);
}
示例6: score
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Scores {@code fr} and writes the predictions to a frame under {@code destination_key}.
 *
 * <p>When {@code computeMetrics} is requested and possible (the response column is
 * present and valid for supervised models), metrics are computed and attached to
 * this model, and the model is re-published to the DKV.
 *
 * @return the prediction frame stored under {@code destination_key}
 * @throws RuntimeException wrapping any {@link XGBoostError} from the native library
 */
@Override
public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics) throws IllegalArgumentException {
    Frame adaptFr = new Frame(fr);
    computeMetrics = computeMetrics && (!isSupervised() || (adaptFr.vec(_output.responseName()) != null && !adaptFr.vec(_output.responseName()).isBad()));
    String[] msg = adaptTestForTrain(adaptFr, true, computeMetrics); // Adapt
    for (String s : msg) {
        Log.warn(s);
    }
    // Declared outside the try so the native matrix can be released in finally.
    DMatrix scoreMat = null;
    try {
        scoreMat = convertFrametoDMatrix(model_info()._dataInfoKey, adaptFr,
                _parms._response_column, _parms._weights_column, _parms._fold_column, null, _output._sparse);
        Key<Frame> destFrameKey = Key.make(destination_key);
        if (computeMetrics) {
            ModelMetrics mm = makeMetrics(model_info().booster(), scoreMat, fr, "Prediction on frame " + fr._key, destFrameKey);
            // Update model with newly computed model metrics
            this.addModelMetrics(mm);
            DKV.put(this);
        } else {
            makePredsOnly(model_info().booster(), scoreMat, destFrameKey);
        }
        return destFrameKey.get();
    } catch (XGBoostError xgBoostError) {
        throw new RuntimeException(xgBoostError);
    } finally {
        // DMatrix wraps native memory; release it once predictions have been produced.
        if (scoreMat != null) {
            scoreMat.dispose();
        }
    }
}
示例7: BasicModel
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
@Test public void BasicModel() throws XGBoostError {
    // load file from text file, also binary buffer generated by xgboost4j
    DMatrix[] mat = getMatrices();
    DMatrix trainMat = mat[0];
    DMatrix testMat = mat[1];
    HashMap<String, Object> params = new HashMap<>();
    params.put("eta", 0.1);
    params.put("max_depth", 5);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");
    HashMap<String, DMatrix> watches = new HashMap<>();
    watches.put("train", trainMat);
    watches.put("test", testMat);
    Booster booster = XGBoost.train(trainMat, params, 10, watches, null, null);
    try {
        float[][] preds = booster.predict(testMat);
        // Log at most 10 predictions; indexing a fixed 10 rows would throw
        // ArrayIndexOutOfBoundsException on a smaller test matrix.
        final int shown = Math.min(10, preds.length);
        for (int i = 0; i < shown; ++i) {
            Log.info(preds[i][0]);
        }
    } finally {
        // Release the booster's native handle; it is not used after this test.
        booster.dispose();
    }
}
示例8: checkpoint
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
@Test
public void checkpoint() throws XGBoostError, IOException {
    // load file from text file, also binary buffer generated by xgboost4j
    DMatrix[] mat = getMatrices();
    DMatrix trainMat = mat[0];
    DMatrix testMat = mat[1];
    HashMap<String, Object> params = new HashMap<>();
    params.put("eta", 0.1);
    params.put("max_depth", 5);
    params.put("silent", 1);
    params.put("objective", "binary:logistic");
    HashMap<String, DMatrix> watches = new HashMap<>();
    watches.put("train", trainMat);
    // Zero rounds: create an untrained booster, then drive the iterations manually below.
    Booster booster = XGBoost.train(trainMat, params, 0, watches, null, null);
    try {
        // Train for 10 iterations, logging predictions after each one
        for (int i = 0; i < 10; ++i) {
            booster.update(trainMat, i);
            float[][] preds = booster.predict(testMat);
            // Guard against test matrices with fewer than 10 rows.
            final int shown = Math.min(10, preds.length);
            for (int j = 0; j < shown; ++j) {
                Log.info(preds[j][0]);
            }
        }
    } finally {
        // Release the booster's native handle; it is not used after this test.
        booster.dispose();
    }
}
示例9: predictWithDMatrix
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Predicts a single instance with the trained booster.
 *
 * @param instance the instance to score
 * @return the booster's output row for this instance, widened to double
 * @throws XGBoostError if matrix creation or prediction fails natively
 */
private double[] predictWithDMatrix(Instance instance) throws XGBoostError {
    DMatrix dmat = DMatrixLoader.instanceToDMatrix(instance);
    try {
        // One input row, so only the first prediction row is relevant.
        float[] scores = booster.predict(dmat)[0];
        double[] predictDouble = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            predictDouble[i] = scores[i];
        }
        return predictDouble;
    } finally {
        // This method creates a fresh native matrix per call, so release it every time.
        dmat.dispose();
    }
}
示例10: instancesToDenseDMatrix
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Converts a Weka dataset into a dense, labeled DMatrix (row-major, class column omitted).
 *
 * @param instances source rows; each row's class value becomes its label
 * @return a size() x (numAttributes()-1) DMatrix
 * @throws XGBoostError if the native matrix cannot be created
 */
public static DMatrix instancesToDenseDMatrix(Instances instances) throws XGBoostError {
    final int featureCount = instances.numAttributes() - 1;
    final int rowCount = instances.size();
    final float[] cells = new float[featureCount * rowCount];
    final float[] labels = new float[rowCount];
    final int skipIndex = instances.classAttribute().index();

    // Write position in the flat row-major buffer; it runs continuously across rows.
    int cursor = 0;
    for (int row = 0; row < rowCount; row++) {
        final Instance current = instances.get(row);
        labels[row] = (float) current.classValue();
        for (Enumeration<Attribute> attrs = current.enumerateAttributes(); attrs.hasMoreElements(); ) {
            final Attribute attr = attrs.nextElement();
            if (attr.index() == skipIndex) {
                continue;
            }
            cells[cursor++] = (float) current.value(attr);
        }
    }

    final DMatrix result = new DMatrix(cells, rowCount, featureCount);
    result.setLabel(labels);
    return result;
}
示例11: instanceToDMatrix
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a single-row CSR DMatrix from one Weka instance.
 *
 * @param instance the instance to convert
 * @return a one-row DMatrix with numAttributes()-1 columns
 * @throws XGBoostError if the native matrix cannot be created
 */
public static DMatrix instanceToDMatrix(Instance instance) throws XGBoostError {
    final List<Float> values = new ArrayList<>();
    final List<Integer> columns = new ArrayList<>();
    processInstance(instance, values, columns);
    final int featureCount = instance.numAttributes() - 1;
    // A single row spans every stored value: offsets [0, nnz).
    return createDMatrix(new long[] {0L, values.size()}, values, columns, featureCount);
}
示例12: createDMatrix
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a CSR-format DMatrix from parallel value / column-index lists.
 *
 * @param rowHeaders CSR row pointers (length = rows + 1)
 * @param dataList   stored values, in row-major order
 * @param colList    column index of each value in {@code dataList}; read for every index of {@code dataList}
 * @param colNum     total number of columns in the matrix
 * @return the native matrix
 * @throws XGBoostError if the native matrix cannot be created
 */
protected static DMatrix createDMatrix(long[] rowHeaders, List<Float> dataList, List<Integer> colList, int colNum) throws XGBoostError {
    final int nnz = dataList.size();
    float[] data = new float[nnz];
    // No pre-seeding of colIndices[0]: with nnz == 0 (e.g. a row with no stored
    // entries) that write threw ArrayIndexOutOfBoundsException, and the loop
    // below assigns every slot anyway.
    int[] colIndices = new int[nnz];
    for (int i = 0; i < nnz; i++) {
        data[i] = dataList.get(i);
        colIndices[i] = colList.get(i);
    }
    return new DMatrix(rowHeaders, colIndices, data, DMatrix.SparseType.CSR, colNum);
}
示例13: createXGBooster
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Instantiates a Booster through its non-public (Map, DMatrix[]) constructor via reflection.
 *
 * <p>A DMatrix built from {@code input} is passed as the constructor's DMatrix[] argument.
 */
@Nonnull
private static Booster createXGBooster(final Map<String, Object> params,
        final List<LabeledPoint> input) throws NoSuchMethodException, XGBoostError,
        IllegalAccessException, InvocationTargetException, InstantiationException {
    final Constructor<Booster> hiddenCtor =
            Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
    // The constructor is not publicly accessible, so reflection access must be forced.
    hiddenCtor.setAccessible(true);
    final DMatrix[] cacheMats = {new DMatrix(input.iterator(), "")};
    return hiddenCtor.newInstance(params, cacheMats);
}
示例14: createDMatrix
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Builds a DMatrix from buffered rows, using only each row's LabeledPoint (row ids are dropped).
 */
@Nonnull
private static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data)
        throws XGBoostError {
    final int n = data.size();
    final List<LabeledPoint> unwrapped = new ArrayList<LabeledPoint>(n);
    for (final LabeledPointWithRowId withId : data) {
        unwrapped.add(withId.point);
    }
    final DMatrix matrix = new DMatrix(unwrapped.iterator(), "");
    return matrix;
}
示例15: predictAndFlush
import ml.dmlc.xgboost4j.java.DMatrix; //导入依赖的package包/类
/**
 * Predicts every buffered row with {@code model}, forwards the results, and clears the buffer.
 *
 * @throws HiveException wrapping any {@link XGBoostError}, or from forwarding the results
 */
private void predictAndFlush(final Booster model, final List<LabeledPointWithRowId> buf)
        throws HiveException {
    // Declared outside the try so the native matrix can be released in finally.
    DMatrix testData = null;
    final float[][] predicted;
    try {
        testData = createDMatrix(buf);
        predicted = model.predict(testData);
    } catch (XGBoostError e) {
        throw new HiveException(e);
    } finally {
        // A new native matrix is built on every flush; free it instead of leaking it.
        if (testData != null) {
            testData.dispose();
        }
    }
    forwardPredicted(buf, predicted);
    buf.clear();
}