当前位置: 首页>>代码示例>>Python>>正文


Python check_ops.assert_rank方法代码示例

本文整理汇总了Python中tensorflow.python.ops.check_ops.assert_rank方法的典型用法代码示例。如果您正苦于以下问题:Python check_ops.assert_rank方法的具体用法?Python check_ops.assert_rank怎么用?Python check_ops.assert_rank使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.python.ops.check_ops的用法示例。


在下文中一共展示了check_ops.assert_rank方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _sample_n

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _sample_n(self, n, seed=None):
    """Draw `n` multinomial count samples.

    Each sample is built by drawing `total_count` categories from
    `self.logits` and summing their one-hot encodings.

    Args:
      n: Number of samples to draw; becomes the leading dimension of the
        result.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      A `Tensor` reshaped to `[n] + batch_shape + [k]` and cast to
      `self.dtype`, where `k = self.event_shape_tensor()[0]`.

    Raises:
      NotImplementedError: If the static rank of `total_count` is known and
        is not 0.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    total_count_ndims = self.total_count.get_shape().ndims
    if total_count_ndims is not None:
      if total_count_ndims != 0:
        raise NotImplementedError(
            "Sample only supported for scalar number of draws.")
    elif self.validate_args:
      # Rank is not known statically, so the scalar check is deferred to an
      # assertion op that `n_draws` depends on.
      is_scalar = check_ops.assert_rank(
          n_draws, 0,
          message="Sample only supported for scalar number of draws.")
      n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
    k = self.event_shape_tensor()[0]
    # Flatten batch dims so logits has shape [B, k],
    # where B = reduce_prod(self.batch_shape_tensor()).
    draws = random_ops.multinomial(
        logits=array_ops.reshape(self.logits, [-1, k]),
        num_samples=n * n_draws,
        seed=seed)
    draws = array_ops.reshape(draws, shape=[-1, n, n_draws])
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                            axis=-2)  # shape: [B, n, k]
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    # `one_hot` is called without an explicit dtype, so the summed counts do
    # not follow the distribution's dtype; cast so samples agree with
    # `self.dtype`.
    return math_ops.cast(x, self.dtype)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:26,代码来源:multinomial.py

示例2: _assert_non_negative_int32_scalar

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _assert_non_negative_int32_scalar(self, x):
    """Validate that `x` is a non-negative int32 scalar and return it.

    Args:
      x: Value convertible to a tensor via `ops.convert_to_tensor`.

    Returns:
      The converted tensor. When the checks cannot be done statically and
      `self.validate_args` is truthy, the returned tensor depends on rank and
      non-negativity assertions.

    Raises:
      TypeError: If the base dtype of `x` is not int32.
      ValueError: If rank and value are both statically known and `x` is
        either not a scalar or negative.
    """
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.base_dtype != dtypes.int32.base_dtype:
      raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, dtypes.int32))

    static_value = tensor_util.constant_value(x)
    static_ndims = x.get_shape().ndims
    fully_static = static_ndims is not None and static_value is not None

    if not fully_static:
      # Nothing can be verified now; optionally defer to run-time assertions.
      if not self.validate_args:
        return x
      runtime_checks = [
          check_ops.assert_rank(x, 0),
          check_ops.assert_non_negative(x)]
      return control_flow_ops.with_dependencies(runtime_checks, x)

    if static_ndims != 0:
      raise ValueError("%s.ndims=%d is not 0 (scalar)" %
                       (x.name, static_ndims))
    if static_value < 0:
      raise ValueError("%s.value=%d cannot be negative" %
                       (x.name, static_value))
    return x
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:21,代码来源:shape.py

示例3: _sample_n

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _sample_n(self, n, seed=None):
    """Draw `n` multinomial count samples.

    Each sample is built by drawing `self.n` categories from `self.logits`
    and summing their one-hot encodings.

    Args:
      n: Number of samples to draw; becomes the leading dimension of the
        result.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      A `Tensor` reshaped to `[n] + batch_shape + [k]` and cast to
      `self.dtype`, where `k = self.event_shape()[0]`.

    Raises:
      NotImplementedError: If the static rank of `self.n` is known and is
        not 0.
    """
    n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
    n_ndims = self.n.get_shape().ndims
    if n_ndims is not None:
      if n_ndims != 0:
        raise NotImplementedError(
            "Sample only supported for scalar number of draws.")
    elif self.validate_args:
      # Rank is not known statically, so the scalar check is deferred to an
      # assertion op that `n_draws` depends on.
      is_scalar = check_ops.assert_rank(
          n_draws, 0,
          message="Sample only supported for scalar number of draws.")
      n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
    k = self.event_shape()[0]
    # Flatten batch dims so logits has shape [B, k],
    # where B = reduce_prod(self.batch_shape()).
    logits = array_ops.reshape(self.logits, [-1, k])
    draws = random_ops.multinomial(logits=logits,
                                   num_samples=n * n_draws,
                                   seed=seed)
    draws = array_ops.reshape(draws, shape=[-1, n, n_draws])
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                            reduction_indices=-2)  # shape: [B, n, k]
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    # `one_hot` is called without an explicit dtype, so the summed counts do
    # not follow the distribution's dtype; cast so samples agree with
    # `self.dtype`.
    return math_ops.cast(x, self.dtype)
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:26,代码来源:multinomial.py

示例4: _check_labels

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _check_labels(labels, expected_labels_dimension):
  """Validate the labels tensor and return an identity of it.

  Args:
    labels: Value converted with
      `sparse_tensor.convert_to_tensor_or_sparse_tensor`.
    expected_labels_dimension: Required size of dimension 1 of `labels`.

  Returns:
    `array_ops.identity(labels)` that depends on a rank-2 assertion and on an
    assertion that dimension 1 equals `expected_labels_dimension`.

  Raises:
    ValueError: If `labels` is a `SparseTensor`, or if the static size of
      dimension 1 is known and differs from `expected_labels_dimension`.
  """
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')

    runtime_shape = array_ops.shape(labels)
    shape_error = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    rank_check = check_ops.assert_rank(labels, 2, message=shape_error)
    with ops.control_dependencies([rank_check]):
      known_shape = labels.shape
      if known_shape is not None:
        second_dim = known_shape[1]
        # Fail early when the mismatch is already visible statically.
        if second_dim is not None:
          if second_dim != expected_labels_dimension:
            raise ValueError(
                'Mismatched label shape. '
                'Classifier configured with n_classes=%s.  Received %s. '
                'Suggested Fix: check your n_classes argument to the estimator '
                'and/or the shape of your label.' %
                (expected_labels_dimension, second_dim))
      dimension_check = check_ops.assert_equal(
          expected_labels_dimension, runtime_shape[1], message=shape_error)
      with ops.control_dependencies([dimension_check]):
        return array_ops.identity(labels, name=scope)
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:27,代码来源:head.py

示例5: _check_logits

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _check_logits(logits, expected_logits_dimension):
  """Validate the logits tensor and return an identity of it.

  Args:
    logits: Value converted with `math_ops.to_float`.
    expected_logits_dimension: Required size of dimension 1 of `logits`.

  Returns:
    `array_ops.identity(logits)` that depends on a rank-2 assertion and on an
    assertion that dimension 1 equals `expected_logits_dimension`.

  Raises:
    ValueError: If the static size of dimension 1 is known and differs from
      `expected_logits_dimension`.
  """
  with ops.name_scope(None, 'logits', (logits,)) as scope:
    logits = math_ops.to_float(logits)
    runtime_shape = array_ops.shape(logits)
    rank_check = check_ops.assert_rank(
        logits, 2, data=[runtime_shape],
        message='logits shape must be [batch_size, logits_dimension]')
    with ops.control_dependencies([rank_check]):
      known_shape = logits.shape
      if known_shape is not None:
        second_dim = known_shape[1]
        # Fail early when the mismatch is already visible statically.
        if second_dim is not None:
          if second_dim != expected_logits_dimension:
            raise ValueError(
                'logits shape must be [batch_size, logits_dimension], got %s.' %
                (known_shape,))
      dimension_check = check_ops.assert_equal(
          expected_logits_dimension, runtime_shape[1], data=[runtime_shape],
          message='logits shape must be [batch_size, logits_dimension]')
      with ops.control_dependencies([dimension_check]):
        return array_ops.identity(logits, name=scope)
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:23,代码来源:head.py

示例6: _sample_n

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _sample_n(self, n, seed=None):
    """Sample `n` count vectors, returned as a `self.dtype` tensor.

    Each sample is formed by drawing `total_count` categories from
    `self.logits` and tallying them per category.

    Args:
      n: Number of samples; becomes the leading dimension of the result.
      seed: Seed passed through to `random_ops.multinomial`.

    Returns:
      A `Tensor` reshaped to `[n] + batch_shape + [k]` and cast to
      `self.dtype`, where `k = self.event_shape_tensor()[0]`.

    Raises:
      NotImplementedError: If the static rank of `total_count` is known and
        is not 0.
    """
    num_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    static_rank = self.total_count.get_shape().ndims
    if static_rank is None:
      if self.validate_args:
        # Rank is only known at run time: attach an assertion to the count.
        scalar_check = check_ops.assert_rank(
            num_draws, 0,
            message="Sample only supported for scalar number of draws.")
        num_draws = control_flow_ops.with_dependencies(
            [scalar_check], num_draws)
    elif static_rank != 0:
      raise NotImplementedError(
          "Sample only supported for scalar number of draws.")

    num_classes = self.event_shape_tensor()[0]
    # Collapse the batch dims so the logits are [B, k], with
    # B = reduce_prod(self.batch_shape_tensor()).
    flat_logits = array_ops.reshape(self.logits, [-1, num_classes])
    category_ids = random_ops.multinomial(
        logits=flat_logits,
        num_samples=n * num_draws,
        seed=seed)
    category_ids = array_ops.reshape(category_ids, shape=[-1, n, num_draws])
    one_hot_draws = array_ops.one_hot(category_ids, depth=num_classes)
    counts = math_ops.reduce_sum(one_hot_draws, axis=-2)  # shape: [B, n, k]
    counts = array_ops.transpose(counts, perm=[1, 0, 2])
    sample_shape = array_ops.concat(
        [[n], self.batch_shape_tensor(), [num_classes]], 0)
    counts = array_ops.reshape(counts, sample_shape)
    return math_ops.cast(counts, self.dtype)
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:27,代码来源:multinomial.py

示例7: _check_num_rows_possibly_add_asserts

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _check_num_rows_possibly_add_asserts(self):
    """Validate `num_rows`, optionally wiring in run-time assertions.

    When `self._assert_proper_shapes` is set, `self._num_rows` is replaced by
    a version that depends on rank-0 and non-negativity assertions. The
    static value `self._num_rows_static`, if available, is then checked
    immediately.

    Raises:
      TypeError: If `self._num_rows` does not have an integer dtype.
      ValueError: If the static value is known and is not 0-D or is negative.
    """
    # Run-time assertions (only when requested).
    if self._assert_proper_shapes:
      rank_assert = check_ops.assert_rank(
          self._num_rows,
          0,
          message="Argument num_rows must be a 0-D Tensor.")
      sign_assert = check_ops.assert_non_negative(
          self._num_rows,
          message="Argument num_rows must be non-negative.")
      self._num_rows = control_flow_ops.with_dependencies(
          [rank_assert, sign_assert], self._num_rows)

    # Checks that are possible at graph-construction time.
    if not self._num_rows.dtype.is_integer:
      raise TypeError("Argument num_rows must be integer type.  Found:"
                      " %s" % self._num_rows)

    static_rows = self._num_rows_static
    if static_rows is not None:
      if static_rows.ndim != 0:
        raise ValueError("Argument num_rows must be a 0-D Tensor.  Found:"
                         " %s" % static_rows)
      if static_rows < 0:
        raise ValueError("Argument num_rows must be non-negative.  Found:"
                         " %s" % static_rows)
    # With no static value there is nothing further to verify here.
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:35,代码来源:linear_operator_identity.py

示例8: _check_batch_shape_possibly_add_asserts

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _check_batch_shape_possibly_add_asserts(self):
    """Validate `batch_shape`, optionally wiring in run-time assertions.

    Does nothing when `self._batch_shape_arg` is `None`. When
    `self._assert_proper_shapes` is set, `self._batch_shape_arg` is replaced
    by a version that depends on rank-1 and non-negativity assertions. The
    static value `self._batch_shape_static`, if available, is then checked
    immediately.

    Raises:
      TypeError: If `self._batch_shape_arg` does not have an integer dtype.
      ValueError: If the static value is known and is not 1-D or contains a
        negative entry.
    """
    if self._batch_shape_arg is None:
      return

    # Run-time assertions (only when requested).
    if self._assert_proper_shapes:
      rank_assert = check_ops.assert_rank(
          self._batch_shape_arg,
          1,
          message="Argument batch_shape must be a 1-D Tensor.")
      sign_assert = check_ops.assert_non_negative(
          self._batch_shape_arg,
          message="Argument batch_shape must be non-negative.")
      self._batch_shape_arg = control_flow_ops.with_dependencies(
          [rank_assert, sign_assert], self._batch_shape_arg)

    # Checks that are possible at graph-construction time.
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type.  Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is not None:
      if self._batch_shape_static.ndim != 1:
        raise ValueError("Argument batch_shape must be a 1-D Tensor.  Found:"
                         " %s" % self._batch_shape_static)
      if np.any(self._batch_shape_static < 0):
        raise ValueError("Argument batch_shape must be non-negative.  Found:"
                         "%s" % self._batch_shape_static)
    # With no static value there is nothing further to verify here.
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:36,代码来源:linear_operator_identity.py

示例9: _sample_n

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _sample_n(self, n, seed=None):
    """Draw `n` Dirichlet-multinomial count samples.

    Logits are obtained as the log of gamma draws with `alpha=self.alpha`;
    `self.n` categories are then drawn from those logits and their one-hot
    encodings are summed.

    Args:
      n: Number of samples to draw; becomes the leading dimension of the
        result.
      seed: Optional seed. It is used directly for the gamma draws, and a
        seed derived from it via `distribution_util.gen_new_seed` is used for
        the categorical draws.

    Returns:
      A `Tensor` reshaped to `[n] + batch_shape + [k]` and cast to
      `self.dtype`, where `k = self.event_shape()[0]`.

    Raises:
      NotImplementedError: If the static rank of `self.n` is known and is
        not 0.
    """
    n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
    n_ndims = self.n.get_shape().ndims
    if n_ndims is not None:
      if n_ndims != 0:
        raise NotImplementedError(
            "Sample only supported for scalar number of draws.")
    elif self.validate_args:
      # Rank is not known statically, so the scalar check is deferred to an
      # assertion op that `n_draws` depends on.
      is_scalar = check_ops.assert_rank(
          n_draws, 0,
          message="Sample only supported for scalar number of draws.")
      n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
    k = self.event_shape()[0]
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.alpha,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                            reduction_indices=-2)
    final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    # `one_hot` is called without an explicit dtype, so the summed counts do
    # not follow the distribution's dtype; cast so samples agree with
    # `self.dtype`.
    return math_ops.cast(x, self.dtype)
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:29,代码来源:dirichlet_multinomial.py

示例10: _check_shapes_dynamic

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _check_shapes_dynamic(self, operator, v, diag):
    """Attach run-time shape assertions to `v` and, when given, `diag`.

    The assertions compare the rank, leading (batch) dimensions and trailing
    dimensions of `v` and `diag` against `operator`.

    Args:
      operator: Object exposing `rank()`, `batch_shape()` and
        `vector_space_dimension()`.
      v: Tensor asserted to have the same rank as `operator`.
      diag: Optional tensor asserted to have rank one less than `operator`;
        may be `None`.

    Returns:
      Tuple `(v, diag)` in which every non-`None` entry depends on all of the
      assertions.
    """
    has_diag = diag is not None
    assertions = []
    with ops.name_scope("check_shapes", values=[operator, v, diag]):
      v_shape = array_ops.shape(v)
      op_rank = operator.rank()
      v_rank = array_ops.rank(v)
      if has_diag:
        diag_shape = array_ops.shape(diag)
        diag_rank = array_ops.rank(diag)

      # Ranks: v matches the operator, diag has one dimension fewer.
      assertions.append(check_ops.assert_rank(
          v, op_rank, message="v is not the same rank as operator."))
      if has_diag:
        assertions.append(check_ops.assert_rank(
            diag, op_rank - 1,
            message="diag is not the same rank as operator."))

      # Batch shape: all but the last two dims of v.
      op_batch_shape = operator.batch_shape()
      v_batch_shape = array_ops.strided_slice(v_shape, [0], [v_rank - 2])
      assertions.append(check_ops.assert_equal(
          op_batch_shape, v_batch_shape,
          message="v does not have same batch shape as operator."))
      if has_diag:
        # Batch shape: all but the last dim of diag.
        op_batch_shape = operator.batch_shape()
        diag_batch_shape = array_ops.strided_slice(
            diag_shape, [0], [diag_rank - 1])
        assertions.append(check_ops.assert_equal(
            op_batch_shape, diag_batch_shape,
            message="diag does not have same batch shape as operator."))

      # Event shape: second-to-last dim of v against the operator dimension.
      op_dimension = operator.vector_space_dimension()
      v_event_dim = array_ops.gather(v_shape, v_rank - 2)
      assertions.append(check_ops.assert_equal(
          op_dimension, v_event_dim,
          message="v does not have same event shape as operator."))
      if has_diag:
        # Last dim of v must agree with last dim of diag.
        v_last_dim = array_ops.gather(v_shape, v_rank - 1)
        diag_last_dim = array_ops.gather(diag_shape, diag_rank - 1)
        assertions.append(check_ops.assert_equal(
            v_last_dim, diag_last_dim,
            message="diag does not have same event shape as v."))

      v = control_flow_ops.with_dependencies(assertions, v)
      if has_diag:
        diag = control_flow_ops.with_dependencies(assertions, diag)
      return v, diag
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:43,代码来源:operator_pd_vdvt_update.py

示例11: _check_shapes_dynamic

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _check_shapes_dynamic(self, operator, v, diag):
    """Attach run-time shape assertions to `v` and, when given, `diag`.

    The assertions compare the rank, leading (batch) dimensions and trailing
    dimensions of `v` and `diag` against `operator`.

    Args:
      operator: Object exposing `rank()`, `batch_shape()` and
        `vector_space_dimension()`.
      v: Tensor asserted to have the same rank as `operator`.
      diag: Optional tensor asserted to have rank one less than `operator`;
        may be `None`.

    Returns:
      Tuple `(v, diag)` in which every non-`None` entry depends on all of the
      assertions.
    """
    diag_given = diag is not None
    shape_checks = []
    with ops.name_scope("check_shapes", values=[operator, v, diag]):
      shape_v = array_ops.shape(v)
      rank_operator = operator.rank()
      rank_v = array_ops.rank(v)
      if diag_given:
        shape_diag = array_ops.shape(diag)
        rank_diag = array_ops.rank(diag)

      # Ranks: v matches the operator, diag has one dimension fewer.
      shape_checks.append(check_ops.assert_rank(v, rank_operator))
      if diag_given:
        shape_checks.append(check_ops.assert_rank(diag, rank_operator - 1))

      # Batch shape: the first `rank_v - 2` entries of v's shape.
      operator_batch = operator.batch_shape()
      v_batch = array_ops.slice(shape_v, [0], [rank_v - 2])
      shape_checks.append(check_ops.assert_equal(operator_batch, v_batch))
      if diag_given:
        # Batch shape: the first `rank_diag - 1` entries of diag's shape.
        operator_batch = operator.batch_shape()
        diag_batch = array_ops.slice(shape_diag, [0], [rank_diag - 1])
        shape_checks.append(
            check_ops.assert_equal(operator_batch, diag_batch))

      # Event shape: second-to-last dim of v against the operator dimension.
      operator_dim = operator.vector_space_dimension()
      v_event_dim = array_ops.gather(shape_v, rank_v - 2)
      shape_checks.append(check_ops.assert_equal(operator_dim, v_event_dim))
      if diag_given:
        # Last dim of v must agree with last dim of diag.
        v_last_dim = array_ops.gather(shape_v, rank_v - 1)
        diag_last_dim = array_ops.gather(shape_diag, rank_diag - 1)
        shape_checks.append(
            check_ops.assert_equal(v_last_dim, diag_last_dim))

      v = control_flow_ops.with_dependencies(shape_checks, v)
      if diag_given:
        diag = control_flow_ops.with_dependencies(shape_checks, diag)
      return v, diag
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:36,代码来源:operator_pd_vdvt_update.py

示例12: _maybe_validate_shape_override

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                     validate_args, name):
    """Validate an override batch/event shape passed to `__init__`.

    Args:
      override_shape: Shape override, or `None` (treated as `[]`). Converted
        to an int32 tensor.
      base_is_scalar: Scalar-ness indicator of the base distribution, combined
        with the override's via `_logical_and` / `_logical_not`.
      validate_args: When truthy, checks that cannot be settled statically
        are added as run-time assertions.
      name: Name given to the converted tensor.

    Returns:
      `self._empty` when the override is statically a scalar shape; otherwise
      the override tensor, carrying any run-time assertions as dependencies.

    Raises:
      TypeError: If the converted override does not have an integer dtype.
      ValueError: If a statically checkable condition fails: the override is
        not a vector, has a non-positive element, or both the base and the
        override are non-scalar.
    """
    if override_shape is None:
      override_shape = []

    override_shape = ops.convert_to_tensor(override_shape, dtype=dtypes.int32,
                                           name=name)

    if not override_shape.dtype.is_integer:
      raise TypeError("shape override must be an integer")

    override_is_scalar = _is_scalar_from_shape(override_shape)
    if tensor_util.constant_value(override_is_scalar):
      return self._empty

    runtime_checks = []

    # The override must be a vector.
    static_ndims = override_shape.get_shape().ndims
    if static_ndims is None:
      if validate_args:
        runtime_checks.append(check_ops.assert_rank(
            override_shape, 1,
            message="shape override must be a vector"))
    elif static_ndims != 1:
      raise ValueError("shape override must be a vector")

    # Every element of the override must be positive.
    static_override = tensor_util.constant_value(override_shape)
    if static_override is None:
      if validate_args:
        runtime_checks.append(check_ops.assert_positive(
            override_shape,
            message="shape override must have positive elements"))
    elif any(s <= 0 for s in static_override):
      raise ValueError("shape override must have positive elements")

    # Base and override may not both be non-scalar.
    is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
                                     _logical_not(override_is_scalar))
    static_both_nonscalar = tensor_util.constant_value(is_both_nonscalar)
    if static_both_nonscalar is None:
      if validate_args:
        runtime_checks.append(check_ops.assert_equal(
            is_both_nonscalar, False,
            message="base distribution not scalar"))
    elif static_both_nonscalar:
      raise ValueError("base distribution not scalar")

    if runtime_checks:
      return control_flow_ops.with_dependencies(
          runtime_checks, override_shape)
    return override_shape
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:50,代码来源:transformed_distribution.py

示例13: dict_to_state_tuple

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def dict_to_state_tuple(input_dict, cell):
  """Rebuild the nested RNN `state` from a flat dict of state `Tensor`s.

  State entries are looked up under the keys produced by
  `_get_state_name(i)` for each flattened entry `i` of `cell.state_size`.
  Every tensor found is wrapped in an identity op that depends on a rank-2
  assertion and on an assertion that its dimension 1 equals the matching
  state size.

  Args:
    input_dict: A dict of `Tensor`s.
    cell: An instance of `RNNCell`.

  Returns:
    `None` if `input_dict` holds none of the state keys. Otherwise the found
    tensors packed into the same structure as
    `cell.zero_state(batch_size=1, dtype=dtypes.bool)`.

  Raises:
    ValueError: If only some of the state components are present; either all
      of them or none must be given.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  found_states = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      candidate = input_dict.get(state_name)
      if candidate is None:
        continue
      rank_check = check_ops.assert_rank(
          candidate, 2, name='check_state_{}_rank'.format(i))
      shape_check = check_ops.assert_equal(
          array_ops.shape(candidate)[1],
          state_size,
          name='check_state_{}_shape'.format(i))
      with ops.control_dependencies([rank_check, shape_check]):
        candidate = array_ops.identity(candidate, name=state_name)
      found_states.append(candidate)

    num_found = len(found_states)
    num_expected = len(flat_state_sizes)
    if num_found == 0:
      return None
    if num_found != num_expected:
      raise ValueError(
          'RNN state was partially specified.'
          'Expected zero or {} state Tensors; got {}'.
          format(num_expected, num_found))
    # The zero state is used purely as a structural template for packing.
    dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
    return nest.pack_sequence_as(dummy_state, found_states)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:44,代码来源:dynamic_rnn_estimator.py

