本文整理汇总了Python中tensorflow.python.ops.control_flow_ops._SwitchRefOrTensor方法的典型用法代码示例。如果您正苦于以下问题：Python control_flow_ops._SwitchRefOrTensor方法的具体用法？Python control_flow_ops._SwitchRefOrTensor怎么用？Python control_flow_ops._SwitchRefOrTensor使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.control_flow_ops的用法示例。
在下文中一共展示了control_flow_ops._SwitchRefOrTensor方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testRefSwitch
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 别名]
def testRefSwitch(self):
  """Assigning through a Switch of a ref-typed variable updates the value.

  The predicate is the constant True, so output index 1 of
  ``_SwitchRefOrTensor`` is the live branch. Assigning 9 through it and
  evaluating the assign op must yield 9.
  """
  with self.test_session():
    var = tf.Variable(7)
    always_true = tf.constant(True)
    # pylint: disable=protected-access
    switched = control_flow_ops._SwitchRefOrTensor(var.ref(), always_true)
    # pylint: enable=protected-access
    true_branch = switched[1]
    assign_op = tf.assign(true_branch, 9)
    tf.global_variables_initializer().run()
    self.assertEqual(9, assign_op.eval())
示例2: _MergeGrad
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 别名]
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op.

  The incoming gradient is routed back to the Merge's inputs with
  ``_SwitchRefOrTensor``. The predicate depends on the control-flow context
  returned by ``_GetOutputContext`` for the Merge's first input op:

  * ``WhileContext``: switch on the current gradient context's ``pivot``.
  * ``CondContext``: switch on the cond's predicate. When the gradient is
    being built inside a loop's gradient context (``grad_ctxt.grad_state`` is
    set), the predicate's per-iteration values are accumulated in the forward
    pass and popped during backprop. The popped tensor is cached in
    ``grad_state.history_map`` under the predicate's name.
  * anything else: one Switch per input. Each is keyed on
    ``op.outputs[1] == i``, and output 1 of each Switch is returned.

  Args:
    op: The forward Merge operation.
    grad: The incoming gradient (presumably for the Merge's first output --
      confirm against the gradient registry).
    _: Gradient for the second output; unused.

  Returns:
    For while/cond contexts, the result of one ``_SwitchRefOrTensor`` call.
    Otherwise, a list with one element per input of ``op``.
  """
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = control_flow_ops._GetOutputContext(input_op)
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access

  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access

  if isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration. Temporarily leave
        # the gradient loop context while adding the forward accumulator,
        # then re-enter it. A distinct local avoids shadowing grad_ctxt.
        loop_grad_ctxt = grad_state.grad_context
        loop_grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        loop_grad_ctxt.Enter()
        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackPropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      # Use the accumulated predicate whether it was just created or cached.
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access

  # Generic Merge: build all "value_index == i" predicates first, then one
  # Switch per predicate, keeping the original op-creation order.
  num_inputs = len(op.inputs)
  index_preds = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
  # pylint: disable=protected-access
  return [control_flow_ops._SwitchRefOrTensor(grad, p)[1] for p in index_preds]
  # pylint: enable=protected-access
示例3: _MergeGrad
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 别名]
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op.

  The incoming gradient is routed back to the Merge's inputs with
  ``_SwitchRefOrTensor``. The predicate depends on the control-flow context
  of the Merge's first input op (``input_op._get_control_flow_context()``):

  * ``WhileContext``: switch on the current gradient context's ``pivot``.
  * ``CondContext``: switch on the cond's predicate. When the gradient is
    being built inside a loop's gradient context (``grad_ctxt.grad_state`` is
    set), the predicate's per-iteration values are accumulated in the forward
    pass and popped during backprop. The popped tensor is cached in
    ``grad_state.history_map`` under the predicate's name.
  * anything else: one Switch per input. Each is keyed on
    ``op.outputs[1] == i``, and output 1 of each Switch is returned.

  Args:
    op: The forward Merge operation.
    grad: The incoming gradient (presumably for the Merge's first output --
      confirm against the gradient registry).
    _: Gradient for the second output; unused.

  Returns:
    For while/cond contexts, the result of one ``_SwitchRefOrTensor`` call.
    Otherwise, a list with one element per input of ``op``.
  """
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = input_op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access

  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access

  if isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration. Temporarily leave
        # the gradient loop context while adding the forward accumulator,
        # then re-enter it. A distinct local avoids shadowing grad_ctxt.
        loop_grad_ctxt = grad_state.grad_context
        loop_grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        loop_grad_ctxt.Enter()
        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackPropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      # Use the accumulated predicate whether it was just created or cached.
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access

  # Generic Merge: build all "value_index == i" predicates first, then one
  # Switch per predicate, keeping the original op-creation order.
  num_inputs = len(op.inputs)
  index_preds = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
  # pylint: disable=protected-access
  return [control_flow_ops._SwitchRefOrTensor(grad, p)[1] for p in index_preds]
  # pylint: enable=protected-access
示例4: _MergeGrad
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 别名]
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op.

  The incoming gradient is routed back to the Merge's inputs with
  ``_SwitchRefOrTensor``. The predicate depends on the control-flow context
  returned by ``_GetOutputContext`` for the Merge's first input op:

  * ``WhileContext``: switch on the current gradient context's ``pivot``.
  * ``CondContext``: switch on the cond's predicate. When the gradient is
    being built inside a loop's gradient context (``grad_ctxt.grad_state`` is
    set), the predicate's per-iteration values are accumulated in the forward
    pass and popped during backprop. The popped tensor is cached in
    ``grad_state.history_map`` under the predicate's name.
  * anything else: one Switch per input. Each is keyed on
    ``op.outputs[1] == i``, and output 1 of each Switch is returned.

  Args:
    op: The forward Merge operation.
    grad: The incoming gradient (presumably for the Merge's first output --
      confirm against the gradient registry).
    _: Gradient for the second output; unused.

  Returns:
    For while/cond contexts, the result of one ``_SwitchRefOrTensor`` call.
    Otherwise, a list with one element per input of ``op``.
  """
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = control_flow_ops._GetOutputContext(input_op)
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access

  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access

  if isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration. Temporarily leave
        # the gradient loop context while adding the forward accumulator,
        # then re-enter it. A distinct local avoids shadowing grad_ctxt.
        loop_grad_ctxt = grad_state.grad_context
        loop_grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        loop_grad_ctxt.Enter()
        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      # Use the accumulated predicate whether it was just created or cached.
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access

  # Generic Merge: build all "value_index == i" predicates first, then one
  # Switch per predicate, keeping the original op-creation order.
  num_inputs = len(op.inputs)
  index_preds = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
  # pylint: disable=protected-access
  return [control_flow_ops._SwitchRefOrTensor(grad, p)[1] for p in index_preds]
  # pylint: enable=protected-access
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:45,代码来源:control_flow_grad.py