当前位置: 首页>>代码示例>>Python>>正文


Python checkpoint_state_pb2.CheckpointState方法代码示例

本文整理汇总了Python中tensorflow.python.training.checkpoint_state_pb2.CheckpointState方法的典型用法代码示例。如果您想了解Python checkpoint_state_pb2.CheckpointState的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。（注：CheckpointState实际上是一个protobuf消息类，下文示例中通过CheckpointState()创建实例并读写其model_checkpoint_path等字段。）您也可以进一步了解该方法所在模块tensorflow.python.training.checkpoint_state_pb2的其他用法示例。


下文一共展示了checkpoint_state_pb2.CheckpointState方法的12个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: ParseCheckpoint

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def ParseCheckpoint(checkpoint):
  """Resolve a checkpoint argument to the path of a real checkpoint file.

  The file at `checkpoint` is first read as a text-format CheckpointState
  proto. If it parses, the proto's `model_checkpoint_path` is returned;
  otherwise `checkpoint` itself is returned unchanged.

  Args:
    checkpoint: Path to either a text-format CheckpointState proto or an
      actual checkpoint file.

  Returns:
    The path to an actual checkpoint file.
  """
  warnings.warn(
      "ParseCheckpoint is deprecated. "
      "Will be removed in DeepChem 1.4.", DeprecationWarning)
  with open(checkpoint) as handle:
    state = checkpoint_state_pb2.CheckpointState()
    contents = handle.read()
    try:
      text_format.Merge(contents, state)
    except text_format.ParseError:
      # Not a CheckpointState proto: treat the argument as the checkpoint.
      return checkpoint
    return state.model_checkpoint_path
开发者ID:deepchem,项目名称:deepchem,代码行数:22,代码来源:utils.py

示例2: testParseCheckpoint

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def testParseCheckpoint(self):
    """ParseCheckpoint handles both CheckpointState protos and raw paths."""
    # Case 1: the file holds a text-format CheckpointState proto, so the
    # path stored inside it should be returned.
    with tempfile.NamedTemporaryFile(mode='w+') as proto_file:
      state = checkpoint_state_pb2.CheckpointState()
      state.model_checkpoint_path = 'my-checkpoint'
      proto_file.write(text_format.MessageToString(state))
      proto_file.file.flush()
      resolved = utils.ParseCheckpoint(proto_file.name)
      self.assertEqual(resolved, 'my-checkpoint')
    # Case 2: the file does not parse as a proto, so its own path should
    # come back unchanged.
    with tempfile.NamedTemporaryFile(mode='w+') as raw_file:
      raw_file.write('This is not a CheckpointState proto.')
      raw_file.file.flush()
      resolved = utils.ParseCheckpoint(raw_file.name)
      self.assertEqual(resolved, raw_file.name)
开发者ID:deepchem,项目名称:deepchem,代码行数:15,代码来源:test_utils.py

示例3: _GetState

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def _GetState(self):
    """Reads the checkpoint state recorded in `self._state_file`.

    Returns:
      A CheckpointState proto parsed from the state file, or an empty
      CheckpointState if the file does not exist.
    """
    state = CheckpointState()
    if not file_io.file_exists(self._state_file):
      return state
    text_format.Merge(file_io.read_file_to_string(self._state_file), state)
    return state
开发者ID:tensorflow,项目名称:lingvo,代码行数:9,代码来源:saver.py

