本文整理汇总了Python中tensorflow.python.ops.control_flow_ops._AddNextAndBackEdge方法的典型用法代码示例。如果您正苦于以下问题:Python control_flow_ops._AddNextAndBackEdge方法的具体用法?Python control_flow_ops._AddNextAndBackEdge怎么用?Python control_flow_ops._AddNextAndBackEdge使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.control_flow_ops的用法示例。
在下文中一共展示了control_flow_ops._AddNextAndBackEdge方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _SwitchGrad
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _AddNextAndBackEdge [as 别名]
def _SwitchGrad(op, *grad):
    """Compute the gradient of a Switch op by building a Merge op.

    If the Switch is a loop switch it is visited twice during backprop: the
    Merge is created on the first visit, and its second input is updated on
    the second visit, when a next_iteration back edge is also added.

    Args:
      op: The forward Switch operation.
      *grad: Gradients with respect to the Switch outputs. ``grad[0]`` is the
        gradient of the false output (the Exit branch inside a while loop)
        and ``grad[1]`` that of the true output; either may be None.

    Returns:
      A pair ``(data_grad, None)``: the gradient for the Switch's data input
      (possibly None) and None for its predicate input.

    Side effects:
      On the first while-loop visit the new Merge is recorded in
      ``grad_ctxt.grad_state.switch_map[op]`` so the second visit can find it.
    """
    graph = ops.get_default_graph()
    # pylint: disable=protected-access
    op_ctxt = op._get_control_flow_context()
    grad_ctxt = graph._get_control_flow_context()
    # pylint: enable=protected-access
    if isinstance(op_ctxt, WhileContext):
        merge_grad = grad_ctxt.grad_state.switch_map.get(op)
        if merge_grad is not None:
            # This is the second time this Switch is visited. It comes from
            # the non-exit branch of the Switch, so update the second input
            # to the Merge.
            # TODO: Perform shape inference with this new input.
            if grad[1] is not None:
                # pylint: disable=protected-access
                control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
                # pylint: enable=protected-access
            return None, None
        elif grad[0] is not None:
            # This is the first time this Switch is visited and it comes from
            # the Exit branch, which is grad[0]. grad[1] is empty at this
            # point. Use grad[0] for both inputs to the merge for now; the
            # second input is updated when this Switch is seen again.
            merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
            grad_ctxt.grad_state.switch_map[op] = merge_grad
            return merge_grad, None
        else:
            # This is the first time this Switch is visited and it comes from
            # the Identity branch. Such a Switch has a `None` gradient for the
            # Exit branch, meaning the output is not differentiable; building
            # merge([None, None]) here would not be meaningful.
            return None, None
    elif isinstance(op_ctxt, CondContext):
        good_grad = grad[op_ctxt.branch]
        zero_grad = grad[1 - op_ctxt.branch]
        # At this point, zero_grad has been created guarded by the right
        # switch. It may still be None for non-trainable data types, in which
        # case no gradient flows through this Switch.
        if zero_grad is None:
            return None, None
        return merge([good_grad, zero_grad], name="cond_grad")[0], None
    else:
        # Outside any control-flow context: route each output gradient
        # through a Switch on the same predicate (op.inputs[1]) and merge
        # the two results.
        false_grad = switch(grad[0], op.inputs[1])[0]
        true_grad = switch(grad[1], op.inputs[1])[1]
        return merge([false_grad, true_grad])[0], None
