当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.case方法代码示例

本文整理汇总了Python中tensorflow.case方法的典型用法代码示例。如果您正苦于以下问题:Python tensorflow.case方法的具体用法?Python tensorflow.case怎么用?Python tensorflow.case使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow的用法示例。


在下文中一共展示了tensorflow.case方法的15个代码示例,这些例子默认根据受欢迎程度排序。注意:其中只有示例1、8、9、10、11、15实际调用了tf.case;示例2至7以及示例12、13、14只是在注释或文档字符串中出现了“case”一词,并未使用该方法。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: adjust_max

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def adjust_max(start, stop, start_value, stop_value, name=None):
  """Linearly ramp a scalar from `start_value` to `stop_value` over steps.

  The value equals `start_value` while the global step is <= `start`, is
  linearly interpolated (polynomial decay with power 1.0) while the step is
  in (`start`, `stop`], and equals `stop_value` afterwards.

  Args:
    start: Global step at which the ramp begins (converted to int64).
    stop: Global step at which the ramp ends (converted to int64).
    start_value: Value up to and including `start` (converted to float32).
    stop_value: Value after `stop` (converted to float32).
    name: Optional name scope; defaults to "AdjustMax".

  Returns:
    A scalar float32 tensor produced by `tf.case`, or None when no global
    step is available.
  """
  # The op inputs (not the scope name string) belong in the values list.
  with tf.name_scope(name, "AdjustMax",
                     [start, stop, start_value, stop_value]):
    global_step = tf.train.get_or_create_global_step()
    # get_or_create_global_step() creates the variable when it is missing,
    # so this guard is purely defensive.
    if global_step is None:
      return None

    start = tf.convert_to_tensor(start, dtype=tf.int64)
    stop = tf.convert_to_tensor(stop, dtype=tf.int64)
    start_value = tf.convert_to_tensor(start_value, dtype=tf.float32)
    stop_value = tf.convert_to_tensor(stop_value, dtype=tf.float32)

    # A list of (predicate, fn) pairs keeps the branch order deterministic
    # and avoids using tensors as dict keys (they are unhashable in TF2).
    pred_fn_pairs = [
        (global_step <= start, lambda: start_value),
        ((global_step > start) & (global_step <= stop),
         lambda: tf.train.polynomial_decay(
             start_value, global_step - start, stop - start,
             end_learning_rate=stop_value, power=1.0, cycle=False)),
    ]
    return tf.case(pred_fn_pairs, lambda: stop_value, exclusive=True)
开发者ID:calico,项目名称:basenji,代码行数:20,代码来源:ops.py

示例2: video_features

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def video_features(self, all_frames, all_actions, all_rewards,
                   all_raw_frames):
    """Hook for features that span the whole video; none by default.

    Override this in models that need every frame at once (for example, to
    approximate a single latent for the entire video). Whatever the override
    returns is made available as `video_features` in `next_frame`.

    Args:
      all_frames: list of all frames, input and target.
      all_actions: list of all actions, input and target.
      all_rewards: list of all rewards, input and target.
      all_raw_frames: list of all raw frames (before modalities).

    Returns:
      A dictionary of video-wide features; this base version returns None.
    """
    # The base implementation ignores every input.
    ignored = (all_frames, all_actions, all_rewards, all_raw_frames)
    del ignored
    return None
开发者ID:yyht,项目名称:BERT,代码行数:22,代码来源:base.py

示例3: video_extra_loss

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def video_extra_loss(self, frames_predicted, frames_target,
                     internal_states, video_features):
    """Hook for an additional loss computed over the whole video.

    Override this when a model needs a loss that looks at all predicted
    frames together (for example a video GAN loss). The base version adds
    nothing.

    Args:
      frames_predicted: list of all predicted frames.
      frames_target: list of all target frames.
      internal_states: internal states of the video.
      video_features: video-wide features from `video_features`.

    Returns:
      The extra loss; always 0.0 here.
    """
    ignored = (frames_predicted, frames_target, internal_states,
               video_features)
    del ignored
    no_extra_loss = 0.0
    return no_extra_loss
