本文整理汇总了Python中tensorflow.python.framework.ops.get_collection_ref方法的典型用法代码示例。如果您正苦于以下问题：Python ops.get_collection_ref方法的具体用法？Python ops.get_collection_ref怎么用？Python ops.get_collection_ref使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.framework.ops的用法示例。
在下文中一共展示了ops.get_collection_ref方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _add_elements_to_collection
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection_ref [as 别名]
def _add_elements_to_collection(elements, collection_list):
    """Append each element to every named graph collection, skipping duplicates.

    Args:
      elements: Element(s) to add; normalized via `_to_list` (assumed to wrap a
        single element into a list -- confirm against its definition).
      collection_list: Collection name(s); normalized via `_to_list`.
    """
    elements = _to_list(elements)
    collection_list = _to_list(collection_list)
    for name in collection_list:
        # get_collection_ref hands back the live collection list, so appending
        # to it changes the collection itself.
        collection = ops.get_collection_ref(name)
        # Membership snapshot for O(1) lookups. It is kept in sync with every
        # append so an element repeated inside `elements` is added only once.
        collection_set = set(collection)
        for element in elements:
            if element not in collection_set:
                collection.append(element)
                collection_set.add(element)
示例2: reset_uids
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection_ref [as 别名]
def reset_uids():
    """Drop the last entry of the 'LAYER_NAME_UIDS' graph collection, if any.

    The live list returned by `ops.get_collection_ref` is modified in place.
    When the collection is empty this is a no-op.
    """
    uid_entries = ops.get_collection_ref('LAYER_NAME_UIDS')
    if not uid_entries:
        return
    uid_entries.pop()
示例3: _set_axis_order
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection_ref [as 别名]
def _set_axis_order(axis_order):
    """Store `axis_order` in slot 0 of the `_AXIS_ORDER_KEY` graph collection.

    If the collection is empty, `axis_order` is appended. Otherwise it
    overwrites element 0 in place, and any later entries are left untouched.
    """
    slots = ops.get_collection_ref(_AXIS_ORDER_KEY)
    if not slots:
        slots.append(axis_order)
        return
    slots[0] = axis_order
示例4: _add_elements_to_collection
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection_ref [as 别名]
def _add_elements_to_collection(elements, collections):
    """Append each element to every named graph collection, skipping duplicates.

    Args:
      elements: Element(s) to add; normalized via `_to_list` (assumed to wrap a
        single element into a list -- confirm against its definition).
      collections: Collection name(s); normalized via `_to_list`. (The name
        shadows the stdlib `collections` module inside this function; kept
        for caller compatibility.)
    """
    elements = _to_list(elements)
    collections = _to_list(collections)
    for name in collections:
        # get_collection_ref hands back the live collection list, so appending
        # to it changes the collection itself.
        collection = ops.get_collection_ref(name)
        # Membership snapshot for O(1) lookups. It is kept in sync with every
        # append so an element repeated inside `elements` is added only once.
        collection_set = set(collection)
        for element in elements:
            if element not in collection_set:
                collection.append(element)
                collection_set.add(element)
示例5: testWithIsRecomputeKwarg
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection_ref [as 别名]
def testWithIsRecomputeKwarg(self):
    """Check that recompute_grad passes `is_recomputing` to the wrapped layer.

    The final assertion expects the layer to be invoked twice: first with
    is_recomputing=False (the forward call below), then with
    is_recomputing=True. On the recomputing call the layer pops the two
    UPDATE_OPS entries it just added, so exactly 2 update ops should remain.
    """
    # Records the is_recomputing value seen on each invocation of the layer.
    kwarg_values = []

    @rev_block_lib.recompute_grad
    def layer_with_recompute(inputs, is_recomputing=False):
        kwarg_values.append(is_recomputing)
        out = core_layers.dense(inputs, 2)
        out = normalization_layers.batch_normalization(out, training=True)
        if is_recomputing:
            # Ensure that the updates are not duplicated by popping off the
            # latest 2 additions. get_collection_ref returns the live list, so
            # pop() removes the entries from the graph collection itself.
            # NOTE(review): assumes batch_normalization(training=True) adds
            # exactly 2 update ops -- consistent with the assertion below.
            update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
            update_ops.pop()
            update_ops.pop()
        return out

    x = array_ops.ones((2, 4), dtypes.float32)
    # NOTE(review): use_resource=True is presumably required by
    # recompute_grad -- confirm.
    with variable_scope.variable_scope("layer1", use_resource=True):
        y = layer_with_recompute(x)
    loss = math_ops.reduce_sum(y)
    tvars = variables.trainable_variables()
    # The second (is_recomputing=True) call is expected to happen here,
    # while gradients are built.
    gradients_impl.gradients(loss, [x] + tvars)
    # get_collection (not _ref) is sufficient: the result is only read.
    update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
    self.assertEqual(2, len(update_ops))
    self.assertEqual([False, True], kwarg_values)
示例6: begin
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection_ref [as 别名]
def begin(self):
    """Build `self._variable_init_op` from the local and 'global_model' vars.

    Pairs the current trainable variables with the live 'global_model'
    collection and hands both to `self._fed_avg_optimizer._assign_vars`.
    NOTE(review): this is presumably a SessionRunHook.begin override, and
    the op presumably copies global values into the local variables --
    confirm against `_assign_vars`.
    """
    trainable = variables.trainable_variables()
    shared_model = ops.get_collection_ref("global_model")
    optimizer = self._fed_avg_optimizer
    self._variable_init_op = optimizer._assign_vars(trainable, shared_model)
示例7: _add_elements_to_collection
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection_ref [as 别名]
def _add_elements_to_collection(elements, collection_list):
    """Append each element to every named graph collection, skipping duplicates.

    Args:
      elements: Element(s) to add; normalized via `_to_list` (assumed to wrap a
        single element into a list -- confirm against its definition).
      collection_list: Collection name(s); normalized via `_to_list`.

    Raises:
      RuntimeError: If called while `context.in_eager_mode()` is true, since
        layer collections are not supported in Eager mode.
    """
    if context.in_eager_mode():
        raise RuntimeError('Using collections from Layers not supported in Eager '
                           'mode. Tried to add %s to %s' % (elements,
                                                            collection_list))
    elements = _to_list(elements)
    collection_list = _to_list(collection_list)
    for name in collection_list:
        # get_collection_ref hands back the live collection list, so appending
        # to it changes the collection itself.
        collection = ops.get_collection_ref(name)
        # Membership snapshot for O(1) lookups. It is kept in sync with every
        # append so an element repeated inside `elements` is added only once.
        collection_set = set(collection)
        for element in elements:
            if element not in collection_set:
                collection.append(element)
                collection_set.add(element)
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:15,代码来源:base.py