Python tensorflow.Estimator方法代码示例

本文整理汇总了Python中tensorflow.Estimator方法（在TensorFlow 1.x中完整路径为`tf.estimator.Estimator`）的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.Estimator方法的具体用法？Python tensorflow.Estimator怎么用？Python tensorflow.Estimator使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


示例1: construct_input_fn

def construct_input_fn(self, records, is_training):
    """Builds an estimator input_fn.

    The input_fn is used to pass feature and target data to the train,
    evaluate, and predict methods of the Estimator.

    Method to be overridden by implementations.

    Args:
      records: A list of Strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      Function, that has signature of ()->(dict of features, target).
        features is a dict mapping feature names to `Tensors`
        containing the corresponding feature data (typically, just a single
        key/value pair 'raw_data' -> image `Tensor` for TCN).
        labels is a 1-D int32 `Tensor` holding labels.

    Raises:
      NotImplementedError: Always; subclasses must override this method.
    """
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing error inside Estimator.train/evaluate.
    raise NotImplementedError(
        'construct_input_fn must be implemented by a subclass.')

示例2: evaluate

def evaluate(self):
    """Runs `Estimator` validation.

    Collects the validation TFRecords under `config.data.validation`, builds
    the subclass-defined input_fn and an evaluation-mode estimator, and
    evaluates over `int(config.val.num_eval_samples / config.data.batch_size)`
    batches.

    Returns:
      The result of `Estimator.evaluate` (a dict of evaluation metrics).

    Raises:
      ValueError: If `config.val.num_eval_samples` is smaller than
        `config.data.batch_size`, which would request zero evaluation steps
        (`Estimator.evaluate` rejects `steps <= 0`).
    """
    config = self._config

    # Define batch size (presumably consumed by construct_input_fn via
    # self._batch_size — confirm in the subclass).
    batch_size = config.data.batch_size
    self._batch_size = batch_size

    # Evaluate whole batches only. Check before doing any file listing or
    # estimator construction so a misconfiguration fails fast and clearly.
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = int(num_eval_samples / batch_size)
    if num_eval_batches <= 0:
        raise ValueError(
            'val.num_eval_samples ({}) must be at least data.batch_size ({}) '
            'to run any evaluation steps.'.format(num_eval_samples, batch_size))

    # Get a list of validation tfrecords.
    validation_records = util.GetFilesRecursively(config.data.validation)

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(validation_records, False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    return estimator.evaluate(
        input_fn=validation_input_fn, steps=num_eval_batches)

示例3: _input_fn_inference

def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None):
    """Mode 1: tf.Estimator inference.

    Args:
      input_fn: Function, that has signature of ()->(dict of features, None).
        This is a function called by the estimator to get input tensors (stored
        in the features dict) to do inference over.
      checkpoint_path: String, path to a specific checkpoint to restore. If
        None, Estimator.predict uses the latest checkpoint in its model_dir.
      predict_keys: List of strings, the keys of the predictions dict returned
        by the model_fn's EstimatorSpec to evaluate during inference. If None,
        all predictions are returned.

    Returns:
      predictions: An Iterator, yielding evaluated values of the predictions
        specified in `predict_keys`.
    """
    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Create an iterator of predicted embeddings. checkpoint_path and
    # predict_keys must be forwarded, otherwise these arguments have no effect.
    predictions = estimator.predict(input_fn=input_fn,
                                    checkpoint_path=checkpoint_path,
                                    predict_keys=predict_keys)
    return predictions

示例4: export_model

def export_model(working_dir, model_path):
    """Take the latest checkpoint and export it to model_path for selfplay.

    Assumes that all relevant model files are prefixed by the same name.
    (For example, foo.index, foo.meta and foo.data-00000-of-00001).
    Each file is copied to `model_path` plus the part of its name after the
    checkpoint prefix, e.g. foo.index -> <model_path>.index.

    Args:
        working_dir: The directory where tf.estimator keeps its checkpoints
        model_path: The path (can be a gs:// path) to export model to

    Raises:
        ValueError: If `working_dir` contains no checkpoint.
    """
    # Estimator.latest_checkpoint() simply looks up the latest checkpoint in
    # its model_dir; doing that directly avoids building an Estimator (which
    # also needs model_fn and params) just to locate the files.
    latest_checkpoint = tf.train.latest_checkpoint(working_dir)
    if latest_checkpoint is None:
        raise ValueError(
            'No checkpoint found in {}; nothing to export.'.format(working_dir))
    for filename in tf.gfile.Glob(latest_checkpoint + '*'):
        # Keep everything after the checkpoint prefix (".index", ".meta", ...).
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print("Copying {} to {}".format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path)

示例5: bootstrap