开发者ID:yyht,项目名称:BERT,代码行数:20,代码来源:base.py

示例4: finish

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def finish(self):
    """Finish transcoding and return the encoded video.

    Closes the process's stdin to signal end of input, joins the stdout and
    stderr reader threads, closes both pipes and then reaps the process.

    Returns:
      The bytes collected from the process's stdout, or None if no process
      was started.

    Raises:
      IOError: if the process exited with a non-zero status. The message
        holds the command line followed by the process's stderr output.
    """
    if self.proc is None:
      return None
    self.proc.stdin.close()
    for thread in (self._out_thread, self._err_thread):
      thread.join()
    out, err = (
        b"".join(chunks) for chunks in (self._out_chunks, self._err_chunks))
    self.proc.stdout.close()
    self.proc.stderr.close()
    # Popen.returncode is only populated by poll()/wait(). Without reaping
    # the process it stays None, and a failed transcode would go unnoticed.
    returncode = self.proc.wait()
    if returncode:
      # Decode leniently so undecodable stderr bytes cannot mask the IOError.
      message = "\n".join(
          [" ".join(self.cmd), err.decode("utf8", errors="replace")])
      raise IOError(message)
    self.proc = None
    return out
开发者ID:yyht,项目名称:BERT,代码行数:27,代码来源:common_video.py

示例5: _bucketize

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def _bucketize(instances, feature, schema, metadata):
  """Applies the bucketize transform to a numeric field.
  """
  field = schema[feature.field]
  if not field.numeric:
    raise ValueError('A scale transform cannot be applied to non-numerical field "%s".' %
                     feature.field)

  transform = feature.transform
  boundaries = map(float, transform['boundaries'].split(','))

  # TODO: Figure out how to use tf.case instead of this contrib op
  from tensorflow.contrib.layers.python.ops.bucketization_op import bucketize

  # Create a one-hot encoded tensor. The dimension of this tensor is the set of buckets defined
  # by N boundaries == N + 1.
  # A squeeze is needed to remove the extra dimension added to the shape.
  value = instances[feature.field]

  value = tf.squeeze(tf.one_hot(bucketize(value, boundaries, name='bucket'),
                                depth=len(boundaries) + 1, on_value=1.0, off_value=0.0,
                                name='one_hot'),
                     axis=1, name='bucketize')
  value.set_shape((None, len(boundaries) + 1))
  return value 
开发者ID:TensorLab,项目名称:tensorfx,代码行数:27,代码来源:_transforms.py

示例6: _get_qmodel_quantities

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def _get_qmodel_quantities(self, grads_and_vars):

    # Compute "preconditioned gradient".
    precon_grads_and_vars = self._multiply_preconditioner(grads_and_vars)

    var_list = tuple(var for (_, var) in grads_and_vars)
    prev_updates_and_vars = self._compute_prev_updates(var_list)

    # While it might seem like this call performs needless computations
    # involving prev_updates_and_vars in the case where it is zero, because
    # we extract out only the part of the solution that is not zero the rest
    # of it will not actually be computed by TensorFlow (I think).
    m, c, b = self._compute_qmodel(
        precon_grads_and_vars, prev_updates_and_vars, grads_and_vars)

    return precon_grads_and_vars, m, c, b 
开发者ID:tensorflow,项目名称:kfac,代码行数:18,代码来源:optimizer.py

示例7: finish

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def finish(self):
    """Finish transcoding and return the encoded video.

    Closes the process's stdin to signal end of input, joins the stdout and
    stderr reader threads, closes both pipes and then reaps the process.

    Returns:
      The bytes collected from the process's stdout, or None if no process
      was started.

    Raises:
      IOError: if the process exited with a non-zero status. The message
        holds the command line followed by the process's stderr output.
    """
    if self.proc is None:
      return None
    self.proc.stdin.close()
    for thread in (self._out_thread, self._err_thread):
      thread.join()
    out, err = (
        b"".join(chunks) for chunks in (self._out_chunks, self._err_chunks))
    # The reader threads are done with the pipes; close them so their file
    # descriptors are not leaked.
    self.proc.stdout.close()
    self.proc.stderr.close()
    # Popen.returncode is only populated by poll()/wait(). Without reaping
    # the process it stays None, and a failed transcode would go unnoticed.
    returncode = self.proc.wait()
    if returncode:
      # Decode leniently so undecodable stderr bytes cannot mask the IOError.
      message = "\n".join(
          [" ".join(self.cmd), err.decode("utf8", errors="replace")])
      raise IOError(message)
    self.proc = None
    return out
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:25,代码来源:common_video.py

