本文整理汇总了Java中hex.tree.drf.DRFModel类的典型用法代码示例。如果您正苦于以下问题:Java DRFModel类的具体用法?Java DRFModel怎么用?Java DRFModel使用的例子?那么, 这里精选的类代码示例或许可以为您提供帮助。
DRFModel类属于hex.tree.drf包,在下文中一共展示了DRFModel类的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: testDRFModelBinomial
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Trains a small binomial DRF model on the prostate dataset, round-trips it through
 * save/load, and checks that both the model and its compressed trees are binary-identical.
 * Both the trained and the reloaded model are deleted afterwards.
 */
@Test
public void testDRFModelBinomial() throws IOException {
  DRFModel original = null;
  DRFModel restored = null;
  try {
    original = prepareDRFModel("smalldata/logreg/prostate.csv", ar("ID"), "CAPSULE", true, 5);
    final CompressedTree[][] originalTrees = getTrees(original);
    restored = saveAndLoad(original);
    // Compare the whole serialized model first, then tree by tree.
    assertModelBinaryEquals(original, restored);
    final CompressedTree[][] restoredTrees = getTrees(restored);
    assertTreeEquals("Trees have to be binary same", originalTrees, restoredTrees);
  } finally {
    if (original != null) {
      original.delete();
    }
    if (restored != null) {
      restored.delete();
    }
  }
}
示例2: prepareDRFModel
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Parses {@code dataset}, optionally converts the response column to categorical, and trains a
 * DRF model on it. The parsed frame is deleted before returning; the caller owns the model.
 *
 * @param dataset path of the test file to parse
 * @param ignoredColumns columns excluded from training
 * @param response name of the response column
 * @param classification when true, a non-categorical response is converted to categorical
 * @param ntrees number of trees to build
 * @return the trained DRF model
 */
private DRFModel prepareDRFModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
  final Frame frame = parse_test_file(dataset);
  try {
    final boolean needsCategoricalResponse = classification && !frame.vec(response).isCategorical();
    if (needsCategoricalResponse) {
      final int responseIdx = frame.find(response);
      // Swap in a categorical copy; the vec returned by replace() (presumably the displaced one) is removed.
      frame.replace(responseIdx, frame.vec(response).toCategoricalVec()).remove();
      // Re-publish the modified frame under its key so training picks up the change.
      DKV.put(frame._key, frame);
    }
    final DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = frame._key;
    parms._ignored_columns = ignoredColumns;
    parms._response_column = response;
    parms._ntrees = ntrees;
    parms._score_each_iteration = true;
    final DRF builder = new DRF(parms);
    return builder.trainModel().get();
  } finally {
    if (frame != null) {
      frame.delete();
    }
  }
}
示例3: testDRFModelMultinomial
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Trains a small multinomial DRF model on iris, round-trips it through save/load, and checks
 * that both the model and its compressed trees are binary-identical.
 * Both the trained and the reloaded model are deleted in {@code finally}.
 */
@Test
public void testDRFModelMultinomial() throws IOException {
  DRFModel model = null, loadedModel = null;
  try {
    model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
    CompressedTree[][] trees = getTrees(model);
    loadedModel = saveAndLoad(model);
    // And compare
    assertModelBinaryEquals(model, loadedModel);
    CompressedTree[][] loadedTrees = getTrees(loadedModel);
    assertTreeEquals("Trees have to be binary same", trees, loadedTrees);
  } finally {
    // Delete the original model too, otherwise it leaks (the binomial variant already does this).
    if (model != null) model.delete();
    if (loadedModel != null) loadedModel.delete();
  }
}
示例4: testPubDev2075
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Checks that two DRF parameter sets are distinguished by their checksums. The sets share the
 * training frame, response, ignored columns and seed, but differ in ntrees, max_depth, nbins and
 * mtries. The parsed frame is deleted afterwards.
 */
