

Python preprocessing.get_input_tensors Method: Code Examples

This article collects typical usage examples of the Python preprocessing.get_input_tensors method. If you are wondering how preprocessing.get_input_tensors is used in practice, the curated examples below should help. You can also explore further usage examples from the preprocessing module in which the method is defined.


The following presents 11 code examples of the preprocessing.get_input_tensors method, sorted by popularity by default.
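
All of the examples share the same pattern: preprocessing.get_input_tensors builds a (features, labels) pair from a list of TFRecord files, and the call is either wrapped in a zero-argument input_fn handed to a tf.estimator.Estimator (examples 3-9) or drained directly in a tf.Session in the test code (examples 1, 2 and 10). The sketch below is not taken from any of these projects; it only illustrates the calling pattern, assuming the Minigo-style signature get_input_tensors(batch_size, tf_records, ...) and a hypothetical file name. Note that some of the forks below instead pass a hyperparameter object as the first argument.

# Minimal illustrative sketch (assumed signature and file name, see the note above).
import tensorflow as tf
import preprocessing

tf_records = ['examples-00000.tfrecord']  # hypothetical TFRecord file

# Build a single, unshuffled pass over the data, mirroring the test examples below.
pos_tensor, label_tensors = preprocessing.get_input_tensors(
    16, tf_records, num_repeats=1,
    shuffle_records=False, shuffle_examples=False)

with tf.Session() as sess:
    try:
        pos_batch, label_batch = sess.run([pos_tensor, label_tensors])
        print(pos_batch.shape, sorted(label_batch.keys()))
    except tf.errors.OutOfRangeError:
        pass  # dataset exhausted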

Example 1: extract_data

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def extract_data(self, tf_record, filter_amount=1):
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        1, [tf_record], num_repeats=1, shuffle_records=False,
        shuffle_examples=False, filter_amount=filter_amount)
    recovered_data = []
    with tf.Session() as sess:
        while True:
            try:
                pos_value, label_values = sess.run([pos_tensor, label_tensors])
                recovered_data.append((
                    pos_value,
                    label_values['pi_tensor'],
                    label_values['value_tensor']))
            except tf.errors.OutOfRangeError:
                break
    return recovered_data
Developer: mlperf, Project: training_results_v0.5, Lines: 18, Source: test_preprocessing.py

Example 2: extract_data

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def extract_data(self, tf_record, filter_amount=1):
  pos_tensor, label_tensors = preprocessing.get_input_tensors(
      model_params.DummyMiniGoParams(), 1, [tf_record], num_repeats=1,
      shuffle_records=False, shuffle_examples=False,
      filter_amount=filter_amount)
  recovered_data = []
  with tf.Session() as sess:
    while True:
      try:
        pos_value, label_values = sess.run([pos_tensor, label_tensors])
        recovered_data.append((
            pos_value,
            label_values['pi_tensor'],
            label_values['value_tensor']))
      except tf.errors.OutOfRangeError:
        break
  return recovered_data
Developer: itsamitgoel, Project: Gun-Detector, Lines: 19, Source: preprocessing_test.py

Example 3: validate

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(*tf_records):
    """Validate a model's performance on a set of holdout data."""
    if FLAGS.use_tpu:
        def _input_fn(params):
            return preprocessing.get_tpu_input_tensors(
                params['train_batch_size'], params['input_layout'], tf_records,
                filter_amount=1.0)
    else:
        def _input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size, FLAGS.input_layout, tf_records,
                filter_amount=1.0, shuffle_examples=False)

    steps = FLAGS.examples_to_validate // FLAGS.train_batch_size
    if FLAGS.use_tpu:
        steps //= FLAGS.num_tpu_cores

    estimator = dual_net.get_estimator()
    with utils.logged_timer("Validating"):
        estimator.evaluate(_input_fn, steps=steps, name=FLAGS.validate_name) 
Developer: mlperf, Project: training, Lines: 22, Source: validate.py

Example 4: train

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(working_dir, tf_records, generation_num, **hparams):
    assert generation_num > 0, "Model 0 is random weights"
    estimator = get_estimator(working_dir, **hparams)
    print ("generations = ", generation_num)
    max_steps = generation_num * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
    print ("max_steps = ", max_steps)

    def input_fn(): return preprocessing.get_input_tensors(
        TRAIN_BATCH_SIZE, tf_records)
    update_ratio_hook = UpdateRatioSessionHook(working_dir)
    print("Train with TRAIN_BATCH_SIZE=", TRAIN_BATCH_SIZE)
    estimator.train(input_fn, hooks=[update_ratio_hook], max_steps=max_steps) 
Developer: mlperf, Project: training_results_v0.5, Lines: 14, Source: dual_net.py

Example 5: validate

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(working_dir, tf_records, checkpoint_name=None, **hparams):
    estimator = get_estimator(working_dir, **hparams)
    if checkpoint_name is None:
        checkpoint_name = estimator.latest_checkpoint()

    def input_fn(): return preprocessing.get_input_tensors(
        TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
        filter_amount=0.05)
    estimator.evaluate(input_fn, steps=1000) 
Developer: mlperf, Project: training_results_v0.5, Lines: 11, Source: dual_net.py

Example 6: train

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(working_dir, tf_records, generation_num, params):
  """Train the model for a specific generation.

  Args:
    working_dir: The model working directory to save model parameters,
      drop logs, checkpoints, and so on.
    tf_records: A list of tf_record filenames for training input.
    generation_num: The generation to be trained.
    params: hyperparams of the model.

  Raises:
    ValueError: if generation_num is not greater than 0.
  """
  if generation_num <= 0:
    raise ValueError('Model 0 is random weights')
  estimator = tf.estimator.Estimator(
      dualnet_model.model_fn, model_dir=working_dir, params=params)
  max_steps = (generation_num * params.examples_per_generation
               // params.batch_size)
  profiler_hook = tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)

  def input_fn():
    return preprocessing.get_input_tensors(
        params, params.batch_size, tf_records)
  estimator.train(
      input_fn, hooks=[profiler_hook], max_steps=max_steps) 
Developer: itsamitgoel, Project: Gun-Detector, Lines: 28, Source: dualnet.py

Example 7: validate

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(working_dir, tf_records, params):
  """Perform model validation on the hold out data.

  Args:
    working_dir: The model working directory.
    tf_records: A list of tf_records filenames for holdout data.
    params: hyperparams of the model.
  """
  estimator = tf.estimator.Estimator(
      dualnet_model.model_fn, model_dir=working_dir, params=params)
  def input_fn():
    return preprocessing.get_input_tensors(
        params, params.batch_size, tf_records, filter_amount=0.05)
  estimator.evaluate(input_fn, steps=1000) 
Developer: itsamitgoel, Project: Gun-Detector, Lines: 16, Source: dualnet.py

Example 8: train

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(estimator_dir, tf_records, model_version, **kwargs):
    """
    Main training function for the PolicyValueNetwork
    Args:
        estimator_dir (str): Path to the estimator directory
        tf_records (list): A list of TFRecords from which we parse the training examples
        model_version (int): The version of the model
    """
    model = get_estimator(estimator_dir, **kwargs)
    logger.info("Training model version: {}".format(model_version))
    max_steps = model_version * GLOBAL_PARAMETER_STORE.EXAMPLES_PER_GENERATION // \
                GLOBAL_PARAMETER_STORE.TRAIN_BATCH_SIZE
    model.train(input_fn=lambda: preprocessing.get_input_tensors(list_tf_records=tf_records),
                max_steps=max_steps)
    logger.info("Trained model version: {}".format(model_version)) 
Developer: PacktPublishing, Project: Python-Reinforcement-Learning-Projects, Lines: 17, Source: network.py

Example 9: validate

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def validate(estimator_dir, tf_records, checkpoint_path=None, **kwargs):
    model = get_estimator(estimator_dir, **kwargs)
    if checkpoint_path is None:
        checkpoint_path = model.latest_checkpoint()
    model.evaluate(input_fn=lambda: preprocessing.get_input_tensors(
        list_tf_records=tf_records,
        buffer_size=GLOBAL_PARAMETER_STORE.VALIDATION_BUFFER_SIZE),
                   steps=GLOBAL_PARAMETER_STORE.VALIDATION_NUMBER_OF_STEPS,
                   checkpoint_path=checkpoint_path) 
Developer: PacktPublishing, Project: Python-Reinforcement-Learning-Projects, Lines: 11, Source: network.py

Example 10: extract_data

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def extract_data(self, tf_record, filter_amount=1, random_rotation=False):
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        1, [tf_record], num_repeats=1, shuffle_records=False,
        shuffle_examples=False, filter_amount=filter_amount,
        random_rotation=random_rotation)
    return self.get_data_tensors(pos_tensor, label_tensors)
Developer: mlperf, Project: training, Lines: 8, Source: test_preprocessing.py

Example 11: train

# Required import: import preprocessing [as alias]
# Or: from preprocessing import get_input_tensors [as alias]
def train(working_dir, tf_records, generation, params):
  """Train the model for a specific generation.

  Args:
    working_dir: The model working directory to save model parameters,
      drop logs, checkpoints, and so on.
    tf_records: A list of tf_record filenames for training input.
    generation: The generation to be trained.
    params: hyperparams of the model.

  Raises:
    ValueError: if generation is not greater than 0.
  """
  if generation <= 0:
    raise ValueError('Model 0 is random weights')
  estimator = tf.estimator.Estimator(
      dualnet_model.model_fn, model_dir=working_dir, params=params)
  max_steps = (generation * params.examples_per_generation
               // params.batch_size)
  profiler_hook = tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)

  def input_fn():
    return preprocessing.get_input_tensors(
        params, params.batch_size, tf_records)
  estimator.train(
      input_fn, hooks=[profiler_hook], max_steps=max_steps) 
Developer: generalized-iou, Project: g-tensorflow-models, Lines: 28, Source: dualnet.py


Note: The preprocessing.get_input_tensors examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective authors; copyright remains with the original authors, and any distribution or use should follow the corresponding project's license. Please do not republish without permission.