示例8: piecewise_function

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def piecewise_function(param, values, changepoints, name=None,
                       dtype=tf.float32):
    """Evaluate a step function of `param`.

    The result is `values[i]` for the first index `i` such that
    `param < changepoints[i]`. When `param` is not below any changepoint,
    the result is `values[-1]`.

    Arguments:
        param: The function parameter.
        values: List of function values (numbers or tensors); each is
            converted to `dtype`.
        changepoints: Sorted list of points where the function moves on to
            the next value. Must be exactly one item shorter than `values`.
        name: Optional name for the op scope (default "PiecewiseFunction").
        dtype: dtype the values are converted to.

    Returns:
        The tensor selected by `tf.case`.

    Raises:
        ValueError: if `len(changepoints) != len(values) - 1`.
    """
    expected = len(values) - 1
    if len(changepoints) != expected:
        raise ValueError("changepoints has length {}, expected {} (values "
                         "has length {})".format(len(changepoints),
                                                 expected,
                                                 len(values)))

    with tf.name_scope(name, "PiecewiseFunction",
                       [param, values, changepoints]) as scope:
        tensors = [tf.convert_to_tensor(v, dtype=dtype) for v in values]
        below = [tf.less(param, point) for point in changepoints]
        # Bind each value as a default argument so every branch returns its
        # own tensor (avoids the late-binding closure pitfall).
        branches = [(pred, lambda out=out: out)
                    for pred, out in zip(below, tensors)]
        # Non-exclusive tf.case takes the first predicate that is true.
        return tf.case(branches, lambda: tensors[-1], name=scope)
开发者ID:ufal,项目名称:neuralmonkey,代码行数:27,代码来源:functions.py

示例9: adjust_max

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def adjust_max(start, stop, start_value, stop_value, name=None):
    """Linearly ramp a scalar from `start_value` to `stop_value` over steps.

    The value equals `start_value` while the global step is <= `start`, is
    linearly interpolated (polynomial decay with power 1.0) while the step
    is in (`start`, `stop`], and equals `stop_value` afterwards.

    Args:
        start: Global step at which the ramp begins (converted to int64).
        stop: Global step at which the ramp ends (converted to int64).
        start_value: Value up to and including `start` (float32).
        stop_value: Value after `stop` (float32).
        name: Optional name scope; defaults to "AdjustMax".

    Returns:
        A scalar float32 tensor produced by `tf.case`, or None when the graph
        has no global step.
    """
    # The op inputs (not the scope name string) belong in the values list.
    with ops.name_scope(name, "AdjustMax",
                        [start, stop, start_value, stop_value]):
        global_step = tf.train.get_global_step()
        # Without a global step there is nothing to schedule against.
        if global_step is None:
            return None

        start = tf.convert_to_tensor(start, dtype=tf.int64)
        stop = tf.convert_to_tensor(stop, dtype=tf.int64)
        start_value = tf.convert_to_tensor(start_value, dtype=tf.float32)
        stop_value = tf.convert_to_tensor(stop_value, dtype=tf.float32)

        # A list of (predicate, fn) pairs keeps the branch order deterministic
        # and avoids using tensors as dict keys (they are unhashable in TF2).
        pred_fn_pairs = [
            (global_step <= start, lambda: start_value),
            ((global_step > start) & (global_step <= stop),
             lambda: tf.train.polynomial_decay(
                 start_value, global_step - start, stop - start,
                 end_learning_rate=stop_value, power=1.0, cycle=False)),
        ]
        return tf.case(pred_fn_pairs, lambda: stop_value, exclusive=True)
开发者ID:shiyemin,项目名称:shuttleNet,代码行数:21,代码来源:ops.py

