当前位置: 首页>>代码示例>>Python>>正文


Python smart_cond.smart_cond函数代码示例

本文整理汇总了Python中tensorflow.python.framework.smart_cond.smart_cond函数的典型用法代码示例。如果您正苦于以下问题:Python smart_cond函数的具体用法?Python smart_cond怎么用?Python smart_cond使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了smart_cond函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: categorical_crossentropy

def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0):
  """Returns the categorical crossentropy between targets and predictions.

  Args:
    y_true: tensor holding the true targets; it is cast to the dtype of
      `y_pred` before use.
    y_pred: tensor holding the predicted targets.
    from_logits: Whether `y_pred` holds logits. When False (the default),
      `y_pred` is taken to encode a probability distribution.
    label_smoothing: Float in [0, 1]. Values above `0` smooth the labels.

  Returns:
    Categorical crossentropy loss value.
  """
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smoothed_targets():
    # The class count is read from axis 1 of the label tensor.
    class_count = math_ops.cast(array_ops.shape(y_true)[1], y_pred.dtype)
    retained = y_true * (1.0 - label_smoothing)
    return retained + (label_smoothing / class_count)

  def _raw_targets():
    return y_true

  targets = smart_cond.smart_cond(
      label_smoothing, _smoothed_targets, _raw_targets)
  return K.categorical_crossentropy(targets, y_pred, from_logits=from_logits)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:27,代码来源:losses.py

示例2: _contraction

  def _contraction():
    """Builds the contraction step and picks between accepting or shrinking.

    The worst vertex is pulled towards the face centroid by the factor
    `contraction`. The new point replaces the worst vertex when its objective
    value does not exceed the worst objective value; otherwise the result of
    `_shrink_towards_best` is returned.
    """
    worst_vertex = simplex[worst_index]
    candidate = face_centroid - contraction * (face_centroid - worst_vertex)
    candidate_value = objective_function(candidate)
    candidate_is_ok = candidate_value <= worst_objective_value

    def _keep_candidate():
      # Swap the candidate (and its objective value) in at the worst slot.
      updated_simplex = _replace_at_index(simplex, worst_index, candidate)
      updated_values = _replace_at_index(
          objective_values, worst_index, candidate_value)
      # NOTE(review): presumably (converged, simplex, objective values,
      # number of evaluations) -- confirm against the caller.
      return (False, updated_simplex, updated_values, 1)

    def _shrink_instead():
      return _shrink_towards_best(
          objective_function,
          simplex,
          best_index,
          shrinkage,
          batch_evaluate_objective)

    return smart_cond.smart_cond(
        candidate_is_ok, _keep_candidate, _shrink_instead)
开发者ID:lewisKit,项目名称:probability,代码行数:26,代码来源:nelder_mead.py

示例3: write

def write(tag, tensor, step=None, metadata=None, name=None):
  """Emits a generic summary through the default SummaryWriter, if any.

  This is mainly a building block for type-specific summary ops such as
  scalar() and image(); it is not meant to be called directly except when
  defining a new type-specific summary op.

  Args:
    tag: string tag identifying the summary (e.g. in TensorBoard), usually
      produced by `tf.summary.summary_scope`
    tensor: the Tensor carrying the summary data to write
    step: Explicit `int64`-castable monotonic step value for this summary.
      When omitted, `tf.summary.experimental.get_step()` is used, and it must
      not be None.
    metadata: Optional SummaryMetadata, either a proto or serialized bytes
    name: Optional string name for this op.

  Returns:
    True on success, or false if no summary was written because no default
    summary writer was available.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None.
  """
  with ops.name_scope(name, "write_summary") as scope:
    # Without a default writer there is nothing to do.
    if context.context().summary_writer is None:
      return constant_op.constant(False)

    # Fall back to the globally configured step when none was passed in.
    step = get_step() if step is None else step
    if step is None:
      raise ValueError("No step set via 'step' argument or "
                       "tf.summary.experimental.set_step()")

    # Accept either a proto-like object, raw bytes, or nothing at all.
    if hasattr(metadata, "SerializeToString"):
      metadata_bytes = metadata.SerializeToString()
    else:
      metadata_bytes = b"" if metadata is None else metadata

    def _emit_summary():
      """Creates the write op and returns True once it has run."""
      with ops.device("cpu:0"):
        writer_resource = context.context().summary_writer._resource  # pylint: disable=protected-access
        # The identity op moves the tensor onto the CPU.
        cpu_tensor = array_ops.identity(tensor)
        emit_op = gen_summary_ops.write_summary(
            writer_resource,
            step,
            cpu_tensor,
            tag,
            metadata_bytes,
            name=scope)
        with ops.control_dependencies([emit_op]):
          return constant_op.constant(True)

    with ops.device("cpu:0"):
      outcome = smart_cond.smart_cond(
          _should_record_summaries_v2(),
          _emit_summary,
          _nothing,
          name="summary_cond")
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, outcome)  # pylint: disable=protected-access
      return outcome