def bootstrap():
    """Initialize a tf.Estimator run with random initial weights.

    Builds the model graph with `model_fn` in PREDICT mode inside a fresh
    graph/session and saves it with `tf.train.Saver` as `model.ckpt-1` under
    `FLAGS.work_dir`.
    """
    # a bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    # Estimator will do this automatically when you call train(), but calling
    # train() requires data, and I didn't feel like creating training data in
    # order to run the full train pipeline for 1 step.
    initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(FLAGS.work_dir, initial_checkpoint_name)
    # NOTE(review): this Session is never closed in the visible code.
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        features, labels = get_inference_input()
        # NOTE(review): the argument list of the model_fn call below is cut
        # off in this excerpt (its trailing argument(s) and closing paren are
        # missing). Also, nothing visible initializes the variables model_fn
        # creates before Saver().save runs; saving uninitialized variables
        # fails, so consider sess.run(tf.global_variables_initializer())
        # first — confirm against the full source.
        model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
        tf.train.Saver().save(sess, save_file)

示例6: export_model

def export_model(model_path):
    """Take the latest checkpoint and copy it to model_path.

    Assumes that all relevant model files are prefixed by the same name.
    (For example, foo.index, foo.meta and foo.data-00000-of-00001).
    Each file is copied to `model_path` plus the part of its name after the
    checkpoint prefix, e.g. foo.index -> <model_path>.index.

    Args:
        model_path: The path (can be a gs:// path) to export model

    Raises:
        ValueError: If `FLAGS.work_dir` contains no checkpoint.
    """
    # Estimator.latest_checkpoint() simply looks up the latest checkpoint in
    # its model_dir; doing that directly avoids building an Estimator (which
    # also needs model_fn and params) just to locate the files.
    work_dir = FLAGS.work_dir
    latest_checkpoint = tf.train.latest_checkpoint(work_dir)
    if latest_checkpoint is None:
        raise ValueError(
            'No checkpoint found in {}; nothing to export.'.format(work_dir))
    for filename in tf.gfile.Glob(latest_checkpoint + '*'):
        # Keep everything after the checkpoint prefix (".index", ".meta", ...).
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print('Copying {} to {}'.format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path)

示例7: __init__

def __init__(self, model=None, atoms=None, to_eV=1.0,
             properties=('energy', 'forces', 'stress')):
    """PiNN interface with ASE as a calculator

    Args:
        model: tf.Estimator object
        atoms: optional, ase Atoms object
        to_eV: float, stored as-is (presumably a unit-conversion factor to
            eV for the model outputs — confirm against the get_* methods).
        properties: properties to calculate.
            the properties to calculate is fixed for each calculator,
            to avoid resetting the predictor during get_* calls.
    """
    # Copy into a fresh list: the former mutable list default was a single
    # object shared by every instance, so mutating one calculator's
    # implemented_properties silently changed all of them.
    self.implemented_properties = list(properties)
    self.model = model
    self.pbc = False
    self.atoms = atoms
    # Starts unset; presumably built later during get_* calls (see above).
    self.predictor = None
    self.to_eV = to_eV

示例8: _verify_prefitting_model

