

Python tensorflow.Estimator Code Examples

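This article collects typical usage examples of tensorflow.Estimator in Python (the class is exposed by the library as `tf.estimator.Estimator`). If you are wondering how tensorflow.Estimator is used in practice, or are looking for working examples of it, the curated code samples here may help. You can also explore further usage examples of the tensorflow module that contains it.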


A total of 12 code examples of tensorflow.Estimator are shown below, sorted by popularity by default.

Example 1: construct_input_fn

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function, that has signature of ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        labels is a 1-D int32 `Tensor` holding labels.
    """
    pass 
Developer: rky0930, Project: yolo_v2, Lines of code: 22, Source: base_estimator.py
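
The docstring above describes the contract of an input_fn: a zero-argument callable that returns a `(features, labels)` pair, where `features` is a dict keyed by feature name. The following is a minimal sketch of that shape only, using plain Python lists in place of `Tensors` so it runs without TensorFlow. The name `make_input_fn`, the `seed` parameter, and the dummy labels are assumptions for illustration and are not taken from the yolo_v2 project.

```python
import random


def make_input_fn(records, is_training, seed=0):
    """Returns a zero-argument callable yielding (features, labels)."""

    def input_fn():
        paths = list(records)
        if is_training:
            # Shuffle only while training; evaluation keeps a stable order.
            random.Random(seed).shuffle(paths)
        # Placeholder payloads: a real input_fn would decode each TFRecord
        # into an image Tensor rather than echo the path back.
        features = {"raw_data": paths}
        labels = [0] * len(paths)  # dummy integer labels, one per record
        return features, labels

    return input_fn


eval_input_fn = make_input_fn(["a.tfrecord", "b.tfrecord"], is_training=False)
features, labels = eval_input_fn()
print(sorted(features), labels)
```

Returning a closure lets the caller bind `records` and `is_training` up front while still handing the Estimator a callable that takes no arguments.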

Example 2: evaluate

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def evaluate(self):
    """Runs `Estimator` validation.
    """
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / eval_batch_size)
    estimator.evaluate(input_fn=validation_input_fn, steps=num_eval_batches) 
Developer: rky0930, Project: yolo_v2, Lines of code: 26, Source: base_estimator.py
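
The step count passed to `evaluate` above comes from `int(num_eval_samples / eval_batch_size)`, which rounds down, so any samples left over after the last full batch are not evaluated. The sketch below isolates that arithmetic and shows the round-up alternative. The helper name `num_eval_batches` and its `drop_remainder` flag are illustrative assumptions, not part of the original project.

```python
import math


def num_eval_batches(num_eval_samples, batch_size, drop_remainder=True):
    """Number of evaluation steps needed for a given sample count."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        # Same result as int(num_eval_samples / batch_size) for
        # non-negative inputs: the partial final batch is skipped.
        return num_eval_samples // batch_size
    # Round up so the partial final batch is also counted.
    return math.ceil(num_eval_samples / batch_size)


print(num_eval_batches(1000, 64))                        # 15
print(num_eval_batches(1000, 64, drop_remainder=False))  # 16
```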

Example 3: _input_fn_inference

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.estimator.Estimator inference.

    Args:
      input_fn: Function, that has signature of ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore.
      predict_keys: List of strings, the keys of the `Tensors` in the
        predictions dict (produced by the model_fn) to evaluate during inference.
    Returns:
      predictions: An Iterator, yielding evaluated values of `Tensors`
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions 
Developer: rky0930, Project: yolo_v2, Lines of code: 24, Source: base_estimator.py

Example 4: export_model

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def export_model(working_dir, model_path):
    """Take the latest checkpoint and export it to model_path for selfplay.

    Assumes that all relevant model files are prefixed by the same name.
    (For example, foo.index, foo.meta and foo.data-00000-of-00001).

    Args:
        working_dir: The directory where tf.estimator keeps its checkpoints
        model_path: The path (can be a gs:// path) to export model to
    """
    estimator = tf.estimator.Estimator(model_fn, model_dir=working_dir,
                                       params='ignored')
    latest_checkpoint = estimator.latest_checkpoint()
    all_checkpoint_files = tf.gfile.Glob(latest_checkpoint + '*')
    for filename in all_checkpoint_files:
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print("Copying {} to {}".format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path) 
Developer: mlperf, Project: training_results_v0.5, Lines of code: 21, Source: dual_net.py
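
The copy loop above relies on `str.partition` to strip the checkpoint prefix from each file name and keep only the suffix (`.index`, `.meta`, and so on). The sketch below reproduces just that renaming step in pure Python, without touching the file system. The function name `plan_checkpoint_copies` and all paths are hypothetical.

```python
def plan_checkpoint_copies(latest_checkpoint, filenames, model_path):
    """Maps each checkpoint file to its destination under model_path."""
    copies = []
    for filename in filenames:
        # partition() returns (before, separator, after); the text after
        # the checkpoint prefix is the suffix to carry over.
        suffix = filename.partition(latest_checkpoint)[2]
        copies.append((filename, model_path + suffix))
    return copies


files = [
    "work/model.ckpt-100.index",
    "work/model.ckpt-100.meta",
    "work/model.ckpt-100.data-00000-of-00001",
]
for src, dst in plan_checkpoint_copies("work/model.ckpt-100", files, "export/foo"):
    print(src, "->", dst)
```

Note that if a file name does not contain the prefix at all, `partition` yields an empty suffix, so that file would map to `model_path` itself. This is why the original docstring assumes every relevant file shares the same prefix.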

Example 5: bootstrap

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def bootstrap():
    """Initialize a tf.estimator.Estimator run with random initial weights."""
    # a bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    #
    # Estimator will do this automatically when you call train(), but calling
    # train() requires data, and I didn't feel like creating training data in
    # order to run the full train pipeline for 1 step.
    maybe_set_seed()
    initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(FLAGS.work_dir, initial_checkpoint_name)
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        features, labels = get_inference_input()
        model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
                 params=FLAGS.flag_values_dict())
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, save_file) 
Developer: mlperf, Project: training, Lines of code: 20, Source: dual_net.py

Example 6: export_model

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def export_model(model_path):
    """Take the latest checkpoint and copy it to model_path.

    Assumes that all relevant model files are prefixed by the same name.
    (For example, foo.index, foo.meta and foo.data-00000-of-00001).

    Args:
        model_path: The path (can be a gs:// path) to export model
    """
    estimator = tf.estimator.Estimator(model_fn, model_dir=FLAGS.work_dir,
                                       params=FLAGS.flag_values_dict())
    latest_checkpoint = estimator.latest_checkpoint()
    all_checkpoint_files = tf.gfile.Glob(latest_checkpoint + '*')
    for filename in all_checkpoint_files:
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print('Copying {} to {}'.format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path) 
Developer: mlperf, Project: training, Lines of code: 20, Source: dual_net.py

Example 7: __init__

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def __init__(self, model=None, atoms=None, to_eV=1.0,
                 properties=['energy', 'forces', 'stress']):
        """PiNN interface with ASE as a calculator

        Args:
            model: tf.estimator.Estimator object
            atoms: optional, ase Atoms object
            properties: properties to calculate.
                the properties to calculate is fixed for each calculator,
                to avoid resetting the predictor during get_* calls.
        """
        Calculator.__init__(self)
        self.implemented_properties = properties
        self.model = model
        self.pbc = False
        self.atoms = atoms
        self.predictor = None
        self.to_eV = to_eV 
Developer: Teoroo-CMC, Project: PiNN, Lines of code: 20, Source: calculator.py

Example 8: _verify_prefitting_model

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def _verify_prefitting_model(prefitting_model, feature_names):
  """Checks that prefitting_model has the proper input layer."""
  if isinstance(prefitting_model, tf.keras.Model):
    layer_names = [layer.name for layer in prefitting_model.layers]
  elif isinstance(prefitting_model, tf.estimator.Estimator):
    layer_names = prefitting_model.get_variable_names()
  else:
    raise ValueError('Invalid model type for prefitting_model: {}'.format(
        type(prefitting_model)))
  for feature_name in feature_names:
    if isinstance(prefitting_model, tf.keras.Model):
      input_layer_name = '{}_{}'.format(INPUT_LAYER_NAME, feature_name)
      if input_layer_name not in layer_names:
        raise ValueError(
            'prefitting_model does not match prefitting_model_config. Make '
            'sure that prefitting_model is the proper type and constructed '
            'from the prefitting_model_config: {}'.format(
                type(prefitting_model)))
    else:
      pwl_input_layer_name = '{}_{}/{}'.format(
          CALIB_LAYER_NAME, feature_name,
          pwl_calibration_layer.PWL_CALIBRATION_KERNEL_NAME)
      cat_input_layer_name = '{}_{}/{}'.format(
          CALIB_LAYER_NAME, feature_name,
          categorical_calibration_layer.CATEGORICAL_CALIBRATION_KERNEL_NAME)
      if (pwl_input_layer_name not in layer_names and
          cat_input_layer_name not in layer_names):
        raise ValueError(
            'prefitting_model does not match prefitting_model_config. Make '
            'sure that prefitting_model is the proper type and constructed '
            'from the prefitting_model_config: {}'.format(
                type(prefitting_model))) 
Developer: tensorflow, Project: lattice, Lines of code: 34, Source: premade_lib.py
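
For the Estimator branch, the check above builds two candidate variable names per feature (one for a piecewise-linear calibrator, one for a categorical calibrator) and requires at least one of them to appear in `get_variable_names()`. The sketch below shows that name-matching logic on plain lists. The constant values are placeholders chosen for illustration, not the real values from the lattice library, and `features_missing_calibrators` is a hypothetical helper.

```python
# Placeholder values; the real constants live in the lattice library.
CALIB_LAYER_NAME = "calib"
PWL_KERNEL_NAME = "pwl_kernel"
CAT_KERNEL_NAME = "cat_kernel"


def features_missing_calibrators(variable_names, feature_names):
    """Returns the features that have neither calibrator variable."""
    known = set(variable_names)
    missing = []
    for feature_name in feature_names:
        candidates = {
            "{}_{}/{}".format(CALIB_LAYER_NAME, feature_name, PWL_KERNEL_NAME),
            "{}_{}/{}".format(CALIB_LAYER_NAME, feature_name, CAT_KERNEL_NAME),
        }
        # A feature passes if either calibrator variable is present.
        if not candidates & known:
            missing.append(feature_name)
    return missing


variables = ["calib_age/pwl_kernel", "calib_city/cat_kernel", "global_step"]
print(features_missing_calibrators(variables, ["age", "city", "income"]))
```

Collecting the missing features, rather than raising on the first one, is a design choice of this sketch; the original raises a `ValueError` as soon as one feature fails the check.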

Example 9: _get_lattice_weights

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def _get_lattice_weights(prefitting_model, lattice_index):
  """Gets the weights of the lattice at the specified index."""
  if isinstance(prefitting_model, tf.keras.Model):
    lattice_layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, lattice_index)
    weights = tf.keras.backend.get_value(
        prefitting_model.get_layer(lattice_layer_name).weights[0])
  else:
    # We have already checked the types by this point, so if prefitting_model
    # is not a keras Model it must be an Estimator.
    lattice_kernel_variable_name = '{}_{}/{}'.format(
        LATTICE_LAYER_NAME, lattice_index, lattice_layer.LATTICE_KERNEL_NAME)
    weights = prefitting_model.get_variable_value(lattice_kernel_variable_name)
  return weights 
Developer: tensorflow, Project: lattice, Lines of code: 15, Source: premade_lib.py

Example 10: get_estimator

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def get_estimator(working_dir, **hparams):
    hparams = get_default_hyperparams(**hparams)
    return tf.estimator.Estimator(
        model_fn,
        model_dir=working_dir,
        params=hparams) 
Developer: mlperf, Project: training_results_v0.5, Lines of code: 8, Source: dual_net.py
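
`get_estimator` passes caller-supplied keyword arguments through `get_default_hyperparams` before handing the result to the Estimator as `params`. The body of that helper is not shown in the excerpt, so the sketch below is only one plausible way to merge overrides into defaults. The name `merge_hyperparams` and the default keys and values are invented for illustration.

```python
def merge_hyperparams(**overrides):
    """Returns default hyperparameters with caller overrides applied."""
    # Invented defaults; the real project defines its own set.
    hparams = {
        "learning_rate": 0.01,
        "num_layers": 8,
        "batch_size": 32,
    }
    # Later values win, so any keyword argument replaces its default.
    hparams.update(overrides)
    return hparams


print(merge_hyperparams(num_layers=4))
```

Because the merged dict is built fresh on every call, callers can mutate the result without affecting later calls.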

Example 11: bootstrap

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def bootstrap(working_dir, **hparams):
    """Initialize a tf.estimator.Estimator run with random initial weights.

    Args:
        working_dir: The directory where tf.estimator will drop logs,
            checkpoints, and so on
        hparams: hyperparams of the model.
    """
    hparams = get_default_hyperparams(**hparams)
    # a bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    #
    # Estimator will do this automatically when you call train(), but calling
    # train() requires data, and I didn't feel like creating training data in
    # order to run the full train pipeline for 1 step.
    estimator_initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(working_dir, estimator_initial_checkpoint_name)
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        features, labels = get_inference_input()
        model_fn(features, labels, tf.estimator.ModeKeys.PREDICT, hparams)
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, save_file)

    with open("./minigo.pbtxt", "w") as f:
        f.write(str(sess.graph.as_graph_def())) 
Developer: mlperf, Project: training_results_v0.5, Lines of code: 28, Source: dual_net.py

Example 12: main

# Required import: import tensorflow as tf
# The class is then available as tf.estimator.Estimator
def main(unused_argv):
  from official.transformer import transformer_main

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    return

  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

  # Set up estimator and params
  params = transformer_main.PARAMS_MAP[FLAGS.param_set]
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s" % FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s" % input_file)
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s" % output_file)

    translate_file(estimator, subtokenizer, input_file, output_file) 
Developer: PipelineAI, Project: models, Lines of code: 40, Source: translate.py


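Note: The tensorflow.Estimator examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by their respective developers, and copyright in the source code belongs to the original authors. For distribution and use, refer to the License of the corresponding project. Do not reproduce without permission.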