当前位置: 首页>>代码示例>>Python>>正文


Python check_ops.assert_rank函数代码示例

本文整理汇总了Python中tensorflow.python.ops.check_ops.assert_rank函数的典型用法代码示例。如果您正苦于以下问题：Python assert_rank函数的具体用法？Python assert_rank怎么用？Python assert_rank使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了assert_rank函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: calculate_reshape

def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Resolves `new_shape`, substituting a computed size for a `-1` entry.

  Args:
    original_shape: Shape vector of the value being reshaped.
    new_shape: Requested shape vector; entries equal to `-1` are replaced.
    validate: When true, runtime assertion ops are built and returned.
    name: Optional name passed to the enclosing name scope.

  Returns:
    A tuple `(expanded_new_shape, batch_shape_static, validations)`. When
    `new_shape` is fully known statically, the first element is an `np.int32`
    array and `validations` is empty; otherwise the first element is the
    result of the graph computation below.
  """
  static_shape = tensor_util.constant_value_as_shape(new_shape)
  if static_shape.is_fully_defined():
    # Everything is known up front, so no graph ops or assertions are built.
    return np.int32(static_shape.as_list()), static_shape, []
  with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
    total_size = math_ops.reduce_prod(original_shape)
    is_implicit = math_ops.equal(new_shape, -1)
    # With a single `-1` present the negated product is the product of the
    # remaining entries; `maximum` keeps the divisor at least 1.
    explicit_size = math_ops.maximum(1, -math_ops.reduce_prod(new_shape))
    inferred_dim = total_size // explicit_size
    rank_vector = array_ops.shape(new_shape)
    filler = array_ops.fill(rank_vector, inferred_dim)
    # Assumes exactly one `-1`.
    resolved_shape = array_ops.where(is_implicit, filler, new_shape)
    if not validate:
      return resolved_shape, static_shape, []
    validations = [
        check_ops.assert_rank(
            original_shape, 1, message="Original shape must be a vector."),
        check_ops.assert_rank(
            new_shape, 1, message="New shape must be a vector."),
        check_ops.assert_less_equal(
            math_ops.count_nonzero(is_implicit, dtype=dtypes.int32),
            1,
            message="At most one dimension can be unknown."),
        check_ops.assert_positive(
            resolved_shape, message="Shape elements must be >=-1."),
        check_ops.assert_equal(
            math_ops.reduce_prod(resolved_shape),
            total_size,
            message="Shape sizes do not match."),
    ]
    return resolved_shape, static_shape, validations
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:batch_reshape.py

示例2: _check_shapes_dynamic

  def _check_shapes_dynamic(self, operator, v, diag):
    """Attaches runtime shape assertions to `v` and, when given, to `diag`.

    Args:
      operator: Object exposing `rank()`, `batch_shape()` and
        `vector_space_dimension()`.
      v: Tensor whose rank, leading dims and second-to-last dim are checked
        against `operator`.
      diag: Optional tensor; when not `None` its rank, leading dims and last
        dim are checked as well.

    Returns:
      `(v, diag)`, each wrapped so that it depends on every assertion built
      here. `diag` is returned unchanged when it is `None`.
    """
    has_diag = diag is not None
    assertions = []
    with ops.op_scope([operator, v, diag], 'check_shapes'):
      v_shape = array_ops.shape(v)
      op_rank = operator.rank()
      v_rank = array_ops.rank(v)
      if has_diag:
        diag_shape = array_ops.shape(diag)
        diag_rank = array_ops.rank(diag)

      # Rank: `v` must match the operator, `diag` has one dimension fewer.
      assertions.append(check_ops.assert_rank(v, op_rank))
      if has_diag:
        assertions.append(check_ops.assert_rank(diag, op_rank - 1))

      # Batch shape: compare the operator's batch shape with the leading
      # dimensions of `v` (all but the last two) and of `diag` (all but one).
      expected_batch = operator.batch_shape()
      v_batch = array_ops.slice(v_shape, [0], [v_rank - 2])
      assertions.append(check_ops.assert_equal(expected_batch, v_batch))
      if has_diag:
        expected_batch = operator.batch_shape()
        diag_batch = array_ops.slice(diag_shape, [0], [diag_rank - 1])
        assertions.append(check_ops.assert_equal(expected_batch, diag_batch))

      # Event shape: the second-to-last dim of `v` must equal the operator's
      # vector space dimension; the last dims of `v` and `diag` must agree.
      space_dim = operator.vector_space_dimension()
      v_rows = array_ops.gather(v_shape, v_rank - 2)
      assertions.append(check_ops.assert_equal(space_dim, v_rows))
      if has_diag:
        v_cols = array_ops.gather(v_shape, v_rank - 1)
        diag_len = array_ops.gather(diag_shape, diag_rank - 1)
        assertions.append(check_ops.assert_equal(v_cols, diag_len))

      v = control_flow_ops.with_dependencies(assertions, v)
      if has_diag:
        diag = control_flow_ops.with_dependencies(assertions, diag)
      return v, diag
开发者ID:10imaging,项目名称:tensorflow,代码行数:34,代码来源:operator_pd_vdvt_update.py

示例3: _check_domain_range_possibly_add_asserts

  def _check_domain_range_possibly_add_asserts(self):
    """Validates `num_rows` and `num_columns`; optionally adds runtime asserts.

    When `self._assert_proper_shapes` is set, `self._num_rows` and
    `self._num_columns` are replaced by versions that depend on rank-0 and
    non-negativity assertions. Static checks always run afterwards.

    Raises:
      TypeError: If either argument's dtype is not an integer type.
      ValueError: If a statically known value is not 0-D or is negative.
    """
    if self._assert_proper_shapes:
      rows_assertions = [
          check_ops.assert_rank(
              self._num_rows,
              0,
              message="Argument num_rows must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_rows,
              message="Argument num_rows must be non-negative."),
      ]
      self._num_rows = control_flow_ops.with_dependencies(
          rows_assertions, self._num_rows)
      columns_assertions = [
          check_ops.assert_rank(
              self._num_columns,
              0,
              message="Argument num_columns must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_columns,
              message="Argument num_columns must be non-negative."),
      ]
      self._num_columns = control_flow_ops.with_dependencies(
          columns_assertions, self._num_columns)

    # dtype checks need no static value, so they come first.
    rows_are_integer = self._num_rows.dtype.is_integer
    if not rows_are_integer:
      raise TypeError(
          "Argument num_rows must be integer type.  Found: %s"
          % self._num_rows)

    columns_are_integer = self._num_columns.dtype.is_integer
    if not columns_are_integer:
      raise TypeError(
          "Argument num_columns must be integer type.  Found: %s"
          % self._num_columns)

    rows_static = self._num_rows_static
    columns_static = self._num_columns_static

    # The remaining checks only apply to values that are known statically.
    if rows_static is not None:
      if rows_static.ndim != 0:
        raise ValueError(
            "Argument num_rows must be a 0-D Tensor.  Found: %s"
            % rows_static)
      if rows_static < 0:
        raise ValueError(
            "Argument num_rows must be non-negative.  Found: %s"
            % rows_static)

    if columns_static is not None:
      if columns_static.ndim != 0:
        raise ValueError(
            "Argument num_columns must be a 0-D Tensor.  Found: %s"
            % columns_static)
      if columns_static < 0:
        raise ValueError(
            "Argument num_columns must be non-negative.  Found: %s"
            % columns_static)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:51,代码来源:linear_operator_zeros.py

示例4: test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank

 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
   """Feeds a rank-1 value to a placeholder that is asserted to be rank 1."""
   with self.test_session():
     my_tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
     rank_assertion = check_ops.assert_rank(my_tensor, 1)
     with ops.control_dependencies([rank_assertion]):
       guarded = array_ops.identity(my_tensor)
       # The fed value [1, 2] has rank 1, so evaluation should not raise.
       guarded.eval(feed_dict={my_tensor: [1, 2]})
开发者ID:1000sprites,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py

示例5: _check_batch_shape_possibly_add_asserts

  def _check_batch_shape_possibly_add_asserts(self):
    """Validates init arg `batch_shape`; optionally adds runtime asserts.

    Does nothing when `self._batch_shape_arg` is `None`. When
    `self._assert_proper_shapes` is set, `self._batch_shape_arg` is replaced
    by a version that depends on rank-1 and non-negativity assertions. Static
    checks on `self._batch_shape_static` run only when that value is known.

    Raises:
      TypeError: If `batch_shape` does not have an integer dtype.
      ValueError: If the statically known `batch_shape` is not 1-D or has a
        negative entry.
    """
    if self._batch_shape_arg is None:
      return

    # Possibly add asserts
    if self._assert_proper_shapes:
      self._batch_shape_arg = control_flow_ops.with_dependencies(
          [
              check_ops.assert_rank(
                  self._batch_shape_arg,
                  1,
                  message="Argument batch_shape must be a 1-D Tensor."),
              check_ops.assert_non_negative(
                  self._batch_shape_arg,
                  message="Argument batch_shape must be non-negative."),
          ],
          self._batch_shape_arg)

    # Static checks
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type.  Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
      return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
      raise ValueError("Argument batch_shape must be a 1-D Tensor.  Found:"
                       " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
      # The separating space before %s was missing here, which glued the
      # value onto "Found:" unlike the other messages in this method.
      raise ValueError("Argument batch_shape must be non-negative.  Found:"
                       " %s" % self._batch_shape_static)
开发者ID:Y-owen,项目名称:tensorflow,代码行数:34,代码来源:linear_operator_identity.py

示例6: _sample_n

 def _sample_n(self, n, seed=None):
   """Builds ops that draw `n` samples shaped `[n] + batch_shape + [k]`.

   Args:
     n: Number of samples; becomes the leading output dimension.
     seed: Optional seed forwarded to the multinomial sampler.

   Returns:
     Per-class counts reshaped to `[n] + self.batch_shape_tensor() + [k]` and
     cast to `self.dtype`, where `k` is the first entry of
     `self.event_shape_tensor()`.

   Raises:
     NotImplementedError: If `self.total_count` has a statically known rank
       other than 0.
   """
   n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
   static_ndims = self.total_count.get_shape().ndims
   if static_ndims is None:
     # Rank is unknown until run time; only check it when validation is on.
     if self.validate_args:
       scalar_check = check_ops.assert_rank(
           n_draws, 0,
           message="Sample only supported for scalar number of draws.")
       n_draws = control_flow_ops.with_dependencies([scalar_check], n_draws)
   elif static_ndims != 0:
     raise NotImplementedError(
         "Sample only supported for scalar number of draws.")
   k = self.event_shape_tensor()[0]
   # Flatten batch dims so logits has shape [B, k],
   # where B = reduce_prod(self.batch_shape_tensor()).
   flat_logits = array_ops.reshape(self.logits, [-1, k])
   draws = random_ops.multinomial(
       logits=flat_logits,
       num_samples=n * n_draws,
       seed=seed)
   draws = array_ops.reshape(draws, shape=[-1, n, n_draws])
   one_hot_draws = array_ops.one_hot(draws, depth=k)
   counts = math_ops.reduce_sum(one_hot_draws, axis=-2)  # shape: [B, n, k]
   counts = array_ops.transpose(counts, perm=[1, 0, 2])
   final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
   counts = array_ops.reshape(counts, final_shape)
   return math_ops.cast(counts, self.dtype)
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:25,代码来源:multinomial.py

示例7: _check_valid_event_ndims

  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Checks that `event_ndims` is a scalar no smaller than `min_event_ndims`.

    Args:
      min_event_ndims: Lower bound that `event_ndims` is compared against.
      event_ndims: Value converted to a tensor and then validated.

    Returns:
      A list of runtime assertion ops. It is only populated for properties
      that cannot be verified statically and when `self.validate_args` is set.

    Raises:
      ValueError: If `event_ndims` is not of integer dtype, is statically
        known to be non-scalar, or is statically known to be smaller than
        `min_event_ndims`.
    """
    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    static_value = tensor_util.constant_value(event_ndims)
    assertions = []

    if not event_ndims.dtype.is_integer:
      raise ValueError("Expected integer dtype, got dtype {}".format(
          event_ndims.dtype))

    if static_value is None:
      # Value unknown at graph-construction time: defer the bound check.
      if self.validate_args:
        assertions.append(
            check_ops.assert_greater_equal(event_ndims, min_event_ndims))
    else:
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar event_ndims, got shape {}".format(
            event_ndims.shape))
      if min_event_ndims > static_value:
        raise ValueError("event_ndims ({}) must be larger than "
                         "min_event_ndims ({})".format(
                             static_value, min_event_ndims))

    if not event_ndims.shape.is_fully_defined():
      # Shape unknown at graph-construction time: defer the scalar check.
      if self.validate_args:
        assertions.append(
            check_ops.assert_rank(event_ndims, 0, message="Expected scalar."))
    elif event_ndims.shape.ndims != 0:
      raise ValueError("Expected scalar shape, got ndims {}".format(
          event_ndims.shape.ndims))
    return assertions
