当前位置: 首页>>代码示例>>Java>>正文


Java RebalanceDataSet类代码示例

本文整理汇总了Java中water.fvec.RebalanceDataSet的典型用法代码示例。如果您正苦于以下问题：Java RebalanceDataSet类的具体用法？Java RebalanceDataSet怎么用？Java RebalanceDataSet使用的例子？那么，这里精选的类代码示例或许可以为您提供帮助。


RebalanceDataSet类属于water.fvec包,在下文中一共展示了RebalanceDataSet类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: reBalance

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Redistributes a frame over a different number of chunks for load balancing.
 * <p>
 * The target chunk count is {@code 4 * H2O.NUMCPUS} per node (a single node when
 * {@code local} is set, otherwise every node of the cloud), capped at the number of
 * rows. When {@code _parms._reproducible} is set, the target is forced to 1 chunk.
 * The rebalanced frame is also added to {@code _delete_me}.
 *
 * @param fr Input frame
 * @param local whether to size the chunk count for the cores of one node only
 * @param name prefix of the key under which the rebalanced frame is stored
 * @return the input frame itself if it already has more chunks than the target
 *         (and reproducibility is off), otherwise the rebalanced frame
 */
private Frame reBalance(final Frame fr, boolean local, final String name) {
  final int nodeCount = local ? 1 : H2O.CLOUD.size();
  int chunks = (int) Math.min(4 * H2O.NUMCPUS * nodeCount, fr.numRows());
  final int existingChunks = fr.anyVec().nChunks();

  if (_parms._reproducible) {
    Log.warn("Reproducibility enforced - using only 1 thread - can be slow.");
    chunks = 1;
  } else if (existingChunks > chunks) {
    Log.info("Dataset already contains " + existingChunks + " chunks. No need to rebalance.");
    return fr;
  }

  if (!_parms._quiet_mode) {
    Log.info("ReBalancing dataset into (at least) " + chunks + " chunks.");
  }

  // Run the rebalance task and block until it has finished.
  final Key newKey = Key.make(name + ".chunks" + chunks);
  final RebalanceDataSet rb = new RebalanceDataSet(fr, newKey, chunks);
  H2O.submitTask(rb);
  rb.join();

  final Frame balanced = DKV.get(newKey).get();
  _delete_me.add(balanced);
  return balanced;
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:25,代码来源:DeepLearning.java

示例2: reBalance

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
   * Redistributes a frame over a different number of chunks for load balancing.
   * <p>
   * The target chunk count is {@code 4 * H2O.NUMCPUS} per node (a single node when
   * {@code local} is set, otherwise every node of the cloud), capped at the number
   * of rows. When {@code reproducible} is set, the target is forced to 1 chunk.
   *
   * @param fr Input frame
   * @param local whether to size the chunk count for the cores of one node only
   * @return the input frame itself if it already has more chunks than the target
   *         (and reproducibility is off), otherwise the rebalanced frame
   */
  private Frame reBalance(final Frame fr, boolean local) {
    final int nodeCount = local ? 1 : H2O.CLOUD.size();
    int chunks = (int) Math.min(4 * H2O.NUMCPUS * nodeCount, fr.numRows());
    final int existingChunks = fr.anyVec().nChunks();

    if (reproducible) {
      Log.warn("Reproducibility enforced - using only 1 thread - can be slow.");
      chunks = 1;
    } else if (existingChunks > chunks) {
      Log.info("Dataset already contains " + existingChunks + " chunks. No need to rebalance.");
      return fr;
    }

    if (!quiet_mode) {
      Log.info("ReBalancing dataset into (at least) " + chunks + " chunks.");
    }

    // Derive the destination key from the source key, or use a random name if the frame has none.
    final String balancedName;
    if (fr._key == null) {
      balancedName = Key.rand();
    } else {
      balancedName = fr._key.toString() + ".balanced";
    }
    final Key newKey = Key.makeSystem(balancedName);

    // Run the rebalance task and block until it has finished.
    final RebalanceDataSet rb = new RebalanceDataSet(fr, newKey, chunks);
    H2O.submitTask(rb);
    rb.join();
    return UKV.get(newKey);
  }
 
开发者ID:h2oai,项目名称:h2o-2,代码行数:25,代码来源:DeepLearning.java

示例3: testTranspose

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
   * Parses the prostate data set, drops the RACE column, rebalances the frame with a
   * chunk argument of 64, transposes it and checks that every cell (i, j) of the input
   * equals cell (j, i) of the transposed frame within 1e-4.
   */
  @Test
  public void testTranspose(){
    final Futures fs = new Futures();
    final Key parsed = Key.make("prostate_parsed");
    Frame fr = getFrameForFile(parsed, "smalldata/glm_test/prostate_cat_replaced.csv");
    fr.remove("RACE").remove(fs);
    final Key k = Key.make("rebalanced");
    H2O.submitTask(new RebalanceDataSet(fr, k, 64)).join();
    // The original frame is no longer needed once the rebalanced copy exists.
    fr.delete();
    fr = DKV.get(k).get();
    final Frame tr = DMatrix.transpose(fr);
    try {
      tr.reloadVecs();
      for(int i = 0; i < fr.numRows(); ++i)
        for(int j = 0; j < fr.numCols(); ++j)
          assertEquals(fr.vec(j).at(i),tr.vec(i).at(j),1e-4);
    } finally {
      // Release the rebalanced input and the transposed vecs even when an assertion fails.
      fr.delete();
      for(Vec v:tr.vecs())
        v.remove(fs);
      fs.blockForPending();
    }
  }
 
开发者ID:h2oai,项目名称:h2o-2,代码行数:25,代码来源:MatrixTest.java

示例4: reBalanceFrames

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Rebalances every frame found in the global key snapshot whose key name does not
 * already contain "balanced".
 * <p>
 * Each such frame is rebalanced into {@code min(numRows, 4 * NUMCPUS * cloud size)}
 * splits and stored under the original key name plus ".rebalanced". A failure for
 * one frame is logged and does not stop the loop.
 */
public void reBalanceFrames () {
  for (final Key key : H2O.KeySnapshot.globalSnapshot().keys()) {
    final Value val = DKV.get(key);
    if (val == null || !val.isFrame()) continue;

    final Frame fr = val.get();
    final String keyName = fr._key.toString();
    // Skip frames that were produced by an earlier rebalance.
    if (keyName.contains("balanced")) continue;

    final int splits = Math.min((int)fr.numRows(), 4*H2O.NUMCPUS*H2O.CLOUD.size());
    Log.info("Load balancing frame under key '" + keyName + "' into " + splits + " splits.");
    try {
      final Key frHexBalanced = Key.make(keyName + ".rebalanced");
      new RebalanceDataSet(fr, frHexBalanced, splits).invoke();
    } catch(Exception ex) {
      // NOTE(review): only the message is logged; the stack trace is dropped.
      Log.err(ex.getMessage());
    }
  }
}
 
开发者ID:h2oai,项目名称:h2o-2,代码行数:20,代码来源:LoadDatasets.java

示例5: testPubDev928

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Load simple dataset, rebalance to a number of chunks &gt; number of rows, and run deep learning.
 * <p>
 * All frames and the model are released in the {@code finally} block, so nothing is
 * leaked when the parse, the rebalance, the zero-length-chunk assertion or the
 * training fails.
 */
@Test public void testPubDev928() {
  final Key rebalancedKey = Key.make("rebalanced");
  Frame fr = null;
  Frame rebalanced = null;
  DeepLearningModel dlModel = null;
  try {
    // Create rebalanced dataset
    NFSFileVec nfs = NFSFileVec.make(find_test_file("smalldata/logreg/prostate.csv"));
    fr = ParseDataset.parse(Key.make(), nfs._key);
    // Asking for one more chunk than there are rows
    RebalanceDataSet rb = new RebalanceDataSet(fr, rebalancedKey, (int)(fr.numRows()+1));
    H2O.submitTask(rb);
    rb.join();
    rebalanced = DKV.get(rebalancedKey).get();

    // Assert that there is at least one 0-len chunk
    assertZeroLengthChunk("Rebalanced dataset should contain at least one 0-len chunk!", rebalanced.anyVec());

    // Launch Deep Learning
    DeepLearningParameters dlParams = new DeepLearningParameters();
    dlParams._train = rebalancedKey;
    dlParams._epochs = 5;
    dlParams._response_column = "CAPSULE";

    dlModel = new DeepLearning(dlParams).trainModel().get();
  } finally {
    if (fr != null) fr.delete();
    if (rebalanced != null) rebalanced.delete();
    if (dlModel != null) dlModel.delete();
  }
}
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:30,代码来源:DeepLearningScoreTest.java

示例6: testChunks

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Builds a reference Aggregator model on the covtype data, then rebalances the frame
 * into several chunk counts and, for each, builds another model and asserts that it
 * has the same number of exemplars as the reference.
 */
@Test public void testChunks() {
  final Frame frame = parse_test_file("smalldata/covtype/covtype.20k.data");

  // Reference model on the frame as parsed.
  AggregatorModel.AggregatorParameters parms = new AggregatorModel.AggregatorParameters();
  parms._train = frame._key;
  parms._target_num_exemplars = 137;
  parms._rel_tol_num_exemplars = 0.05;
  long start = System.currentTimeMillis();
  final AggregatorModel agg = new Aggregator(parms).trainModel().get();  // 0.418
  double seconds = (System.currentTimeMillis() - start)/1000.;
  System.out.println("AggregatorModel finished in: " + seconds + " seconds");
  agg.checkConsistency();
  Frame output = agg._output._output_frame.get();
  checkNumExemplars(agg);
  output.remove();
  agg.remove();

  final int[] chunkCounts = {1,2,5,10,50,100};
  for (int nChunks : chunkCounts) {
    final Key key = Key.make();
    final RebalanceDataSet rb = new RebalanceDataSet(frame, key, nChunks);
    H2O.submitTask(rb);
    rb.join();
    final Frame rebalanced = DKV.get(key).get();

    parms = new AggregatorModel.AggregatorParameters();
    // NOTE(review): this trains on the original frame, not on `rebalanced` - confirm intent.
    parms._train = frame._key;
    parms._target_num_exemplars = 137;
    parms._rel_tol_num_exemplars = 0.05;
    start = System.currentTimeMillis();
    final AggregatorModel agg2 = new Aggregator(parms).trainModel().get();  // 0.373 0.504 0.357 0.454 0.368 0.355
    seconds = (System.currentTimeMillis() - start)/1000.;
    System.out.println("AggregatorModel finished in: " + seconds + " seconds");
    agg2.checkConsistency();
    Log.info("Number of exemplars for " + nChunks + " chunks: " + agg2._exemplars.length);
    rebalanced.delete();
    // Exemplar count must match the reference model exactly.
    Assert.assertTrue(agg._exemplars.length == agg2._exemplars.length);
    output = agg2._output._output_frame.get();
    output.remove();
    // NOTE(review): re-checks the reference model `agg`, not `agg2` - confirm intent.
    checkNumExemplars(agg);
    agg2.remove();
  }
  frame.delete();
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:40,代码来源:AggregatorTest.java

示例7: testPubDev928

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Load simple dataset, rebalance to a number of chunks &gt; number of rows, and run deep learning.
 * <p>
 * All frames and the model are released in the {@code finally} block, so nothing is
 * leaked when the parse, the rebalance, the zero-length-chunk assertion or the
 * training fails.
 */
@Test public void testPubDev928() {
  final Key rebalancedKey = Key.make("rebalanced");
  Frame fr = null;
  Frame rebalanced = null;
  DeepLearningModel dlModel = null;
  try {
    // Create rebalanced dataset
    NFSFileVec nfs = TestUtil.makeNfsFileVec("smalldata/logreg/prostate.csv");
    fr = ParseDataset.parse(Key.make(), nfs._key);
    // Asking for one more chunk than there are rows
    RebalanceDataSet rb = new RebalanceDataSet(fr, rebalancedKey, (int)(fr.numRows()+1));
    H2O.submitTask(rb);
    rb.join();
    rebalanced = DKV.get(rebalancedKey).get();

    // Assert that there is at least one 0-len chunk
    assertZeroLengthChunk("Rebalanced dataset should contain at least one 0-len chunk!", rebalanced.anyVec());

    // Launch Deep Learning
    DeepLearningParameters dlParams = new DeepLearningParameters();
    dlParams._train = rebalancedKey;
    dlParams._epochs = 5;
    dlParams._response_column = "CAPSULE";

    dlModel = new DeepLearning(dlParams).trainModel().get();
  } finally {
    if (fr != null) fr.delete();
    if (rebalanced != null) rebalanced.delete();
    if (dlModel != null) dlModel.delete();
  }
}
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:30,代码来源:DeepLearningScoreTest.java

示例8: serve

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Rebalances {@code source} into {@code chunks} chunks and stores the result under
 * {@code after}, defaulting that key to the source key name plus ".balanced".
 *
 * @return a "done" response on success, or an error response wrapping whatever was
 *         thrown while validating or rebalancing
 * @throws IllegalArgumentException if no source frame was supplied
 */
@Override public RequestBuilders.Response serve() {
  if( source==null ) throw new IllegalArgumentException("Missing frame to rebalance!");
  try {
    final long rows = source.numRows();
    if (chunks > rows) {
      throw new IllegalArgumentException("Cannot create more than " + rows + " chunks.");
    }
    if( after==null ) {
      after = Key.make(source._key.toString() + ".balanced");
    }
    // Run the rebalance task and block until it has finished.
    final RebalanceDataSet task = new RebalanceDataSet(source, after, chunks);
    H2O.submitTask(task);
    task.join();
    return RequestBuilders.Response.done(this);
  } catch( Throwable t ) {
    return RequestBuilders.Response.error(t);
  }
}
 
开发者ID:h2oai,项目名称:h2o-2,代码行数:14,代码来源:ReBalance.java

示例9: testReprodubility

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Trains the same single-tree GBM N times on a rebalanced covtype frame and checks
 * that every run reports the same final training MSE as the first run.
 * <p>
 * {@code Scope.exit()} is called from the {@code finally} block so the scope opened
 * by {@code Scope.enter()} is closed even when training or an assertion fails.
 */
@Test public void testReprodubility() {
    Frame tfr=null;
    final int N = 5;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parse_test_file("smalldata/covtype/covtype.20k.data");

      // rebalance to 256 chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = DKV.get(dest).get();

      for (int i=0; i<N; ++i) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "C55";
        parms._nbins = 1000;
        parms._ntrees = 1;
        parms._max_depth = 8;
        parms._learn_rate = 0.1f;
        parms._min_rows = 10;
        parms._distribution = Distribution.Family.gaussian;

        // Build a first model; all remaining models should be equal
        GBM job = new GBM(parms);
        GBMModel gbm = job.trainModel().get();
        assertEquals(parms._ntrees, gbm._output._ntrees);

        mses[i] = gbm._output._scored_train[gbm._output._scored_train.length-1]._mse;
        job.remove();
        gbm.delete();
      }
    } finally{
      if (tfr != null) tfr.remove();
      Scope.exit();
    }
    // The first run is the reference value for all others.
    for( double mse : mses ) assertEquals(mses[0], mse, 1e-15);
  }
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:48,代码来源:GBMTest.java