示例14: _concatenate_context_input

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _concatenate_context_input(sequence_input, context_input):
  """Append `context_input` to every timestep of `sequence_input`.

  `context_input` gets a new dimension 1, is tiled `padded_length` times
  along it, and is then concatenated to `sequence_input` along dimension 2.

  Args:
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input`
      does not have rank 2.
  """
  # Rank and dtype checks for both inputs, in the order they are created.
  input_checks = [
      check_ops.assert_rank(
          sequence_input,
          3,
          message='sequence_input must have rank 3',
          data=[array_ops.shape(sequence_input)]),
      check_ops.assert_type(
          sequence_input,
          dtypes.float32,
          message='sequence_input must have dtype float32; got {}.'.format(
              sequence_input.dtype)),
      check_ops.assert_rank(
          context_input,
          2,
          message='context_input must have rank 2',
          data=[array_ops.shape(context_input)]),
      check_ops.assert_type(
          context_input,
          dtypes.float32,
          message='context_input must have dtype float32; got {}.'.format(
              context_input.dtype)),
  ]
  with ops.control_dependencies(input_checks):
    padded_length = array_ops.shape(sequence_input)[1]
    expanded_context = array_ops.expand_dims(context_input, 1)
    tile_multiples = array_ops.concat([[1], [padded_length], [1]], 0)
    tiled_context_input = array_ops.tile(expanded_context, tile_multiples)
  return array_ops.concat([sequence_input, tiled_context_input], 2)

