当前位置: 首页>>代码示例>>Python>>正文


Python util.list_objects函数代码示例

本文整理汇总了Python中tensorflow.python.training.checkpointable.util.list_objects函数的典型用法代码示例。如果您正苦于以下问题：Python list_objects函数的具体用法？Python list_objects怎么用？Python list_objects使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了list_objects函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testNestedLists

 def testNestedLists(self):
   """Checkpointables inside lists of lists are tracked as dependencies.

   Also checks that non-checkpointable list elements are not listed, that a
   NoDependency-wrapped element does not block saving, and that replacing an
   element of a tracked inner list makes the root object unsaveable.
   """
   a = tracking.Checkpointable()
   a.l = []
   b = tracking.Checkpointable()
   a.l.append([b])
   c = tracking.Checkpointable()
   # Appending to the inner list after it was added is still tracked.
   a.l[0].append(c)
   a_deps = util.list_objects(a)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   # A plain int may live in the list but is not a checkpoint dependency.
   a.l[0].append(1)
   d = tracking.Checkpointable()
   a.l[0].append(d)
   a_deps = util.list_objects(a)
   self.assertIn(d, a_deps)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   self.assertNotIn(1, a_deps)
   e = tracking.Checkpointable()
   f = tracking.Checkpointable()
   # Nested lists assigned in one statement (including an empty inner list
   # that is filled afterwards) are tracked as well.
   a.l1 = [[], [e]]
   a.l1[0].append(f)
   a_deps = util.list_objects(a)
   self.assertIn(e, a_deps)
   self.assertIn(f, a_deps)
   checkpoint = util.Checkpoint(a=a)
   checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   # Appending a NoDependency-wrapped list and then mutating it is expected
   # not to prevent the next save.
   a.l[0].append(data_structures.NoDependency([]))
   a.l[0][-1].append(5)
   checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   # Dirtying the inner list means the root object is unsaveable.
   # a.l[0] is currently [b, c, 1, d, <NoDependency list>], so index 1 is c;
   # overwriting it in place is the mutation save() must reject.
   a.l[0][1] = 2
   with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:34,代码来源:tracking_test.py

示例2: get_non_optimizer_objects

 def get_non_optimizer_objects(m, g):
   """Return the checkpointable objects of `m` that do not belong to an optimizer.

   Args:
     m: Root checkpointable object (e.g. a Keras model) whose object graph is
       walked with `list_objects`.
     g: Graph made the default while collecting, because
       `optimizer.variables()` only returns variables defined in the default
       graph.

   Returns:
     A set containing every object reachable from `m`, minus each
     `optimizers.TFOptimizer` found among them, everything reachable from
     that wrapper, and the variables of the optimizer it wraps.
   """
   # Set default graph because optimizer.variables() returns optimizer
   # variables defined in the default graph.
   with g.as_default():
     all_objects = set(checkpointable_utils.list_objects(m))
     optimizer_and_variables = set()
     for obj in all_objects:
       if isinstance(obj, optimizers.TFOptimizer):
         # list_objects(obj) includes `obj` itself plus its dependencies.
         optimizer_and_variables.update(checkpointable_utils.list_objects(obj))
         # set.update accepts any iterable; no intermediate set is needed.
         optimizer_and_variables.update(obj.optimizer.variables())
   # The set difference does not depend on the default graph.
   return all_objects - optimizer_and_variables
开发者ID:AnishShah,项目名称:tensorflow,代码行数:12,代码来源:keras_saved_model.py