示例10: testReprodubilityAirline

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Trains a GBM on a rebalanced airlines frame, with several columns removed, and
 * checks the final training MSE against a fixed expected value.
 * <p>
 * {@code Scope.exit()} is called from the {@code finally} block so the scope opened
 * by {@code Scope.enter()} is closed even when training or an assertion fails.
 */
@Test public void testReprodubilityAirline() {
    Frame tfr=null;
    final int N = 1;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");

      // rebalance to fixed number of chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = DKV.get(dest).get();
      for (String s : new String[]{
              "DepTime", "ArrTime", "ActualElapsedTime",
              "AirTime", "ArrDelay", "DepDelay", "Cancelled",
              "CancellationCode", "CarrierDelay", "WeatherDelay",
              "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
      }) {
        tfr.remove(s).remove();
      }
      // Publish the frame again after dropping the columns.
      DKV.put(tfr);
      for (int i=0; i<N; ++i) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "IsDepDelayed";
        parms._nbins = 10;
        parms._nbins_cats = 500;
        parms._ntrees = 7;
        parms._max_depth = 5;
        parms._min_rows = 10;
        parms._distribution = Distribution.Family.bernoulli;
        parms._balance_classes = true;
        parms._seed = 0;

        // Build a first model; all remaining models should be equal
        GBM job = new GBM(parms);
        GBMModel gbm = job.trainModel().get();
        assertEquals(parms._ntrees, gbm._output._ntrees);

        mses[i] = gbm._output._scored_train[gbm._output._scored_train.length-1]._mse;
        job.remove();
        gbm.delete();
      }
    } finally {
      if (tfr != null) tfr.remove();
      Scope.exit();
    }
    // Check for the same result on 1 node and 5 nodes (will only work with enough chunks).
    for( double mse : mses )
      assertEquals(0.21979375165014595, mse, 1e-8);
  }
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:58,代码来源:GBMTest.java

