当前位置: 首页>>代码示例>>Java>>正文


Java DRFModel类代码示例

本文整理汇总了Java中hex.tree.drf.DRFModel的典型用法代码示例。如果您正苦于以下问题：Java DRFModel类的具体用法是什么？Java DRFModel怎么用？有哪些Java DRFModel的使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。


DRFModel类属于hex.tree.drf包，在下文中一共展示了DRFModel类的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testDRFModelBinomial

import hex.tree.drf.DRFModel; //导入依赖的package包/类
@Test
public void testDRFModelBinomial() throws IOException {
  // Train a small binomial DRF, round-trip it through save/load, and verify that
  // both the model bytes and the compressed trees are unchanged by serialization.
  DRFModel original = null;
  DRFModel restored = null;
  try {
    original = prepareDRFModel("smalldata/logreg/prostate.csv", ar("ID"), "CAPSULE", true, 5);
    final CompressedTree[][] originalTrees = getTrees(original);
    restored = saveAndLoad(original);
    assertModelBinaryEquals(original, restored);
    assertTreeEquals("Trees have to be binary same", originalTrees, getTrees(restored));
  } finally {
    // Always delete both the trained and the reloaded model, even when an assertion fails.
    if (original != null) {
      original.delete();
    }
    if (restored != null) {
      restored.delete();
    }
  }
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:17,代码来源:ModelSerializationTest.java

示例2: prepareDRFModel

import hex.tree.drf.DRFModel; //导入依赖的package包/类
private DRFModel prepareDRFModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
  // The parsed frame is only needed while training; it is deleted in the finally block.
  Frame f = parse_test_file(dataset);
  try {
    // For classification, make the response categorical: the replaced vec is removed
    // and the modified frame is put back into the DKV under its own key.
    final boolean convertResponse = classification && !f.vec(response).isCategorical();
    if (convertResponse) {
      f.replace(f.find(response), f.vec(response).toCategoricalVec()).remove();
      DKV.put(f._key, f);
    }

    final DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = f._key;
    parms._response_column = response;
    parms._ignored_columns = ignoredColumns;
    parms._ntrees = ntrees;
    parms._score_each_iteration = true;

    // Block until training completes and hand back the finished model.
    final DRF builder = new DRF(parms);
    return builder.trainModel().get();
  } finally {
    if (f != null) {
      f.delete();
    }
  }
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:19,代码来源:ModelSerializationTest.java

示例3: testDRFModelMultinomial

import hex.tree.drf.DRFModel; //导入依赖的package包/类
@Test
public void testDRFModelMultinomial() throws IOException {
  // Train a multinomial DRF on iris (response "C5"), round-trip it through save/load,
  // and check that the reloaded model and its compressed trees are binary-identical.
  DRFModel model = null, loadedModel = null;
  try {
    model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
    CompressedTree[][] trees = getTrees(model);
    loadedModel = saveAndLoad(model);
    // And compare
    assertModelBinaryEquals(model, loadedModel);
    CompressedTree[][] loadedTrees = getTrees(loadedModel);
    assertTreeEquals("Trees have to be binary same", trees, loadedTrees);
  } finally {
    // Delete both models. Previously only the reloaded copy was deleted,
    // so the originally trained model was never cleaned up.
    if (model!=null) model.delete();
    if (loadedModel!=null) loadedModel.delete();
  }
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:16,代码来源:ModelSerializationTest.java

示例4: testPubDev2075

import hex.tree.drf.DRFModel; //导入依赖的package包/类
@Test
public void testPubDev2075() {
  Frame fr = null;
  try {
    fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
    // Make "cylinders" categorical and remove the original vec it replaces.
    fr.replace(fr.find("cylinders"), fr.vec("cylinders").toCategoricalVec()).remove();

    DRFModel.DRFParameters first = new DRFModel.DRFParameters();
    DRFModel.DRFParameters second = new DRFModel.DRFParameters();

    // Settings shared by both parameter sets (each gets its own ignored-columns array).
    for (DRFModel.DRFParameters parms : new DRFModel.DRFParameters[]{first, second}) {
      parms._train = fr._key;
      parms._response_column = "economy_20mpg";
      parms._ignored_columns = new String[]{"name", "columns", "cylinders"};
      parms._seed = 8887264963748798740L;
    }

    // Tree-shape settings that differ between the two sets.
    first._ntrees = 2;
    first._max_depth = 5;
    first._nbins = 6;
    first._mtries = 2;

    second._ntrees = 5;
    second._max_depth = 1;
    second._nbins = 3;
    second._mtries = 4;

    // Parameter sets that differ in these settings must not produce the same checksum.
    Assert.assertNotEquals(first.checksum(), second.checksum());
  } finally {
    if (fr != null) {
      fr.delete();
    }
  }
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:37,代码来源:ModelParametersChecksumTest.java

示例5: testDRFModelMultinomial

import hex.tree.drf.DRFModel; //导入依赖的package包/类
@Test
public void testDRFModelMultinomial() throws IOException {
  // Train a multinomial DRF on iris (response "C5"), round-trip it through save/load,
  // and check that the reloaded model and its compressed trees are binary-identical.
  // `model` is initialized to null so it can be cleaned up in the finally block.
  DRFModel model = null, loadedModel = null;
  try {
    model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
    CompressedTree[][] trees = getTrees(model);
    loadedModel = saveAndLoad(model);
    // And compare
    assertModelBinaryEquals(model, loadedModel);
    CompressedTree[][] loadedTrees = getTrees(loadedModel);
    assertTreeEquals("Trees have to be binary same", trees, loadedTrees);
  } finally {
    // Delete both models. Previously only the reloaded copy was deleted,
    // so the originally trained model was never cleaned up.
    if (model!=null) model.delete();
    if (loadedModel!=null) loadedModel.delete();
  }
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:16,代码来源:ModelSerializationTest.java

示例6: defaultRandomForest

import hex.tree.drf.DRFModel; //导入依赖的package包/类
Job<DRFModel> defaultRandomForest() {
  // Skip DRF entirely once the AutoML search limits have been exceeded.
  if (exceededSearchLimits("DRF")) {
    return null;
  }

  final DRFModel.DRFParameters drfParameters = new DRFModel.DRFParameters();
  setCommonModelBuilderParams(drfParameters);
  // Carry the stopping tolerance over from the build spec's stopping criteria.
  drfParameters._stopping_tolerance =
      this.buildSpec.build_control.stopping_criteria.stopping_tolerance();

  // Kept as a raw Job local so the conversion to Job<DRFModel> compiles
  // regardless of trainModel's declared generic return type.
  Job job = trainModel(null, "drf", drfParameters);
  return job;
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:12,代码来源:AutoML.java

示例7: buildModel

import hex.tree.drf.DRFModel; //导入依赖的package包/类
/** Creates a new DRF model builder configured with the given parameters. */
@Override
public ModelBuilder buildModel(DRFModel.DRFParameters params) {
  final DRF builder = new DRF(params);
  return builder;
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:5,代码来源:ModelFactories.java

示例8: getModelFactory

import hex.tree.drf.DRFModel; //导入依赖的package包/类
/** Returns the shared DRF model factory, {@code ModelFactories.DRF_MODEL_FACTORY}. */
@Override
protected ModelFactory<DRFModel.DRFParameters> getModelFactory() {
  final ModelFactory<DRFModel.DRFParameters> factory = ModelFactories.DRF_MODEL_FACTORY;
  return factory;
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:5,代码来源:DRFGridSearchHandler.java

示例9: createImpl

import hex.tree.drf.DRFModel; //导入依赖的package包/类
@Override
public DRFModel createImpl() {
  // Convert the schema-level parameters into implementation parameters first.
  final DRFV3.DRFParametersV3 schemaParms = this.parameters;
  final DRFModel.DRFParameters implParms = schemaParms.createImpl();
  // The output is constructed with a null first argument and NaN for the two numeric arguments.
  return new DRFModel(model_id.key(), implParms,
      new DRFModel.DRFOutput(null, Double.NaN, Double.NaN));
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:6,代码来源:DRFModelV3.java

示例10: computeDRFMetalearner

import hex.tree.drf.DRFModel; //导入依赖的package包/类
void computeDRFMetalearner(){
    // Train a DRF metalearner on the level-one frame(s), attach it to the ensemble model,
    // clean up the level-one frames, then update and unlock the ensemble model.
    DRF metaDRFBuilder = ModelBuilder.make("DRF", _metalearnerJob, _metalearnerKey);
    DRFV3.DRFParametersV3 params = new DRFV3.DRFParametersV3();
    params.init_meta();
    params.fillFromImpl(metaDRFBuilder._parms); // Defaults for this builder into schema

    // Optional user-supplied metalearner parameters: JSON mapping parameter name -> string values.
    if (_hasMetalearnerParams) {
        Properties p = new Properties();
        HashMap<String, String[]> map = new Gson().fromJson(_metalearner_params, new TypeToken<HashMap<String, String[]>>() {
        }.getType());
        for (Map.Entry<String, String[]> param : map.entrySet()) {
            String[] paramVal = param.getValue();
            // A single value is passed verbatim; several values are rendered via Arrays.toString.
            if (paramVal.length == 1) {
                p.setProperty(param.getKey(), paramVal[0]);
            } else {
                p.setProperty(param.getKey(), Arrays.toString(paramVal));
            }
        }
        // Apply all collected properties once, after the loop, instead of re-applying the
        // growing property set on every iteration. Skipped when the map is empty.
        if (!p.isEmpty()) {
            params.fillFromParms(p, true);
        }
        DRFModel.DRFParameters drfParams = params.createAndFillImpl();
        metaDRFBuilder._parms = drfParams;
    }
    metaDRFBuilder._parms._train = _levelOneTrainingFrame._key;
    metaDRFBuilder._parms._valid = (_levelOneValidationFrame == null ? null : _levelOneValidationFrame._key);
    metaDRFBuilder._parms._response_column = _model.responseColumn;
    // Cross-validation of the metalearner is configured either through nfolds (+ fold
    // assignment) or through an explicit fold column. nfolds is only forwarded when no fold
    // column is given; the former unconditional assignment duplicated this branch.
    if (_model._parms._metalearner_fold_column == null) {
        metaDRFBuilder._parms._nfolds = _model._parms._metalearner_nfolds;
        if (_model._parms._metalearner_nfolds > 1) {
            if (_model._parms._metalearner_fold_assignment == null) {
                metaDRFBuilder._parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.AUTO;
            } else {
                metaDRFBuilder._parms._fold_assignment = _model._parms._metalearner_fold_assignment;
            }
        }
    } else {
        metaDRFBuilder._parms._fold_column = _model._parms._metalearner_fold_column;
    }

    metaDRFBuilder.init(false);

    Job<DRFModel> j = metaDRFBuilder.trainModel();

    // An interrupt while polling does not abort the wait for the metalearner. It is
    // recorded and the thread's interrupt status is restored when this method exits,
    // instead of being silently discarded.
    boolean interrupted = false;
    try {
        // Poll the metalearner job every 100 ms, mirroring its progress onto the parent job.
        while (j.isRunning()) {
            try {
                _job.update(j._work, "training metalearner(" + _model._parms._metalearner_algorithm + ")");
                Thread.sleep(100);
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }

        Log.info("Finished training metalearner model(" + _model._parms._metalearner_algorithm + ").");

        _model._output._metalearner = metaDRFBuilder.get();
        _model.doScoreOrCopyMetrics(_job);
        if (_parms._keep_levelone_frame) {
            _model._output._levelone_frame_id = _levelOneTrainingFrame; //Keep Level One Training Frame in Stacked Ensemble model object
        } else {
            DKV.remove(_levelOneTrainingFrame._key); //Remove Level One Training Frame from DKV
        }
        if (null != _levelOneValidationFrame) {
            DKV.remove(_levelOneValidationFrame._key); //Remove Level One Validation Frame from DKV
        }
        _model.update(_job);
        _model.unlock(_job);
    } finally {
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:70,代码来源:Metalearner.java

示例11: createImpl

import hex.tree.drf.DRFModel; //导入依赖的package包/类
@Override
public DRFModel createImpl() {
  // Convert the schema-level parameters into implementation parameters first.
  final DRFV3.DRFParametersV3 schemaParms = this.parameters;
  final DRFModel.DRFParameters implParms = schemaParms.createImpl();
  // The output is constructed with a null argument.
  return new DRFModel(model_id.key(), implParms, new DRFModel.DRFOutput(null));
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:6,代码来源:DRFModelV3.java

示例12: testXValPredictions

import hex.tree.drf.DRFModel; //导入依赖的package包/类
@Test public void testXValPredictions() {
  final int nfolds = 3;
  Frame tfr = null;
  // Models are declared outside the try block so they can be deleted in finally.
  // Previously none of the four trained models were ever deleted.
  GBMModel gbm = null;
  DRFModel drf = null;
  GLMModel glm = null;
  DeepLearningModel dl = null;
  try {
    // Load data and append a seeded k-fold assignment column ("foldId") built from
    // a zero vec shaped like the "class" column.
    tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
    Frame foldId = new Frame(new String[]{"foldId"}, new Vec[]{AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789)});
    tfr.add(foldId);
    DKV.put(tfr);

    // GBM
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    parms._train = tfr._key;
    parms._response_column = "class";
    parms._ntrees = 1;
    parms._max_depth = 1;
    parms._fold_column = "foldId";
    parms._distribution = DistributionFamily.multinomial;
    parms._keep_cross_validation_predictions=true;
    GBM job = new GBM(parms);
    gbm = job.trainModel().get();
    checkModel(gbm, foldId.anyVec(),3);

    // DRF
    DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
    parmsDRF._train = tfr._key;
    parmsDRF._response_column = "class";
    parmsDRF._ntrees = 1;
    parmsDRF._max_depth = 1;
    parmsDRF._fold_column = "foldId";
    parmsDRF._distribution = DistributionFamily.multinomial;
    parmsDRF._keep_cross_validation_predictions=true;
    DRF drfJob = new DRF(parmsDRF);
    drf = drfJob.trainModel().get();
    checkModel(drf, foldId.anyVec(),3);

    // GLM (regression on "sepal_len")
    GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
    parmsGLM._train = tfr._key;
    parmsGLM._response_column = "sepal_len";
    parmsGLM._fold_column = "foldId";
    parmsGLM._keep_cross_validation_predictions=true;
    GLM glmJob = new GLM(parmsGLM);
    glm = glmJob.trainModel().get();
    checkModel(glm, foldId.anyVec(),1);

    // DL
    DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
    parmsDL._train = tfr._key;
    parmsDL._response_column = "class";
    parmsDL._hidden = new int[]{1};
    parmsDL._epochs = 1;
    parmsDL._fold_column = "foldId";
    parmsDL._keep_cross_validation_predictions=true;
    DeepLearning dlJob = new DeepLearning(parmsDL);
    dl = dlJob.trainModel().get();
    checkModel(dl, foldId.anyVec(),3);

  } finally {
    // Delete every model that was trained, then the training frame (which now also
    // holds the foldId vec).
    if (gbm != null) gbm.delete();
    if (drf != null) drf.delete();
    if (glm != null) glm.delete();
    if (dl != null) dl.delete();
    if (tfr != null) tfr.remove();
  }
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:63,代码来源:XValPredictionsCheck.java


注：本文中的hex.tree.drf.DRFModel类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。