當前位置: 首頁>>代碼示例>>Python>>正文


Python preprocessing.get_input_tensors方法代碼示例

本文整理匯總了Python中preprocessing.get_input_tensors方法的典型用法代碼示例。如果您正苦於以下問題：Python preprocessing.get_input_tensors方法的具體用法？Python preprocessing.get_input_tensors怎麼用？Python preprocessing.get_input_tensors使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊preprocessing的用法示例。


在下文中一共展示了preprocessing.get_input_tensors方法的11個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者覺得有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: extract_data

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def extract_data(self, tf_record, filter_amount=1):
    """Read back every record of a single TFRecord file.

    The input pipeline is built with a leading argument of 1, a single
    repeat and shuffling disabled, then drained inside a TF session until
    it raises ``tf.errors.OutOfRangeError``.

    Args:
        tf_record: path of the TFRecord file to read.
        filter_amount: forwarded unchanged to
            ``preprocessing.get_input_tensors``.

    Returns:
        A list with one ``(position, pi, value)`` tuple per ``sess.run``
        call, where ``pi`` and ``value`` are the ``'pi_tensor'`` and
        ``'value_tensor'`` entries of the fetched label dict.
    """
    features, labels = preprocessing.get_input_tensors(
        1, [tf_record], num_repeats=1, shuffle_records=False,
        shuffle_examples=False, filter_amount=filter_amount)
    fetches = [features, labels]
    examples = []
    with tf.Session() as session:
        while True:
            try:
                position, label_dict = session.run(fetches)
            except tf.errors.OutOfRangeError:
                # The one-pass pipeline is exhausted.
                break
            examples.append(
                (position, label_dict['pi_tensor'], label_dict['value_tensor']))
    return examples
開發者ID:mlperf,項目名稱:training_results_v0.5,代碼行數:18,代碼來源:test_preprocessing.py

示例2: extract_data

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def extract_data(self, tf_record, filter_amount=1):
  """Read back every record of a single TFRecord file.

  Uses dummy MiniGo params and an unshuffled, single-repeat pipeline,
  draining it in a TF session until end of input.

  Args:
    tf_record: Path of the TFRecord file to read.
    filter_amount: Forwarded unchanged to preprocessing.get_input_tensors.

  Returns:
    A list with one (position, pi, value) tuple per sess.run call, taken
    from the 'pi_tensor' and 'value_tensor' entries of the label dict.
  """
  dummy_params = model_params.DummyMiniGoParams()
  features, labels = preprocessing.get_input_tensors(
      dummy_params, 1, [tf_record], num_repeats=1,
      shuffle_records=False, shuffle_examples=False,
      filter_amount=filter_amount)
  records = []
  with tf.Session() as sess:
    try:
      while True:
        position, label_dict = sess.run([features, labels])
        records.append(
            (position, label_dict['pi_tensor'], label_dict['value_tensor']))
    except tf.errors.OutOfRangeError:
      # End of the one-pass input: this is the normal loop exit.
      pass
  return records
開發者ID:itsamitgoel,項目名稱:Gun-Detector,代碼行數:19,代碼來源:preprocessing_test.py

示例3: validate

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def validate(*tf_records):
    """Validate a model's performance on a set of holdout data.

    Evaluates for ``FLAGS.examples_to_validate // FLAGS.train_batch_size``
    steps. On TPU that count is further divided by ``FLAGS.num_tpu_cores``.
    Results are recorded under ``FLAGS.validate_name``.

    Args:
        *tf_records: TFRecord filenames holding the holdout examples.
    """
    num_steps = FLAGS.examples_to_validate // FLAGS.train_batch_size

    if FLAGS.use_tpu:
        # The TPU input function takes its batch size and layout from the
        # ``params`` dict it is called with.
        def _input_fn(params):
            return preprocessing.get_tpu_input_tensors(
                params['train_batch_size'], params['input_layout'],
                tf_records, filter_amount=1.0)

        num_steps //= FLAGS.num_tpu_cores
    else:
        def _input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size, FLAGS.input_layout, tf_records,
                filter_amount=1.0, shuffle_examples=False)

    estimator = dual_net.get_estimator()
    with utils.logged_timer("Validating"):
        estimator.evaluate(
            _input_fn, steps=num_steps, name=FLAGS.validate_name)
開發者ID:mlperf,項目名稱:training,代碼行數:22,代碼來源:validate.py