示例11: testReprodubilityAirlineSingleNode

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Same as the airlines GBM reproducibility check, but with
 * {@code _build_tree_one_node} enabled; the final training MSE is compared against a
 * fixed expected value.
 * <p>
 * {@code Scope.exit()} is called from the {@code finally} block so the scope opened
 * by {@code Scope.enter()} is closed even when training or an assertion fails.
 */
@Test public void testReprodubilityAirlineSingleNode() {
    Frame tfr=null;
    final int N = 1;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");

      // rebalance to fixed number of chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = DKV.get(dest).get();
      for (String s : new String[]{
              "DepTime", "ArrTime", "ActualElapsedTime",
              "AirTime", "ArrDelay", "DepDelay", "Cancelled",
              "CancellationCode", "CarrierDelay", "WeatherDelay",
              "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
      }) {
        tfr.remove(s).remove();
      }
      // Publish the frame again after dropping the columns.
      DKV.put(tfr);
      for (int i=0; i<N; ++i) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "IsDepDelayed";
        parms._nbins = 10;
        parms._nbins_cats = 500;
        parms._ntrees = 7;
        parms._max_depth = 5;
        parms._min_rows = 10;
        parms._distribution = Distribution.Family.bernoulli;
        parms._balance_classes = true;
        parms._seed = 0;
        parms._build_tree_one_node = true;

        // Build a first model; all remaining models should be equal
        GBM job = new GBM(parms);
        GBMModel gbm = job.trainModel().get();
        assertEquals(parms._ntrees, gbm._output._ntrees);

        mses[i] = gbm._output._scored_train[gbm._output._scored_train.length-1]._mse;
        job.remove();
        gbm.delete();
      }
    } finally {
      if (tfr != null) tfr.remove();
      Scope.exit();
    }
    // Check for the same result on 1 node and 5 nodes (will only work with enough chunks).
    for( double mse : mses )
      assertEquals(0.21979375165014595, mse, 1e-8);
  }
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:59,代码来源:GBMTest.java