示例10: permute_rgb

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def permute_rgb(inputs, permute, data_format='channels_last'):
    """Reorder the three colour channels of `inputs` as selected by `permute`.

    Args:
      inputs: Image tensor accepted by `Conv2dUtilities.has_valid_shape` and
        having exactly 3 channels along the channel axis for `data_format`.
      permute: Scalar integer tensor. Values 1..5 select one of the
        non-identity channel orders in `permutations` below; any other value
        (e.g. 0) returns `inputs` unchanged.
      data_format: Layout passed to `Conv2dUtilities` to locate the channel
        axis; defaults to 'channels_last'.

    Returns:
      A tensor with the channels of `inputs`, possibly reordered.

    Raises:
      ValueError: if `inputs` has an invalid shape or not exactly 3 channels.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not Conv2dUtilities.has_valid_shape(inputs):
        raise ValueError('permute_rgb: inputs has an invalid shape.')
    if Conv2dUtilities.number_of_channels(inputs, data_format) != 3:
        raise ValueError('permute_rgb: inputs must have exactly 3 channels.')

    def _reorder(tensor, order):
        # Split into single-channel slices and concatenate them in `order`.
        channel_axis = Conv2dUtilities.channel_axis(tensor, data_format)
        channels = tf.split(tensor, [1, 1, 1], channel_axis)
        return tf.concat([channels[index] for index in order], channel_axis)

    # permute value -> channel order.
    permutations = (
        (1, (0, 2, 1)),
        (2, (1, 0, 2)),
        (3, (1, 2, 0)),
        (4, (2, 0, 1)),
        (5, (2, 1, 0)),
    )
    # `order` is bound as a default argument so each branch keeps its own
    # permutation; `inputs` is never rebound, so every closure sees it intact.
    cases = [
        (tf.equal(permute, key), lambda order=order: _reorder(inputs, order))
        for key, order in permutations
    ]
    return tf.case(cases, default=lambda: inputs, exclusive=True)
开发者ID:DeepBlender,项目名称:DeepDenoiser,代码行数:21,代码来源:DataAugmentation.py

示例11: _augment

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def _augment(image):
    """Randomly augment an image: noise, then one of several colour jitters.

    Noise is always added first via `DatasetUtil._add_noise`. A random int32
    drawn from [0, 4) then selects the colour augmentation. Values 0, 1 and 2
    apply `DatasetUtil._augment_cond_0`, `_augment_cond_1` and
    `_augment_cond_2` respectively; 3 applies none. The result is clipped to
    [0.0, 1.0] so the augmentations cannot push values out of range.

    :param image: image to augment, shape (H, W, ?)
    :return: the augmented, clipped image
    """
    noisy = DatasetUtil._add_noise(image)
    # Which augmentation order to apply.
    choice = tf.random_uniform([], minval=0, maxval=4, dtype=tf.int32)
    branches = [
        (tf.equal(choice, 0), lambda: DatasetUtil._augment_cond_0(noisy)),
        (tf.equal(choice, 1), lambda: DatasetUtil._augment_cond_1(noisy)),
        (tf.equal(choice, 2), lambda: DatasetUtil._augment_cond_2(noisy)),
    ]
    augmented = tf.case(pred_fn_pairs=branches, default=lambda: noisy)
    return tf.clip_by_value(augmented, 0.0, 1.0)
开发者ID:zheng-yuwei,项目名称:multi-label-classification,代码行数:19,代码来源:dataset_util.py

示例12: _metric_key

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def _metric_key(self, base_key: Text) -> Text:
    """Constructs a metric key, including user-specified prefix if necessary.

    In cases with multi-headed models, an evaluation may need multiple instances
    of the same metric for different predictions and/or labels. To support this
    case, the metric should be named with the specified label to disambiguate
    between the two (and prevent key collisions).

    Args:
      base_key: The original key for the metric, often from metric_keys.

    Returns:
      Either the base key, or the key augmented with a specified tag or label.
    """
    if self._metric_tag:
      return metric_keys.tagged_key(base_key, self._metric_tag)
    return base_key 
开发者ID:tensorflow,项目名称:model-analysis,代码行数:19,代码来源:post_export_metrics.py

示例13: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def __init__(self,
             example_weight_key: Optional[Text] = None,
             target_prediction_keys: Optional[List[Text]] = None,
             labels_key: Optional[Text] = None,
             metric_tag: Optional[Text] = None,
             tensor_index: Optional[int] = None):
    """Create a metric that computes calibration.

    Args:
      example_weight_key: The key of the example weight column in the features
        dict. If None, all predictions are given a weight of 1.0.
      target_prediction_keys: If provided, the prediction keys to look for in
        order.
      labels_key: If provided, a custom label key.
      metric_tag: If provided, a custom metric tag. Only necessary to
        disambiguate instances of the same metric on different predictions.
      tensor_index: Optional index to specify class predictions to calculate
        metrics on in the case of multi-class models.
    """
    self._example_weight_key = example_weight_key
    # Forward tensor_index as well; it is documented above and must not be
    # silently dropped. NOTE(review): assumes the base class __init__ accepts
    # a `tensor_index` keyword, confirm against the base class definition.
    super(_Calibration, self).__init__(
        target_prediction_keys=target_prediction_keys,
        labels_key=labels_key,
        metric_tag=metric_tag,
        tensor_index=tensor_index)
开发者ID:tensorflow,项目名称:model-analysis,代码行数:27,代码来源:post_export_metrics.py

示例14: build_feed_dict

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
                    keep_prob=None):
    """Assemble the feed dictionary, falling back to defaults where needed.

    Args:
      train_name: The key into the datasets, used to set the tf.case
        statement for the proper readin / readout matrices.
      data_bxtxd: The data tensor.
      ext_input_bxtxi (optional): The external input tensor. It is fed only
        when the model also has an external-input placeholder.
      keep_prob: The dropout keep probability. When None,
        `self.hps.keep_prob` is used.

    Returns:
      A dictionary mapping TF tensors to data values, for use with
      tf.Session.run().
    """
    # Unpacking the shape also enforces that the data has exactly three
    # dimensions; the `_bxtxd` suffix suggests batch x time x dim.
    _batch, _time, _dim = data_bxtxd.shape
    feed = {self.dataName: train_name, self.dataset_ph: data_bxtxd}

    if self.ext_input is not None and ext_input_bxtxi is not None:
      feed[self.ext_input] = ext_input_bxtxi

    feed[self.keep_prob] = (
        self.hps.keep_prob if keep_prob is None else keep_prob)
    return feed
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:32,代码来源:lfads.py

示例15: piecewise_constant

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import case [as 别名]
def piecewise_constant(x, boundaries, values):
    """ Piecewise constant function.

    Arguments:
        x: A 0-D Tensor.
        boundaries: A 1-D NumPy array with strictly increasing entries.
        values: A 1-D NumPy array that specifies the values for the intervals
            defined by `boundaries`. (It should therefore have one more entry
            than `boundaries`.)

    Returns: A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
        `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ..., and
        values[-1] when `x > boundaries[-1]`. If no interval matches (e.g.
        when `x` is NaN), the result falls back to `values[0]`.
    """
    # A list of (predicate, fn) pairs keeps the branch order deterministic
    # and avoids using tensors as dict keys (they are unhashable in TF2).
    pred_fn_pairs = [
        (x <= boundaries[0], lambda: tf.constant(values[0])),
        (x > boundaries[-1], lambda: tf.constant(values[-1])),
    ]
    for lower, upper, value in zip(boundaries[:-1],
                                   boundaries[1:],
                                   values[1:-1]):
        # Bind `value` as a default argument so each branch keeps its own
        # value instead of the loop's last one.
        pred = (x > lower) & (x <= upper)
        pred_fn_pairs.append((pred, lambda value=value: tf.constant(value)))

    return tf.case(pred_fn_pairs, lambda: tf.constant(values[0]),
                   exclusive=True)
开发者ID:rdipietro,项目名称:miccai-2016-surgical-activity-rec,代码行数:29,代码来源:optimizers.py


注:本文中的tensorflow.case方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。