@Test
public void testPubDev2075() {
  Frame fr = null;
  try {
    fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
    final int cylindersIdx = fr.find("cylinders");
    fr.replace(cylindersIdx, fr.vec("cylinders").toCategoricalVec()).remove();

    final long seed = 8887264963748798740L;

    final DRFModel.DRFParameters first = new DRFModel.DRFParameters();
    first._train = fr._key;
    first._response_column = "economy_20mpg";
    first._ignored_columns = new String[]{"name", "columns", "cylinders"};
    first._ntrees = 2;
    first._max_depth = 5;
    first._nbins = 6;
    first._mtries = 2;
    first._seed = seed;

    final DRFModel.DRFParameters second = new DRFModel.DRFParameters();
    second._train = fr._key;
    second._response_column = "economy_20mpg";
    second._ignored_columns = new String[]{"name", "columns", "cylinders"};
    second._ntrees = 5;
    second._max_depth = 1;
    second._nbins = 3;
    second._mtries = 4;
    second._seed = seed;

    Assert.assertNotEquals(first.checksum(), second.checksum());
  } finally {
    if (fr != null) {
      fr.delete();
    }
  }
}
示例5: testDRFModelMultinomial
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Trains a small multinomial DRF model on iris, round-trips it through save/load, and checks
 * that both the model and its compressed trees are binary-identical.
 * Both the trained and the reloaded model are deleted in {@code finally}.
 */
@Test
public void testDRFModelMultinomial() throws IOException {
  // Initialise model to null so it can be cleaned up in finally.
  DRFModel model = null, loadedModel = null;
  try {
    model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
    CompressedTree[][] trees = getTrees(model);
    loadedModel = saveAndLoad(model);
    // And compare
    assertModelBinaryEquals(model, loadedModel);
    CompressedTree[][] loadedTrees = getTrees(loadedModel);
    assertTreeEquals("Trees have to be binary same", trees, loadedTrees);
  } finally {
    // Delete the original model too, otherwise it leaks.
    if (model != null) model.delete();
    if (loadedModel != null) loadedModel.delete();
  }
}
示例6: defaultRandomForest
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Starts training a DRF model ("drf").
 *
 * The model uses the common model-builder parameters plus the stopping tolerance taken from
 * {@code buildSpec.build_control.stopping_criteria}.
 *
 * @return the training job, or {@code null} when {@code exceededSearchLimits("DRF")} is true
 */
Job<DRFModel> defaultRandomForest() {
  if (!exceededSearchLimits("DRF")) {
    final DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    setCommonModelBuilderParams(parms);
    parms._stopping_tolerance = this.buildSpec.build_control.stopping_criteria.stopping_tolerance();
    // NOTE(review): raw Job kept as in the original; trainModel's declared return type is not visible here.
    final Job job = trainModel(null, "drf", parms);
    return job;
  }
  return null;
}
示例7: buildModel
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Creates a DRF model builder for the given parameters.
 *
 * @param params DRF parameters to build with
 * @return a new {@code DRF} builder wrapping {@code params}
 */
@Override
public ModelBuilder buildModel(DRFModel.DRFParameters params) {
  final DRF builder = new DRF(params);
  return builder;
}
示例8: getModelFactory
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Returns the factory for DRF parameters.
 *
 * @return {@code ModelFactories.DRF_MODEL_FACTORY}
 */
@Override
protected ModelFactory<DRFModel.DRFParameters> getModelFactory() {
  final ModelFactory<DRFModel.DRFParameters> factory = ModelFactories.DRF_MODEL_FACTORY;
  return factory;
}
示例9: createImpl
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Builds a {@code DRFModel} from this schema.
 *
 * The model uses the schema's model id, the implementation parameters created from
 * {@code this.parameters}, and an output constructed as {@code DRFOutput(null, NaN, NaN)}.
 */