示例12: testReproducibility

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Trains the same single-tree DRF N times on a rebalanced covtype frame and checks
 * that every run reports the same final training MSE as the first run.
 * <p>
 * {@code Scope.exit()} is called from the {@code finally} block so the scope opened
 * by {@code Scope.enter()} is closed even when training or an assertion fails.
 */
@Test public void testReproducibility() {
    Frame tfr=null;
    final int N = 5;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parse_test_file("smalldata/covtype/covtype.20k.data");

      // rebalance to 256 chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = DKV.get(dest).get();

      for (int i=0; i<N; ++i) {
        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
        parms._train = tfr._key;
        parms._response_column = "C55";
        parms._nbins = 1000;
        parms._ntrees = 1;
        parms._max_depth = 8;
        parms._mtries = -1;
        parms._min_rows = 10;
        parms._seed = 1234;

        // Build a first model; all remaining models should be equal
        DRF job = new DRF(parms);
        DRFModel drf = job.trainModel().get();
        assertEquals(parms._ntrees, drf._output._ntrees);

        mses[i] = drf._output._scored_train[drf._output._scored_train.length-1]._mse;
        job.remove();
        drf.delete();
      }
    } finally{
      if (tfr != null) tfr.remove();
      Scope.exit();
    }
    for (int i=0; i<mses.length; ++i) {
      Log.info("trial: " + i + " -> MSE: " + mses[i]);
    }
    // The first run is the reference value for all others.
    for(double mse : mses)
      assertEquals(mses[0], mse, 1e-15);
  }
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:51,代码来源:DRFTest.java