开发者ID:AnishShah,项目名称:tensorflow,代码行数:31,代码来源:bijector_impl.py

示例8: _sample_n

 def _sample_n(self, n, seed=None):
   """Builds ops that draw `n` samples shaped `[n] + batch_shape + [k]`.

   Args:
     n: Number of samples; used as the gamma draw shape and as the leading
       output dimension.
     seed: Optional seed given to the gamma sampler and, after salting, to
       the multinomial sampler.

   Returns:
     Per-class counts reshaped to `[n] + self.batch_shape() + [k]`, where `k`
     is the first entry of `self.event_shape()`.

   Raises:
     NotImplementedError: If `self.n` has a statically known rank other
       than 0.
   """
   n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
   static_ndims = self.n.get_shape().ndims
   if static_ndims is None:
     # Rank is unknown until run time; only check it when validation is on.
     if self.validate_args:
       scalar_check = check_ops.assert_rank(
           n_draws, 0,
           message="Sample only supported for scalar number of draws.")
       n_draws = control_flow_ops.with_dependencies([scalar_check], n_draws)
   elif static_ndims != 0:
     raise NotImplementedError(
         "Sample only supported for scalar number of draws.")
   k = self.event_shape()[0]
   gamma_draws = random_ops.random_gamma(
       shape=[n],
       alpha=self.alpha,
       dtype=self.dtype,
       seed=seed)
   # Log of the gamma draws, flattened to rows of length k, serves as logits.
   unnormalized_logits = array_ops.reshape(
       math_ops.log(gamma_draws), shape=[-1, k])
   multinomial_seed = distribution_util.gen_new_seed(
       seed, salt="dirichlet_multinomial")
   draws = random_ops.multinomial(
       logits=unnormalized_logits,
       num_samples=n_draws,
       seed=multinomial_seed)
   one_hot_draws = array_ops.one_hot(draws, depth=k)
   counts = math_ops.reduce_sum(one_hot_draws, reduction_indices=-2)
   final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
   return array_ops.reshape(counts, final_shape)
