当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.nest.pack_sequence_as用法及代码示例


返回打包到给定结构中的给定展平序列。

用法

tf.nest.pack_sequence_as(
    structure, flat_sequence, expand_composites=False
)

参数

  • structure 嵌套结构,其结构由嵌套列表、元组和字典给出。注意:numpy 数组和字符串被认为是标量。
  • flat_sequence 要打包的扁平序列。
  • expand_composites 如果为真,则复合张量(例如 tf.sparse.SparseTensortf.RaggedTensor)将扩展为它们的分量张量。

返回

  • packed flat_sequence 转换为具有与 structure 相同的递归结构。

抛出

  • ValueError 如果flat_sequencestructure 具有不同的原子数。
  • TypeError structure 是或包含带有不可排序键的字典。

有关结构的定义,请参阅tf.nest

如果structure是一个原子,flat_sequence必须是一个single-item列表;在这种情况下,返回值为 flat_sequence[0]

如果 structure 是或包含 dict 实例,则将对键进行排序以按确定顺序打包平面序列。对于OrderedDict 实例也是如此:它们的序列顺序被忽略,而是使用键的排序顺序。 flatten 遵循相同的约定。这会在 dicts 和 OrderedDict 被展平后正确地重新打包它们,并且还允许展平 OrderedDict 然后使用相应的普通 dict 将其重新打包,反之亦然。具有不可排序键的字典不能被展平。

例子:

  1. Python字典:
structure = { "key3":"", "key1":"", "key2":"" }
  flat_sequence = ["value1", "value2", "value3"]
  tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3':'value3', 'key1':'value1', 'key2':'value2'}
  1. 对于嵌套的 python 元组:
structure = (('a','b'), ('c','d','e'), 'f')
  flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
  tf.nest.pack_sequence_as(structure, flat_sequence)
    ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
  1. 对于字典的嵌套字典:
structure = { "key3":{"c":('alpha', 'beta'), "a":('gamma')},
                "key1":{"e":"val1", "d":"val2"} }
  flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0]
  tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3':{'c':(1.0, 2.0), 'a':3.0}, 'key1':{'e':'val1', 'd':'val2'} }
  1. Numpy 数组(视为标量):
structure = ['a']
  flat_sequence = [np.array([[1, 2], [3, 4]])]
  tf.nest.pack_sequence_as(structure, flat_sequence)
    [array([[1, 2],
           [3, 4]])]
  1. tf.Tensor(视为标量):
structure = ['a']
  flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])]
  tf.nest.pack_sequence_as(structure, flat_sequence)
    [<tf.Tensor:shape=(2, 3), dtype=float32,
     numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>]
  1. tf.RaggedTensor :这是一个复合张量,它的表示由 'values' 的扁平化列表和 'row_splits' 的列表组成,它们指示如何将扁平化列表分割成不同的行。有关 tf.RaggedTensor 的更多详细信息,请访问 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor。

使用 expand_composites=False ,我们将 RaggedTensor 视为标量。

structure = { "foo":tf.ragged.constant([[1, 2], [3]]),
                "bar":tf.constant([[5]]) }
  flat_sequence = [ "one", "two" ]
  tf.nest.pack_sequence_as(structure, flat_sequence,
  expand_composites=False)
    {'foo':'two', 'bar':'one'}

对于 expand_composites=True ,我们期望扁平化的输入包含构成参差不齐的张量的张量,即值和 row_splits 张量。

structure = { "foo":tf.ragged.constant([[1., 2.], [3.]]),
                "bar":tf.constant([[5.]]) }
  tensors = tf.nest.flatten(structure, expand_composites=True)
  print(tensors)
    [<tf.Tensor:shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>,
     <tf.Tensor:shape=(3,), dtype=float32, numpy=array([1., 2., 3.],
     dtype=float32)>,
     <tf.Tensor:shape=(3,), dtype=int64, numpy=array([0, 2, 3])>]
  verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor:')
                      if t.dtype==tf.float32 else t
                      for t in tensors]
  tf.nest.pack_sequence_as(structure, verified_tensors,
                           expand_composites=True)
    {'foo':<tf.RaggedTensor [[1.0, 2.0], [3.0]]>,
     'bar':<tf.Tensor:shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>}

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.nest.pack_sequence_as。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。