示例13: testReproducibilityAirline

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Trains a DRF on a rebalanced airlines frame, with several columns removed, and
 * checks the training-metrics MSE against a fixed expected value within 1e-4.
 * <p>
 * {@code Scope.exit()} is called from the {@code finally} block so the scope opened
 * by {@code Scope.enter()} is closed even when training or an assertion fails.
 */
@Test public void testReproducibilityAirline() {
    Frame tfr=null;
    final int N = 1;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");

      // rebalance to fixed number of chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = DKV.get(dest).get();
      for (String s : new String[]{
              "DepTime", "ArrTime", "ActualElapsedTime",
              "AirTime", "ArrDelay", "DepDelay", "Cancelled",
              "CancellationCode", "CarrierDelay", "WeatherDelay",
              "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
      }) {
        tfr.remove(s).remove();
      }
      // Publish the frame again after dropping the columns.
      DKV.put(tfr);
      for (int i=0; i<N; ++i) {
        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
        parms._train = tfr._key;
        parms._response_column = "IsDepDelayed";
        parms._nbins = 10;
        parms._nbins_cats = 1024;
        parms._ntrees = 7;
        parms._max_depth = 10;
        parms._binomial_double_trees = false;
        parms._mtries = -1;
        parms._min_rows = 1;
        parms._sample_rate = 0.632f;   // Simulated sampling with replacement
        parms._balance_classes = true;
        parms._seed = (1L<<32)|2;

        // Build a first model; all remaining models should be equal
        DRF job = new DRF(parms);
        DRFModel drf = job.trainModel().get();
        assertEquals(parms._ntrees, drf._output._ntrees);

        mses[i] = drf._output._training_metrics.mse();
        job.remove();
        drf.delete();
      }
    } finally{
      if (tfr != null) tfr.remove();
      Scope.exit();
    }
    for (int i=0; i<mses.length; ++i) {
      Log.info("trial: " + i + " -> MSE: " + mses[i]);
    }
    // Check for the same result on 1 node and 5 nodes.
    for (int i=0; i<mses.length; ++i) {
      assertEquals(0.2148575516521361, mses[i], 1e-4);
    }
  }
 