开发者ID:ivankreso,项目名称:tensorflow,代码行数:27,代码来源:dirichlet_multinomial.py

示例9: _check_labels

def _check_labels(labels, expected_labels_dimension):
  """Validates that `labels` is a dense rank-2 tensor of the expected width.

  Args:
    labels: Value converted to a tensor (or sparse tensor) before checking.
    expected_labels_dimension: Required size of dimension 1 of `labels`.

  Returns:
    An identity of `labels`, named after the enclosing scope, that depends on
    the rank and dimension assertions.

  Raises:
    ValueError: If `labels` is a `SparseTensor`, or if the static size of
      dimension 1 is known and differs from `expected_labels_dimension`.
  """
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    dynamic_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    rank_assertion = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([rank_assertion]):
      static_shape = labels.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        # Fail at graph-construction time when the width is already known.
        if dim1 is not None:
          if dim1 != expected_labels_dimension:
            mismatch_message = (
                'Mismatched label shape. '
                'Classifier configured with n_classes=%s.  Received %s. '
                'Suggested Fix: check your n_classes argument to the estimator '
                'and/or the shape of your label.' %
                (expected_labels_dimension, dim1))
            raise ValueError(mismatch_message)
      dimension_assertion = check_ops.assert_equal(
          expected_labels_dimension, dynamic_shape[1], message=err_msg)
      with ops.control_dependencies([dimension_assertion]):
        return array_ops.identity(labels, name=scope)