示例4: train

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def train(working_dir, tf_records, generation_num, **hparams):
    """Train the model up to the step budget of ``generation_num``.

    Args:
        working_dir: model directory passed to ``get_estimator`` and
            ``UpdateRatioSessionHook``.
        tf_records: TFRecord filenames used as training input.
        generation_num: generation being trained; must be greater than 0.
        **hparams: forwarded unchanged to ``get_estimator``.

    Raises:
        ValueError: if ``generation_num`` is not greater than 0.
    """
    # Use an explicit check, not ``assert``: assertions are stripped under
    # ``python -O``, which would let a non-positive generation through.
    if generation_num <= 0:
        raise ValueError("Model 0 is random weights")
    estimator = get_estimator(working_dir, **hparams)
    print("generations = ", generation_num)
    # ``max_steps`` is the estimator's total global-step target, so the
    # budget grows with the generation number.
    max_steps = generation_num * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
    print("max_steps = ", max_steps)

    def input_fn():
        return preprocessing.get_input_tensors(TRAIN_BATCH_SIZE, tf_records)

    update_ratio_hook = UpdateRatioSessionHook(working_dir)
    print("Train with TRAIN_BATCH_SIZE=", TRAIN_BATCH_SIZE)
    estimator.train(input_fn, hooks=[update_ratio_hook], max_steps=max_steps)
開發者ID:mlperf,項目名稱:training_results_v0.5,代碼行數:14,代碼來源:dual_net.py

示例5: validate

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def validate(working_dir, tf_records, checkpoint_name=None, **hparams):
    """Evaluate a checkpoint on the holdout records for 1000 steps.

    Args:
        working_dir: model directory passed to ``get_estimator``.
        tf_records: TFRecord filenames holding the holdout data.
        checkpoint_name: checkpoint to evaluate. When None, the estimator's
            latest checkpoint is used.
        **hparams: forwarded unchanged to ``get_estimator``.
    """
    estimator = get_estimator(working_dir, **hparams)
    if checkpoint_name is None:
        checkpoint_name = estimator.latest_checkpoint()

    def input_fn():
        # NOTE(review): filter_amount=0.05 presumably keeps ~5% of the
        # examples; confirm against preprocessing.get_input_tensors.
        return preprocessing.get_input_tensors(
            TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
            filter_amount=0.05)

    # Forward the resolved checkpoint so that an explicitly requested
    # ``checkpoint_name`` is actually the one evaluated.
    estimator.evaluate(input_fn, steps=1000, checkpoint_path=checkpoint_name)
開發者ID:mlperf,項目名稱:training_results_v0.5,代碼行數:11,代碼來源:dual_net.py

示例6: train

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def train(working_dir, tf_records, generation_num, params):
  """Train the model through the given generation.

  Args:
    working_dir: Directory used as the estimator's model_dir; the profiler
      hook writes its output here as well.
    tf_records: A list of tf_record filenames for training input.
    generation_num: The generation to be trained; must be positive.
    params: Model hyperparameters. `examples_per_generation` and
      `batch_size` are read here, and the object is also handed to the
      estimator and to the input pipeline.

  Raises:
    ValueError: if generation_num is not greater than 0.
  """
  if generation_num <= 0:
    raise ValueError('Model 0 is random weights')

  estimator = tf.estimator.Estimator(
      dualnet_model.model_fn, model_dir=working_dir, params=params)

  # Total global-step target: all examples up to this generation, in batches.
  total_examples = generation_num * params.examples_per_generation
  max_steps = total_examples // params.batch_size

  hooks = [tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)]

  def _training_input_fn():
    return preprocessing.get_input_tensors(
        params, params.batch_size, tf_records)

  estimator.train(_training_input_fn, hooks=hooks, max_steps=max_steps)
開發者ID:itsamitgoel,項目名稱:Gun-Detector,代碼行數:28,代碼來源:dualnet.py

示例7: validate

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def validate(working_dir, tf_records, params):
  """Evaluate the model on held-out data for 1000 steps.

  Args:
    working_dir: The estimator's model directory.
    tf_records: A list of tf_records filenames for holdout data.
    params: Model hyperparameters; `batch_size` sizes the input batches and
      the object is also handed to the estimator and the input pipeline.
  """
  def _holdout_input_fn():
    # NOTE(review): filter_amount=0.05 presumably keeps ~5% of the
    # examples; confirm against preprocessing.get_input_tensors.
    return preprocessing.get_input_tensors(
        params, params.batch_size, tf_records, filter_amount=0.05)

  estimator = tf.estimator.Estimator(
      dualnet_model.model_fn, model_dir=working_dir, params=params)
  estimator.evaluate(_holdout_input_fn, steps=1000)
開發者ID:itsamitgoel,項目名稱:Gun-Detector,代碼行數:16,代碼來源:dualnet.py