示例2: _SwitchGrad
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _AddNextAndBackEdge [as 别名]
def _SwitchGrad(op, *grad):
    """Compute the gradient of a Switch op by building a Merge op.

    If the Switch is a loop switch it is visited twice during backprop: the
    Merge is created on the first visit, and its second input is updated on
    the second visit, when a next_iteration back edge is also added.

    Args:
      op: The forward Switch operation.
      *grad: Gradients with respect to the Switch outputs. ``grad[0]`` is the
        gradient of the false output (the Exit branch inside a while loop)
        and ``grad[1]`` that of the true output; either may be None.

    Returns:
      A pair ``(data_grad, None)``: the gradient for the Switch's data input
      (possibly None) and None for its predicate input.

    Side effects:
      On the first while-loop visit the new Merge is recorded in
      ``grad_ctxt.grad_state.switch_map[op]`` so the second visit can find it.
    """
    graph = ops.get_default_graph()
    # pylint: disable=protected-access
    op_ctxt = op._get_control_flow_context()
    grad_ctxt = graph._get_control_flow_context()
    # pylint: enable=protected-access
    if isinstance(op_ctxt, WhileContext):
        merge_grad = grad_ctxt.grad_state.switch_map.get(op)
        if merge_grad is not None:
            # This is the second time this Switch is visited. It comes from
            # the non-exit branch of the Switch, so update the second input
            # to the Merge.
            # TODO(yuanbyu): Perform shape inference with this new input.
            if grad[1] is not None:
                # pylint: disable=protected-access
                control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
                # pylint: enable=protected-access
            return None, None
        elif grad[0] is not None:
            # This is the first time this Switch is visited. It comes from
            # the Exit branch, which is grad[0]. grad[1] is empty at this
            # point. Use grad[0] for both inputs to the merge for now; the
            # second input is updated when this Switch is seen again.
            merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
            grad_ctxt.grad_state.switch_map[op] = merge_grad
            return merge_grad, None
        else:
            # This is the first time this Switch is visited. It comes from the
            # Identity branch. Such a Switch has `None` gradient for the Exit
            # branch, meaning the output is not differentiable.
            return None, None
    elif isinstance(op_ctxt, CondContext):
        good_grad = grad[op_ctxt.branch]
        zero_grad = grad[1 - op_ctxt.branch]
        # At this point, zero_grad has been created guarded by the right
        # switch. It may still be None for non-trainable data types, in which
        # case merging it would be invalid and no gradient flows.
        if zero_grad is None:
            return None, None
        return merge([good_grad, zero_grad], name="cond_grad")[0], None
    else:
        # Outside any control-flow context: route each output gradient
        # through a Switch on the same predicate (op.inputs[1]) and merge
        # the two results.
        false_grad = switch(grad[0], op.inputs[1])[0]
        true_grad = switch(grad[1], op.inputs[1])[1]
        return merge([false_grad, true_grad])[0], None
示例3: _SwitchGrad
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _AddNextAndBackEdge [as 别名]
def _SwitchGrad(op, *grad):
    """Compute the gradient of a Switch op by building a Merge op.

    If the Switch is a loop switch it is visited twice during backprop: the
    Merge is created on the first visit, and its second input is updated on
    the second visit, when a next_iteration back edge is also added.

    Args:
      op: The forward Switch operation.
      *grad: Gradients with respect to the Switch outputs. ``grad[0]`` is the
        gradient of the false output (the Exit branch inside a while loop)
        and ``grad[1]`` that of the true output; either may be None.

    Returns:
      A pair ``(data_grad, None)``: the gradient for the Switch's data input
      (possibly None) and None for its predicate input.

    Side effects:
      On the first while-loop visit the new Merge is recorded in
      ``grad_ctxt.grad_state.switch_map[op]`` so the second visit can find it.
    """
    graph = ops.get_default_graph()
    # pylint: disable=protected-access
    op_ctxt = op._get_control_flow_context()
    grad_ctxt = graph._get_control_flow_context()
    # pylint: enable=protected-access
    if isinstance(op_ctxt, WhileContext):
        merge_grad = grad_ctxt.grad_state.switch_map.get(op)
        if merge_grad is not None:
            # This is the second time this Switch is visited. It comes from
            # the non-exit branch of the Switch, so update the second input
            # to the Merge.
            # TODO: Perform shape inference with this new input.
            if grad[1] is not None:
                # pylint: disable=protected-access
                control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
                # pylint: enable=protected-access
            return None, None
        elif grad[0] is not None:
            # This is the first time this Switch is visited. It comes from
            # the Exit branch, which is grad[0]. grad[1] is empty at this
            # point. Use grad[0] for both inputs to the merge for now; the
            # second input is updated when this Switch is seen again.
            merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
            grad_ctxt.grad_state.switch_map[op] = merge_grad
            return merge_grad, None
        else:
            # This is the first time this Switch is visited. It comes from the
            # Identity branch. Such a Switch has `None` gradient for the Exit
            # branch, meaning the output is not differentiable.
            return None, None
    elif isinstance(op_ctxt, CondContext):
        good_grad = grad[op_ctxt.branch]
        zero_grad = grad[1 - op_ctxt.branch]
        # At this point, zero_grad has been created guarded by the right
        # switch. It may still be None for non-trainable data types, in which
        # case merging it would be invalid and no gradient flows.
        if zero_grad is None:
            return None, None
        return merge([good_grad, zero_grad], name="cond_grad")[0], None
    else:
        # Outside any control-flow context: route each output gradient
        # through a Switch on the same predicate (op.inputs[1]) and merge
        # the two results.
        false_grad = switch(grad[0], op.inputs[1])[0]
        true_grad = switch(grad[1], op.inputs[1])[1]
        return merge([false_grad, true_grad])[0], None