开发者ID:cneeruko,项目名称:tensorflow,代码行数:25,代码来源:head.py

示例10: test_rank_one_tensor_raises_if_rank_too_small_static_rank

 def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
   """Expects a "rank" ValueError when a vector is asserted to be rank 2."""
   vector = constant_op.constant([1, 2], name="my_tensor")
   with self.assertRaisesRegexp(ValueError, "rank"):
     # The assertion is built inside the context so an error raised while
     # constructing it is captured too.
     rank_assertion = check_ops.assert_rank(vector, 2)
     with ops.control_dependencies([rank_assertion]):
       guarded = array_ops.identity(vector)
       self.evaluate(guarded)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py

示例11: test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank

 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
   """Asserts rank 1 on a vector constant and evaluates it without error."""
   with self.test_session():
     vector = constant_op.constant([1, 2], name="my_tensor")
     rank_assertion = check_ops.assert_rank(vector, 1)
     with ops.control_dependencies([rank_assertion]):
       guarded = array_ops.identity(vector)
       guarded.eval()
开发者ID:1000sprites,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py

示例12: concatenate_context_input

def concatenate_context_input(context_input, sequence_input):
  """Appends `context_input` to every timestep of `sequence_input`.

  `context_input` gets a new dimension 1, is tiled `padded_length` times along
  it, and is then concatenated onto `sequence_input` along dimension 2.

  Args:
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  # Checks are built in a fixed order: sequence rank, sequence dtype,
  # context rank, context dtype.
  sequence_shape = array_ops.shape(sequence_input)
  input_checks = [
      check_ops.assert_rank(
          sequence_input,
          3,
          message='sequence_input must have rank 3',
          data=[sequence_shape]),
  ]
  sequence_dtype_message = (
      'sequence_input must have dtype float32; got {}.'.format(
          sequence_input.dtype))
  input_checks.append(
      check_ops.assert_type(
          sequence_input, dtypes.float32, message=sequence_dtype_message))
  context_shape = array_ops.shape(context_input)
  input_checks.append(
      check_ops.assert_rank(
          context_input,
          2,
          message='context_input must have rank 2',
          data=[context_shape]))
  context_dtype_message = (
      'context_input must have dtype float32; got {}.'.format(
          context_input.dtype))
  input_checks.append(
      check_ops.assert_type(
          context_input, dtypes.float32, message=context_dtype_message))
  with ops.control_dependencies(input_checks):
    padded_length = array_ops.shape(sequence_input)[1]
    expanded_context = array_ops.expand_dims(context_input, 1)
    # Repeat only along the newly inserted time dimension.
    tile_multiples = array_ops.concat([[1], [padded_length], [1]], 0)
    tiled_context_input = array_ops.tile(expanded_context, tile_multiples)
  return array_ops.concat([sequence_input, tiled_context_input], 2)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:47,代码来源:sequence_feature_column.py

示例13: validate_init_args

def validate_init_args(
    distribution,
    batch_shape,
    validate_args,
    batch_shape_static):
  """Raises on statically detectable errors, else builds runtime assertions.

  Three properties are checked: `batch_shape` is a vector, its total size
  matches that of `distribution.batch_shape`, and its elements are positive.
  Each is verified immediately when enough static information exists;
  otherwise an assertion op is added, but only if `validate_args` is set.

  Args:
    distribution: Object exposing `batch_shape`, `batch_shape_tensor()` and
      `_graph_parents`.
    batch_shape: Tensor holding the requested batch shape.
    validate_args: Whether runtime assertion ops should be created.
    batch_shape_static: Statically known value of `batch_shape`, or `None`.

  Returns:
    A list of assertion ops; empty when everything was checked statically or
    `validate_args` is false.

  Raises:
    ValueError: If a static check fails.
  """
  with ops.name_scope(name="validate_init_args",
                      values=[batch_shape] + distribution._graph_parents):  # pylint: disable=protected-access
    runtime_assertions = []

    # 1. `batch_shape` must be a vector.
    static_rank = batch_shape.shape.ndims
    if static_rank is None:
      if validate_args:
        runtime_assertions.append(
            check_ops.assert_rank(
                batch_shape,
                1,
                message="`batch_shape` must be a vector.",
                name="assert_batch_shape_is_vector"))
    elif static_rank != 1:
      raise ValueError("`batch_shape` must be a vector "
                       "(saw rank: {}).".format(
                           batch_shape.shape.ndims))

    # 2. Total sizes must agree.
    # NOTE(review): when `batch_shape_static` is None this relies on
    # `np.prod(None)` yielding None - confirm for the NumPy version in use.
    batch_size_static = np.prod(batch_shape_static)
    if distribution.batch_shape.is_fully_defined():
      dist_batch_size_static = np.prod(distribution.batch_shape).value
    else:
      dist_batch_size_static = None

    sizes_known = (
        batch_size_static is not None and dist_batch_size_static is not None)
    if sizes_known:
      if batch_size_static != dist_batch_size_static:
        raise ValueError("`batch_shape` size ({}) must match "
                         "`distribution.batch_shape` size ({}).".format(
                             batch_size_static,
                             dist_batch_size_static))
    elif validate_args:
      runtime_assertions.append(
          check_ops.assert_equal(
              math_ops.reduce_prod(batch_shape),
              math_ops.reduce_prod(distribution.batch_shape_tensor()),
              message=("`batch_shape` size must match "
                       "`distributions.batch_shape` size."),
              name="assert_batch_size"))

    # 3. Every element must be positive.
    if batch_shape_static is None:
      if validate_args:
        runtime_assertions.append(
            check_ops.assert_positive(
                batch_shape,
                message=("`batch_shape` elements must be positive "
                         "(i.e., larger than zero)."),
                name="assert_batch_shape_positive"))
    elif np.any(batch_shape_static < 1):
      raise ValueError("`batch_shape` elements must be positive "
                       "(i.e., larger than zero).")

    return runtime_assertions
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:59,代码来源:batch_reshape.py

示例14: test_rank_one_tensor_raises_if_rank_too_large_static_rank

 def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self):
   """Expects a ValueError naming the tensor when a vector is asserted rank 0."""
   with self.test_session():
     vector = constant_op.constant([1, 2], name="my_tensor")
     with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
       # Built inside the context so an error raised while constructing the
       # assertion is captured too.
       rank_assertion = check_ops.assert_rank(vector, 0)
       with ops.control_dependencies([rank_assertion]):
         guarded = array_ops.identity(vector)
         guarded.eval()
开发者ID:1000sprites,项目名称:tensorflow,代码行数:8,代码来源:check_ops_test.py

示例15: test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank

 def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
   """Expects an op error when a fed rank-1 value is asserted to be rank 2."""
   with self.test_session():
     my_tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
     # The assertion op is created outside the error context; only the
     # evaluation below is expected to fail.
     rank_assertion = check_ops.assert_rank(my_tensor, 2)
     with ops.control_dependencies([rank_assertion]):
       with self.assertRaisesOpError("my_tensor.*rank"):
         guarded = array_ops.identity(my_tensor)
         guarded.eval(feed_dict={my_tensor: [1, 2]})
开发者ID:1000sprites,项目名称:tensorflow,代码行数:8,代码来源:check_ops_test.py


注:本文中的tensorflow.python.ops.check_ops.assert_rank函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。