本文整理汇总了 Python 中 tensorflow.python.training.checkpointable.util.list_objects 函数的典型用法代码示例。如果您正苦于以下问题：Python list_objects 函数的具体用法？Python list_objects 怎么用？Python list_objects 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 list_objects 函数的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: testNestedLists
def testNestedLists(self):
    """Check that checkpointable objects inside nested lists are tracked.

    Builds lists of lists on a ``tracking.Checkpointable`` and expects
    ``util.list_objects`` to report every checkpointable element at any
    nesting depth while not reporting plain values such as ints. Finally
    expects saving to fail once an element of a tracked inner list has been
    replaced.
    """
    a = tracking.Checkpointable()
    a.l = []
    b = tracking.Checkpointable()
    a.l.append([b])
    c = tracking.Checkpointable()
    a.l[0].append(c)
    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    # Appending a plain int must not add it to the reported objects.
    a.l[0].append(1)
    d = tracking.Checkpointable()
    a.l[0].append(d)
    a_deps = util.list_objects(a)
    self.assertIn(d, a_deps)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertNotIn(1, a_deps)
    # A nested list literal, mutated after being assigned as an attribute.
    e = tracking.Checkpointable()
    f = tracking.Checkpointable()
    a.l1 = [[], [e]]
    a.l1[0].append(f)
    a_deps = util.list_objects(a)
    self.assertIn(e, a_deps)
    self.assertIn(f, a_deps)
    checkpoint = util.Checkpoint(a=a)
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    # Appending a NoDependency-wrapped list (and mutating it) is expected
    # not to prevent the next save.
    a.l[0].append(data_structures.NoDependency([]))
    a.l[0][-1].append(5)
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    # Dirtying the inner list means the root object is unsaveable.
    # (Index 1 of a.l[0] holds the tracked object `c`.)
    a.l[0][1] = 2
    # NOTE(review): assertRaisesRegexp is a deprecated unittest alias that
    # was removed from the stdlib in Python 3.12; switch to assertRaisesRegex
    # once Python 2 compatibility is no longer required.
    with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
        checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
示例2: get_non_optimizer_objects
def get_non_optimizer_objects(m, g):
    """Return the checkpointable objects of ``m`` that are not optimizer-related.

    Collects every object reachable from ``m`` via ``list_objects`` and then
    removes, for each ``optimizers.TFOptimizer`` found among them, everything
    ``list_objects`` reports for that optimizer (which includes the optimizer
    object itself) plus the variables returned by its wrapped
    ``optimizer.variables()``.

    Args:
      m: Root checkpointable object to inspect (presumably a model).
      g: Graph that is made the default while objects are collected.

    Returns:
      A set containing the objects reachable from ``m`` minus all
      optimizer-related objects and optimizer variables.
    """
    # Set default graph because optimizer.variables() returns optimizer
    # variables defined in the default graph.
    with g.as_default():
        all_objects = set(checkpointable_utils.list_objects(m))
        optimizer_and_variables = set()
        for obj in all_objects:
            if isinstance(obj, optimizers.TFOptimizer):
                optimizer_and_variables.update(
                    checkpointable_utils.list_objects(obj))
                # set.update accepts any iterable; no intermediate set needed.
                optimizer_and_variables.update(obj.optimizer.variables())
        return all_objects - optimizer_and_variables
示例3: testAddVariableOverwrite
def testAddVariableOverwrite(self):
    """Re-declaring variable "v" works only when overwrite=True is passed."""
    root = base.Checkpointable()
    getter = variable_scope.get_variable
    first_var = root._add_variable_with_custom_getter(
        name="v", shape=[], getter=getter)
    self.assertEqual([root, first_var], util.list_objects(root))

    # With overwrite=True the new variable replaces the old dependency.
    with ops.Graph().as_default():
        replacement = root._add_variable_with_custom_getter(
            name="v", shape=[], overwrite=True, getter=getter)
        self.assertEqual([root, replacement], util.list_objects(root))

    # Without overwrite, declaring "v" a second time must raise.
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(
                ValueError, "already declared as a dependency"):
            root._add_variable_with_custom_getter(
                name="v", shape=[], overwrite=False, getter=getter)
示例4: testDictWrapperNoDependency
def testDictWrapperNoDependency(self):
    """A dict wrapped in NoDependency contributes no checkpoint dependencies."""
    holder = tracking.Checkpointable()
    holder.d = data_structures.NoDependency({})
    holder.d[1] = [3]
    # Only the holder itself should be reachable.
    self.assertEqual([holder], util.list_objects(holder))

    # A save/load round trip through a training.Model must not fail.
    model = training.Model()
    model.sub = holder
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(ckpt_path)
    model.load_weights(ckpt_path)
