當前位置: 首頁>>代碼示例>>Python>>正文


Python io_ops.save_v2方法代碼示例

本文整理匯總了Python中tensorflow.python.ops.io_ops.save_v2方法的典型用法代碼示例。如果您正苦於以下問題:Python io_ops.save_v2方法的具體用法?Python io_ops.save_v2怎麼用?Python io_ops.save_v2使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊tensorflow.python.ops.io_ops的用法示例。


在下文中一共展示了io_ops.save_v2方法的7個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: from tensorflow.python.ops import io_ops [as 別名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 別名]
def main(argv):
  """Dumps a tf.keras dataset's train/test splits into a TF V2 checkpoint.

  The dataset is looked up on tf.keras.datasets by FLAGS.dataset. The
  checkpoint prefix is FLAGS.out, or /tmp/<dataset>/<dataset> when FLAGS.out
  is empty. The four arrays are stored under the keys "x_train", "y_train",
  "x_test" and "y_test".

  Args:
    argv: Command-line arguments; ignored.
  """
  del argv  # Not used.

  loader = getattr(tf.keras.datasets, FLAGS.dataset)
  (train_x, train_y), (test_x, test_y) = loader.load_data()

  def _as_tensor(arr):
    # Expose the numpy array to the graph through a zero-input py_func.
    arr_dtype = tf.as_dtype(arr.dtype)
    assert arr_dtype != tf.string  # tf.string is not supported by py_func.
    return tf.py_func(lambda: arr, [], arr_dtype)

  prefix = FLAGS.out or os.path.join("/tmp", FLAGS.dataset, FLAGS.dataset)
  tf.logging.info("Save %s dataset to %s ckpt." % (FLAGS.dataset, prefix))

  key_names = ["x_train", "y_train", "x_test", "y_test"]
  with tf.Session() as sess:
    # Tensors are created in the same order as the key names above.
    arrays = [_as_tensor(a) for a in (train_x, train_y, test_x, test_y)]
    save = io_ops.save_v2(
        prefix=prefix,
        tensor_names=key_names,
        shape_and_slices=[""] * len(key_names),
        tensors=arrays)
    sess.run(save)
開發者ID:tensorflow,項目名稱:lingvo,代碼行數:27,代碼來源:keras2ckpt.py

示例2: FakeMnistData

# 需要導入模塊: from tensorflow.python.ops import io_ops [as 別名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 別名]
def FakeMnistData(tmpdir, train_size=60000, test_size=10000):
  """Writes an all-ones, MNIST-shaped checkpoint for unit tests.

  Images are uint8 tensors of shape (N, 28, 28, 1) and labels are uint8
  tensors with N elements, all filled with ones. They are stored under
  'x_train', 'y_train', 'x_test' and 'y_test'.

  Args:
    tmpdir: Directory in which the checkpoint is written.
    train_size: Number of fake training examples.
    test_size: Number of fake test examples.

  Returns:
    The checkpoint prefix, os.path.join(tmpdir, 'ckpt').
  """
  ckpt_prefix = os.path.join(tmpdir, 'ckpt')
  with tf.Graph().as_default(), tf.Session() as sess:
    # Build (images, labels) for train first, then test, so that the tensor
    # order matches the key list passed to save_v2 below.
    fake_tensors = []
    for size in (train_size, test_size):
      fake_tensors.append(tf.ones((size, 28, 28, 1), dtype=tf.uint8))
      fake_tensors.append(tf.ones((size), dtype=tf.uint8))
    sess.run(
        io_ops.save_v2(ckpt_prefix, ['x_train', 'y_train', 'x_test', 'y_test'],
                       [''] * 4, fake_tensors))
  return ckpt_prefix
開發者ID:tensorflow,項目名稱:lingvo,代碼行數:15,代碼來源:input_generator.py

示例3: _BuildSave

# 需要導入模塊: from tensorflow.python.ops import io_ops [as 別名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 別名]
def _BuildSave(self):
    """Constructs the op that writes all tracked variables to a checkpoint.

    Sets:
      self._save_global_step: the global step tensor from py_utils.
      self._save_prefix: string tensor self._logdir_ph + "/ckpt-" + the step
        rendered with width 8 and fill "0".
      self._save_op: a save_v2 op storing each variable in self._vars under
        _VarKey(v) with its current value.
    """
    step = py_utils.GetGlobalStep()
    self._save_global_step = step
    padded_step = tf.as_string(step, width=8, fill="0")
    self._save_prefix = tf.strings.join(
        [self._logdir_ph, "/ckpt-", padded_step])
    variables = self._vars
    self._save_op = io_ops.save_v2(
        prefix=self._save_prefix,
        tensor_names=list(map(_VarKey, variables)),
        tensors=[var.read_value() for var in variables],
        shape_and_slices=[""] * len(variables))
