当前位置: 首页>>代码示例>>Python>>正文


Python checkpoint_utils.init_from_checkpoint函数代码示例

本文整理汇总了Python中tensorflow.python.training.checkpoint_utils.init_from_checkpoint函数的典型用法代码示例。如果您正苦于以下问题：Python init_from_checkpoint函数的具体用法？Python init_from_checkpoint怎么用？Python init_from_checkpoint使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了init_from_checkpoint函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testInitFromCheckpoint

  def testInitFromCheckpoint(self):
    """Restores four variables through two init_from_checkpoint calls.

    Mapping targets are given as a variable name, as a scope prefix (trailing
    "/" on both sides) and as a variable object.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      expected_1, expected_2, expected_3, expected_4 = _create_checkpoints(
          sess, ckpt_dir)

    # Restore into a new graph and session.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        with variable_scope.variable_scope("some_scope"):
          target_1 = variable_scope.get_variable("my1", [1, 10])
          with variable_scope.variable_scope("some_other_scope"):
            target_2 = variable_scope.get_variable("my2", [10, 10])
            with variable_scope.variable_scope("other_useful_scope"):
              target_4 = variable_scope.get_variable("var4", [9, 9])
        target_3 = variable_scope.get_variable("my3", [100, 100])

        # First call: one name mapping plus one scope-to-scope mapping.
        first_map = {
            "var1": "some_scope/my1",
            "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
        }
        checkpoint_utils.init_from_checkpoint(ckpt_dir, first_map)

        # Second call: one name mapping plus one mapping to a variable object.
        second_map = {
            "var2": "some_scope/some_other_scope/my2",
            "var3": target_3,
        }
        checkpoint_utils.init_from_checkpoint(ckpt_dir, second_map)

        sess.run(variables.global_variables_initializer())
        for restored, expected in ((target_1, expected_1),
                                   (target_2, expected_2),
                                   (target_3, expected_3),
                                   (target_4, expected_4)):
          self.assertAllEqual(restored.eval(sess), expected)

        # Check that tensors are not explicitly in the graph (presumably the
        # size bound guards against checkpoint values embedded as constants).
        self.assertLess(len(str(sess.graph.as_graph_def())), 29000)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:33,代码来源:checkpoint_utils_test.py

示例2: testInitialValueComesFromCheckpoint

  def testInitialValueComesFromCheckpoint(self):
    """Checks that init_from_checkpoint replaces my1's initial value.

    `before` is a Variable built from `my1.initialized_value()` while my1 still
    uses its zeros initializer. `after` is built the same way after my1 has
    been mapped to checkpoint tensor "var1". After global initialization,
    `before` holds zeros and `after` holds the checkpoint value. The relative
    order of these statements is what the test measures; do not reorder them.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope(
            "some_scope", initializer=init_ops.zeros_initializer()):
          my1 = variable_scope.get_variable("my1", [1, 10])

        # At this point, my1.initialized_value() will add ops that reference
        # the zeros initializer of my1.
        before = variables.Variable(my1.initialized_value(), name="before")

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

        # At this point, my1.initialized_value() will add ops that reference
        # the newly set initializer of my1.
        after = variables.Variable(my1.initialized_value(), name="after")

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(session.run(my1), v1)
        self.assertAllEqual(session.run(my1.initialized_value()), v1)
        self.assertAllClose(session.run(before), [[0.0] * 10])
        self.assertAllClose(session.run(after), v1)
        # The two snapshots must differ: zeros versus the checkpoint value.
        with self.assertRaises(AssertionError):
          self.assertAllClose(session.run(before), session.run(after))
开发者ID:QiangCai,项目名称:tensorflow,代码行数:29,代码来源:checkpoint_utils_test.py

示例3: _warm_start_var