示例5: testNonStringKeyNotCheckpointableValue
def testNonStringKeyNotCheckpointableValue(self):
    """A non-string dict key is fine when its value is NoDependency-wrapped."""
    holder = tracking.Checkpointable()
    holder.d = {}
    holder.d["a"] = [3]
    holder.d[1] = data_structures.NoDependency([3])
    # The int-keyed entry contributes nothing to the reported objects.
    self.assertEqual(
        [holder, holder.d, holder.d["a"]], util.list_objects(holder))

    # A save/load round trip through a training.Model must not fail.
    model = training.Model()
    model.sub = holder
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(ckpt_path)
    model.load_weights(ckpt_path)
示例6: testListBasic
def testListBasic(self):
    """Checkpointables stored in a list attribute are tracked through it."""
    owner = tracking.Checkpointable()
    first = tracking.Checkpointable()
    owner.l = [first]
    second = tracking.Checkpointable()
    owner.l.append(second)

    reachable = util.list_objects(owner)
    for item in (first, second):
        self.assertIn(item, reachable)

    # The list is the single direct dependency, named after the attribute.
    (list_dep,) = owner._checkpoint_dependencies
    self.assertEqual("l", list_dep.name)
    for item in (first, second):
        self.assertIn(item, list_dep.ref)
示例7: testShallowCopyCheckpointable
def testShallowCopyCheckpointable(self):
    """copy.copy makes a new object that shares the original's sub-objects."""
    source = tracking.Checkpointable()
    child = tracking.Checkpointable()
    source.a = [[1.]]
    source.b = {"a": child}
    clone = copy.copy(source)
    self.assertIs(child, clone.b["a"])
    self.assertIsNot(source, clone)
    self.assertEqual([[1.]], clone.a)

    # The copy's containers and their contents are reported as dependencies.
    clone_deps = util.list_objects(clone)
    for expected in (clone.a, clone.b, clone.b["a"]):
        self.assertIn(expected, clone_deps)
