This article collects typical usage examples of the Python function tensorflow.python.training.checkpoint_utils.init_from_checkpoint. If you are wondering what init_from_checkpoint does, how to call it, or what real usage looks like, the curated examples below should help.
The following sections show 15 code examples of init_from_checkpoint, listed by popularity by default.
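Before the collected examples, here is a minimal sketch of the basic call pattern (TF 1.x graph mode). The checkpoint directory and the tensor/variable names are hypothetical placeholders; keys in the assignment map are names (or scope prefixes ending in "/") in the checkpoint, and values may be current variable names, scope prefixes, or variable objects, as the examples below illustrate.

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpoint_utils

# Hypothetical checkpoint directory; assumed to already contain a checkpoint
# with a tensor named "old_scope/w" of shape [10, 10].
checkpoint_dir = "/tmp/prev_model"

with ops.Graph().as_default():
  with variable_scope.variable_scope("new_scope"):
    w = variable_scope.get_variable("w", [10, 10])

  # Replace w's initializer with one that reads "old_scope/w" from the
  # checkpoint.
  checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"old_scope/w": w})

  # init_from_checkpoint only rewires the initializer; the checkpoint value is
  # loaded when the initializer actually runs.
  with session_lib.Session() as sess:
    sess.run(variables.global_variables_initializer())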
Example 1: testInitFromCheckpoint
def testInitFromCheckpoint(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with variable_scope.variable_scope("some_scope"):
        my1 = variable_scope.get_variable("my1", [1, 10])
        with variable_scope.variable_scope("some_other_scope"):
          my2 = variable_scope.get_variable("my2", [10, 10])
          with variable_scope.variable_scope("other_useful_scope"):
            my4 = variable_scope.get_variable("var4", [9, 9])
      my3 = variable_scope.get_variable("my3", [100, 100])

      checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
          "var1": "some_scope/my1",
          "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
      })
      checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
          "var2": "some_scope/some_other_scope/my2",
          "var3": my3,
      })

      session.run(variables.global_variables_initializer())
      self.assertAllEqual(my1.eval(session), v1)
      self.assertAllEqual(my2.eval(session), v2)
      self.assertAllEqual(my3.eval(session), v3)
      self.assertAllEqual(my4.eval(session), v4)

      # Check that tensors are not explicitly in the graph.
      self.assertLess(len(str(session.graph.as_graph_def())), 29000)
Example 2: testInitialValueComesFromCheckpoint
def testInitialValueComesFromCheckpoint(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with variable_scope.variable_scope(
          "some_scope", initializer=init_ops.zeros_initializer()):
        my1 = variable_scope.get_variable("my1", [1, 10])

      # At this point, my1.initialized_value() will add ops that reference
      # the zeros initializer of my1.
      before = variables.Variable(my1.initialized_value(), name="before")

      checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

      # At this point, my1.initialized_value() will add ops that reference
      # the newly set initializer of my1.
      after = variables.Variable(my1.initialized_value(), name="after")

      session.run(variables.global_variables_initializer())

      self.assertAllEqual(session.run(my1), v1)
      self.assertAllEqual(session.run(my1.initialized_value()), v1)
      self.assertAllClose(session.run(before), [[0.0] * 10])
      self.assertAllClose(session.run(after), v1)

      with self.assertRaises(AssertionError):
        self.assertAllClose(session.run(before), session.run(after))
Example 3: _warm_start_var
def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
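For orientation, a call to this private helper might look like the sketch below; the scope name, shape, checkpoint path, and previous tensor name are all hypothetical.

# Hypothetical usage of the helper above (names and paths are made up).
with variable_scope.variable_scope("linear"):
  weights = variable_scope.get_variable("weights", [100, 1])

# Initialize `weights` from the tensor "old_linear/weights" stored in a
# previously written checkpoint directory.
_warm_start_var(weights, "/tmp/prev_model",
                prev_tensor_name="old_linear/weights")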
Example 4: testInitialValueComesFromCheckpoint
def testInitialValueComesFromCheckpoint(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with variable_scope.variable_scope(
          "some_scope", initializer=init_ops.zeros_initializer()):
        my1 = variable_scope.get_variable("my1", [1, 10])

      before = my1.initialized_value()

      checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

      after = my1.initialized_value()

      self.assertAllEqual(session.run(before), [[0.0] * 10])
      self.assertAllEqual(session.run(after), v1)

      session.run(variables.global_variables_initializer())

      self.assertAllEqual(session.run(my1), v1)
      self.assertAllEqual(session.run(my1.initialized_value()), v1)
      self.assertAllClose(session.run(before), v1)
      self.assertAllClose(session.run(after), v1)

      with self.assertRaises(AssertionError):
        self.assertAllClose(v1, [[0.0] * 10])
Example 5: testNoAdditionalReadOpsForResourceVariables
def testNoAdditionalReadOpsForResourceVariables(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.session(graph=g) as session:
      my1 = resource_variable_ops.ResourceVariable([[0.0] * 10], name="my1")

      with ops.name_scope("init_from_checkpoint"):
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

      # Basic sanity checks:
      session.run(variables.global_variables_initializer())
      self.assertAllEqual(session.run(my1), v1)

  ops_in_init_from_checkpoint_scope = [
      op for op in g.get_operations()
      if (op.name.startswith("init_from_checkpoint/") and
          not op.name.startswith("init_from_checkpoint/checkpoint_initializer"
                                ) and
          op.type != "AssignVariableOp" and
          op.type != "Identity")
  ]
  self.assertEqual(ops_in_init_from_checkpoint_scope, [])
Example 6: init_and_verify
def init_and_verify(g):
  v1 = variable_scope.get_variable("new_var1", [1, 10])
  checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
      "var1": "new_var1",
  })
  with self.test_session(graph=g) as session:
    session.run(variables.global_variables_initializer())
    self.assertAllEqual(v1_value, self.evaluate(v1))
Example 7: testRestoreRunsOnSameDevice
def testRestoreRunsOnSameDevice(self):
  checkpoint_dir = self.get_temp_dir()
  with self.cached_session() as session:
    _create_checkpoints(session, checkpoint_dir)

  with ops.Graph().as_default():
    with ops.device("/job:ps"):
      with variable_scope.variable_scope("useful_scope"):
        my4 = variable_scope.get_variable("var4", [9, 9])

    checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                          {"useful_scope/": "useful_scope/"})
Example 8: init_and_verify
def init_and_verify(g):
  v1 = variable_scope.get_variable("new_var1", [1, 10])
  # Use string add to create new object in each replica
  prefix = "new_"
  suffix = "var1"
  new_var1 = prefix + suffix
  checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
      "var1": new_var1,
  })
  with self.test_session(graph=g) as session:
    session.run(variables.global_variables_initializer())
    self.assertAllEqual(v1_value, self.evaluate(v1))
Example 9: init_and_verify
def init_and_verify(g):
  v1 = variable_scope.get_variable("new_var1", [1, 10])
  v2 = variable_scope.get_variable(
      "new_var2", [10, 10],
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.MEAN)
  checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
      "var1": "new_var1",
      "var2": "new_var2"
  })
  with self.session(graph=g) as session:
    session.run(variables.global_variables_initializer())
    self.assertAllEqual(v1_value, self.evaluate(v1))
    self.assertAllEqual(v2_value, self.evaluate(v2))
Example 10: testRestoreRunsOnSameDevice
def testRestoreRunsOnSameDevice(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    _create_checkpoints(session, checkpoint_dir)

  with ops.Graph().as_default():
    with ops.device("/job:ps"):
      with variable_scope.variable_scope("useful_scope"):
        my4 = variable_scope.get_variable("var4", [9, 9])

    checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                          {"useful_scope/": "useful_scope/"})
    # initializer runs on the same task but always on CPU.
    self.assertEqual(my4._initializer_op.op.inputs[1].device,
                     "/job:ps/device:CPU:0")
Example 11: testInitWithScopeDoesNotCaptureSuffixes
def testInitWithScopeDoesNotCaptureSuffixes(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)

  with ops.Graph().as_default() as g:
    with variable_scope.variable_scope("useful_scope"):
      my4 = variable_scope.get_variable("var4", [9, 9])
    with variable_scope.variable_scope("useful_scope_1"):
      my5_init = [[1.0, 2.0], [3.0, 4.0]]
      my5 = variable_scope.get_variable("var5", initializer=my5_init)

    checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                          {"useful_scope/": "useful_scope/"})
    with self.test_session(graph=g) as session:
      session.run(variables.global_variables_initializer())
      self.assertAllEqual(my4.eval(session), v4)
      self.assertAllEqual(my5.eval(session), my5_init)
Example 12: testInitToRootCheckpoint
def testInitToRootCheckpoint(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      my1 = variable_scope.get_variable("var1", [1, 10])
      my2 = variable_scope.get_variable("var2", [10, 10])
      my3 = variable_scope.get_variable("var3", [100, 100])
      with variable_scope.variable_scope("useful_scope"):
        my4 = variable_scope.get_variable("var4", [9, 9])

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"/": "/",})

      session.run(variables.global_variables_initializer())
      self.assertAllEqual(my1.eval(session), v1)
      self.assertAllEqual(my2.eval(session), v2)
      self.assertAllEqual(my3.eval(session), v3)
      self.assertAllEqual(my4.eval(session), v4)
Example 13: testInitFromCheckpointMissing
def testInitFromCheckpointMissing(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with variable_scope.variable_scope("some_scope"):
        _ = variable_scope.get_variable("my1", [10, 10])
        _ = variable_scope.get_variable(
            "my2", [1, 10],
            dtype=dtypes.int64,
            initializer=init_ops.zeros_initializer())

      # No directory.
      with self.assertRaises(errors_impl.OpError):
        checkpoint_utils.init_from_checkpoint("no_dir",
                                              {"var1": "some_scope/my1"})

      # No variable in checkpoint.
      with self.assertRaises(ValueError):
        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"no_var": "some_scope/my1"})

      # No variable in the graph.
      with self.assertRaises(ValueError):
        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"var3": "some_scope/no_var"})

      # Shape mismatch.
      with self.assertRaises(ValueError):
        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"var1": "some_scope/my1"})

      # Variable 'my1' and 'my2' are missing in given checkpoint scope.
      with self.assertRaises(ValueError):
        checkpoint_utils.init_from_checkpoint(
            checkpoint_dir, {"useful_scope/": "some_scope/"})

      # Mapping is not to scope name.
      with self.assertRaises(ValueError):
        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"useful_scope": "some_scope/"})
Example 14: testInitFromPartitionVar
def testInitFromPartitionVar(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1 = _create_partition_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with variable_scope.variable_scope("some_scope"):
        my1 = variable_scope.get_variable(
            name="my1",
            shape=[100, 100],
            initializer=init_ops.zeros_initializer(),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=8 << 10))
        my1_var_list = my1._get_variable_list()
      # Create another variable with different partitions than the variable in
      # the checkpoint.
      with variable_scope.variable_scope("some_other_scope"):
        my2 = variable_scope.get_variable(
            name="var1",
            shape=[100, 100],
            initializer=init_ops.zeros_initializer(),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=16 << 10))
        my2_var_list = my2._get_variable_list()

      checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
          "scope/var1": "some_scope/my1",
          "scope/": "some_other_scope/"})

      session.run(variables.global_variables_initializer())
      my1_values = session.run(my1_var_list)
      self.assertAllEqual(my1_values, v1)
      my2_values = session.run(my2_var_list)
      # Verify we created different number of partitions.
      self.assertNotEquals(len(my2_values), len(v1))
      # Verify the values were correctly initialized inspite of different
      # partitions.
      full_my2_values = np.concatenate(my2_values, axis=0)
      full_v1_values = np.concatenate(v1, axis=0)
      self.assertAllEqual(full_my2_values, full_v1_values)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with variable_scope.variable_scope("some_scope"):
        my1 = variable_scope.get_variable(
            name="my1",
            shape=[100, 100],
            initializer=init_ops.truncated_normal_initializer(0.5),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=8 << 10))
        my1_var_list = my1._get_variable_list()

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"scope/var1": my1_var_list,})

      session.run(variables.global_variables_initializer())
      my1_values = session.run(my1_var_list)
      self.assertAllEqual(my1_values, v1)
Example 15: warm_start
#.........(part of the code is omitted here).........
      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection. Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions. If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model. Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used. This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used. Err on the safer side by throwing an
  # exception if any are unused by the end of the loop. It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name,
              vocab_info.new_vocab,
              vocab_info.new_vocab_size,
              vocab_info.old_vocab,
              (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
               else "All"),
              vocab_info.num_oov_buckets,
              prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}. Perhaps you misspelled them? Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them? Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))