def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Initializes `var` from the tensor `prev_tensor_name` stored in `prev_ckpt`.

  Args:
    var: The current graph's variable(s) to warm-start. Accepted forms:
      a `Variable` (including `ResourceVariable`), a list of `Variable`s that
      are slices of one larger variable, or a `PartitionedVariable` (which is
      replaced by its `_get_variable_list()` before restoring).
    prev_ckpt: Checkpoint directory or checkpoint path. It must contain a
      tensor named `prev_tensor_name`, or, when that is not given, a tensor
      with the same name as `var`.
    prev_tensor_name: Name of the tensor to read from `prev_ckpt`. When falsy
      (e.g. `None`), the name inferred from `var` is used.

  Raises:
    TypeError: If `var` is not one of the accepted forms.
  """
  # pylint: disable=protected-access
  is_variable = checkpoint_utils._is_variable
  if is_variable(var):
    current_var_name = _infer_var_name([var])
  elif isinstance(var, list) and all(is_variable(v) for v in var):
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    # Restore into the individual partitions.
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  # pylint: enable=protected-access

  # Without an explicit tensor name, assume the name is unchanged.
  tensor_name = prev_tensor_name or current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {tensor_name: var})
开发者ID:AnishShah,项目名称:tensorflow,代码行数:33,代码来源:warm_starting_util.py

示例4: testInitialValueComesFromCheckpoint

  def testInitialValueComesFromCheckpoint(self):
    """Checks how initialized_value() behaves around init_from_checkpoint.

    `before` is taken while my1 still uses its zeros initializer; `after` is
    taken once my1 has been mapped to checkpoint tensor "var1". Before the
    variables are initialized, `before` yields zeros and `after` yields the
    checkpoint value. Once my1 is initialized, both yield the checkpoint
    value. The statement order is what the test measures; do not reorder it.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope(
            "some_scope", initializer=init_ops.zeros_initializer()):
          my1 = variable_scope.get_variable("my1", [1, 10])

        # Snapshot taken while my1 still has the zeros initializer.
        before = my1.initialized_value()

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

        # Snapshot taken after the initializer was switched to the checkpoint.
        after = my1.initialized_value()

        # Variables are not yet initialized here.
        self.assertAllEqual(session.run(before), [[0.0] * 10])
        self.assertAllEqual(session.run(after), v1)

        session.run(variables.global_variables_initializer())

        # Once my1 is initialized, both snapshots yield its current value.
        self.assertAllEqual(session.run(my1), v1)
        self.assertAllEqual(session.run(my1.initialized_value()), v1)
        self.assertAllClose(session.run(before), v1)
        self.assertAllClose(session.run(after), v1)
        # Sanity check: v1 is not all zeros, so the checks above are meaningful.
        with self.assertRaises(AssertionError):
          self.assertAllClose(v1, [[0.0] * 10])
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:29,代码来源:checkpoint_utils_test.py

示例5: testNoAdditionalReadOpsForResourceVariables

  def testNoAdditionalReadOpsForResourceVariables(self):
    """Checks init_from_checkpoint adds no stray ops for a ResourceVariable.

    Every op created under the "init_from_checkpoint" name scope must be part
    of the checkpoint initializer, an AssignVariableOp, or an Identity op.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      expected_1, _, _, _ = _create_checkpoints(sess, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as graph:
      with self.session(graph=graph) as sess:
        resource_var = resource_variable_ops.ResourceVariable(
            [[0.0] * 10], name="my1")

        with ops.name_scope("init_from_checkpoint"):
          checkpoint_utils.init_from_checkpoint(ckpt_dir,
                                                {"var1": resource_var})

        # Basic sanity check: the restore itself works.
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(sess.run(resource_var), expected_1)

    allowed_types = ("AssignVariableOp", "Identity")

    def _is_unexpected(op):
      """True for ops in the init scope that are not explicitly allowed."""
      if not op.name.startswith("init_from_checkpoint/"):
        return False
      if op.name.startswith("init_from_checkpoint/checkpoint_initializer"):
        return False
      return op.type not in allowed_types

    unexpected_ops = [op for op in graph.get_operations() if _is_unexpected(op)]
    self.assertEqual(unexpected_ops, [])
开发者ID:clsung,项目名称:tensorflow,代码行数:26,代码来源:checkpoint_utils_test.py

示例6: init_and_verify

 def init_and_verify(g):
   """Restores checkpoint tensor "var1" into a new "new_var1" and checks it.

   `checkpoint_dir`, `v1_value` and `self` are free variables supplied by the
   enclosing test (not visible here).
   """
   target = variable_scope.get_variable("new_var1", [1, 10])
   assignment_map = {"var1": "new_var1"}
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
   with self.test_session(graph=g) as sess:
     sess.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(target))
开发者ID:ChristinaEricka,项目名称:tensorflow,代码行数:8,代码来源:checkpoint_utils_test.py

示例7: testRestoreRunsOnSameDevice

  def testRestoreRunsOnSameDevice(self):
    """Checks that a scope restore works for variables placed on "/job:ps".

    The test passes as long as init_from_checkpoint does not raise.
    NOTE(review): despite the name, no device placement is asserted here.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
      _create_checkpoints(session, checkpoint_dir)

    with ops.Graph().as_default():
      with ops.device("/job:ps"):
        with variable_scope.variable_scope("useful_scope"):
          # Only the variable's existence in the graph matters; the returned
          # object is never inspected, so it is not bound to a name.
          variable_scope.get_variable("var4", [9, 9])

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"useful_scope/": "useful_scope/"})
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:12,代码来源:checkpoint_utils_test.py

