当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.nest.flatten用法及代码示例


从给定结构返回一个平面列表。

用法

tf.nest.flatten(
    structure, expand_composites=False
)

参数

  • structure 一个原子或嵌套结构。请注意,numpy 数组被认为是原子,不会被展平。
  • expand_composites 如果为真,则复合张量(例如 tf.sparse.SparseTensortf.RaggedTensor)将扩展为它们的分量张量。

返回

  • Python 列表,输入的扁平化版本。

抛出

  • TypeError 嵌套是或包含具有不可排序键的字典。

有关结构的定义,请参阅tf.nest

如果结构是原子,则返回 single-item 列表:[结构]。

这与 nest.pack_sequence_as 方法相反,该方法接受一个扁平列表并将其重新打包到嵌套结构中。

在 dict 实例的情况下,序列由值组成,按键排序以确保确定性行为。对于 OrderedDict 实例也是如此:它们的序列顺序被忽略,而是使用键的排序顺序。 nest.pack_sequence_as 遵循相同的约定。这会在 dicts 和 OrderedDicts 被展平后正确地重新打包它们,并且还允许展平 OrderedDict 然后使用相应的普通 dict 重新打包它,反之亦然。具有不可排序键的字典不能被展平。

在此函数运行时,用户不得修改嵌套中使用的任何集合。

例子:

  1. Python dict(按键排序):
dict = { "key3":"value3", "key1":"value1", "key2":"value2" }
  tf.nest.flatten(dict)
    ['value1', 'value2', 'value3']
  1. 对于嵌套的 python 元组:
tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
  tf.nest.flatten(tuple)
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
  1. 对于字典的嵌套字典:
dict = { "key3":{"c":(1.0, 2.0), "a":(3.0)},
  "key1":{"m":"val1", "g":"val2"} }
  tf.nest.flatten(dict)
    ['val2', 'val1', 3.0, 1.0, 2.0]
  1. Numpy 数组(不会展平):
array = np.array([[1, 2], [3, 4]])
  tf.nest.flatten(array)
        [array([[1, 2],
                [3, 4]])]
  1. tf.Tensor(不会变平):
tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
  tf.nest.flatten(tensor)
        [<tf.Tensor:shape=(3, 3), dtype=float32, numpy=
          array([[1., 2., 3.],
                 [4., 5., 6.],
                 [7., 8., 9.]], dtype=float32)>]
  1. tf.RaggedTensor :这是一个复合张量,它的表示由 'values' 的扁平化列表和 'row_splits' 的列表组成,它们指示如何将扁平化列表分割成不同的行。有关 tf.RaggedTensor 的更多详细信息,请访问 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor。

使用 expand_composites=False ,我们只需按原样返回 RaggedTensor。

tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
  tf.nest.flatten(tensor, expand_composites=False)
    [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>]

使用 expand_composites=True ,我们返回组成 RaggedTensor 表示的组件张量(值和 row_splits 张量)

tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
  tf.nest.flatten(tensor, expand_composites=True)
    [<tf.Tensor:shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2],
                                                      dtype=int32)>,
     <tf.Tensor:shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>]

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.nest.flatten。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。