本文整理匯總了 Python 中 tensorflow.python.ops.control_flow_ops._SwitchRefOrTensor 方法的典型用法代碼示例。如果您正苦於以下問題:Python control_flow_ops._SwitchRefOrTensor 方法的具體用法?Python control_flow_ops._SwitchRefOrTensor 怎麼用?Python control_flow_ops._SwitchRefOrTensor 使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊 tensorflow.python.ops.control_flow_ops 的用法示例。
在下文中一共展示了 control_flow_ops._SwitchRefOrTensor 方法的 4 個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的 Python 代碼示例。
示例1: testRefSwitch
# 需要導入模塊: from tensorflow.python.ops import control_flow_ops [as 別名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 別名]
def testRefSwitch(self):
  """Check that a variable ref routed through a switch can still be assigned.

  Builds an integer variable initialised to 7, passes its reference through
  `_SwitchRefOrTensor` with a constant-True predicate, assigns 9 through
  output index 1 of the switch, and asserts the assign op evaluates to 9.
  """
  with self.test_session():
    v = tf.Variable(7)

    p = tf.constant(True)
    # NOTE(review): `Variable.ref()` is version-dependent; newer TensorFlow
    # releases appear to expose this only as `_ref()` -- confirm against the
    # TensorFlow version this test targets.
    # pylint: disable=protected-access
    v1 = control_flow_ops._SwitchRefOrTensor(v.ref(), p)
    # pylint: enable=protected-access
    # Index 1 is used as the assignment target; with a True predicate this is
    # presumably the live ("true") output of the switch.
    v2 = tf.assign(v1[1], 9)
    tf.global_variables_initializer().run()
    self.assertEqual(9, v2.eval())
示例2: _MergeGrad
# 需要導入模塊: from tensorflow.python.ops import control_flow_ops [as 別名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 別名]
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op.

  The predicate used for the gradient switch depends on the control flow
  context of the op that produced the Merge's first input.

  Args:
    op: The Merge op being differentiated.
    grad: The incoming gradient to route back to the Merge inputs.
    _: Unused (presumably the gradient for the op's second output).

  Returns:
    For a while-loop context, the result of switching `grad` on the gradient
    context's pivot. For a cond context, the result of switching `grad` on
    the (possibly accumulated) cond predicate. Otherwise, a list with one
    entry per Merge input, each being output index 1 of a switch on
    `op.outputs[1] == i`.
  """
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = control_flow_ops._GetOutputContext(input_op)
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access
  elif isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      # Reuse a previously built accumulated predicate, keyed by pred name.
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration.
        # The forward accumulator is added while the gradient context is
        # exited; the context is re-entered right after.
        grad_ctxt = grad_state.grad_context
        grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        grad_ctxt.Enter()

        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackPropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access
  else:
    # Not inside a while or cond context: build one predicate per input by
    # comparing the Merge op's second output against each input index.
    num_inputs = len(op.inputs)
    cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
    # pylint: disable=protected-access
    return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
            for i in xrange(num_inputs)]
    # pylint: enable=protected-access
示例3: _MergeGrad
# 需要導入模塊: from tensorflow.python.ops import control_flow_ops [as 別名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 別名]
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op.

  The predicate used for the gradient switch depends on the control flow
  context of the op that produced the Merge's first input.

  Args:
    op: The Merge op being differentiated.
    grad: The incoming gradient to route back to the Merge inputs.
    _: Unused (presumably the gradient for the op's second output).

  Returns:
    For a while-loop context, the result of switching `grad` on the gradient
    context's pivot. For a cond context, the result of switching `grad` on
    the (possibly accumulated) cond predicate. Otherwise, a list with one
    entry per Merge input, each being output index 1 of a switch on
    `op.outputs[1] == i`.
  """
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  # This variant reads the context directly from the producing op.
  op_ctxt = input_op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access
  elif isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      # Reuse a previously built accumulated predicate, keyed by pred name.
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration.
        # The forward accumulator is added while the gradient context is
        # exited; the context is re-entered right after.
        grad_ctxt = grad_state.grad_context
        grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        grad_ctxt.Enter()

        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackPropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access
  else:
    # Not inside a while or cond context: build one predicate per input by
    # comparing the Merge op's second output against each input index.
    num_inputs = len(op.inputs)
    cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
    # pylint: disable=protected-access
    return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
            for i in xrange(num_inputs)]
    # pylint: enable=protected-access
示例4: _MergeGrad
# 需要導入模塊: from tensorflow.python.ops import control_flow_ops [as 別名]
# 或者: from tensorflow.python.ops.control_flow_ops import _SwitchRefOrTensor [as 別名]
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op.

  The predicate used for the gradient switch depends on the control flow
  context of the op that produced the Merge's first input.

  Args:
    op: The Merge op being differentiated.
    grad: The incoming gradient to route back to the Merge inputs.
    _: Unused (presumably the gradient for the op's second output).

  Returns:
    For a while-loop context, the result of switching `grad` on the gradient
    context's pivot. For a cond context, the result of switching `grad` on
    the (possibly accumulated) cond predicate. Otherwise, a list with one
    entry per Merge input, each being output index 1 of a switch on
    `op.outputs[1] == i`.
  """
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = control_flow_ops._GetOutputContext(input_op)
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access
  elif isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      # Reuse a previously built accumulated predicate, keyed by pred name.
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration.
        # The forward accumulator is added while the gradient context is
        # exited; the context is re-entered right after.
        grad_ctxt = grad_state.grad_context
        grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        grad_ctxt.Enter()

        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access
  else:
    # Not inside a while or cond context: build one predicate per input by
    # comparing the Merge op's second output against each input index.
    num_inputs = len(op.inputs)
    cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
    # pylint: disable=protected-access
    return [control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
            for i in xrange(num_inputs)]
    # pylint: enable=protected-access
開發者ID:PacktPublishing,項目名稱:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代碼行數:45,代碼來源:control_flow_grad.py