示例8: init_and_verify

 def init_and_verify(g):
   """Restores "var1" into "new_var1", naming the target with a runtime string.

   `checkpoint_dir`, `v1_value` and `self` are free variables supplied by the
   enclosing test (not visible here).
   """
   target = variable_scope.get_variable("new_var1", [1, 10])
   # Use string add to create new object in each replica
   head, tail = "new_", "var1"
   runtime_name = head + tail
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": runtime_name})
   with self.test_session(graph=g) as sess:
     sess.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(target))
开发者ID:jackd,项目名称:tensorflow,代码行数:12,代码来源:checkpoint_utils_test.py

示例9: init_and_verify

 def init_and_verify(g):
   """Restores "var1"/"var2"; the "var2" target is ON_READ with MEAN aggregation.

   `checkpoint_dir`, `v1_value`, `v2_value` and `self` are free variables
   supplied by the enclosing test (not visible here).
   """
   plain_var = variable_scope.get_variable("new_var1", [1, 10])
   on_read_var = variable_scope.get_variable(
       "new_var2",
       [10, 10],
       synchronization=variable_scope.VariableSynchronization.ON_READ,
       aggregation=variable_scope.VariableAggregation.MEAN)
   name_map = {"var1": "new_var1", "var2": "new_var2"}
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, name_map)
   with self.session(graph=g) as sess:
     sess.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(plain_var))
     self.assertAllEqual(v2_value, self.evaluate(on_read_var))
开发者ID:becster,项目名称:tensorflow,代码行数:14,代码来源:checkpoint_utils_test.py

示例10: testRestoreRunsOnSameDevice

  def testRestoreRunsOnSameDevice(self):
    """Checks where the restore input of a "/job:ps" variable is placed.

    Input 1 of the variable's initializer op (presumably the restored value)
    must be on the same job as the variable, on its CPU.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      _create_checkpoints(sess, ckpt_dir)

    with ops.Graph().as_default():
      with ops.device("/job:ps"):
        with variable_scope.variable_scope("useful_scope"):
          ps_var = variable_scope.get_variable("var4", [9, 9])

      checkpoint_utils.init_from_checkpoint(ckpt_dir,
                                            {"useful_scope/": "useful_scope/"})
      # initializer runs on the same task but always on CPU.
      # pylint: disable=protected-access
      restore_input = ps_var._initializer_op.op.inputs[1]
      # pylint: enable=protected-access
      self.assertEqual(restore_input.device, "/job:ps/device:CPU:0")
开发者ID:QiangCai,项目名称:tensorflow,代码行数:15,代码来源:checkpoint_utils_test.py

示例11: testInitWithScopeDoesNotCaptureSuffixes

  def testInitWithScopeDoesNotCaptureSuffixes(self):
    """Checks that mapping "useful_scope/" leaves "useful_scope_1/" alone.

    "useful_scope_1" shares the "useful_scope" prefix but is a different
    scope, so its variable must keep its own initializer.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      _, _, _, expected_var4 = _create_checkpoints(sess, ckpt_dir)

    with ops.Graph().as_default() as graph:
      with variable_scope.variable_scope("useful_scope"):
        mapped = variable_scope.get_variable("var4", [9, 9])
      with variable_scope.variable_scope("useful_scope_1"):
        own_init = [[1.0, 2.0], [3.0, 4.0]]
        unmapped = variable_scope.get_variable("var5", initializer=own_init)

      checkpoint_utils.init_from_checkpoint(ckpt_dir,
                                            {"useful_scope/": "useful_scope/"})
      with self.test_session(graph=graph) as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(mapped.eval(sess), expected_var4)
        self.assertAllEqual(unmapped.eval(sess), own_init)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:18,代码来源:checkpoint_utils_test.py

