本文整理汇总了Python中tensorflow.python.framework.smart_cond.smart_cond函数的典型用法代码示例。如果您正苦于以下问题:Python smart_cond函数的具体用法?Python smart_cond怎么用?Python smart_cond使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了smart_cond函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: categorical_crossentropy
def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0):
    """Computes the categorical crossentropy loss.

    When `label_smoothing` is non-zero the targets are first smoothed as
    `y_true * (1 - label_smoothing) + label_smoothing / num_classes`, where
    `num_classes` is the size of the last axis of `y_true`.

    Args:
      y_true: tensor of true targets.
      y_pred: tensor of predicted targets.
      from_logits: Whether `y_pred` is expected to be a logits tensor. By
        default, we assume that `y_pred` encodes a probability distribution.
      label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.

    Returns:
      Categorical crossentropy loss value.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = math_ops.cast(y_true, y_pred.dtype)
    # `smart_cond` needs a tensor (or bool/0/1) predicate, so a Python float
    # smoothing factor is turned into a tensor up front.
    label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

    def _smooth_labels():
        # Read the class count from the LAST axis rather than axis 1, so that
        # targets with extra leading dimensions (e.g. batch x time x classes)
        # are smoothed with the right number of classes. For rank-2 targets
        # this is the same axis as before.
        num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
        return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)

    # The smoothing branch is only built when `label_smoothing` is (or may be)
    # non-zero; otherwise the targets pass through untouched.
    y_true = smart_cond.smart_cond(label_smoothing,
                                   _smooth_labels, lambda: y_true)
    return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
示例2: _contraction
def _contraction():
    """Performs a contraction.

    Moves from the face centroid towards the worst vertex by the factor
    `contraction`, evaluates the objective there, and either keeps the new
    point or delegates to `_shrink_towards_best`.
    """
    worst_vertex = simplex[worst_index]
    contracted = face_centroid - contraction * (face_centroid - worst_vertex)
    objective_at_contracted = objective_function(contracted)

    def _accept_contraction():
        """Swaps the contracted point (and its objective) in at `worst_index`."""
        updated_simplex = _replace_at_index(simplex, worst_index, contracted)
        updated_objectives = _replace_at_index(
            objective_values, worst_index, objective_at_contracted)
        # NOTE(review): the first and last entries presumably are a convergence
        # flag and an evaluation count -- confirm against the caller.
        return (False, updated_simplex, updated_objectives, 1)

    def _reject_contraction():
        """Falls back to shrinking when the contracted point is not accepted."""
        return _shrink_towards_best(objective_function, simplex, best_index,
                                    shrinkage, batch_evaluate_objective)

    # The contraction is acceptable when it is no worse than the worst vertex.
    return smart_cond.smart_cond(
        objective_at_contracted <= worst_objective_value,
        _accept_contraction,
        _reject_contraction)
示例3: write
def write(tag, tensor, step=None, metadata=None, name=None):
    """Writes a generic summary to the default SummaryWriter if one exists.

    This is the building block for type-specific summary ops such as scalar()
    and image(); it is not meant to be called directly unless a new
    type-specific summary op is being defined.

    Args:
      tag: string tag used to identify the summary (e.g. in TensorBoard),
        usually generated with `tf.summary.summary_scope`
      tensor: the Tensor holding the summary data to write
      step: Explicit `int64`-castable monotonic step value for this summary.
        If omitted, this defaults to `tf.summary.experimental.get_step()`,
        which must not be None.
      metadata: Optional SummaryMetadata, as a proto or serialized bytes
      name: Optional string name for this op.

    Returns:
      True on success, or false if no summary was written because no default
      summary writer was available.

    Raises:
      ValueError: if a default writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    with ops.name_scope(name, "write_summary") as scope:
        # Without a default writer there is nothing to do.
        if context.context().summary_writer is None:
            return constant_op.constant(False)

        # Fall back to the globally configured step when none was passed in.
        step = get_step() if step is None else step
        if step is None:
            raise ValueError("No step set via 'step' argument or "
                             "tf.summary.experimental.set_step()")

        # Accept metadata as a proto (serialized here), as already-serialized
        # bytes (used as is), or absent (empty bytes).
        serialized_metadata = metadata
        if metadata is None:
            serialized_metadata = b""
        elif hasattr(metadata, "SerializeToString"):
            serialized_metadata = metadata.SerializeToString()

        def record():
            """Record the actual summary and return True."""
            # Note the identity to move the tensor to the CPU.
            with ops.device("cpu:0"):
                summary_write_op = gen_summary_ops.write_summary(
                    context.context().summary_writer._resource,  # pylint: disable=protected-access
                    step,
                    array_ops.identity(tensor),
                    tag,
                    serialized_metadata,
                    name=scope)
                # The returned constant only materializes after the write ran.
                with ops.control_dependencies([summary_write_op]):
                    return constant_op.constant(True)

        with ops.device("cpu:0"):
            summary_op = smart_cond.smart_cond(
                _should_record_summaries_v2(), record, _nothing,
                name="summary_cond")
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, summary_op)  # pylint: disable=protected-access
            return summary_op
