当前位置: 首页>>代码示例>>Python>>正文


Python nest.assert_same_structure函数代码示例

本文整理汇总了Python中tensorflow.python.data.util.nest.assert_same_structure函数的典型用法代码示例。如果您想了解Python assert_same_structure函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许可以为您提供帮助。


下文共展示了assert_same_structure函数的14个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: make_initializer

  def make_initializer(self, dataset, name=None):
    """Builds an op that initializes this iterator to read from `dataset`.

    Before the op is created, the element structure of `dataset` is checked
    against this iterator's: the nesting of types and of shapes must match,
    every dtype must be identical, and every shape must be compatible.

    Args:
      dataset: A `Dataset` whose element structure is compatible with this
        iterator's.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that, when run, initializes this iterator on `dataset`.

    Raises:
      TypeError: If a dtype of `dataset` differs from this iterator's, or a
        shape is not compatible with this iterator's.
      ValueError, TypeError: Propagated from `nest.assert_same_structure` when
        the nesting (or sequence types) of the structures differ.
    """
    with ops.name_scope(name, "make_initializer") as name:
      my_types, their_types = self._output_types, dataset.output_types
      my_shapes, their_shapes = self._output_shapes, dataset.output_shapes
      nest.assert_same_structure(my_types, their_types)
      nest.assert_same_structure(my_shapes, their_shapes)

      type_pairs = zip(nest.flatten(my_types), nest.flatten(their_types))
      if any(ours != theirs for ours, theirs in type_pairs):
        raise TypeError(
            "Expected output types %r but got dataset with output types %r." %
            (my_types, their_types))

      shape_pairs = zip(nest.flatten(my_shapes), nest.flatten(their_shapes))
      if not all(ours.is_compatible_with(theirs)
                 for ours, theirs in shape_pairs):
        raise TypeError("Expected output shapes compatible with %r but got "
                        "dataset with output shapes %r." %
                        (my_shapes, their_shapes))
    # `name` now holds the string bound by `ops.name_scope(...) as name`; it is
    # forwarded as the op name while the op is colocated with the resource.
    with ops.colocate_with(self._iterator_resource):
      variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      return gen_dataset_ops.make_iterator(
          variant, self._iterator_resource, name=name)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:34,代码来源:iterator.py

示例2: testSerializeDeserialize

 def testSerializeDeserialize(self):
   """Round-trips each test case through sparse serialization.

   Every case is serialized with `sparse.serialize_sparse_tensors`, then
   deserialized using int32 types and fully unknown shapes derived from its
   classes. The result must have the same nesting and equal sparse values.
   """
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # One-element tuple. The trailing comma matters: without it the
       # parentheses only group, and this case would silently duplicate the
       # bare `SparseTensor` case above instead of testing a 1-tuple.
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((), sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     classes = sparse.get_classes(expected)
     shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                 classes)
     types = nest.map_structure(lambda _: dtypes.int32, classes)
     # Reuse `classes` instead of calling `sparse.get_classes` a second time.
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_sparse_tensors(expected), types, shapes, classes)
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:27,代码来源:sparse_test.py

示例3: from_string_handle

  def from_string_handle(string_handle, output_types, output_shapes=None):
    """Creates an uninitialized `Iterator` backed by a string handle.

    This builds a "feedable" iterator. The concrete iterator to read from is
    chosen at run time by feeding a handle in a @{tf.Session.run} call.
    Typically `string_handle` is a @{tf.placeholder`}, fed with the value of
    @{tf.data.Iterator.string_handle} for the desired iterator.

    For example, to switch between a training iterator and a test iterator,
    each tracking its own position in its dataset:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
        objects, one per `tf.Tensor` (or `tf.SparseTensor`) component of an
        element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects, one per component of an element of this dataset. When
        omitted, every component has an unconstrained shape.

    Returns:
      An `Iterator`.
    """
    dtype_structure = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
      shape_structure = nest.map_structure_up_to(
          dtype_structure, tensor_shape.as_shape, output_shapes)
    else:
      # No shapes supplied: give every component an unknown shape.
      shape_structure = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), dtype_structure)
    nest.assert_same_structure(dtype_structure, shape_structure)

    handle_tensor = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    flat_dtypes = nest.flatten(sparse.unwrap_sparse_types(dtype_structure))
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        handle_tensor,
        output_types=flat_dtypes,
        output_shapes=nest.flatten(shape_structure))
    return Iterator(iterator_resource, None, dtype_structure, shape_structure)
