本文整理汇总了Python中tensorflow.python.ops.variables.local_variables方法的典型用法代码示例。如果您正苦于以下问题：Python variables.local_variables方法的具体用法？Python variables.local_variables怎么用？Python variables.local_variables使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.variables的用法示例。
在下文中一共展示了variables.local_variables方法的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_epoch_variable
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def get_epoch_variable():
    """Return the `limit_epochs` epoch counter reshaped to shape [1].

    Searches the graph's local variables for the first one whose op name
    contains 'limit_epochs/epoch' (defined in
    //third_party/tensorflow/python/training/input.py::limit_epochs).
    When no such variable exists, the plain Python list [0] is returned.
    """
    epoch_var = next(
        (candidate for candidate in tf_variables.local_variables()
         if 'limit_epochs/epoch' in candidate.op.name),
        None)
    if epoch_var is None:
        # TODO(thomaswc): Access epoch from the data feeder.
        return [0]
    return array_ops.reshape(epoch_var, [1])
# A simple container to hold the training variables for a single tree.
示例2: testVars
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def testVars(self):
    """f1_score creates exactly its three counters as local and metric vars."""
    classification.f1_score(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        num_thresholds=3)
    # `expected` is already a set, so it can be compared directly.
    expected = {'f1/true_positives:0', 'f1/false_positives:0',
                'f1/false_negatives:0'}
    # The counters must be registered as local variables...
    self.assertEqual(
        expected, {v.name for v in variables.local_variables()})
    # ...and also in the METRIC_VARIABLES collection.
    self.assertEqual(
        expected,
        {v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)})
示例3: test_local_variable
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def test_local_variable(self):
    """local_variable() adds uninitialized locals that init to their values."""
    first_value, second_value = 42, 43
    with self.cached_session() as session:
        # No local variables may exist before the test creates any.
        self.assertEqual([], variables_lib.local_variables())
        for initial in (first_value, second_value):
            variables_lib2.local_variable(initial)
        local_vars = variables_lib.local_variables()
        self.assertEqual(2, len(local_vars))
        # Reading the variables before initialization must fail.
        self.assertRaises(errors_impl.OpError, session.run, local_vars)
        variables_lib.variables_initializer(local_vars).run()
        self.assertAllEqual({first_value, second_value},
                            set(session.run(local_vars)))
示例4: testLocalVariableNotInAllVariables
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def testLocalVariableNotInAllVariables(self):
    """A local variable appears in local_variables() but not global ones."""
    with self.cached_session(), variable_scope.variable_scope('A'):
        local_var = variables_lib2.local_variable(0)
        self.assertIn(local_var, variables_lib.local_variables())
        self.assertNotIn(local_var, variables_lib.global_variables())
示例5: testLocalVariableNotInVariablesToRestore
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def testLocalVariableNotInVariablesToRestore(self):
    """A local variable is tracked as local but excluded from restore."""
    with self.cached_session(), variable_scope.variable_scope('A'):
        local_var = variables_lib2.local_variable(0)
        self.assertIn(local_var, variables_lib.local_variables())
        self.assertNotIn(local_var, variables_lib2.get_variables_to_restore())
示例6: testGlobalVariableNotInLocalVariables
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def testGlobalVariableNotInLocalVariables(self):
    """A global variable appears in global_variables() but not local ones."""
    with self.cached_session(), variable_scope.variable_scope('A'):
        global_var = variables_lib2.global_variable(0)
        self.assertIn(global_var, variables_lib.global_variables())
        self.assertNotIn(global_var, variables_lib.local_variables())
示例7: testCreateVariable
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def testCreateVariable(self):
    """variable() makes a scoped global var that is neither model nor local."""
    with self.cached_session(), variable_scope.variable_scope('A'):
        created = variables_lib2.variable('a', [5])
        self.assertEqual(created.op.name, 'A/a')
        self.assertListEqual(created.get_shape().as_list(), [5])
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        model_vars = ops.get_collection(ops.GraphKeys.MODEL_VARIABLES)
        self.assertIn(created, global_vars)
        self.assertNotIn(created, model_vars)
        self.assertNotIn(created, variables_lib.local_variables())
示例8: testNotInLocalVariables
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def testNotInLocalVariables(self):
    """model_variable() makes a global model var that is not local."""
    with self.cached_session(), variable_scope.variable_scope('A'):
        model_var = variables_lib2.model_variable('a', [5])
        model_collection = ops.get_collection(ops.GraphKeys.MODEL_VARIABLES)
        self.assertIn(model_var, variables_lib.global_variables())
        self.assertIn(model_var, model_collection)
        self.assertNotIn(model_var, variables_lib.local_variables())
示例9: run
# 需要导入模块: from tensorflow.python.ops import variables [as 别名]
# 或者: from tensorflow.python.ops.variables import local_variables [as 别名]
def run(self,
        num_batches=None,
        graph=None,
        session=None,
        start_queues=True,
        initialize_variables=True,
        **kwargs):
    """Builds and runs the columns of the `DataFrame` and yields batches.

    This is a generator that yields a dictionary mapping column names to
    evaluated columns.

    Args:
      num_batches: the maximum number of batches to produce. If none specified,
        the returned value will iterate through infinite batches.
      graph: the `Graph` in which the `DataFrame` should be built. Defaults to
        the current default graph.
      session: the `Session` in which to run the columns of the `DataFrame`.
        If none is given, a new `Session` is created here and closed when the
        generator finishes, fails, or is closed early. A caller-supplied
        session is never closed.
      start_queues: if true, queue runners will be started before running and
        stopped after producing `num_batches` batches (or when the input is
        exhausted, an error occurs, or the generator is closed early).
      initialize_variables: if true, local and global variables will be
        initialized.
      **kwargs: Additional keyword arguments e.g. `num_epochs`, forwarded to
        `self.build`.

    Yields:
      A dictionary, mapping column names to the values resulting from running
      each column for a single batch.
    """
    if graph is None:
        graph = ops.get_default_graph()
    with graph.as_default():
        # Only a session created here belongs to this generator, so only that
        # one is closed during cleanup.
        owns_session = session is None
        if owns_session:
            session = sess.Session()
        coord = None
        threads = []
        try:
            self_built = self.build(**kwargs)
            keys = list(self_built.keys())
            cols = list(self_built.values())
            if initialize_variables:
                if variables.local_variables():
                    session.run(variables.local_variables_initializer())
                if variables.global_variables():
                    session.run(variables.global_variables_initializer())
            if start_queues:
                coord = coordinator.Coordinator()
                threads = qr.start_queue_runners(sess=session, coord=coord)
            i = 0
            while num_batches is None or i < num_batches:
                i += 1
                # Keep only session.run inside the try, so an exception thrown
                # into the generator at `yield` is not mistaken for exhaustion.
                try:
                    values = session.run(cols)
                except errors.OutOfRangeError:
                    break
                yield collections.OrderedDict(zip(keys, values))
        finally:
            # Runs on normal exit, on error, and on early generator close.
            # Stop the queue threads before closing the session they use.
            try:
                if coord is not None:
                    coord.request_stop()
                    coord.join(threads)
            finally:
                if owns_session:
                    session.close()