开发者ID:aritratony,项目名称:tensorflow,代码行数:60,代码来源:summary_ops_v2.py

示例4: summary_writer_function

def summary_writer_function(name, tensor, function, family=None):
  """Helper function to write summaries.

  The summary is only built when `should_record_summaries()` selects the
  recording branch; the choice is made through `smart_cond`. If the current
  context has no summary writer resource, a no-op is returned instead.

  Args:
    name: name of the summary
    tensor: main tensor to form the summary
    function: function taking a tag and a scope which writes the summary
    family: optional, the summary's family

  Returns:
    The result of writing the summary: the `smart_cond` output, or a no-op
    when no summary writer resource is set on the context.
  """
  # Capture the caller's name scope now so `record` can re-enter it later.
  name_scope = ops.get_name_scope()
  if name_scope:
    # Add a slash to allow reentering the name scope.
    name_scope += "/"
  def record():
    """Runs `function(tag, scope)` and returns True gated on its result."""
    with ops.name_scope(name_scope), summary_op_util.summary_scope(
        name, family, values=[tensor]) as (tag, scope):
      # The returned constant depends on whatever `function` produced.
      with ops.control_dependencies([function(tag, scope)]):
        return constant_op.constant(True)

  if context.context().summary_writer_resource is None:
    return control_flow_ops.no_op()
  with ops.device("cpu:0"):
    # NOTE(review): `_nothing` is defined elsewhere; presumably it returns a
    # constant False without writing anything -- confirm.
    op = smart_cond.smart_cond(
        should_record_summaries(), record, _nothing, name="")
    # Outside eager execution the op is also registered in the summary
    # collection.
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op)  # pylint: disable=protected-access
  return op
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:summary_ops_v2.py

示例5: testSmartCondTrue

 def testSmartCondTrue(self):
   """A literal True predicate yields the first branch: 2 * 16 == 32."""
   with ops.Graph().as_default(), session.Session():
     two = constant_op.constant(2)
     five = constant_op.constant(5)

     def _when_true():
       return math_ops.multiply(two, 16)

     def _when_false():
       return math_ops.multiply(five, 5)

     picked = smart_cond.smart_cond(True, _when_true, _when_false)
     self.assertEqual(picked.eval(), 32)
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:8,代码来源:smart_cond_test.py

示例6: testUnknown

 def testUnknown(self):
   """A predicate built from a placeholder is resolved from the fed value."""
   with ops.Graph().as_default(), session.Session():
     fed = array_ops.placeholder(dtype=dtypes.int32)

     def _one():
       return constant_op.constant(1)

     def _two():
       return constant_op.constant(2)

     chosen = smart_cond.smart_cond(fed > 0, _one, _two)
     # A positive feed selects the first branch, a negative one the second.
     for feed_value, expected in ((1, 1), (-1, 2)):
       self.assertEqual(chosen.eval(feed_dict={fed: feed_value}), expected)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:8,代码来源:smart_cond_test.py