@Override public DRFModel createImpl() {
  final DRFV3.DRFParametersV3 schemaParms = this.parameters;
  final DRFModel.DRFParameters implParms = schemaParms.createImpl();
  return new DRFModel(
      model_id.key(),
      implParms,
      new DRFModel.DRFOutput(null, Double.NaN, Double.NaN));
}
示例10: computeDRFMetalearner
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Trains a DRF metalearner on the level-one frame(s) and attaches it to {@code _model}.
 *
 * <p>The builder starts from its default parameters. When {@code _hasMetalearnerParams} is set,
 * the JSON in {@code _metalearner_params} is parsed as a map of parameter name to value array and
 * applied on top of those defaults. The training/validation frames, response column and
 * metalearner cross-validation settings are then taken from {@code _model}.
 *
 * <p>The method polls the training job until it stops running, then stores the resulting model
 * and calls {@code doScoreOrCopyMetrics}. The level-one frames are removed from the DKV, except
 * the training frame when {@code _parms._keep_levelone_frame} is set.
 */
void computeDRFMetalearner() {
  // DRF metalearner
  DRF metaDRFBuilder = ModelBuilder.make("DRF", _metalearnerJob, _metalearnerKey);
  DRFV3.DRFParametersV3 params = new DRFV3.DRFParametersV3();
  params.init_meta();
  params.fillFromImpl(metaDRFBuilder._parms); // Defaults for this builder into schema

  // User-supplied metalearner parameters override the defaults.
  if (_hasMetalearnerParams) {
    Properties p = new Properties();
    HashMap<String, String[]> map = new Gson().fromJson(_metalearner_params, new TypeToken<HashMap<String, String[]>>() {
    }.getType());
    for (Map.Entry<String, String[]> param : map.entrySet()) {
      String[] paramVal = param.getValue();
      // Single values are passed as-is; multiple values use the Arrays.toString "[a, b]" form.
      if (paramVal.length == 1) {
        p.setProperty(param.getKey(), paramVal[0]);
      } else {
        p.setProperty(param.getKey(), Arrays.toString(paramVal));
      }
    }
    // Apply the collected properties once, after the loop, instead of re-applying on every entry.
    // The guard keeps the empty-map case untouched, as before.
    if (!p.isEmpty()) {
      params.fillFromParms(p, true);
    }
    DRFModel.DRFParameters drfParams = params.createAndFillImpl();
    metaDRFBuilder._parms = drfParams;
  }

  metaDRFBuilder._parms._train = _levelOneTrainingFrame._key;
  metaDRFBuilder._parms._valid = (_levelOneValidationFrame == null ? null : _levelOneValidationFrame._key);
  metaDRFBuilder._parms._response_column = _model.responseColumn;

  // Cross-validation of the metalearner: either nfolds (+ fold assignment) or an explicit fold column.
  // nfolds is set only in the no-fold-column branch, so it no longer clobbers the fold-column setup.
  if (_model._parms._metalearner_fold_column == null) {
    metaDRFBuilder._parms._nfolds = _model._parms._metalearner_nfolds;
    if (_model._parms._metalearner_nfolds > 1) {
      if (_model._parms._metalearner_fold_assignment == null) {
        metaDRFBuilder._parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.AUTO;
      } else {
        metaDRFBuilder._parms._fold_assignment = _model._parms._metalearner_fold_assignment;
      }
    }
  } else {
    metaDRFBuilder._parms._fold_column = _model._parms._metalearner_fold_column;
  }

  metaDRFBuilder.init(false);
  Job<DRFModel> j = metaDRFBuilder.trainModel();
  boolean interrupted = false;
  while (j.isRunning()) {
    try {
      _job.update(j._work, "training metalearner(" + _model._parms._metalearner_algorithm + ")");
      Thread.sleep(100);
    } catch (InterruptedException e) {
      // Keep waiting for the metalearner, but remember the interrupt instead of silently dropping it.
      interrupted = true;
    }
  }
  Log.info("Finished training metalearner model(" + _model._parms._metalearner_algorithm + ").");
  _model._output._metalearner = metaDRFBuilder.get();
  _model.doScoreOrCopyMetrics(_job);
  if (_parms._keep_levelone_frame) {
    _model._output._levelone_frame_id = _levelOneTrainingFrame; //Keep Level One Training Frame in Stacked Ensemble model object
  } else {
    DKV.remove(_levelOneTrainingFrame._key); //Remove Level One Training Frame from DKV
  }
  if (null != _levelOneValidationFrame) {
    DKV.remove(_levelOneValidationFrame._key); //Remove Level One Validation Frame from DKV
  }
  _model.update(_job);
  _model.unlock(_job);
  // Restore the interrupt status only after cleanup, so callers up the stack can still observe it.
  if (interrupted) {
    Thread.currentThread().interrupt();
  }
}
示例11: createImpl
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Builds a {@code DRFModel} from this schema.
 *
 * The model uses the schema's model id, the implementation parameters created from
 * {@code this.parameters}, and an output constructed as {@code DRFOutput(null)}.
 */