示例3: testAddVariableOverwrite

 def testAddVariableOverwrite(self):
   """Re-declaring dependency "v" requires overwrite=True."""
   root = base.Checkpointable()
   a = root._add_variable_with_custom_getter(
       name="v", shape=[], getter=variable_scope.get_variable)
   self.assertEqual([root, a], util.list_objects(root))
   # NOTE(review): a fresh graph is presumably used so get_variable does not
   # collide with the "v" created above; the checkpoint dependency remains.
   with ops.Graph().as_default():
     b = root._add_variable_with_custom_getter(
         name="v", shape=[], overwrite=True,
         getter=variable_scope.get_variable)
     # With overwrite=True the new variable replaces `a` as dependency "v"
     # instead of being listed alongside it.
     self.assertEqual([root, b], util.list_objects(root))
   with ops.Graph().as_default():
     # Without overwrite, declaring "v" again raises ValueError.
     with self.assertRaisesRegexp(
         ValueError, "already declared as a dependency"):
       root._add_variable_with_custom_getter(
           name="v", shape=[], overwrite=False,
           getter=variable_scope.get_variable)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:16,代码来源:base_test.py

示例4: testDictWrapperNoDependency

 def testDictWrapperNoDependency(self):
   """A dict wrapped in NoDependency adds no checkpoint dependencies."""
   holder = tracking.Checkpointable()
   holder.d = data_structures.NoDependency({})
   holder.d[1] = [3]
   # Neither the dict nor its list value is reachable; only the holder is.
   self.assertEqual([holder], util.list_objects(holder))
   outer = training.Model()
   outer.sub = holder
   # Round-trip Keras weight saving with the untracked dict attached.
   weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
   outer.save_weights(weights_prefix)
   outer.load_weights(weights_prefix)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:10,代码来源:data_structures_test.py

示例5: testNonStringKeyNotCheckpointableValue

 def testNonStringKeyNotCheckpointableValue(self):
   """A non-string dict key is fine when its value is NoDependency-wrapped."""
   holder = tracking.Checkpointable()
   holder.d = {}
   holder.d["a"] = [3]
   holder.d[1] = data_structures.NoDependency([3])
   tracked_dict = holder.d
   # Expected graph: holder -> dict -> list stored under "a"; the value under
   # key 1 does not appear.
   expected = [holder, tracked_dict, tracked_dict["a"]]
   self.assertEqual(expected, util.list_objects(holder))
   outer = training.Model()
   outer.sub = holder
   weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
   outer.save_weights(weights_prefix)
   outer.load_weights(weights_prefix)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:11,代码来源:data_structures_test.py

示例6: testListBasic

 def testListBasic(self):
   """Checkpointables in a list attribute are tracked through that list."""
   owner = tracking.Checkpointable()
   first = tracking.Checkpointable()
   owner.l = [first]
   second = tracking.Checkpointable()
   owner.l.append(second)
   reachable = util.list_objects(owner)
   for member in (first, second):
     self.assertIn(member, reachable)
   # The list is the owner's single direct dependency, named after the
   # attribute, and it refers to both members.
   (list_dep,) = owner._checkpoint_dependencies
   self.assertEqual("l", list_dep.name)
   for member in (first, second):
     self.assertIn(member, list_dep.ref)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:13,代码来源:tracking_test.py

示例7: testShallowCopyCheckpointable

 def testShallowCopyCheckpointable(self):
   """copy.copy shares sub-objects and the copy's containers are tracked."""
   source = tracking.Checkpointable()
   child = tracking.Checkpointable()
   source.a = [[1.]]
   source.b = {"a": child}
   clone = copy.copy(source)
   # Shallow copy: new top-level object, same child object.
   self.assertIs(child, clone.b["a"])
   self.assertIsNot(source, clone)
   self.assertEqual([[1.]], clone.a)
   clone_deps = util.list_objects(clone)
   for expected in (clone.a, clone.b, clone.b["a"]):
     self.assertIn(expected, clone_deps)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:13,代码来源:data_structures_test.py