示例12: testInitToRootCheckpoint

  def testInitToRootCheckpoint(self):
    """Checks that a "/" -> "/" mapping restores every variable by name."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      exp1, exp2, exp3, exp4 = _create_checkpoints(sess, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        # Same names and shapes as the checkpointed variables.
        root_vars = [
            variable_scope.get_variable("var1", [1, 10]),
            variable_scope.get_variable("var2", [10, 10]),
            variable_scope.get_variable("var3", [100, 100]),
        ]
        with variable_scope.variable_scope("useful_scope"):
          root_vars.append(variable_scope.get_variable("var4", [9, 9]))

        checkpoint_utils.init_from_checkpoint(ckpt_dir, {"/": "/"})

        sess.run(variables.global_variables_initializer())
        for var, expected in zip(root_vars, (exp1, exp2, exp3, exp4)):
          self.assertAllEqual(var.eval(sess), expected)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:22,代码来源:checkpoint_utils_test.py

示例13: testInitFromCheckpointMissing

  def testInitFromCheckpointMissing(self):
    """Checks the errors raised for bad checkpoints and bad assignment maps."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      _, _, _, _ = _create_checkpoints(sess, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        with variable_scope.variable_scope("some_scope"):
          _ = variable_scope.get_variable("my1", [10, 10])
          _ = variable_scope.get_variable(
              "my2", [1, 10],
              dtype=dtypes.int64,
              initializer=init_ops.zeros_initializer())

        # (expected exception, checkpoint location, assignment map), checked
        # in this order.
        bad_calls = [
            # No directory.
            (errors_impl.OpError, "no_dir", {"var1": "some_scope/my1"}),
            # No variable in checkpoint.
            (ValueError, ckpt_dir, {"no_var": "some_scope/my1"}),
            # No variable in the graph.
            (ValueError, ckpt_dir, {"var3": "some_scope/no_var"}),
            # Shape mismatch.
            (ValueError, ckpt_dir, {"var1": "some_scope/my1"}),
            # Variable 'my1' and 'my2' are missing in given checkpoint scope.
            (ValueError, ckpt_dir, {"useful_scope/": "some_scope/"}),
            # Mapping is not to scope name.
            (ValueError, ckpt_dir, {"useful_scope": "some_scope/"}),
        ]
        for expected_error, ckpt, assignment_map in bad_calls:
          with self.assertRaises(expected_error):
            checkpoint_utils.init_from_checkpoint(ckpt, assignment_map)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:44,代码来源:checkpoint_utils_test.py

示例14: testInitFromPartitionVar

  def testInitFromPartitionVar(self):
    """Checks restoring partitioned variables from a partitioned checkpoint.

    First graph: "scope/var1" is mapped by name onto "some_scope/my1", whose
    partitions are compared one-to-one against the checkpoint values. The
    "scope/" scope is mapped onto "some_other_scope/", whose variable is
    partitioned differently (min_slice_size 16 << 10 instead of 8 << 10); its
    concatenated value must still equal the concatenated checkpoint value.
    Second graph: "scope/var1" is mapped directly onto the list of partition
    variables.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1 = _create_partition_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable(
              name="my1",
              shape=[100, 100],
              initializer=init_ops.zeros_initializer(),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=8 << 10))
          my1_var_list = my1._get_variable_list()
        # Create another variable with different partitions than the variable in
        # the checkpoint.
        with variable_scope.variable_scope("some_other_scope"):
          my2 = variable_scope.get_variable(
              name="var1",
              shape=[100, 100],
              initializer=init_ops.zeros_initializer(),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=16 << 10))
          my2_var_list = my2._get_variable_list()

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "scope/var1": "some_scope/my1",
            "scope/": "some_other_scope/"})

        session.run(variables.global_variables_initializer())
        my1_values = session.run(my1_var_list)
        self.assertAllEqual(my1_values, v1)
        my2_values = session.run(my2_var_list)
        # Verify we created a different number of partitions. Uses
        # assertNotEqual: the deprecated assertNotEquals alias was removed
        # from unittest in Python 3.12.
        self.assertNotEqual(len(my2_values), len(v1))
        # Verify the values were correctly initialized in spite of different
        # partitions.
        full_my2_values = np.concatenate(my2_values, axis=0)
        full_v1_values = np.concatenate(v1, axis=0)
        self.assertAllEqual(full_my2_values, full_v1_values)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable(
              name="my1",
              shape=[100, 100],
              initializer=init_ops.truncated_normal_initializer(0.5),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=8 << 10))
          my1_var_list = my1._get_variable_list()

        # Map the checkpoint tensor directly onto the list of partitions.
        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"scope/var1": my1_var_list,})

        session.run(variables.global_variables_initializer())
        my1_values = session.run(my1_var_list)
        self.assertAllEqual(my1_values, v1)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:61,代码来源:checkpoint_utils_test.py

示例15: warm_start


#.........这里部分代码省略.........
      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).
  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name,
              vocab_info.new_vocab,
              vocab_info.new_vocab_size,
              vocab_info.old_vocab,
              (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
               else "All"),
              vocab_info.num_oov_buckets,
              prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:101,代码来源:warm_starting_util.py


注:本文中的tensorflow.python.training.checkpoint_utils.init_from_checkpoint函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。