示例8: test_checkpointable_save_restore
def test_checkpointable_save_restore(self):
    """Save template variables via a Checkpoint and restore into a new template.

    Variables created by a ``template.make_template`` template ("s1"), together
    with an ``adam.AdamOptimizer``, are saved, then restored into a second
    template ("s2") built from the same function. The test checks dependency
    names and the restored values.
    """
    def _templated():
        # Creates zero-initialized resource variables "v" and "v2" (shape [1])
        # and a _ManualScope (defined elsewhere in this test module).
        v = variable_scope.get_variable(
            "v", shape=[1], initializer=init_ops.zeros_initializer(),
            use_resource=True)
        v2 = variable_scope.get_variable(
            "v2", shape=[1], initializer=init_ops.zeros_initializer(),
            use_resource=True)
        manual = _ManualScope()
        # Returns (v, v + 1, v2, the scope object, result of calling it).
        return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    # The template reports itself, both variables, the manual scope and the
    # scope's own dependency; the derived `v + 1.` value is not expected.
    six.assertCountEqual(
        self,
        [v1_save, v2_save, manual_scope, manual_scope_v, save_template],
        checkpointable_utils.list_objects(save_template))
    # The manual scope's single dependency is its call result.
    manual_dep, = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)
    optimizer = adam.AdamOptimizer(0.0)
    save_root = checkpointable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    # Distinct sentinel values so the restore can be verified below.
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    # Restore into a differently named template built from the same function.
    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = checkpointable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    # restore() is requested before load_template() has created its variables.
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    # Expected template dependencies: "v", "v2", then "ManualScope".
    self.assertEqual(3, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual("ManualScope",
                     load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    # Restored values match the sentinels written before saving.
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
示例9: testDeepCopyCheckpointable
def testDeepCopyCheckpointable(self):
    """copy.deepcopy makes an independent tree of checkpointable objects."""
    source = tracking.Checkpointable()
    child = tracking.Checkpointable()
    source.a = [[1.]]
    source.b = {"a": child}
    clone = copy.deepcopy(source)
    self.assertIsNot(source, clone)
    self.assertIsNot(child, clone.b["a"])
    self.assertEqual([[1.]], clone.a)
    self.assertIsInstance(clone.b["a"], tracking.Checkpointable)

    # The copy tracks its own containers and sub-object, never the original.
    clone_deps = util.list_objects(clone)
    for expected in (clone.a, clone.b, clone.b["a"]):
        self.assertIn(expected, clone_deps)
    self.assertNotIn(child, clone_deps)
示例10: test_checkpointable_dependencies
def test_checkpointable_dependencies(self):
    """Every variable of a fitted SimpleRNN model is reported by list_objects."""
    layer_cls = keras.layers.SimpleRNN
    with self.test_session():
        inputs = np.random.random((2, 2, 2))
        targets = np.random.random((2, 2))
        model = keras.models.Sequential()
        model.add(layer_cls(2))
        model.compile(optimizer='rmsprop', loss='mse')
        model.fit(inputs, targets, epochs=1, batch_size=1)

        # Each model variable must appear among the checkpointable objects.
        tracked = set(checkpointable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
示例11: test_timedistributed_dense
def test_timedistributed_dense(self):
    """A TimeDistributed(Dense) model trains and reports all its variables."""
    model = keras.models.Sequential()
    wrapped = keras.layers.TimeDistributed(
        keras.layers.Dense(2), input_shape=(3, 4))
    model.add(wrapped)
    model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
    features = np.random.random((10, 3, 4))
    labels = np.random.random((10, 3, 2))
    model.fit(features, labels, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # Each model variable must appear among the checkpointable objects.
    tracked = set(checkpointable_util.list_objects(model))
    for variable in model.variables:
        self.assertIn(variable, tracked)
示例12: testNonAppendNotCheckpointable
def testNonAppendNotCheckpointable(self):
    """Overwriting or deleting untracked dict values keeps the dict saveable.

    Every non-append mutation here touches either a plain int or a
    ``NoDependency``-wrapped object. The test expects ``list_objects`` and a
    save/load round trip to succeed afterwards.
    """
    # Non-append mutations (deleting or overwriting values) are OK when the
    # values aren't tracked.
    a = tracking.Checkpointable()
    a.d = {}
    a.d["a"] = [3]
    # Overwrite, then delete, a plain int value.
    a.d[1] = 3
    a.d[1] = 2
    self.assertEqual(2, a.d[1])
    del a.d[1]
    # Overwrite a NoDependency-wrapped checkpointable; reading the key back
    # yields the unwrapped object.
    a.d[2] = data_structures.NoDependency(tracking.Checkpointable())
    second = tracking.Checkpointable()
    a.d[2] = data_structures.NoDependency(second)
    self.assertIs(second, a.d[2])
    # Neither the int nor the NoDependency values are reported.
    self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
    model = training.Model()
    model.sub = a
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)
    model.load_weights(save_path)
示例13: testNoDependency
def testNoDependency(self):
    """NoDependency and no_automatic_dependency_tracking suppress tracking."""
    root = tracking.Checkpointable()
    hasdep = tracking.Checkpointable()
    root.hasdep = hasdep
    nodep = tracking.Checkpointable()
    # Assign through a NoDependency wrapper; this object should not be tracked.
    root.nodep = data_structures.NoDependency(nodep)
    # Only `hasdep` is recorded as a dependency.
    self.assertEqual(1, len(root._checkpoint_dependencies))
    self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
    # Attribute reads return the original objects (the wrapper is not exposed).
    self.assertIs(root.hasdep, hasdep)
    self.assertIs(root.nodep, nodep)

    class NoDependencyModel(training.Model):
        # Attributes assigned in this decorated __init__ are expected not to
        # become dependencies (checked by the assertion below).

        @base.no_automatic_dependency_tracking
        def __init__(self):
            super(NoDependencyModel, self).__init__()
            self.a = []
            self.b = tracking.Checkpointable()

    nodeps = NoDependencyModel()
    # Neither the list nor the Checkpointable assigned in __init__ is reported.
    self.assertEqual([nodeps], util.list_objects(nodeps))
示例14: testDictionariesBasic
def testDictionariesBasic(self):
    """Models held in a dict attribute (including nested lists) are tracked."""
    outer = training.Model()
    inner_b = training.Model()
    outer.attribute = {"b": inner_b}
    inner_c = training.Model()
    outer.attribute["c"] = []
    outer.attribute["c"].append(inner_c)

    reachable = util.list_objects(outer)
    self.assertIn(inner_b, reachable)
    self.assertIn(inner_c, reachable)
    self.assertIs(inner_b, outer.attribute["b"])

    # The dict's direct dependencies are named after its keys.
    dep_names = [dep.name for dep in outer.attribute._checkpoint_dependencies]
    six.assertCountEqual(self, ["b", "c"], dep_names)

    # Layers are collected through the dict and the nested list.
    self.assertEqual([inner_b, inner_c], outer.layers)
    self.assertEqual([inner_b, inner_c], outer.attribute.layers)
    self.assertEqual([inner_c], outer.attribute["c"].layers)

    # Save, then restore and call assert_consumed() on the returned status.
    checkpoint = util.Checkpoint(a=outer)
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = checkpoint.save(ckpt_prefix)
    checkpoint.restore(save_path).assert_consumed()
示例15: testDictDeepCopy
def testDictDeepCopy(self):
    """Deep-copying a tracked dict copies its contents and its dirtiness."""
    root = tracking.Checkpointable()
    orig_dict = {"a": [1.]}
    root.a = orig_dict
    copied = copy.deepcopy(root.a)
    self.assertAllEqual([1.], copied["a"])
    # The copy shares neither the dict nor its nested list with root.a.
    self.assertIsNot(root.a, copied)
    self.assertIsNot(root.a["a"], copied["a"])
    # Dirtiness should be inherited
    # Before any out-of-band mutation, listing objects succeeds.
    util.list_objects(root.a)
    # Mutating orig_dict directly (rather than through root.a) is expected to
    # make list_objects(root.a) raise ValueError.
    orig_dict["b"] = []
    with self.assertRaises(ValueError):
        util.list_objects(root.a)
    # A deep copy taken after the mutation is expected to raise as well.
    with self.assertRaises(ValueError):
        util.list_objects(copy.deepcopy(root.a))