示例8: test_checkpointable_save_restore

  def test_checkpointable_save_restore(self):
    """Template variables and optimizer state round-trip through Checkpoint."""

    def _templated():
      # Creates two zero-initialized resource variables and a _ManualScope;
      # calling the _ManualScope returns the variable it creates.
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      manual = _ManualScope()
      return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    # Reachable from the template: both variables, the _ManualScope object,
    # the variable it created, and the template itself (not the `v + 1.`
    # tensor).
    six.assertCountEqual(
        self,
        [v1_save, v2_save, manual_scope, manual_scope_v, save_template],
        checkpointable_utils.list_objects(save_template))
    # The variable created by _ManualScope hangs off the scope object as
    # dependency "in_manual_scope".
    manual_dep, = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)
    optimizer = adam.AdamOptimizer(0.0)
    save_root = checkpointable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    # Minimizing (loss given as a callable) makes the optimizer create its
    # variables, which are initialized below and saved with the checkpoint.
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    # Restore into a differently named template ("s2"). restore() is called
    # before the template has created any variables; run_restore_ops() on the
    # status applies the saved values afterwards.
    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = checkpointable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    # Direct template dependencies: the two variables, then the _ManualScope.
    self.assertEqual(3, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual("ManualScope",
                     load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    # Restored values: 12 and 14 as assigned before saving; 13 == 12 + 1.
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:49,代码来源:util_with_v1_optimizers_test.py

示例9: testDeepCopyCheckpointable

 def testDeepCopyCheckpointable(self):
   """copy.deepcopy duplicates sub-objects and tracks only the duplicates."""
   source = tracking.Checkpointable()
   child = tracking.Checkpointable()
   source.a = [[1.]]
   source.b = {"a": child}
   clone = copy.deepcopy(source)
   # Deep copy: both the top-level object and the child are new objects.
   self.assertIsNot(source, clone)
   self.assertIsNot(child, clone.b["a"])
   self.assertEqual([[1.]], clone.a)
   self.assertIsInstance(clone.b["a"], tracking.Checkpointable)
   clone_deps = util.list_objects(clone)
   for expected in (clone.a, clone.b, clone.b["a"]):
     self.assertIn(expected, clone_deps)
   # The original child must not leak into the copy's object graph.
   self.assertNotIn(child, clone_deps)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:15,代码来源:data_structures_test.py

示例10: test_checkpointable_dependencies

  def test_checkpointable_dependencies(self):
    """Every variable of a fitted SimpleRNN model is reachable for saving."""
    with self.test_session():
      inputs = np.random.random((2, 2, 2))
      targets = np.random.random((2, 2))
      model = keras.models.Sequential()
      model.add(keras.layers.SimpleRNN(2))
      model.compile(optimizer='rmsprop', loss='mse')
      model.fit(inputs, targets, epochs=1, batch_size=1)

      # Each model variable must appear among the objects list_objects
      # reaches from the model.
      reachable = set(checkpointable_util.list_objects(model))
      for variable in model.variables:
        self.assertIn(variable, reachable)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:15,代码来源:recurrent_test.py

示例11: test_timedistributed_dense

  def test_timedistributed_dense(self):
    """TimeDistributed(Dense) variables are all reachable via list_objects."""
    model = keras.models.Sequential()
    wrapped = keras.layers.TimeDistributed(
        keras.layers.Dense(2), input_shape=(3, 4))
    model.add(wrapped)
    model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
    inputs = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 2))
    model.fit(inputs, targets, epochs=1, batch_size=10)

    # Serializing the config must not raise.
    model.get_config()

    reachable = set(checkpointable_util.list_objects(model))
    for variable in model.variables:
      self.assertIn(variable, reachable)
开发者ID:LongJun123456,项目名称:tensorflow,代码行数:18,代码来源:wrappers_test.py

示例12: testNonAppendNotCheckpointable

 def testNonAppendNotCheckpointable(self):
   """Overwriting or deleting untracked dict values keeps the dict saveable."""
   # Non-append mutations (deleting or overwriting values) are OK when the
   # values aren't tracked.
   a = tracking.Checkpointable()
   a.d = {}
   a.d["a"] = [3]
   # Overwrite, then delete, a plain int stored under a non-string key.
   a.d[1] = 3
   a.d[1] = 2
   self.assertEqual(2, a.d[1])
   del a.d[1]
   # Overwrite a NoDependency-wrapped Checkpointable; reading the key back
   # yields the unwrapped object.
   a.d[2] = data_structures.NoDependency(tracking.Checkpointable())
   second = tracking.Checkpointable()
   a.d[2] = data_structures.NoDependency(second)
   self.assertIs(second, a.d[2])
   # Only the root, the dict and the list under "a" are dependencies.
   self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
   model = training.Model()
   model.sub = a
   save_path = os.path.join(self.get_temp_dir(), "ckpt")
   model.save_weights(save_path)
   model.load_weights(save_path)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:20,代码来源:data_structures_test.py

