当前位置: 首页>>代码示例>>Python>>正文


Python checkpoint_utils.load_variable函数代码示例

本文整理汇总了Python中tensorflow.contrib.framework.python.framework.checkpoint_utils.load_variable函数的典型用法代码示例。如果您正苦于以下问题：Python load_variable函数的具体用法？Python load_variable怎么用？Python load_variable使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了load_variable函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_save_steps_saves_periodically

 def test_save_steps_saves_periodically(self):
   """CheckpointSaver(save_steps=2) only refreshes the checkpoint every 2 steps.

   After step 2 the checkpoint still holds global step 1 (written on the
   first step); it then advances to 3 and 5 on the odd steps.
   """
   with self.graph.as_default():
     monitor = learn.monitors.CheckpointSaver(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     monitor.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       self._run(monitor, 1, self.train_op, sess)
       # (step to run, global step expected in the checkpoint afterwards)
       schedule = ((2, 1),  # not saved
                   (3, 3),  # saved
                   (4, 3),  # not saved
                   (5, 5))  # saved
       for step, saved_step in schedule:
         self._run(monitor, step, self.train_op, sess)
         self.assertEqual(
             saved_step,
             checkpoint_utils.load_variable(self.model_dir,
                                            self.global_step.name))
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:29,代码来源:monitors_test.py

示例2: test_save_steps_saves_periodically

 def test_save_steps_saves_periodically(self):
   """CheckpointSaverHook(save_steps=2) only refreshes the checkpoint every 2 steps.

   The hook is driven through a _HookedSession; after each of steps 2..5 the
   checkpointed global step is expected to read 1, 3, 3, 5.
   """
   with self.graph.as_default():
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       hooked = monitored_session._HookedSession(sess, [hook])
       hooked.run(self.train_op)
       # Global step expected in the checkpoint after each further step:
       # step 2 not saved, step 3 saved, step 4 not saved, step 5 saved.
       for saved_step in (1, 3, 3, 5):
         hooked.run(self.train_op)
         self.assertEqual(
             saved_step,
             checkpoint_utils.load_variable(self.model_dir,
                                            self.global_step.name))
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:30,代码来源:basic_session_run_hooks_test.py

示例3: test_train_max_steps_is_not_incremental

  def test_train_max_steps_is_not_incremental(self):
    """`max_steps` is a target global step, not a number of extra steps.

    Training to 10 and then training again (in a fresh graph, same output
    dir) with max_steps=15 should leave the checkpointed global step at 15.
    """
    for max_steps in (10, 15):
      # Each round builds a brand-new graph writing to the same output dir.
      with ops.Graph().as_default() as g, self.test_session(g):
        with ops.control_dependencies(self._build_inference_graph()):
          train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
        learn.graph_actions.train(
            g,
            output_dir=self._output_dir,
            train_op=train_op,
            loss_op=constant_op.constant(2.0),
            max_steps=max_steps)
        step = checkpoint_utils.load_variable(
            self._output_dir, variables_lib.get_global_step().name)
        self.assertEqual(max_steps, step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py

示例4: test_train_skip_train_if_max_step_already_saved

  def test_train_skip_train_if_max_step_already_saved(self):
    """A second run with the same `max_steps` leaves the global step unchanged.

    Both rounds use `_monitored_train` with max_steps=10 against the same
    output dir; the checkpointed global step must be 10 after each.
    """
    for max_steps in (10, 10):
      # Each round builds a brand-new graph writing to the same output dir.
      with ops.Graph().as_default() as g, self.test_session(g):
        with ops.control_dependencies(self._build_inference_graph()):
          train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
        learn.graph_actions._monitored_train(  # pylint: disable=protected-access
            g,
            output_dir=self._output_dir,
            train_op=train_op,
            loss_op=constant_op.constant(2.0),
            max_steps=max_steps)
        step = checkpoint_utils.load_variable(
            self._output_dir, variables_lib.get_global_step().name)
        self.assertEqual(10, step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py

示例5: testGetTensor

 def testGetTensor(self):
   """Each variable written by _create_checkpoints can be loaded back by name."""
   checkpoint_dir = self.get_temp_dir()
   with self.test_session() as session:
     v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
   names = ("var1", "var2", "var3", "useful_scope/var4")
   for name, expected in zip(names, (v1, v2, v3, v4)):
     self.assertAllEqual(
         checkpoint_utils.load_variable(checkpoint_dir, name), expected)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:12,代码来源:checkpoint_utils_test.py

示例6: print_tensors_in_checkpoint_file

def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Print one tensor, or the variable listing, of a checkpoint.

  Args:
    file_name: Checkpoint path, forwarded unchanged to `checkpoint_utils`.
    tensor_name: Name of the tensor whose value should be printed. When it is
      falsy (e.g. "" or None), every variable's name and shape is printed
      instead, one tab-separated pair per line.

  Any exception raised while reading is not propagated: its message is
  printed, followed by a hint when it looks like a SNAPPY-compressed file.
  """
  try:
    if tensor_name:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
    else:
      for name, shape in checkpoint_utils.list_variables(file_name):
        print("%s\t%s" % (name, str(shape)))
  except Exception as e:  # pylint: disable=broad-except
    message = str(e)
    print(message)
    if "corrupted compressed block contents" in message:
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
开发者ID:kadeng,项目名称:tensorflow,代码行数:25,代码来源:inspect_checkpoint.py

示例7: testNoTensor

 def testNoTensor(self):
   """Loading a variable name that is not in the checkpoint raises OpError."""
   checkpoint_dir = self.get_temp_dir()
   with self.test_session() as session:
     _create_checkpoints(session, checkpoint_dir)
   # The load itself is what must raise; wrapping it in an equality assertion
   # would only obscure that (the comparison can never run on success).
   with self.assertRaises(errors_impl.OpError):
     checkpoint_utils.load_variable(checkpoint_dir, "var5")
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:7,代码来源:checkpoint_utils_test.py

示例8: test_save_secs_saves_periodically

  def test_save_secs_saves_periodically(self, mock_time):
    """CheckpointSaverHook(save_secs=2) saves according to (mocked) time.

    NOTE(review): `mock_time` is presumably a patched `time.time` injected by
    a decorator outside this view; its return value is pinned before each
    step to simulate elapsed wall-clock time.
    """
    start = 1484695987.209386  # a realistic epoch start time

    # (seconds since start for this step, global step expected in the latest
    #  checkpoint after the step, or None when no check is made)
    schedule = (
        (0.0, None),   # step 1: saved
        (0.5, 1),      # step 2: not saved
        (2.5, None),   # step 3: saved (2.5 s after the last save)
        (2.6, None),   # step 4: not saved
        (2.7, 3),      # step 5: not saved
        (10.0, 6),     # step 6: saved (7.5 s after the last save)
    )

    with self.graph.as_default():
      mock_time.return_value = start
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()

      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        hooked = monitored_session._HookedSession(sess, [hook])
        for offset, saved_step in schedule:
          mock_time.return_value = start + offset
          hooked.run(self.train_op)
          if saved_step is not None:
            self.assertEqual(
                saved_step,
                checkpoint_utils.load_variable(self.model_dir,
                                               self.global_step.name))
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:45,代码来源:basic_session_run_hooks_test.py

示例9: test_saves_when_saver_and_scaffold_both_missing

 def test_saves_when_saver_and_scaffold_both_missing(self):
   """A hook built without a scaffold still writes a checkpoint after one step."""
   with self.graph.as_default():
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=1)
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       hooked = monitored_session._HookedSession(sess, [hook])
       hooked.run(self.train_op)
       saved_step = checkpoint_utils.load_variable(
           self.model_dir, self.global_step.name)
       self.assertEqual(1, saved_step)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:13,代码来源:basic_session_run_hooks_test.py

示例10: test_save_saves_at_end

 def test_save_saves_at_end(self):
   """Calling `end()` leaves a checkpoint holding the final global step (2)."""
   with self.graph.as_default():
     monitor = learn.monitors.CheckpointSaver(
         self.model_dir, save_secs=2, scaffold=self.scaffold)
     monitor.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       for step in (1, 2):
         self._run(monitor, step, self.train_op, sess)
       monitor.end(sess)
       final_step = checkpoint_utils.load_variable(
           self.model_dir, self.global_step.name)
       self.assertEqual(2, final_step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:14,代码来源:monitors_test.py

示例11: covariances

 def covariances(self):
   """Return the covariances variable read from the checkpoint in `model_dir`."""
   var_name = gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE
   return checkpoint_utils.load_variable(self.model_dir, var_name)
开发者ID:ivankreso,项目名称:tensorflow,代码行数:4,代码来源:gmm.py

示例12: clusters

 def clusters(self):
   """Return the cluster centers read from the checkpoint in `model_dir`.

   Axis 1 of the stored variable is squeezed away; `np.squeeze` raises if
   that axis does not have size 1.
   """
   var_name = gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE
   centers = checkpoint_utils.load_variable(self.model_dir, var_name)
   return np.squeeze(centers, axis=1)
开发者ID:ivankreso,项目名称:tensorflow,代码行数:5,代码来源:gmm.py

示例13: weights

 def weights(self):
   """Return the cluster-weight variable read from the checkpoint in `model_dir`."""
   var_name = gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT
   return checkpoint_utils.load_variable(self.model_dir, var_name)
开发者ID:Immexxx,项目名称:tensorflow,代码行数:4,代码来源:gmm.py

示例14: load_variable

def load_variable(checkpoint_dir, name):
  """Return the value of variable `name` stored in `checkpoint_dir`.

  Thin forwarding wrapper; see `tf.contrib.framework.load_variable`.
  """
  value = checkpoint_utils.load_variable(checkpoint_dir, name)
  return value
开发者ID:2020zyc,项目名称:tensorflow,代码行数:3,代码来源:checkpoints.py

示例15: testNoCheckpoints

 def testNoCheckpoints(self):
   """Loading from a directory that holds no checkpoints raises OpError."""
   checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
   # The load itself is what must raise; wrapping it in an equality assertion
   # would only obscure that (the comparison can never run on success).
   with self.assertRaises(errors_impl.OpError):
     checkpoint_utils.load_variable(checkpoint_dir, "var1")
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:5,代码来源:checkpoint_utils_test.py


注：本文中的tensorflow.contrib.framework.python.framework.checkpoint_utils.load_variable函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。