开发者ID:kyoren,项目名称:https-github.com-h2oai-h2o-3,代码行数:64,代码来源:DRFTest.java

示例14: testReprodubilityAirline

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Trains the same GBM N times on a rebalanced airlines frame, with several columns
 * removed, prints the final training MSE of each run and checks each against a fixed
 * expected value.
 * <p>
 * {@code Scope.exit()} is called from the {@code finally} block so the scope opened
 * by {@code Scope.enter()} is closed even when training or an assertion fails.
 */
@Test public void testReprodubilityAirline() {
    Frame tfr=null;
    final int N = 5;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");

      // rebalance to fixed number of chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = DKV.get(dest).get();
      for (String s : new String[]{
              "DepTime", "ArrTime", "ActualElapsedTime",
              "AirTime", "ArrDelay", "DepDelay", "Cancelled",
              "CancellationCode", "CarrierDelay", "WeatherDelay",
              "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
      }) {
        tfr.remove(s).remove();
      }
      // Publish the frame again after dropping the columns.
      DKV.put(tfr);
      for (int i=0; i<N; ++i) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "IsDepDelayed";
        parms._nbins = 10;
        parms._nbins_cats = 500;
        parms._ntrees = 7;
        parms._max_depth = 5;
        parms._min_rows = 10;
        parms._distribution = DistributionFamily.bernoulli;
        parms._balance_classes = true;
        parms._seed = 0;

        // Build a first model; all remaining models should be equal
        GBMModel gbm = new GBM(parms).trainModel().get();
        assertEquals(parms._ntrees, gbm._output._ntrees);

        mses[i] = gbm._output._scored_train[gbm._output._scored_train.length-1]._mse;
        gbm.delete();
      }
    } finally {
      if (tfr != null) tfr.remove();
      Scope.exit();
    }
    System.out.println("MSEs start");
    for(double d:mses)
      System.out.println(d);
    System.out.println("MSEs End");
    System.out.flush();
    // Check for the same result on 1 node and 5 nodes (will only work with enough chunks).
    for( double mse : mses )
      assertEquals(0.21694215729861027, mse, 1e-8);
  }
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:61,代码来源:GBMTest.java