开发者ID:SylChan,项目名称:tensorflow,代码行数:57,代码来源:iterator_ops.py

示例4: __init__

  def __init__(self, dataset, output_types, output_shapes=None):
    """Wraps `dataset`, exposing its elements under a different nesting.

    `dataset` must be convertible to the requested structure:
    * its flattened `output_types` must equal the flattened `output_types`
      given here (the same dtypes, modulo nesting);
    * if `output_shapes` is given, each of its shapes must be compatible with
      the corresponding shape in `dataset.output_shapes`.

    Note: shapes may be narrowed (an "unsafe cast"), much like calling
    `tf.Tensor.set_shape()` when domain-specific knowledge justifies it.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. When omitted, the shapes of `dataset` are reused.

    Raises:
      ValueError: If `output_types` or `output_shapes` is not compatible with
        the structure of `dataset`. `nest.assert_same_structure` may also
        raise ValueError or TypeError if `output_shapes` is nested
        differently from `output_types`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # The dtypes must match exactly once both structures are flattened.
    new_types = nest.map_structure(dtypes.as_dtype, output_types)
    if nest.flatten(dataset.output_types) != nest.flatten(new_types):
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have output "
          "types %r" % (dataset.output_types, new_types))
    self._output_types = new_types

    if output_shapes is None:
      # Reuse the shapes of `dataset`, regrouped to follow `new_types`.
      self._output_shapes = nest.pack_sequence_as(
          new_types, nest.flatten(dataset.output_shapes))
      return

    nest.assert_same_structure(new_types, output_shapes)
    shape_pairs = zip(nest.flatten(dataset.output_shapes),
                      nest.flatten_up_to(new_types, output_shapes))
    for old_shape, new_shape in shape_pairs:
      if not old_shape.is_compatible_with(new_shape):
        raise ValueError(
            "Dataset with output shapes %r cannot be restructured to have "
            "incompatible output shapes %r" % (dataset.output_shapes,
                                               output_shapes))
    self._output_shapes = nest.map_structure_up_to(
        new_types, tensor_shape.as_shape, output_shapes)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:55,代码来源:batching.py

示例5: __eq__

  def __eq__(self, other):
    """Returns True iff `other` is a `NestedStructure` equal to this one.

    Equality requires the same nesting (as checked by
    `nest.assert_same_structure`) and equal flattened component lists.
    """
    if isinstance(other, NestedStructure):
      # pylint: disable=protected-access
      mine, theirs = self._nested_structure, other._nested_structure
      try:
        nest.assert_same_structure(mine, theirs)
      except (ValueError, TypeError):
        # Different nesting or sequence types: the structures are not equal.
        return False
      return nest.flatten(mine) == nest.flatten(theirs)
    return False
开发者ID:aritratony,项目名称:tensorflow,代码行数:12,代码来源:structure.py

示例6: __init__

 def __init__(self, variant_tensor, output_shapes, output_types,
              output_classes):
   """Stores `variant_tensor` together with its validated element structure.

   Args:
     variant_tensor: Stored as-is in `self._variant_tensor`; not inspected here.
     output_shapes: A nested structure convertible by `tensor_shape.as_shape`,
       mapped up to the structure of `output_types`.
     output_types: A nested structure convertible by `dtypes.as_dtype`.
     output_classes: A nested structure with the same nesting as
       `output_types`; stored unchanged.

   Raises:
     ValueError, TypeError: Propagated from `nest.assert_same_structure` when
       the shapes or classes are not nested like the types.
   """
   # TODO(b/110122868): Consolidate the structure validation logic with the
   # similar logic in `Iterator.from_structure()` and
   # `Dataset.from_generator()`.
   dtype_structure = nest.map_structure(dtypes.as_dtype, output_types)
   shape_structure = nest.map_structure_up_to(
       dtype_structure, tensor_shape.as_shape, output_shapes)
   # Shapes are validated first, then classes.
   for companion in (shape_structure, output_classes):
     nest.assert_same_structure(dtype_structure, companion)
   self._variant_tensor = variant_tensor
   self._output_types = dtype_structure
   self._output_shapes = shape_structure
   self._output_classes = output_classes
开发者ID:AnishShah,项目名称:tensorflow,代码行数:14,代码来源:optional_ops.py

示例7: _compareOutputToExpected

 def _compareOutputToExpected(self, result_values, expected_values,
                              assert_items_equal):
   """Asserts that `result_values` matches `expected_values`.

   Args:
     result_values: A sequence of produced elements.
     expected_values: A sequence of expected elements.
     assert_items_equal: If True, compare the sequences ignoring order via
       `assertItemsEqual`. Nested elements containing sparse tensors are not
       supported on this path. Otherwise, elements are compared position by
       position, including their nesting and sparse/dense values.
   """
   if assert_items_equal:
     # TODO(shivaniagrawal): add support for nested elements containing sparse
     # tensors when needed.
     self.assertItemsEqual(result_values, expected_values)
     return
   # Check lengths explicitly. Otherwise extra expected elements would be
   # ignored, and extra results would surface as an IndexError rather than a
   # test failure.
   self.assertEqual(len(result_values), len(expected_values))
   for result_element, expected_element in zip(result_values,
                                               expected_values):
     nest.assert_same_structure(result_element, expected_element)
     for result_value, expected_value in zip(
         nest.flatten(result_element), nest.flatten(expected_element)):
       if sparse_tensor.is_sparse(result_value):
         self.assertSparseValuesEqual(result_value, expected_value)
       else:
         self.assertAllEqual(result_value, expected_value)
开发者ID:aeverall,项目名称:tensorflow,代码行数:15,代码来源:test_base.py

示例8: is_compatible_with

  def is_compatible_with(self, other):
    """Returns True iff `other` is a `NestedStructure` compatible with this one.

    Compatibility requires the same nesting and that every component of this
    structure reports `is_compatible_with` for the matching component of
    `other`.
    """
    if not isinstance(other, NestedStructure):
      return False
    # pylint: disable=protected-access
    mine, theirs = self._nested_structure, other._nested_structure
    try:
      nest.assert_same_structure(mine, theirs)
    except (ValueError, TypeError):
      return False
    for component, other_component in zip(nest.flatten(mine),
                                          nest.flatten(theirs)):
      if not component.is_compatible_with(other_component):
        return False
    return True
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:15,代码来源:structure.py

示例9: testMapStructure

  def testMapStructure(self):
    """Checks nest.map_structure on valid inputs and on each error path."""
    nested_a = (((1, 2), 3), 4, (5, 6))
    nested_b = (((7, 8), 9), 10, (11, 12))

    incremented = nest.map_structure(lambda x: x + 1, nested_a)
    nest.assert_same_structure(nested_a, incremented)
    self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(incremented))

    summed = nest.map_structure(lambda x, y: x + y, nested_a, nested_b)
    self.assertEqual(
        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), summed)

    # Non-sequence inputs are mapped directly.
    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
    self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

    # Each entry: (expected exception, message regex, call that must fail).
    failure_cases = (
        (TypeError, "callable",
         lambda: nest.map_structure("bad", incremented)),
        (ValueError, "same nested structure",
         lambda: nest.map_structure(lambda x, y: None, 3, (3,))),
        (TypeError, "same sequence type",
         lambda: nest.map_structure(lambda x, y: None, ((3, 4), 5),
                                    {"a": (3, 4), "b": 5})),
        (ValueError, "same nested structure",
         lambda: nest.map_structure(lambda x, y: None, ((3, 4), 5),
                                    (3, (4, 5)))),
        (ValueError, "same nested structure",
         lambda: nest.map_structure(lambda x, y: None, ((3, 4), 5),
                                    (3, (4, 5)), check_types=False)),
        (ValueError, "Only valid keyword argument",
         lambda: nest.map_structure(lambda x: None, nested_a, foo="a")),
        (ValueError, "Only valid keyword argument",
         lambda: nest.map_structure(lambda x: None, nested_a,
                                    check_types=False, foo="a")),
    )
    for error_type, message_regex, trigger in failure_cases:
      with self.assertRaisesRegexp(error_type, message_regex):
        trigger()
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:39,代码来源:nest_test.py

示例10: testSerializeDeserialize

 def testSerializeDeserialize(self):
   """Round-trips each test case through sparse serialization.

   Every case is serialized with `sparse.serialize_sparse_tensors` and then
   deserialized using `sparse.get_sparse_types`. The result must have the same
   nesting and equal sparse values.
   """
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # One-element tuple. The trailing comma matters: without it the
       # parentheses only group, and this case would silently duplicate the
       # bare `SparseTensor` case above instead of testing a 1-tuple.
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((), sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_sparse_tensors(expected),
         sparse.get_sparse_types(expected))
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
开发者ID:SylChan,项目名称:tensorflow,代码行数:23,代码来源:sparse_test.py

示例11: testAssertSameStructure

  def testAssertSameStructure(self):
    """Exercises nest.assert_same_structure on matching and mismatching input."""
    numbers = (((1, 2), 3), 4, (5, 6))
    strings = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    too_few_elements = ("spam", "eggs")
    regrouped = (((1, 2), 3), 4, 5, (6,))

    nest.assert_same_structure(numbers, strings)
    # A string, a float, a numpy array and a tensor each count as one element.
    for single in (1.0, np.array([0, 1]), constant_op.constant([0, 1])):
      nest.assert_same_structure("abc", single)

    count_mismatches = (
        (numbers, too_few_elements),
        ((0, 1), np.array([0, 1])),
        (0, (0, 1)),
    )
    for left, right in count_mismatches:
      with self.assertRaisesRegexp(ValueError,
                                   "don't have the same number of elements"):
        nest.assert_same_structure(left, right)

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same nested structure"):
      nest.assert_same_structure(numbers, regrouped)

    pair_type_0 = collections.namedtuple("named_0", ("a", "b"))
    pair_type_1 = collections.namedtuple("named_1", ("a", "b"))
    # A namedtuple differs from a plain tuple, and from a different namedtuple
    # class with identical fields.
    self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                      pair_type_0("a", "b"))
    nest.assert_same_structure(pair_type_0(3, 4), pair_type_0("a", "b"))
    self.assertRaises(TypeError, nest.assert_same_structure,
                      pair_type_0(3, 4), pair_type_1(3, 4))

    nesting_mismatches = (
        (pair_type_0(3, 4), pair_type_0((3,), 4)),
        (((3,), 4), (3, (4,))),
    )
    for left, right in nesting_mismatches:
      with self.assertRaisesRegexp(ValueError,
                                   "don't have the same nested structure"):
        nest.assert_same_structure(left, right)

    numbers_as_dict = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
    with self.assertRaisesRegexp(TypeError,
                                 "don't have the same sequence type"):
      nest.assert_same_structure(numbers, numbers_as_dict)
    # With check_types=False only the nesting is compared.
    nest.assert_same_structure(numbers, strings, check_types=False)
    nest.assert_same_structure(numbers, numbers_as_dict, check_types=False)
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:50,代码来源:nest_test.py

示例12: from_structure

  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError, ValueError: If `output_shapes` does not have the same
        structure as `output_types` (TypeError for differing sequence types,
        ValueError for differing nesting), or if `output_classes` is nested
        differently from `output_types`. These errors are raised by `nest`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    # `sparse.as_dense_types`/`as_dense_shapes` below pair classes with types
    # and shapes, so a caller-supplied `output_classes` must follow the same
    # nesting. Only the nesting is compared, not the sequence types, so that a
    # plain tuple of classes may still accompany, e.g., a namedtuple of types.
    nest.assert_same_structure(output_types, output_classes,
                               check_types=False)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator(
        container="",
        shared_name=shared_name,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
开发者ID:modkzs,项目名称:tensorflow,代码行数:90,代码来源:iterator_ops.py

示例13: from_string_handle

  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates an uninitialized `Iterator` backed by a string handle.

    This builds a "feedable" iterator. The concrete iterator to read from is
    chosen at run time by feeding a handle in a `tf.Session.run` call.
    Typically `string_handle` is a `tf.placeholder`, fed with the value of
    `tf.data.Iterator.string_handle` for the desired iterator.

    For example, to switch between a training iterator and a test iterator,
    each tracking its own position in its dataset:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects, one per
        component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects, one per component of an element of this dataset. When
        omitted, every component has an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects,
        one per component of an element of this iterator. When omitted, every
        component is assumed to be a `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    dtype_structure = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
      shape_structure = nest.map_structure_up_to(
          dtype_structure, tensor_shape.as_shape, output_shapes)
    else:
      shape_structure = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), dtype_structure)
    class_structure = output_classes
    if class_structure is None:
      class_structure = nest.map_structure(lambda _: ops.Tensor,
                                           dtype_structure)
    nest.assert_same_structure(dtype_structure, shape_structure)
    output_structure = structure_lib.convert_legacy_structure(
        dtype_structure, shape_structure, class_structure)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)

    def build_resource(op_fn):
      # Invokes one of the `iterator_from_string_handle*` ops with the flat
      # types and shapes of `output_structure`.
      # pylint: disable=protected-access
      return op_fn(
          string_handle,
          output_types=output_structure._flat_types,
          output_shapes=output_structure._flat_shapes)
      # pylint: enable=protected-access

    if not compat.forward_compatible(2018, 8, 3):
      iterator_resource = build_resource(
          gen_dataset_ops.iterator_from_string_handle)
    elif _device_stack_is_empty():
      # With an empty device stack, the v2 op is placed on "/cpu:0".
      with ops.device("/cpu:0"):
        iterator_resource = build_resource(
            gen_dataset_ops.iterator_from_string_handle_v2)
    else:
      iterator_resource = build_resource(
          gen_dataset_ops.iterator_from_string_handle_v2)
    return Iterator(iterator_resource, None, dtype_structure, shape_structure,
                    class_structure)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:82,代码来源:iterator_ops.py

示例14: __init__

  def __init__(self,
               dataset,
               output_types,
               output_shapes=None,
               output_classes=None,
               allow_unsafe_cast=False):
    """Wraps `dataset`, exposing its elements under a different structure.

    Unless `allow_unsafe_cast` is set, `dataset` must be convertible to the
    requested structure:
    * its flattened legacy output types must equal the flattened
      `output_types` (the same dtypes, modulo nesting);
    * if `output_shapes` is given, each of its shapes must be compatible with
      the corresponding legacy output shape of `dataset`.

    Note: shapes may be narrowed (an "unsafe cast"), much like calling
    `tf.Tensor.set_shape()` when domain-specific knowledge justifies it.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. When omitted, the shapes of `dataset` are reused.
      output_classes: (Optional.) A nested structure of class types. When
        omitted, the class types of `dataset` are reused.
      allow_unsafe_cast: (Optional.) If `True`, the dtype and shape
        compatibility checks are skipped, so the caller may change the
        reported output types and shapes of the restructured dataset, e.g. to
        present a sparse tensor encoded as `tf.variant` with its user-visible
        type and shape.

    Raises:
      ValueError: If `output_types` or `output_shapes` is not compatible with
        the structure of `dataset`. `nest.assert_same_structure` may also
        raise ValueError or TypeError if `output_shapes` is nested
        differently from `output_types`.
    """
    self._input_dataset = dataset

    input_types = dataset_ops.get_legacy_output_types(dataset)
    if not allow_unsafe_cast:
      # The dtypes must match exactly once both structures are flattened.
      output_types = nest.map_structure(dtypes.as_dtype, output_types)
      if nest.flatten(input_types) != nest.flatten(output_types):
        raise ValueError(
            "Dataset with output types %r cannot be restructured to have "
            "output types %r" %
            (dataset_ops.get_legacy_output_types(dataset), output_types))

    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    if output_shapes is None:
      # Reuse the shapes of `dataset`, regrouped to follow `output_types`.
      output_shapes = nest.pack_sequence_as(
          output_types, nest.flatten(input_shapes))
    else:
      if not allow_unsafe_cast:
        nest.assert_same_structure(output_types, output_shapes)
        shape_pairs = zip(nest.flatten(input_shapes),
                          nest.flatten_up_to(output_types, output_shapes))
        if not all(old_shape.is_compatible_with(new_shape)
                   for old_shape, new_shape in shape_pairs):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (input_shapes,
                                                 output_shapes))
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

    input_classes = dataset_ops.get_legacy_output_classes(dataset)
    if output_classes is None:
      # Reuse the class types of `dataset`, regrouped like `output_types`.
      output_classes = nest.pack_sequence_as(
          output_types, nest.flatten(input_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:78,代码来源:batching.py


注:本文中的tensorflow.python.data.util.nest.assert_same_structure函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。