def _verify_prefitting_model(prefitting_model, feature_names):
  """Checks that prefitting_model has the proper input layer.

  For a `tf.keras.Model` the names checked are its layer names; for a
  `tf.estimator.Estimator` they are its variable names. For each feature, a
  Keras model must contain a layer named '<INPUT_LAYER_NAME>_<feature_name>',
  while the other case must contain at least one of two calibration names of
  the form '<CALIB_LAYER_NAME>_<feature_name>/...'.

  Args:
    prefitting_model: A `tf.keras.Model` or `tf.estimator.Estimator`.
    feature_names: Iterable of feature names to check.

  Raises:
    ValueError: If a required layer/variable name is missing for a feature
      (and presumably if `prefitting_model` is of an unsupported type).
  """
  # NOTE(review): this excerpt is truncated. The `else:` before the type
  # error and before the calibration checks, the `.format(...)` arguments,
  # and the final component of both calibration names are missing; the error
  # messages also mention `prefitting_model_config`, which is not a parameter
  # of this function. Restore from the full source before relying on it.
  if isinstance(prefitting_model, tf.keras.Model):
    layer_names = [layer.name for layer in prefitting_model.layers]
  elif isinstance(prefitting_model, tf.estimator.Estimator):
    layer_names = prefitting_model.get_variable_names()
    raise ValueError('Invalid model type for prefitting_model: {}'.format(
  for feature_name in feature_names:
    if isinstance(prefitting_model, tf.keras.Model):
      input_layer_name = '{}_{}'.format(INPUT_LAYER_NAME, feature_name)
      if input_layer_name not in layer_names:
        raise ValueError(
            'prefitting_model does not match prefitting_model_config. Make '
            'sure that prefitting_model is the proper type and constructed '
            'from the prefitting_model_config: {}'.format(
      # Presumably the non-Keras branch (under a missing `else:`): at least
      # one of the two calibration names must be present.
      pwl_input_layer_name = '{}_{}/{}'.format(
          CALIB_LAYER_NAME, feature_name,
      cat_input_layer_name = '{}_{}/{}'.format(
          CALIB_LAYER_NAME, feature_name,
      if (pwl_input_layer_name not in layer_names and
          cat_input_layer_name not in layer_names):
        raise ValueError(
            'prefitting_model does not match prefitting_model_config. Make '
            'sure that prefitting_model is the proper type and constructed '
            'from the prefitting_model_config: {}'.format(

示例9: _get_lattice_weights

def _get_lattice_weights(prefitting_model, lattice_index):
  """Gets the weights of the lattice at the specified index.

  Args:
    prefitting_model: A `tf.keras.Model`, or otherwise a
      `tf.estimator.Estimator`.
    lattice_index: Index appended to LATTICE_LAYER_NAME to form the lattice
      layer / variable name.

  Returns:
    The lattice weights, as returned by `tf.keras.backend.get_value` (Keras)
    or `Estimator.get_variable_value` (Estimator).
  """
  if isinstance(prefitting_model, tf.keras.Model):
    lattice_layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, lattice_index)
    # NOTE(review): this call's argument and the `else:` that should follow
    # it are cut off in this excerpt; restore them from the full source.
    weights = tf.keras.backend.get_value(
    # We have already checked the types by this point, so if prefitting_model
    # is not a keras Model it must be an Estimator.
    lattice_kernel_variable_name = '{}_{}/{}'.format(
        LATTICE_LAYER_NAME, lattice_index, lattice_layer.LATTICE_KERNEL_NAME)
    weights = prefitting_model.get_variable_value(lattice_kernel_variable_name)
  return weights

示例10: get_estimator

def get_estimator(working_dir, **hparams):
    """Build a `tf.estimator.Estimator` for this model.

    Args:
        working_dir: Directory for the estimator (presumably its model_dir —
            the constructor arguments are cut off in this excerpt).
        **hparams: Hyperparameter overrides, passed to
            get_default_hyperparams before use.

    Returns:
        A `tf.estimator.Estimator`.
    """
    hparams = get_default_hyperparams(**hparams)
    # NOTE(review): the Estimator constructor arguments are missing from this
    # excerpt; restore them from the full source.
    return tf.estimator.Estimator(

示例11: bootstrap

def bootstrap(working_dir, **hparams):
    """Initialize a tf.Estimator run with random initial weights.

    Builds the model graph with `model_fn` in PREDICT mode inside a fresh
    graph/session and saves it with `tf.train.Saver` as `model.ckpt-1` under
    `working_dir`.

    Args:
        working_dir: The directory where tf.estimator will drop logs,
            checkpoints, and so on
        hparams: hyperparams of the model (passed through
            get_default_hyperparams before use).
    """
    hparams = get_default_hyperparams(**hparams)
    # a bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    # Estimator will do this automatically when you call train(), but calling
    # train() requires data, and I didn't feel like creating training data in
    # order to run the full train pipeline for 1 step.
    estimator_initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(working_dir, estimator_initial_checkpoint_name)
    # NOTE(review): this Session is never closed in the visible code.
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        features, labels = get_inference_input()
        model_fn(features, labels, tf.estimator.ModeKeys.PREDICT, hparams)
        # NOTE(review): nothing visible initializes the variables model_fn
        # created before saving; Saver.save fails on uninitialized variables,
        # so consider sess.run(tf.global_variables_initializer()) here —
        # confirm against the full source.
        tf.train.Saver().save(sess, save_file)

    # NOTE(review): the body of this `with` block is missing from this
    # excerpt; restore it before use.
    with open("./minigo.pbtxt", "w") as f:

示例12: main

def main(unused_argv):
  """Translates `--text` and/or `--file` with a Transformer estimator.

  Args:
    unused_argv: Positional command-line arguments (ignored).

  Raises:
    ValueError: If `--file` is set but the file does not exist.
  """
  import copy

  from official.transformer import transformer_main

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    # Nothing to do: skip loading the vocabulary and building the estimator.
    return

  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

  # Set up estimator and params. Work on a (type-preserving) copy so the
  # decode-time overrides below don't mutate the shared PARAMS_MAP entry.
  params = copy.copy(transformer_main.PARAMS_MAP[FLAGS.param_set])
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s", FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s", input_file)
    # Check the path as given: abspath would mangle remote (e.g. gs://) paths.
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s", output_file)

    translate_file(estimator, subtokenizer, input_file, output_file)