示例4: summary_writer_function
def summary_writer_function(name, tensor, function, family=None):
    """Helper function to write summaries.

    Args:
      name: name of the summary
      tensor: main tensor to form the summary
      function: function taking a tag and a scope which writes the summary
      family: optional, the summary's family

    Returns:
      The result of writing the summary.
    """
    enclosing_scope = ops.get_name_scope()
    if enclosing_scope:
        # Add a slash to allow reentering the name scope.
        enclosing_scope = enclosing_scope + "/"

    def record():
        """Runs `function` inside the captured scope and returns True."""
        with ops.name_scope(enclosing_scope):
            with summary_op_util.summary_scope(
                    name, family, values=[tensor]) as (tag, scope):
                write_op = function(tag, scope)
                with ops.control_dependencies([write_op]):
                    return constant_op.constant(True)

    # No writer resource configured: hand back a no-op instead of a summary.
    if context.context().summary_writer_resource is None:
        return control_flow_ops.no_op()

    with ops.device("cpu:0"):
        summary_op = smart_cond.smart_cond(
            should_record_summaries(), record, _nothing, name="")
        if not context.executing_eagerly():
            ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, summary_op)  # pylint: disable=protected-access
    return summary_op
示例5: testSmartCondTrue
def testSmartCondTrue(self):
    """A Python `True` predicate must select the first branch (2 * 16)."""
    with ops.Graph().as_default(), session.Session():
        two = constant_op.constant(2)
        five = constant_op.constant(5)

        def taken_branch():
            return math_ops.multiply(two, 16)

        def skipped_branch():
            return math_ops.multiply(five, 5)

        product = smart_cond.smart_cond(True, taken_branch, skipped_branch)
        self.assertEqual(product.eval(), 32)
示例6: testUnknown
def testUnknown(self):
    """A predicate that depends on a fed placeholder is decided at run time."""
    with ops.Graph().as_default(), session.Session():
        fed_value = array_ops.placeholder(dtype=dtypes.int32)
        selected = smart_cond.smart_cond(
            fed_value > 0,
            lambda: constant_op.constant(1),
            lambda: constant_op.constant(2))
        # Positive input picks the first branch, negative input the second.
        self.assertEqual(selected.eval(feed_dict={fed_value: 1}), 1)
        self.assertEqual(selected.eval(feed_dict={fed_value: -1}), 2)
示例7: testSmartCondFalse
def testSmartCondFalse(self):
    """A Python `False` predicate must select the second branch (3 * 3)."""
    with ops.Graph().as_default(), session.Session():
        four = constant_op.constant(4)
        three = constant_op.constant(3)

        def skipped_branch():
            return math_ops.multiply(four, 16)

        def taken_branch():
            return math_ops.multiply(three, 3)

        product = smart_cond.smart_cond(False, skipped_branch, taken_branch)
        self.assertEqual(product.eval(), 9)
