This article collects and summarizes typical code examples of the Python method tensorflow.python.ops.control_flow_ops.cond. If you have been wondering how control_flow_ops.cond is used in practice, or what working examples of it look like, the curated code examples here may help. You can also browse further usage examples from the module that contains this method, tensorflow.python.ops.control_flow_ops.
The sections below present 15 code examples of the control_flow_ops.cond method, sorted by popularity by default.
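Before diving into the examples, the short sketch below shows the same conditional through the public `tf.cond` API, which wraps `control_flow_ops.cond`. It is a minimal illustration only, assuming a TensorFlow 1.x graph-mode environment (or `tf.compat.v1` in TensorFlow 2.x):

import tensorflow as tf

x = tf.constant(2.0)
y = tf.constant(5.0)

# Both branches are callables; only the value of the taken branch is returned.
z = tf.cond(tf.less(x, y),
            lambda: tf.add(x, y),    # taken when x < y
            lambda: tf.square(y))    # taken otherwise

with tf.Session() as sess:
    print(sess.run(z))  # 7.0, because 2.0 < 5.0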
Example 1: static_cond
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def static_cond(pred, fn1, fn2):
  """Return either fn1() or fn2() based on the boolean value of `pred`.

  Same signature as `control_flow_ops.cond()` but requires pred to be a bool.

  Args:
    pred: A value determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  if pred:
    return fn1()
  else:
    return fn2()
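A quick usage sketch for `static_cond` as defined above (assuming the function is in scope); because the predicate is a plain Python bool, the branch is chosen at graph-construction time and no TensorFlow op is created for it:

# Hypothetical usage of the static_cond helper above.
result = static_cond(True,
                     lambda: 'branch fn1',
                     lambda: 'branch fn2')
print(result)  # branch fn1

# Non-callable arguments are rejected up front.
try:
    static_cond(False, 'not callable', lambda: 0)
except TypeError as err:
    print(err)  # fn1 must be callable.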
Example 2: smart_cond
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def smart_cond(pred, fn1, fn2, name=None):
  """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

  If `pred` is bool or has a constant value it would use `static_cond`,
  otherwise it would use `tf.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix when using tf.cond.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  pred_value = constant_value(pred)
  if pred_value is not None:
    # Use static_cond if pred has a constant value.
    return static_cond(pred_value, fn1, fn2)
  else:
    # Use dynamic cond otherwise.
    return control_flow_ops.cond(pred, fn1, fn2, name)
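A usage sketch for `smart_cond`, assuming the `static_cond` helper above is in scope and that `constant_value` is `tensorflow.python.framework.tensor_util.constant_value` (the import is not shown in the example); a constant predicate is resolved in Python, while a runtime predicate falls back to `control_flow_ops.cond`:

import tensorflow as tf
from tensorflow.python.framework.tensor_util import constant_value
from tensorflow.python.ops import control_flow_ops

# Constant predicate: resolved at graph-construction time via static_cond.
a = smart_cond(tf.constant(True), lambda: tf.constant(1), lambda: tf.constant(2))

# Runtime predicate: a tf.cond node is added to the graph.
p = tf.placeholder(tf.bool, shape=[])
b = smart_cond(p, lambda: tf.constant(1), lambda: tf.constant(2))

with tf.Session() as sess:
    print(sess.run(a))                        # 1
    print(sess.run(b, feed_dict={p: False}))  # 2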
Example 3: _assert
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def _assert(cond, ex_type, msg):
  """A polymorphic assert that works with tensors and boolean expressions.

  If `cond` is not a tensor, behave like an ordinary assert statement, except
  that an empty list is returned. If `cond` is a tensor, return a list
  containing a single TensorFlow assert op.

  Args:
    cond: Something that evaluates to a boolean value. May be a tensor.
    ex_type: The exception class to use.
    msg: The error message.

  Returns:
    A list, containing at most one assert op.
  """
  if _is_tensor(cond):
    return [control_flow_ops.Assert(cond, [msg])]
  else:
    if not cond:
      raise ex_type(msg)
    else:
      return []
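A usage sketch for `_assert`, assuming a stand-in for the `_is_tensor` helper it relies on (not shown above); with a Python bool it raises immediately, while with a tensor it returns an assert op to wire into the graph:

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops

def _is_tensor(x):
    # Assumed stand-in for the helper used by _assert above.
    return isinstance(x, (ops.Tensor, tf.Variable))

# Python condition: behaves like a normal assert and returns [].
checks = _assert(2 > 1, ValueError, 'never raised')

# Tensor condition: returns a list holding one tf.Assert op.
x = tf.placeholder(tf.float32, shape=[])
checks = _assert(x > 0.0, ValueError, 'x must be positive')
with tf.control_dependencies(checks):
    y = tf.identity(x)

with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: 3.0}))  # 3.0
    # Feeding x = -1.0 would fail at run time with an InvalidArgumentError.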
Example 4: random_flip_left_right
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def random_flip_left_right(image, bboxes, seed=None):
  """Randomly flip an image and its bounding boxes left-right."""
  def flip_bboxes(bboxes):
    """Flip bounding box coordinates."""
    bboxes = tf.stack([bboxes[:, 0], 1 - bboxes[:, 3],
                       bboxes[:, 2], 1 - bboxes[:, 1]], axis=-1)
    return bboxes

  # Random flip. TensorFlow implementation.
  with tf.name_scope('random_flip_left_right'):
    image = ops.convert_to_tensor(image, name='image')
    _Check3DImage(image, require_static=False)
    uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
    mirror_cond = math_ops.less(uniform_random, .5)
    # Flip image.
    result = control_flow_ops.cond(mirror_cond,
                                   lambda: array_ops.reverse_v2(image, [1]),
                                   lambda: image)
    # Flip bboxes.
    bboxes = control_flow_ops.cond(mirror_cond,
                                   lambda: flip_bboxes(bboxes),
                                   lambda: bboxes)
    return fix_image_flip_shape(image, result), bboxes
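A usage sketch for `random_flip_left_right`, assuming the function above and its TF-internal helpers (`_Check3DImage`, `fix_image_flip_shape`) are also in scope, and that `bboxes` holds normalized [ymin, xmin, ymax, xmax] rows:

import tensorflow as tf

image = tf.random_uniform([224, 224, 3])
bboxes = tf.constant([[0.1, 0.2, 0.5, 0.6],
                      [0.3, 0.1, 0.9, 0.4]])

flipped_image, flipped_bboxes = random_flip_left_right(image, bboxes, seed=42)

with tf.Session() as sess:
    img, boxes = sess.run([flipped_image, flipped_bboxes])
    # When the flip is taken, each row becomes [ymin, 1 - xmax, ymax, 1 - xmin].
    print(boxes)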
Example 5: testDebugCondWatchingWholeGraphWorks
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def testDebugCondWatchingWholeGraphWorks(self):
  with session.Session() as sess:
    x = variables.Variable(10.0, name="x")
    y = variables.Variable(20.0, name="y")
    cond = control_flow_ops.cond(
        x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
    sess.run(variables.global_variables_initializer())

    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(run_options,
                            sess.graph,
                            debug_urls=self._debug_urls())
    run_metadata = config_pb2.RunMetadata()
    self.assertEqual(
        21, sess.run(cond, options=run_options, run_metadata=run_metadata))

    dump = debug_data.DebugDumpDir(
        self._dump_root, partition_graphs=run_metadata.partition_graphs)
    self.assertAllClose(
        [21.0], dump.get_tensors("cond/Merge", 0, "DebugIdentity"))
Example 6: _safe_scalar_div
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`.
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return control_flow_ops.cond(
      math_ops.equal(
          array_ops.constant(0.0, dtype=dtypes.float64), denominator),
      lambda: array_ops.constant(0.0, dtype=dtypes.float64),
      lambda: math_ops.div(numerator, denominator),
      name=name)
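A usage sketch for `_safe_scalar_div`, assuming the definition above and its module imports are in scope; both inputs must be `float64`:

import tensorflow as tf

num = tf.constant(3.0, dtype=tf.float64)
zero_den = tf.constant(0.0, dtype=tf.float64)
two_den = tf.constant(2.0, dtype=tf.float64)

guarded = _safe_scalar_div(num, zero_den, name='guarded')  # 0.0 instead of inf/NaN
regular = _safe_scalar_div(num, two_den, name='regular')   # plain division

with tf.Session() as sess:
    print(sess.run([guarded, regular]))  # [0.0, 1.5]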
Example 7: initialized_value
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def initialized_value(self):
  """Returns the value of the initialized variable.

  You should use this instead of the variable itself to initialize another
  variable with a value that depends on the value of this variable.

  ```python
  # Initialize 'v' with a random tensor.
  v = tf.Variable(tf.truncated_normal([10, 40]))
  # Use `initialized_value` to guarantee that `v` has been
  # initialized before its value is used to initialize `w`.
  # The random values are picked only once.
  w = tf.Variable(v.initialized_value() * 2.0)
  ```

  Returns:
    A `Tensor` holding the value of this variable after its initializer
    has run.
  """
  with ops.control_dependencies(None):
    return control_flow_ops.cond(is_variable_initialized(self),
                                 self.read_value,
                                 lambda: self.initial_value)
Example 8: _assert
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def _assert(cond, ex_type, msg):
  """A polymorphic assert that works with tensors and boolean expressions.

  If `cond` is not a tensor, behave like an ordinary assert statement, except
  that an empty list is returned. If `cond` is a tensor, return a list
  containing a single TensorFlow assert op.

  Args:
    cond: Something that evaluates to a boolean value. May be a tensor.
    ex_type: The exception class to use.
    msg: The error message.

  Returns:
    A list, containing at most one assert op.
  """
  if _is_tensor(cond):
    return [control_flow_ops.Assert(cond, [msg])]
  else:
    if not cond:
      raise ex_type(msg)
    else:
      return []
Example 9: _log_prob
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def _log_prob(self, event):
  event = self._maybe_assert_valid_sample(event)
  # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
  # inconsistent behavior for logits = inf/-inf.
  event = math_ops.cast(event, self.logits.dtype)
  logits = self.logits
  # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
  # so we do this here.
  def _broadcast(logits, event):
    return (array_ops.ones_like(event) * logits,
            array_ops.ones_like(logits) * event)
  # First check static shape.
  if (event.get_shape().is_fully_defined() and
      logits.get_shape().is_fully_defined()):
    if event.get_shape() != logits.get_shape():
      logits, event = _broadcast(logits, event)
  else:
    logits, event = control_flow_ops.cond(
        distribution_util.same_dynamic_shape(logits, event),
        lambda: (logits, event),
        lambda: _broadcast(logits, event))
  return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
Example 10: average_impurity
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def average_impurity(self):
  """Constructs a TF graph for evaluating the average leaf impurity of a tree.

  If in regression mode, this is the leaf variance. If in classification mode,
  this is the gini impurity.

  Returns:
    The last op in the graph.
  """
  children = array_ops.squeeze(array_ops.slice(
      self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
  is_leaf = math_ops.equal(constants.LEAF_NODE, children)
  leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
                                               squeeze_dims=[1]))
  counts = array_ops.gather(self.variables.node_sums, leaves)
  gini = self._weighted_gini(counts)
  # Guard against step 1, when there often are no leaves yet.
  def impurity():
    return gini
  # Since average impurity can be used for loss, when there's no data just
  # return a big number so that loss always decreases.
  def big():
    return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
  return control_flow_ops.cond(math_ops.greater(
      array_ops.shape(leaves)[0], 0), impurity, big)
Example 11: insert
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def insert(self, ids, scores):
  """Insert the ids and scores into the TopN."""
  with ops.control_dependencies(self.last_ops):
    scatter_op = state_ops.scatter_update(self.id_to_score, ids, scores)
    larger_scores = math_ops.greater(scores, self.sl_scores[0])

    def shortlist_insert():
      larger_ids = array_ops.boolean_mask(
          math_ops.to_int64(ids), larger_scores)
      larger_score_values = array_ops.boolean_mask(scores, larger_scores)
      shortlist_ids, new_ids, new_scores = tensor_forest_ops.top_n_insert(
          self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
      u1 = state_ops.scatter_update(self.sl_ids, shortlist_ids, new_ids)
      u2 = state_ops.scatter_update(self.sl_scores, shortlist_ids, new_scores)
      return control_flow_ops.group(u1, u2)

    # We only need to insert into the shortlist if there are any
    # scores larger than the threshold.
    cond_op = control_flow_ops.cond(
        math_ops.reduce_any(larger_scores), shortlist_insert,
        control_flow_ops.no_op)
    with ops.control_dependencies([cond_op]):
      self.last_ops = [scatter_op, cond_op]
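The `cond`/`no_op` pairing in this example is a common way to make a stateful update conditional. The sketch below reproduces just that pattern with hypothetical names, assuming TF 1.x graph mode:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

counter = tf.Variable(0, dtype=tf.int32)
value = tf.placeholder(tf.float32, shape=[])

def do_update():
    # True branch: perform the update and group it into a single op.
    return tf.group(tf.assign_add(counter, 1))

# Only increment when value exceeds the threshold; otherwise do nothing.
maybe_update = control_flow_ops.cond(value > 0.5, do_update,
                                     control_flow_ops.no_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(maybe_update, feed_dict={value: 0.9})
    sess.run(maybe_update, feed_dict={value: 0.1})
    print(sess.run(counter))  # 1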
Example 12: _log_prob
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def _log_prob(self, event):
  # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
  # inconsistent behavior for logits = inf/-inf.
  event = ops.convert_to_tensor(event, name="event")
  event = math_ops.cast(event, self.logits.dtype)
  logits = self.logits
  # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
  # so we do this here.
  broadcast = lambda logits, event: (
      array_ops.ones_like(event) * logits,
      array_ops.ones_like(logits) * event)
  # First check static shape.
  if (event.get_shape().is_fully_defined() and
      logits.get_shape().is_fully_defined()):
    if event.get_shape() != logits.get_shape():
      logits, event = broadcast(logits, event)
  else:
    logits, event = control_flow_ops.cond(
        distribution_util.same_dynamic_shape(logits, event),
        lambda: (logits, event),
        lambda: broadcast(logits, event))
  return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
Example 13: _wrap_computation_in_while_loop_with_stopping_signals
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop."""

  def cond(scalar_stopping_signal):
    return math_ops.logical_not(
        _StopSignals.should_stop(scalar_stopping_signal))

  def computation(unused_scalar_stopping_signal):
    return_value = op_fn()
    execute_ops = return_value['ops']
    signals = return_value['signals']
    with ops.control_dependencies(execute_ops):
      return _StopSignals.as_scalar_stopping_signal(signals)

  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
  with ops.device(device):
    return control_flow_ops.while_loop(
        cond,
        computation, [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
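For readers new to `while_loop`, the self-contained counter below shows the same `cond`/`body` structure with `parallel_iterations=1`; the names are illustrative only:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

def loop_cond(i):
    # Keep iterating while i < 5 (the analogue of "not stopped").
    return tf.less(i, 5)

def loop_body(i):
    # The body returns the next loop state.
    return i + 1

# parallel_iterations=1 effectively serializes the loop iterations.
final_i = control_flow_ops.while_loop(loop_cond, loop_body, [tf.constant(0)],
                                      parallel_iterations=1)

with tf.Session() as sess:
    print(sess.run(final_i))  # 5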
Example 14: AddCrossEntropy
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def AddCrossEntropy(batch_size, n):
  """Adds a cross entropy cost function."""
  cross_entropies = []

  def _Pass():
    return tf.constant(0, dtype=tf.float32, shape=[1])

  for beam_id in range(batch_size):
    beam_gold_slot = tf.reshape(
        tf.strided_slice(n['gold_slot'], [beam_id], [beam_id + 1]), [1])

    def _ComputeCrossEntropy():
      """Adds ops to compute cross entropy of the gold path in a beam."""
      # Requires a cast so that UnsortedSegmentSum, in the gradient,
      # is happy with the type of its input 'segment_ids', which
      # must be int32.
      idx = tf.cast(
          tf.reshape(
              tf.where(tf.equal(n['beam_ids'], beam_id)), [-1]), tf.int32)
      beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
      num = tf.shape(idx)
      return tf.nn.softmax_cross_entropy_with_logits(
          labels=tf.expand_dims(
              tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0),
          logits=beam_scores)

    # The conditional here is needed to deal with the last few batches of the
    # corpus which can contain -1 in beam_gold_slot for empty batch slots.
    cross_entropies.append(cf.cond(
        beam_gold_slot[0] >= 0, _ComputeCrossEntropy, _Pass))
  return {'cross_entropy': tf.div(tf.add_n(cross_entropies), batch_size)}
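A reduced sketch of the guard used above: `cond` substitutes a constant zero loss whenever the gold slot is marked invalid with -1. The names, shapes, and the use of `tf.one_hot` here are illustrative simplifications, not the original model code:

import tensorflow as tf

gold_slot = tf.placeholder(tf.int32, shape=[1])
beam_scores = tf.constant([[0.1, 2.0, 0.3]])

def _compute_loss():
    labels = tf.one_hot(gold_slot, depth=3)
    return tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                   logits=beam_scores)

def _pass():
    return tf.constant(0, dtype=tf.float32, shape=[1])

loss = tf.cond(gold_slot[0] >= 0, _compute_loss, _pass)

with tf.Session() as sess:
    print(sess.run(loss, feed_dict={gold_slot: [1]}))   # real cross entropy
    print(sess.run(loss, feed_dict={gold_slot: [-1]}))  # [0.]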
Example 15: initialize
# Required module import: from tensorflow.python.ops import control_flow_ops [as alias]
# Or: from tensorflow.python.ops.control_flow_ops import cond [as alias]
def initialize(self, name=None):
  with ops.name_scope(name, "TrainingHelperInitialize"):
    finished = math_ops.equal(0, self._sequence_length)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished, lambda: self._zero_inputs,
        lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
    return (finished, next_inputs)
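A stripped-down sketch of the same "all sequences finished?" selection, with the `TensorArray` reads replaced by plain tensors; the names are illustrative:

import tensorflow as tf

sequence_length = tf.placeholder(tf.int32, shape=[None])
first_inputs = tf.placeholder(tf.float32, shape=[None, 8])
zero_inputs = tf.zeros_like(first_inputs)

finished = tf.equal(0, sequence_length)
all_finished = tf.reduce_all(finished)

# If every sequence has length 0, feed zeros; otherwise feed the real step-0 inputs.
next_inputs = tf.cond(all_finished,
                      lambda: zero_inputs,
                      lambda: first_inputs)

with tf.Session() as sess:
    out = sess.run(next_inputs,
                   feed_dict={sequence_length: [0, 0],
                              first_inputs: [[1.0] * 8, [2.0] * 8]})
    print(out)  # all zeros, because every sequence has length 0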