開發者ID:tensorflow,項目名稱:lingvo,代碼行數:14,代碼來源:saver.py

示例4: WriteNpArrays

# 需要導入模塊: from tensorflow.python.ops import io_ops [as 別名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 別名]
def WriteNpArrays(file_prefix, nmap):
  """Saves every numpy-array leaf of a NestedMap into a TF checkpoint.

  Each leaf is stored under the flattened key yielded by nmap.FlattenItems().
  The save op is built in a fresh, private tf.Graph and run in a session
  bound to that graph.

  Args:
    file_prefix: A TF checkpoint filename prefix.
    nmap: A NestedMap whose leaves are numpy arrays (checked with assert).
  """
  graph = tf.Graph()
  with graph.as_default():

    def _ToTensor(arr):
      # Feed the array into the graph via a zero-input py_func; each call
      # creates its own closure over `arr`.
      tf_dtype = tf.as_dtype(arr.dtype)
      assert tf_dtype != tf.string  # tf.string is not supported by py_func.
      return tf.py_func(lambda: arr, [], tf_dtype)

    keys, tensors = [], []
    for key, arr in nmap.FlattenItems():
      assert isinstance(arr, np.ndarray)
      keys.append(key)
      tensors.append(_ToTensor(arr))

    save_op = io_ops.save_v2(
        prefix=file_prefix,
        tensor_names=keys,
        tensors=tensors,
        shape_and_slices=[""] * len(tensors))

  with tf.Session(graph=graph) as sess:
    sess.run(save_op)
開發者ID:tensorflow,項目名稱:lingvo,代碼行數:31,代碼來源:saver.py

示例5: save_op

# 需要導入模塊: from tensorflow.python.ops import io_ops [as 別名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 別名]
def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor. For V1 it is the checkpoint filename; for
        V2 it is used as the checkpoint prefix.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    # Flatten every saveable into parallel (name, tensor, slice) lists.
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      # _write_version is compared against SaverDef enum values, which are
      # ints in Python protobufs. Concatenating it to a str would raise
      # TypeError instead of the documented RuntimeError, so format it.
      raise RuntimeError(
          "Unexpected write_version: %s" % (self._write_version,))

  # pylint: disable=unused-argument 
開發者ID:ryfeus,項目名稱:lambda-packs,代碼行數:43,代碼來源:saver.py

示例6: save_op

# 需要導入模塊: from tensorflow.python.ops import io_ops [as 別名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 別名]
def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables', collapsing replicated variables.

    Specs whose name does not start with 'replicated_' are saved under their
    own name. Specs whose name starts with 'replicated_' are kept only if the
    name starts with 'replicated_0' or contains 'avg'. Kept specs are saved
    with the first '/'-separated component removed, for example
    'replicated_0/conv1/w' becomes 'conv1/w'. All other replicated specs are
    dropped.

    Args:
      filename_tensor: String Tensor. For V1 it is the checkpoint filename; for
        V2 it is used as the checkpoint prefix.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the selected variables.

    Raises:
      RuntimeError: if "self._write_version" is an unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        name = spec.name
        if name.startswith('replicated_'):
          if not (name.startswith('replicated_0') or 'avg' in name):
            # Copies owned by replicas other than 0 are not saved.
            continue
          # NOTE(review): 'avg' specs from replicas other than 0 are also kept
          # and stripped. If several replicas hold one, this could produce
          # duplicate names. Confirm that this cannot happen.
          name = '/'.join(name.split('/')[1:])
        tensor_names.append(name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      # SaverDef enum values are ints in Python protobufs. Concatenating one to
      # a str raised TypeError instead of RuntimeError, so format it.
      raise RuntimeError(
          "Unexpected write_version: %s" % (self._write_version,))
開發者ID:medivhna,項目名稱:TF_Face_Toolbox,代碼行數:30,代碼來源:saver.py

示例7: save_op

# 需要導入模塊: from tensorflow.python.ops import io_ops [as 別名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 別名]
def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor. For V1 it is the checkpoint filename; for
        V2 it is used as the checkpoint prefix.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    # Flatten every saveable into parallel (name, tensor, slice) lists.
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)

    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      # _write_version is compared against SaverDef enum values, which are
      # ints in Python protobufs. Concatenating it to a str would raise
      # TypeError instead of the documented RuntimeError, so format it.
      raise RuntimeError(
          "Unexpected write_version: %s" % (self._write_version,))

  # pylint: disable=unused-argument 
開發者ID:abhisuri97,項目名稱:auto-alt-text-lambda-api,代碼行數:44,代碼來源:saver.py


注:本文中的tensorflow.python.ops.io_ops.save_v2方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。