示例8: train

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def train(estimator_dir, tf_records, model_version, **kwargs):
    """
    Main training function for the PolicyValueNetwork.

    Trains until the global step reaches
    model_version * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE.

    Args:
        estimator_dir (str): Path to the estimator directory
        tf_records (list): TFRecords from which the training examples are parsed
        model_version (int): The version of the model; scales the step budget
        **kwargs: Forwarded unchanged to get_estimator
    """
    model = get_estimator(estimator_dir, **kwargs)
    logger.info("Training model version: {}".format(model_version))

    total_examples = (model_version *
                      GLOBAL_PARAMETER_STORE.EXAMPLES_PER_GENERATION)
    max_steps = total_examples // GLOBAL_PARAMETER_STORE.TRAIN_BATCH_SIZE

    def training_input_fn():
        return preprocessing.get_input_tensors(list_tf_records=tf_records)

    model.train(input_fn=training_input_fn, max_steps=max_steps)
    logger.info("Trained model version: {}".format(model_version))
開發者ID:PacktPublishing,項目名稱:Python-Reinforcement-Learning-Projects,代碼行數:17,代碼來源:network.py

示例9: validate

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def validate(estimator_dir, tf_records, checkpoint_path=None, **kwargs):
    """
    Evaluate a checkpoint of the PolicyValueNetwork on validation records.

    Args:
        estimator_dir (str): Path to the estimator directory
        tf_records (list): TFRecords holding the validation examples
        checkpoint_path (str, optional): Checkpoint to evaluate; when None the
            estimator's latest checkpoint is used
        **kwargs: Forwarded unchanged to get_estimator
    """
    model = get_estimator(estimator_dir, **kwargs)
    checkpoint = (checkpoint_path if checkpoint_path is not None
                  else model.latest_checkpoint())

    def validation_input_fn():
        return preprocessing.get_input_tensors(
            list_tf_records=tf_records,
            buffer_size=GLOBAL_PARAMETER_STORE.VALIDATION_BUFFER_SIZE)

    model.evaluate(
        input_fn=validation_input_fn,
        steps=GLOBAL_PARAMETER_STORE.VALIDATION_NUMBER_OF_STEPS,
        checkpoint_path=checkpoint)
開發者ID:PacktPublishing,項目名稱:Python-Reinforcement-Learning-Projects,代碼行數:11,代碼來源:network.py

示例10: extract_data

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def extract_data(self, tf_record, filter_amount=1, random_rotation=False):
    """Build an unshuffled single-pass pipeline over one TFRecord and read it.

    Args:
        tf_record: path of the TFRecord file to read.
        filter_amount: forwarded to ``preprocessing.get_input_tensors``.
        random_rotation: forwarded to ``preprocessing.get_input_tensors``.

    Returns:
        Whatever ``self.get_data_tensors`` returns for the pipeline's
        feature and label tensors.
    """
    pipeline_options = dict(
        num_repeats=1,
        shuffle_records=False,
        shuffle_examples=False,
        filter_amount=filter_amount,
        random_rotation=random_rotation,
    )
    features, labels = preprocessing.get_input_tensors(
        1, [tf_record], **pipeline_options)
    return self.get_data_tensors(features, labels)
開發者ID:mlperf,項目名稱:training,代碼行數:8,代碼來源:test_preprocessing.py

示例11: train

# 需要導入模塊: import preprocessing [as 別名]
# 或者: from preprocessing import get_input_tensors [as 別名]
def train(working_dir, tf_records, generation, params):
  """Train the model for a specific generation.

  Args:
    working_dir: The model working directory; used as the estimator's
      model_dir and as the profiler hook's output directory.
    tf_records: A list of tf_record filenames for training input.
    generation: The generation to be trained; must be positive.
    params: Hyperparameters of the model; `examples_per_generation` and
      `batch_size` determine the step budget.

  Raises:
    ValueError: if generation is not greater than 0.
  """
  if generation <= 0:
    raise ValueError('Model 0 is random weights')

  model_estimator = tf.estimator.Estimator(
      dualnet_model.model_fn, model_dir=working_dir, params=params)
  step_budget = generation * params.examples_per_generation // params.batch_size
  profiler = tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)

  def _input_fn():
    return preprocessing.get_input_tensors(
        params, params.batch_size, tf_records)

  model_estimator.train(_input_fn, hooks=[profiler], max_steps=step_budget)
開發者ID:generalized-iou,項目名稱:g-tensorflow-models,代碼行數:28,代碼來源:dualnet.py


注:本文中的preprocessing.get_input_tensors方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。