示例8: testPlaceholderWithDefault
def testPlaceholderWithDefault(self):
    """The default value drives the branch unless a value is fed explicitly."""
    with ops.Graph().as_default(), session.Session():
        maybe_fed = array_ops.placeholder_with_default(1, shape=())
        selected = smart_cond.smart_cond(
            maybe_fed > 0,
            lambda: constant_op.constant(1),
            lambda: constant_op.constant(2))
        # Nothing fed: the default (1) is positive, so the first branch runs.
        self.assertEqual(selected.eval(), 1)
        # Feeding a negative value overrides the default.
        self.assertEqual(selected.eval(feed_dict={maybe_fed: -1}), 2)
示例9: binary_crossentropy
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
    """Computes the binary crossentropy loss, averaged over the last axis.

    When `label_smoothing` is non-zero the targets are first smoothed as
    `y_true * (1 - label_smoothing) + 0.5 * label_smoothing`.

    Args:
      y_true: tensor of true targets.
      y_pred: tensor of predicted targets.
      from_logits: Whether `y_pred` is expected to be a logits tensor;
        forwarded to `K.binary_crossentropy`.
      label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.

    Returns:
      Mean of `K.binary_crossentropy` over axis -1.
    """
    # Function-scope imports keep this block self-contained; these modules were
    # not referenced by the previous version of this function.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import math_ops

    # Normalize the inputs first: targets must share the prediction dtype so
    # the smoothing arithmetic works for integer labels, and the smoothing
    # factor must be a tensor because `smart_cond` does not accept an
    # arbitrary Python float (e.g. 0.1) as its predicate.
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = math_ops.cast(y_true, y_pred.dtype)
    label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

    def _smooth_labels():
        return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

    y_true = smart_cond.smart_cond(label_smoothing,
                                   _smooth_labels, lambda: y_true)
    return K.mean(
        K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
示例10: testEval
def testEval(self):
    """A constant-foldable predicate must not build the untaken branch."""
    with ops.Graph().as_default(), session.Session():
        one = constant_op.constant(1)
        two = constant_op.constant(2)
        # one * two > 0 can be evaluated at graph construction time, so the
        # false branch (which raises) shouldn't be evaluated at all.
        predicate = one * two > 0
        selected = smart_cond.smart_cond(
            predicate, lambda: constant_op.constant(1), raise_exception)
        self.assertEqual(selected.eval(feed_dict={one: 1}), 1)
示例11: binary_crossentropy
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
    """Computes the binary crossentropy loss, averaged over the last axis.

    Targets are cast to the prediction dtype and, when `label_smoothing` is
    non-zero, smoothed as `y_true * (1 - label_smoothing) + 0.5 *
    label_smoothing` before the loss is computed.

    Args:
      y_true: tensor of true targets.
      y_pred: tensor of predicted targets.
      from_logits: forwarded to `K.binary_crossentropy`.
      label_smoothing: smoothing factor; converted to a `K.floatx()` tensor.

    Returns:
      Mean of `K.binary_crossentropy` over axis -1.
    """
    predictions = ops.convert_to_tensor(y_pred)
    targets = math_ops.cast(y_true, predictions.dtype)
    smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

    def _smooth_labels():
        return targets * (1.0 - smoothing) + 0.5 * smoothing

    def _keep_labels():
        return targets

    smoothed_targets = smart_cond.smart_cond(
        smoothing, _smooth_labels, _keep_labels)
    elementwise_loss = K.binary_crossentropy(
        smoothed_targets, predictions, from_logits=from_logits)
    return K.mean(elementwise_loss, axis=-1)
示例12: testEval
def testEval(self):
    """A constant-foldable predicate must not build the untaken branch."""
    # Constant expression evaluation only works with the C API enabled.
    if not ops._USE_C_API:
        return

    with ops.Graph().as_default(), session.Session():
        one = constant_op.constant(1)
        two = constant_op.constant(2)
        # one * two > 0 can be evaluated at graph construction time, so the
        # false branch (which raises) shouldn't be evaluated at all.
        predicate = one * two > 0
        selected = smart_cond.smart_cond(
            predicate, lambda: constant_op.constant(1), raise_exception)
        self.assertEqual(selected.eval(feed_dict={one: 1}), 1)
示例13: _maybe_convert_labels
def _maybe_convert_labels(y_true):
    """Converts binary labels into -1/1.

    If every element of `y_true` equals 0 or 1 the labels are mapped through
    `2 * y_true - 1`; otherwise `y_true` is returned unchanged.
    """
    zero_or_one = math_ops.logical_or(
        math_ops.equal(y_true, 0), math_ops.equal(y_true, 1))
    all_binary = math_ops.reduce_all(zero_or_one)

    def _to_signed_labels():
        # Maps 0 -> -1 and 1 -> 1.
        return 2. * y_true - 1.

    def _keep_labels():
        return y_true

    return smart_cond.smart_cond(all_binary, _to_signed_labels, _keep_labels)
示例14: result
def result(self, write_summary=True):
    """Returns the result of the Metric.

    The value is `self.numer / self.denom`. When `write_summary` is true a
    scalar summary named `self.name` is emitted for it as a side effect.

    Args:
      write_summary: bool (or boolean tensor) indicating whether to feed the
        result to the summary before returning.

    Returns:
      aggregated metric as float.

    Raises:
      ValueError: if the optional argument is not bool
        (NOTE(review): no explicit check is visible in this method; this is
        presumably raised by the tensor conversion / cond -- confirm).
    """
    # Convert the boolean to tensor for tf.cond, if it is not.
    should_write = write_summary
    if not isinstance(should_write, ops.Tensor):
        should_write = ops.convert_to_tensor(should_write)

    metric_value = self.numer / self.denom

    def _emit_summary():
        summary_ops.scalar(name=self.name, tensor=metric_value)
        return metric_value

    def _skip_summary():
        return metric_value

    # Only the side effect matters here; both branches yield the same value.
    smart_cond.smart_cond(should_write, _emit_summary, _skip_summary, name="")
    return metric_value
示例15: write_raw_pb
def write_raw_pb(tensor, step=None, name=None):
    """Writes a summary using raw `tf.compat.v1.Summary` protocol buffers.

    Experimental: this bridges V1-style manual summary writing (building a
    `tf.compat.v1.Summary` protocol buffer by hand) and the V2 summary
    writing API.

    Args:
      tensor: the string Tensor holding one or more serialized `Summary`
        protobufs
      step: Explicit `int64`-castable monotonic step value for this summary.
        If omitted, this defaults to `tf.summary.experimental.get_step()`,
        which must not be None.
      name: Optional string name for this op.

    Returns:
      True on success, or false if no summary was written because no default
      summary writer was available.

    Raises:
      ValueError: if a default writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    with ops.name_scope(name, "write_raw_pb") as scope:
        # Without a default writer there is nothing to do.
        if context.context().summary_writer is None:
            return constant_op.constant(False)

        # Fall back to the globally configured step when none was passed in.
        step = get_step() if step is None else step
        if step is None:
            raise ValueError("No step set via 'step' argument or "
                             "tf.summary.experimental.set_step()")

        def record():
            """Record the actual summary and return True."""
            # Note the identity to move the tensor to the CPU.
            with ops.device("cpu:0"):
                proto_write_op = gen_summary_ops.write_raw_proto_summary(
                    context.context().summary_writer._resource,  # pylint: disable=protected-access
                    step,
                    array_ops.identity(tensor),
                    name=scope)
                # The returned constant only materializes after the write ran.
                with ops.control_dependencies([proto_write_op]):
                    return constant_op.constant(True)

        with ops.device("cpu:0"):
            summary_op = smart_cond.smart_cond(
                _should_record_summaries_v2(), record, _nothing,
                name="summary_cond")
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, summary_op)  # pylint: disable=protected-access
            return summary_op