示例4: _GetCheckpointFilename

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def _GetCheckpointFilename(save_dir, latest_filename):
  """Returns a filename for storing the CheckpointState.

  Args:
    save_dir: The directory for saving and restoring checkpoints.
    latest_filename: Name of the file in 'save_dir' that is used
      to store the CheckpointState.

  Returns:
    The path of the file that contains the CheckpointState proto.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  return os.path.join(save_dir, latest_filename) 
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:16,代码来源:saver.py

示例5: update_checkpoint_state

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
  """Updates the content of the 'checkpoint' file.

  Rewrites the checkpoint state file (which holds a CheckpointState proto)
  by delegating to `_update_checkpoint_state` with relative-path rewriting
  disabled.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Defaults to
      'checkpoint'.

  Raises:
    RuntimeError: Propagated from `_update_checkpoint_state` when the model
      checkpoint path conflicts with the file containing the CheckpointState.
  """
  # This public entry point never asks for relative-path rewriting.
  options = dict(
      save_dir=save_dir,
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths,
      latest_filename=latest_filename,
      save_relative_paths=False)
  _update_checkpoint_state(**options)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:31,代码来源:saver.py

示例6: update_checkpoint_state

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Raises:
    RuntimeError: If the save paths conflict, i.e. model_checkpoint_path
      (as given, or as stored in the proto) names the checkpoint state file.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  ckpt = generate_checkpoint_state_proto(
      save_dir,
      model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths)

  # generate_checkpoint_state_proto may rewrite model_checkpoint_path relative
  # to save_dir, so checking only the stored value can miss a clash. Compare
  # both the caller's path and the stored path, normalized, against the
  # state file.
  state_file = os.path.normpath(coord_checkpoint_filename)
  candidates = (model_checkpoint_path, ckpt.model_checkpoint_path)
  if any(os.path.normpath(p) == state_file for p in candidates):
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:40,代码来源:saver.py

示例7: _get_current_checkpoint

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def _get_current_checkpoint(self):
        """Loads the checkpoint metadata from the checkpoint directory.

        Returns:
            A CheckpointState parsed from CHECKPOINT_METADATA_FILENAME inside
            ``self.params.checkpoint_dir``, or None if that file does not exist.

        Raises:
            Exception: any error raised while locating, reading or parsing the
                file is printed and then re-raised unchanged.
        """
        try:
            checkpoint_metadata_filepath = os.path.abspath(
                os.path.join(self.params.checkpoint_dir, CHECKPOINT_METADATA_FILENAME))
            checkpoint = CheckpointState()
            if not os.path.exists(checkpoint_metadata_filepath):
                return None

            # Context manager closes the handle promptly, even if read() fails.
            with open(checkpoint_metadata_filepath, 'r') as metadata_file:
                contents = metadata_file.read()
            text_format.Merge(contents, checkpoint)
            return checkpoint
        except Exception as e:
            print("Got exception while reading checkpoint metadata", e)
            # Bare raise re-raises the active exception with its traceback.
            raise
开发者ID:aws-samples,项目名称:aws-builders-fair-projects,代码行数:16,代码来源:s3_boto_data_store.py

示例8: test_checkpoint_contains_relative_paths

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def test_checkpoint_contains_relative_paths(self):
    """The 'checkpoint' file written by training stores model-dir-relative names."""
    model_dir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(
        model_dir=model_dir, model_fn=model_fn_global_step_incrementer)
    est.train(dummy_input_fn, steps=5)

    state_path = os.path.join(model_dir, 'checkpoint')
    ckpt = checkpoint_state_pb2.CheckpointState()
    text_format.Merge(file_io.read_file_to_string(state_path), ckpt)

    self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
    # TODO(b/78461127): Please modify tests to not directly rely on names of
    # checkpoints.
    self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],
                        ckpt.all_model_checkpoint_paths)
开发者ID:tensorflow,项目名称:estimator,代码行数:17,代码来源:estimator_test.py

示例9: generate_checkpoint_state_proto

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
  """Generates a checkpoint state proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.  The caller's list is not
      modified; a copy is used.

  Returns:
    CheckpointState proto with model_checkpoint_path and
    all_model_checkpoint_paths updated to either absolute paths or
    relative paths to the current save_dir.
  """
  # Work on a copy so the caller's list is neither appended to nor rewritten
  # in place.
  all_model_checkpoint_paths = list(all_model_checkpoint_paths or [])

  if (not all_model_checkpoint_paths or
      all_model_checkpoint_paths[-1] != model_checkpoint_path):
    logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
                 model_checkpoint_path)
    all_model_checkpoint_paths.append(model_checkpoint_path)

  # When save_dir is itself relative, relative checkpoint paths are rewritten
  # to be relative to save_dir (presumably because they already include
  # save_dir as a prefix). Absolute paths are left untouched.
  if not os.path.isabs(save_dir):
    if not os.path.isabs(model_checkpoint_path):
      model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    all_model_checkpoint_paths = [
        p if os.path.isabs(p) else os.path.relpath(p, save_dir)
        for p in all_model_checkpoint_paths
    ]

  coord_checkpoint_proto = CheckpointState(
      model_checkpoint_path=model_checkpoint_path,
      all_model_checkpoint_paths=all_model_checkpoint_paths)

  return coord_checkpoint_proto
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:44,代码来源:saver.py

示例10: _update_checkpoint_state

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def _update_checkpoint_state(save_dir,
                             model_checkpoint_path,
                             all_model_checkpoint_paths=None,
                             latest_filename=None,
                             save_relative_paths=False):
  """Updates the content of the 'checkpoint' file.

  This updates the checkpoint file containing a CheckpointState
  proto.

  Args:
    save_dir: Directory where the model was saved.
    model_checkpoint_path: The checkpoint file.
    all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
      checkpoints, sorted from oldest to newest.  If this is a non-empty list,
      the last element must be equal to model_checkpoint_path.  These paths
      are also saved in the CheckpointState proto.  None is treated as an
      empty list.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.
    save_relative_paths: If `True`, absolute paths are rewritten relative to
      `save_dir` before being written to the checkpoint state file.

  Raises:
    RuntimeError: If model_checkpoint_path (as given, or as stored in the
      proto) names the file containing the CheckpointState.
  """
  # Writes the "checkpoint" file for the coordinator for later restoration.
  coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
  if save_relative_paths:
    # Absolute paths become relative to save_dir; relative ones are kept.
    # `or []` avoids a TypeError when the default None is passed.
    path_to_write = (
        os.path.relpath(model_checkpoint_path, save_dir)
        if os.path.isabs(model_checkpoint_path) else model_checkpoint_path)
    all_paths_to_write = [
        os.path.relpath(p, save_dir) if os.path.isabs(p) else p
        for p in (all_model_checkpoint_paths or [])
    ]
  else:
    path_to_write = model_checkpoint_path
    all_paths_to_write = all_model_checkpoint_paths
  ckpt = generate_checkpoint_state_proto(
      save_dir,
      path_to_write,
      all_model_checkpoint_paths=all_paths_to_write)

  # The stored path may have been relativized (above, or inside
  # generate_checkpoint_state_proto), so comparing only it against the
  # state file can miss a clash. Check the caller's path too, normalized.
  state_file = os.path.normpath(coord_checkpoint_filename)
  candidates = (model_checkpoint_path, ckpt.model_checkpoint_path)
  if any(os.path.normpath(p) == state_file for p in candidates):
    raise RuntimeError("Save path '%s' conflicts with path used for "
                       "checkpoint state.  Please use a different save path." %
                       model_checkpoint_path)

  # Preventing potential read/write race condition by *atomically* writing to a
  # file.
  file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                      text_format.MessageToString(ckpt))
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:61,代码来源:saver.py

示例11: get_checkpoint_state

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it. Relative paths stored in the proto are returned
  joined onto `checkpoint_dir`.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise (missing file, read error, or parse error).

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(
          coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        # Interpolate explicitly: ValueError does not %-format its args.
        raise ValueError("Invalid checkpoint state loaded from %s" %
                         (checkpoint_dir,))
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      # Iterate over a snapshot because entries are replaced by index.
      for i, p in enumerate(list(ckpt.all_model_checkpoint_paths)):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning("%s: %s", type(e).__name__, e)
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:57,代码来源:saver.py

示例12: get_checkpoint_state

# 需要导入模块: from tensorflow.python.training import checkpoint_state_pb2 [as 别名]
# 或者: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as 别名]
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it. Relative paths stored in the proto are returned
  joined onto `checkpoint_dir`.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise (missing file, read error, or parse error).

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(coord_checkpoint_filename)
      # read_file_to_string is not guaranteed to return bytes; calling
      # .decode() on a Python 3 str would raise AttributeError.
      if isinstance(file_content, bytes):
        file_content = file_content.decode("utf-8")
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        # Interpolate explicitly: ValueError does not %-format its args.
        raise ValueError("Invalid checkpoint state loaded from %s" %
                         (checkpoint_dir,))
      # For relative model_checkpoint_path and all_model_checkpoint_paths,
      # prepend checkpoint_dir.
      if not os.path.isabs(ckpt.model_checkpoint_path):
        ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
                                                  ckpt.model_checkpoint_path)
      # Iterate over a snapshot because entries are replaced by index.
      for i, p in enumerate(list(ckpt.all_model_checkpoint_paths)):
        if not os.path.isabs(p):
          ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
  except errors.OpError as e:
    # It's ok if the file cannot be read
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  except text_format.ParseError as e:
    logging.warning(str(e))
    logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
    return None
  return ckpt
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:57,代码来源:saver.py


注:本文中的tensorflow.python.training.checkpoint_state_pb2.CheckpointState方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。