示例7: testSmartCondFalse

 def testSmartCondFalse(self):
   """A literal False predicate yields the second branch: 3 * 3 == 9."""
   with ops.Graph().as_default(), session.Session():
     four = constant_op.constant(4)
     three = constant_op.constant(3)

     def _when_true():
       return math_ops.multiply(four, 16)

     def _when_false():
       return math_ops.multiply(three, 3)

     picked = smart_cond.smart_cond(False, _when_true, _when_false)
     self.assertEqual(picked.eval(), 9)
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:8,代码来源:smart_cond_test.py

示例8: testPlaceholderWithDefault

 def testPlaceholderWithDefault(self):
   """The default value is used unless the placeholder is explicitly fed."""
   with ops.Graph().as_default(), session.Session():
     maybe_fed = array_ops.placeholder_with_default(1, shape=())

     def _one():
       return constant_op.constant(1)

     def _two():
       return constant_op.constant(2)

     chosen = smart_cond.smart_cond(maybe_fed > 0, _one, _two)
     # Default value 1 is positive, so the first branch is evaluated.
     self.assertEqual(chosen.eval(), 1)
     # Feeding a negative value flips the result to the second branch.
     self.assertEqual(chosen.eval(feed_dict={maybe_fed: -1}), 2)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:8,代码来源:smart_cond_test.py

示例9: binary_crossentropy