示例15: _concatenate_context_input

# 需要导入模块: from tensorflow.python.ops import check_ops [as 别名]
# 或者: from tensorflow.python.ops.check_ops import assert_rank [as 别名]
def _concatenate_context_input(sequence_input, context_input):
  """Replicate `context_input` across all timesteps of `sequence_input`.

  A dimension 1 is inserted into `context_input`, which is then tiled
  `padded_length` times along it; the tiled tensor is joined to
  `sequence_input` along dimension 2.

  Args:
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input`
      does not have rank 2.
  """
  sequence_rank_ok = check_ops.assert_rank(
      sequence_input,
      3,
      message='sequence_input must have rank 3',
      data=[array_ops.shape(sequence_input)])
  sequence_type_ok = check_ops.assert_type(
      sequence_input,
      dtypes.float32,
      message='sequence_input must have dtype float32; got {}.'.format(
          sequence_input.dtype))
  context_rank_ok = check_ops.assert_rank(
      context_input,
      2,
      message='context_input must have rank 2',
      data=[array_ops.shape(context_input)])
  context_type_ok = check_ops.assert_type(
      context_input,
      dtypes.float32,
      message='context_input must have dtype float32; got {}.'.format(
          context_input.dtype))
  all_checks = [
      sequence_rank_ok, sequence_type_ok, context_rank_ok, context_type_ok]
  with ops.control_dependencies(all_checks):
    # Dimension 1 of `sequence_input` determines how often to repeat.
    num_timesteps = array_ops.shape(sequence_input)[1]
    context_with_time_axis = array_ops.expand_dims(context_input, 1)
    repeats = array_ops.concat([[1], [num_timesteps], [1]], 0)
    tiled_context_input = array_ops.tile(context_with_time_axis, repeats)
  return array_ops.concat([sequence_input, tiled_context_input], 2)


注:本文中的tensorflow.python.ops.check_ops.assert_rank方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。