示例15: testReprodubilityAirlineSingleNode

import water.fvec.RebalanceDataSet; //导入依赖的package包/类
/**
 * Trains the same GBM N times on a rebalanced airlines frame with
 * {@code _build_tree_one_node} enabled, prints the final training MSE of each run and
 * checks each against a fixed expected value.
 * <p>
 * {@code Scope.exit()} is called from the {@code finally} block so the scope opened
 * by {@code Scope.enter()} is closed even when training or an assertion fails.
 */
@Test public void testReprodubilityAirlineSingleNode() {
    Frame tfr=null;
    final int N = 10;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");

      // rebalance to fixed number of chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = DKV.get(dest).get();
      for (String s : new String[]{
              "DepTime", "ArrTime", "ActualElapsedTime",
              "AirTime", "ArrDelay", "DepDelay", "Cancelled",
              "CancellationCode", "CarrierDelay", "WeatherDelay",
              "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
      }) {
        tfr.remove(s).remove();
      }
      // Publish the frame again after dropping the columns.
      DKV.put(tfr);
      for (int i=0; i<N; ++i) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "IsDepDelayed";
        parms._nbins = 10;
        parms._nbins_cats = 500;
        parms._ntrees = 7;
        parms._max_depth = 5;
        parms._min_rows = 10;
        parms._distribution = DistributionFamily.bernoulli;
        parms._balance_classes = true;
        parms._seed = 0;
        parms._build_tree_one_node = true;

        // Build a first model; all remaining models should be equal
        GBMModel gbm = new GBM(parms).trainModel().get();
        assertEquals(parms._ntrees, gbm._output._ntrees);

        mses[i] = gbm._output._scored_train[gbm._output._scored_train.length-1]._mse;
        gbm.delete();
      }
    } finally {
      if (tfr != null) tfr.remove();
      Scope.exit();
    }
    System.out.println("MSE");
    for(double d:mses)
      System.out.println(d);
    // Check for the same result on 1 node and 5 nodes (will only work with enough chunks).
    for( double mse : mses )
      assertEquals(0.21694215729861027, mse, 1e-8);
  }
 
开发者ID:h2oai,项目名称:h2o-3,代码行数:60,代码来源:GBMTest.java


注:本文中的water.fvec.RebalanceDataSet类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。