def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
  """Computes the binary crossentropy loss, averaged over the last axis.

  Args:
    y_true: tensor of true targets; cast to the dtype of `y_pred`.
    y_pred: tensor of predicted targets.
    from_logits: forwarded unchanged to `K.binary_crossentropy`.
    label_smoothing: amount of label smoothing, as a Python number or a
      tensor. It is converted to a `K.floatx()` tensor. When the smoothing
      branch is taken, labels become
      `y_true * (1 - label_smoothing) + 0.5 * label_smoothing`.

  Returns:
    The mean over axis -1 of `K.binary_crossentropy(y_true, y_pred)`.
  """
  # Normalize the inputs first. `smart_cond` only accepts a Tensor, a bool or
  # the integers 0/1 as predicate, so a plain float such as 0.1 must be turned
  # into a tensor, and `y_true` must share `y_pred`'s dtype for the arithmetic
  # in `_smooth_labels` to be well-typed.
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smooth_labels():
    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

  y_true = smart_cond.smart_cond(label_smoothing,
                                 _smooth_labels, lambda: y_true)
  return K.mean(
      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:9,代码来源:losses.py

示例10: testEval

 def testEval(self):
   """A constant-foldable predicate must not build the false branch."""
   with ops.Graph().as_default(), session.Session():
     one = constant_op.constant(1)
     two = constant_op.constant(2)
     # one * two > 0 is known while the graph is being built, so the false
     # branch (which raises) must never be invoked.
     product_is_positive = one * two > 0

     def _constant_one():
       return constant_op.constant(1)

     chosen = smart_cond.smart_cond(
         product_is_positive, _constant_one, raise_exception)
     self.assertEqual(chosen.eval(feed_dict={one: 1}), 1)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:smart_cond_test.py

示例11: binary_crossentropy

def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
  """Returns the binary crossentropy loss, averaged over the last axis.

  `y_true` is cast to the dtype of `y_pred`, and `label_smoothing` is
  converted to a `K.floatx()` tensor. When the smoothing branch is taken, the
  labels become `y_true * (1 - label_smoothing) + 0.5 * label_smoothing`.
  `from_logits` is forwarded to `K.binary_crossentropy`.
  """
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smoothed_targets():
    retained = y_true * (1.0 - label_smoothing)
    return retained + 0.5 * label_smoothing

  def _raw_targets():
    return y_true

  targets = smart_cond.smart_cond(
      label_smoothing, _smoothed_targets, _raw_targets)
  elementwise = K.binary_crossentropy(
      targets, y_pred, from_logits=from_logits)
  return K.mean(elementwise, axis=-1)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:12,代码来源:losses.py

示例12: testEval

  def testEval(self):
    """A constant-foldable predicate must not build the false branch."""
    # Constant expression evaluation only works with the C API enabled.
    if not ops._USE_C_API:
      return

    with ops.Graph().as_default(), session.Session():
      one = constant_op.constant(1)
      two = constant_op.constant(2)
      # one * two > 0 is known while the graph is being built, so the false
      # branch (which raises) must never be invoked.
      product_is_positive = one * two > 0

      def _constant_one():
        return constant_op.constant(1)

      chosen = smart_cond.smart_cond(
          product_is_positive, _constant_one, raise_exception)
      self.assertEqual(chosen.eval(feed_dict={one: 1}), 1)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:13,代码来源:smart_cond_test.py

示例13: _maybe_convert_labels

def _maybe_convert_labels(y_true):
  """Maps labels from {0, 1} to {-1, 1}; other labels pass through as-is."""
  # True only when every entry of `y_true` equals 0 or 1.
  only_zeros_and_ones = math_ops.reduce_all(
      math_ops.logical_or(
          math_ops.equal(y_true, 0),
          math_ops.equal(y_true, 1)))

  def _to_signed():
    # 0 -> -1 and 1 -> 1.
    return 2. * y_true - 1.

  def _as_given():
    return y_true

  return smart_cond.smart_cond(only_zeros_and_ones, _to_signed, _as_given)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:13,代码来源:losses.py

示例14: result

  def result(self, write_summary=True):
    """Computes the metric value as `self.numer / self.denom`.

    Args:
      write_summary: bool (or tensor) controlling whether the result is also
        passed to the scalar summary before being returned.
    Returns:
      aggregated metric as float.
    Raises:
      ValueError: if the optional argument is not bool
    """
    # tf.cond needs a tensor predicate, so wrap anything that is not one yet.
    if not isinstance(write_summary, ops.Tensor):
      write_summary = ops.convert_to_tensor(write_summary)
    ratio = self.numer / self.denom

    def _summarize_then_return():
      summary_ops.scalar(name=self.name, tensor=ratio)
      return ratio

    def _return_only():
      return ratio

    # The conditional's own output is intentionally not used; the metric value
    # is returned directly below.
    smart_cond.smart_cond(
        write_summary, _summarize_then_return, _return_only, name="")
    return ratio
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:23,代码来源:metrics_impl.py

示例15: write_raw_pb

def write_raw_pb(tensor, step=None, name=None):
  """Emits raw `tf.compat.v1.Summary` protocol buffers as a summary.

  Experimental: this bridges V1-style manual summary writing (building a
  `tf.compat.v1.Summary` protocol buffer by hand) and the V2 summary writing
  API.

  Args:
    tensor: the string Tensor holding one or more serialized `Summary` protobufs
    step: Explicit `int64`-castable monotonic step value for this summary.
      When omitted, `tf.summary.experimental.get_step()` is used, and it must
      not be None.
    name: Optional string name for this op.

  Returns:
    True on success, or false if no summary was written because no default
    summary writer was available.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None.
  """
  with ops.name_scope(name, "write_raw_pb") as scope:
    # Without a default writer there is nothing to do.
    if context.context().summary_writer is None:
      return constant_op.constant(False)

    # Fall back to the globally configured step when none was passed in.
    step = get_step() if step is None else step
    if step is None:
      raise ValueError("No step set via 'step' argument or "
                       "tf.summary.experimental.set_step()")

    def _emit_raw_summary():
      """Creates the raw-proto write op and returns True once it has run."""
      with ops.device("cpu:0"):
        writer_resource = context.context().summary_writer._resource  # pylint: disable=protected-access
        # The identity op moves the tensor onto the CPU.
        cpu_tensor = array_ops.identity(tensor)
        emit_op = gen_summary_ops.write_raw_proto_summary(
            writer_resource,
            step,
            cpu_tensor,
            name=scope)
        with ops.control_dependencies([emit_op]):
          return constant_op.constant(True)

    with ops.device("cpu:0"):
      outcome = smart_cond.smart_cond(
          _should_record_summaries_v2(),
          _emit_raw_summary,
          _nothing,
          name="summary_cond")
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, outcome)  # pylint: disable=protected-access
      return outcome
开发者ID:aritratony,项目名称:tensorflow,代码行数:49,代码来源:summary_ops_v2.py


注:本文中的tensorflow.python.framework.smart_cond.smart_cond函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。