This article collects typical usage examples of the Python method tensorflow.python.ops.control_flow_ops._AddNextAndBackEdge. If you are wondering what control_flow_ops._AddNextAndBackEdge does, how to call it, or what real-world usage looks like, the curated examples below may help. You can also look at the enclosing module, tensorflow.python.ops.control_flow_ops, for further context.
The sections below show 3 code examples of control_flow_ops._AddNextAndBackEdge, sorted by popularity by default.
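Before the examples, a quick orientation: _AddNextAndBackEdge is an internal TensorFlow 1.x helper used while building gradients for graph-mode control flow. The _SwitchGrad functions below (all taken from control_flow_grad.py) call it to wire a NextIteration op and its back edge into the backprop Merge of a while loop. As a hedged sketch (assuming a TensorFlow 1.x environment; the loop body and numeric values are illustrative only), this is the kind of user code whose gradient construction reaches that path:

import tensorflow as tf  # assumes TensorFlow 1.x graph mode

i0 = tf.constant(0)
x0 = tf.constant(2.0)
_, y = tf.while_loop(lambda i, x: i < 3,
                     lambda i, x: (i + 1, x * x),
                     [i0, x0])

# Building this gradient visits the loop's Switch ops; _SwitchGrad creates
# the backprop Merge on the first visit and, on the second visit, calls
# control_flow_ops._AddNextAndBackEdge to close the loop with NextIteration.
dy_dx = tf.gradients(y, x0)[0]

with tf.Session() as sess:
    print(sess.run(dy_dx))  # y = x**8, so dy/dx at x=2 is 8 * 2**7 = 1024.0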
Example 1: _SwitchGrad
# Required module: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import _AddNextAndBackEdge [as alias]
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
# These names arrive via a wildcard import in the original control_flow_grad.py:
from tensorflow.python.ops.control_flow_ops import (
    CondContext, WhileContext, merge, switch)


def _SwitchGrad(op, *grad):
  """Gradients for a Switch op are calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    else:
      # This is the first time this Switch is visited. It always comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
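Note: this first variant assumes that the first visit to a loop Switch always comes from the Exit branch (grad[0]). The next example is a later revision of the same function that also handles a first visit arriving from the Identity branch, where grad[0] is None and the Switch output is simply not differentiable.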
Example 2: _SwitchGrad
# Required module: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import _AddNextAndBackEdge [as alias]
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
# These names arrive via a wildcard import in the original control_flow_grad.py:
from tensorflow.python.ops.control_flow_ops import (
    CondContext, WhileContext, merge, switch)


def _SwitchGrad(op, *grad):
  """Gradients for a Switch op are calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO(yuanbyu): Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has None gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
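The CondContext branch, merge([good_grad, zero_grad], name="cond_grad"), is the path that tf.cond gradients go through. A minimal sketch, again assuming TensorFlow 1.x (the placeholder and constants are illustrative):

import tensorflow as tf  # assumes TensorFlow 1.x graph mode

x = tf.constant(3.0)
pred = tf.placeholder(tf.bool)
y = tf.cond(pred, lambda: x * x, lambda: x + 1.0)

# Each branch's gradient is guarded by a Switch on `pred`; _SwitchGrad's
# CondContext path merges the taken branch's gradient with the zero
# gradient coming from the untaken branch.
dy_dx = tf.gradients(y, x)[0]

with tf.Session() as sess:
    print(sess.run(dy_dx, feed_dict={pred: True}))   # 2 * x -> 6.0
    print(sess.run(dy_dx, feed_dict={pred: False}))  # 1.0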
Example 3: _SwitchGrad
# Required module: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import _AddNextAndBackEdge [as alias]
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
# These names arrive via a wildcard import in the original control_flow_grad.py:
from tensorflow.python.ops.control_flow_ops import (
    CondContext, WhileContext, merge, switch)


def _SwitchGrad(op, *grad):
  """Gradients for a Switch op are calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO(yuanbyu): Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has None gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # Unfortunately, we may still get None here for non-trainable data types.
    if zero_grad is None:
      return None, None
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
Developer ID: PacktPublishing, Project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda, Lines of code: 51, Source file: control_flow_grad.py
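Compared with the previous examples, this last variant adds one more guard in the CondContext branch: for data types that are not trainable (for example, an integer-valued branch output, an assumption based on the source comment about "not trainable data types"), zero_grad can still be None, and the function then returns None for the whole Switch instead of building a Merge with a missing input.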