示例13: testNoDependency

  def testNoDependency(self):
    """NoDependency and no_automatic_dependency_tracking both skip tracking."""
    parent = tracking.Checkpointable()
    tracked_child = tracking.Checkpointable()
    parent.hasdep = tracked_child
    untracked_child = tracking.Checkpointable()
    parent.nodep = data_structures.NoDependency(untracked_child)
    # Only the plain assignment registers a checkpoint dependency.
    deps = parent._checkpoint_dependencies
    self.assertEqual(1, len(deps))
    self.assertIs(deps[0].ref, parent.hasdep)
    # Both attributes read back as the original objects.
    self.assertIs(parent.hasdep, tracked_child)
    self.assertIs(parent.nodep, untracked_child)

    class NoDependencyModel(training.Model):

      @base.no_automatic_dependency_tracking
      def __init__(self):
        super(NoDependencyModel, self).__init__()
        self.a = []
        self.b = tracking.Checkpointable()

    # Attributes set inside the decorated __init__ are not tracked, so the
    # model is the only object reachable from itself.
    untracked_model = NoDependencyModel()
    self.assertEqual([untracked_model], util.list_objects(untracked_model))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:21,代码来源:tracking_test.py

示例14: testDictionariesBasic

 def testDictionariesBasic(self):
   """Models stored in a dict attribute are tracked and exposed as layers."""
   outer = training.Model()
   first = training.Model()
   outer.attribute = {"b": first}
   second = training.Model()
   outer.attribute["c"] = []
   outer.attribute["c"].append(second)
   reachable = util.list_objects(outer)
   self.assertIn(first, reachable)
   self.assertIn(second, reachable)
   self.assertIs(first, outer.attribute["b"])
   dict_dep_names = [dep.name for dep in outer.attribute._checkpoint_dependencies]
   six.assertCountEqual(self, ["b", "c"], dict_dep_names)
   # Both nested models surface as layers of the outer model and of the
   # dict; the list under "c" contributes only `second`.
   self.assertEqual([first, second], outer.layers)
   self.assertEqual([first, second], outer.attribute.layers)
   self.assertEqual([second], outer.attribute["c"].layers)
   checkpoint = util.Checkpoint(a=outer)
   prefix = os.path.join(self.get_temp_dir(), "ckpt")
   save_path = checkpoint.save(prefix)
   checkpoint.restore(save_path).assert_consumed()
开发者ID:dan-lennox,项目名称:tensorflow,代码行数:21,代码来源:tracking_test.py

示例15: testDictDeepCopy

  def testDictDeepCopy(self):
    """deepcopy of a tracked dict copies its values and keeps its dirtiness."""
    root = tracking.Checkpointable()
    orig_dict = {"a": [1.]}
    root.a = orig_dict
    copied = copy.deepcopy(root.a)
    self.assertAllEqual([1.], copied["a"])
    self.assertIsNot(root.a, copied)
    self.assertIsNot(root.a["a"], copied["a"])

    # Dirtiness should be inherited
    # list_objects succeeds while nothing has been changed behind root.a.
    util.list_objects(root.a)
    # Mutating orig_dict directly (root.a presumably wraps it) makes
    # list_objects raise ValueError, on root.a and on a fresh deep copy of it.
    orig_dict["b"] = []
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:16,代码来源:data_structures_test.py


注:本文中的tensorflow.python.training.checkpointable.util.list_objects函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。