@Override public DRFModel createImpl() {
  final DRFV3.DRFParametersV3 schemaParms = this.parameters;
  final DRFModel.DRFParameters implParms = schemaParms.createImpl();
  return new DRFModel(
      model_id.key(),
      implParms,
      new DRFModel.DRFOutput(null));
}
示例12: testXValPredictions
import hex.tree.drf.DRFModel; //导入依赖的package包/类
/**
 * Trains GBM, DRF, GLM and DeepLearning models on iris with an explicit 3-fold column.
 *
 * Each model keeps its cross-validation predictions and is passed to {@code checkModel} with
 * the fold vector and an expected count (3 for the classifiers, 1 for the GLM regression on
 * {@code sepal_len}).
 *
 * <p>All trained models and the training frame are cleaned up in {@code finally}. The frame
 * includes the added fold column. Cleanup in {@code finally} means a failing check does not
 * leak keys.
 */
@Test public void testXValPredictions() {
  final int nfolds = 3;
  Frame tfr = null;
  GBMModel gbm = null;
  DRFModel drf = null;
  GLMModel glm = null;
  DeepLearningModel dl = null;
  try {
    // Load data, hack frames
    tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
    Frame foldId = new Frame(new String[]{"foldId"}, new Vec[]{AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789)});
    tfr.add(foldId);
    DKV.put(tfr);
    // GBM
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    parms._train = tfr._key;
    parms._response_column = "class";
    parms._ntrees = 1;
    parms._max_depth = 1;
    parms._fold_column = "foldId";
    parms._distribution = DistributionFamily.multinomial;
    parms._keep_cross_validation_predictions = true;
    GBM job = new GBM(parms);
    gbm = job.trainModel().get();
    checkModel(gbm, foldId.anyVec(), 3);
    // DRF
    DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
    parmsDRF._train = tfr._key;
    parmsDRF._response_column = "class";
    parmsDRF._ntrees = 1;
    parmsDRF._max_depth = 1;
    parmsDRF._fold_column = "foldId";
    parmsDRF._distribution = DistributionFamily.multinomial;
    parmsDRF._keep_cross_validation_predictions = true;
    DRF drfJob = new DRF(parmsDRF);
    drf = drfJob.trainModel().get();
    checkModel(drf, foldId.anyVec(), 3);
    // GLM
    GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
    parmsGLM._train = tfr._key;
    parmsGLM._response_column = "sepal_len";
    parmsGLM._fold_column = "foldId";
    parmsGLM._keep_cross_validation_predictions = true;
    GLM glmJob = new GLM(parmsGLM);
    glm = glmJob.trainModel().get();
    checkModel(glm, foldId.anyVec(), 1);
    // DL
    DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
    parmsDL._train = tfr._key;
    parmsDL._response_column = "class";
    parmsDL._hidden = new int[]{1};
    parmsDL._epochs = 1;
    parmsDL._fold_column = "foldId";
    parmsDL._keep_cross_validation_predictions = true;
    DeepLearning dlJob = new DeepLearning(parmsDL);
    dl = dlJob.trainModel().get();
    checkModel(dl, foldId.anyVec(), 3);
  } finally {
    // Models were previously never deleted and leaked into the DKV.
    // NOTE(review): the kept cross-validation prediction frames may need separate cleanup — confirm.
    if (gbm != null) gbm.delete();
    if (drf != null) drf.delete();
    if (glm != null) glm.delete();
    if (dl != null) dl.delete();
    if (tfr != null) tfr.remove();
  }
}