示例4: _SwitchGrad
# 需要导入模块: from tensorflow.python.ops import control_flow_ops [as 别名]
# 或者: from tensorflow.python.ops.control_flow_ops import _AddNextAndBackEdge [as 别名]
def _SwitchGrad(op, *grad):
    """Build the backprop graph for a Switch op out of a Merge op.

    Three cases are distinguished by the forward op's control-flow context:

    * While loop: the Switch is reached twice. The first visit (from the
      Exit branch) creates a Merge fed twice by the Exit gradient and stores
      it in the gradient state's ``switch_map``. The second visit (from the
      non-exit branch) rewires the Merge's second input and adds a
      next_iteration back edge.
    * Cond: the taken-branch gradient is merged with the other branch's
      gradient, unless the latter is None.
    * No context: each output gradient is re-switched on the forward
      predicate and the two results are merged.

    Args:
      op: The forward Switch operation.
      *grad: Gradients for the Switch outputs; ``grad[0]`` belongs to the
        false (Exit) output, ``grad[1]`` to the true output. Either may be
        None.

    Returns:
      ``(data_grad, None)``; the predicate input never receives a gradient.
    """
    default_graph = ops.get_default_graph()
    # pylint: disable=protected-access
    fwd_ctxt = op._get_control_flow_context()
    bwd_ctxt = default_graph._get_control_flow_context()
    # pylint: enable=protected-access

    if isinstance(fwd_ctxt, WhileContext):
        switch_map = bwd_ctxt.grad_state.switch_map
        pending_merge = switch_map.get(op)
        if pending_merge is None:
            exit_grad = grad[0]
            if exit_grad is None:
                # First visit, arriving from the Identity branch: the Exit
                # output has no gradient, so the output is not differentiable.
                return None, None
            # First visit, arriving from the Exit branch; grad[1] is still
            # empty. Feed the Exit gradient into both Merge inputs for now and
            # patch the second one when the Switch is visited again.
            pending_merge = merge([exit_grad, exit_grad], name="b_switch")[0]
            switch_map[op] = pending_merge
            return pending_merge, None
        # Second visit, arriving from the non-exit branch: wire grad[1] in as
        # the Merge's second input.
        # TODO(yuanbyu): Perform shape inference with this new input.
        if grad[1] is not None:
            # pylint: disable=protected-access
            control_flow_ops._AddNextAndBackEdge(pending_merge, grad[1])
            # pylint: enable=protected-access
        return None, None

    if isinstance(fwd_ctxt, CondContext):
        taken = fwd_ctxt.branch
        live_grad = grad[taken]
        other_grad = grad[1 - taken]
        # other_grad is already guarded by the matching switch, but it can
        # still be None for non-trainable data types; then nothing flows back.
        if other_grad is None:
            return None, None
        return merge([live_grad, other_grad], name="cond_grad")[0], None

    # No enclosing control-flow context.
    predicate = op.inputs[1]
    false_grad = switch(grad[0], predicate)[0]
    true_grad = switch(grad[1], predicate)[1]
    return merge([false_grad, true_grad])[0], None
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:51,代码来源:control_flow_grad.py