将 data
张量中的值交错成单个张量。
用法
tf.dynamic_stitch(
indices, data, name=None
)
参数
-
indices
至少 1 个Tensor
类型为int32
的对象的列表。 -
data
与具有相同类型的Tensor
对象的indices
长度相同的列表。 -
name
操作的名称(可选)。
返回
-
一个
Tensor
。具有与data
相同的类型。
构建一个合并的张量,使得
merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]
例如,如果每个 indices[m]
是标量或向量,我们有
# Scalar indices:
merged[indices[m], ...] = data[m][...]
# Vector indices:
merged[indices[m][i], ...] = data[m][i, ...]
每个 data[i].shape
必须以相应的 indices[i].shape
开头,并且 data[i].shape
的其余部分必须是常量 w.r.t。 i
。也就是说,我们必须有 data[i].shape = indices[i].shape + constant
。就这个 constant
而言,输出形状是
merged.shape = [max(indices)] + constant
值按顺序合并,因此如果索引同时出现在 indices[m][i]
和 indices[n][j]
中,则 (m,i) < (n,j)
切片 data[n][j]
将出现在合并结果中。如果您不需要此保证,ParallelDynamicStitch 可能在某些设备上表现更好。
例如:
indices[0] = 6
indices[1] = [4, 1]
indices[2] = [[5, 2], [0, 3]]
data[0] = [61, 62]
data[1] = [[41, 42], [11, 12]]
data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]
merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],
[51, 52], [61, 62]]
此方法可用于合并由dynamic_partition
创建的分区,如下例所示:
# Apply function (increments x_i) on elements for which a certain condition
# apply (x_i != -1 in this example).
x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])
condition_mask=tf.not_equal(x,tf.constant(-1.))
partitioned_data = tf.dynamic_partition(
x, tf.cast(condition_mask, tf.int32) , 2)
partitioned_data[1] = partitioned_data[1] + 1.0
condition_indices = tf.dynamic_partition(
tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)
x = tf.dynamic_stitch(condition_indices, partitioned_data)
# Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
# unchanged.
相关用法
- Python tf.dynamic_partition用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.data.Dataset.take_while用法及代码示例
- Python tf.data.experimental.RandomDataset.group_by_window用法及代码示例
- Python tf.data.TFRecordDataset.filter用法及代码示例
- Python tf.data.TextLineDataset.reduce用法及代码示例
- Python tf.data.TextLineDataset.with_options用法及代码示例
- Python tf.data.experimental.SqlDataset.enumerate用法及代码示例
- Python tf.data.TextLineDataset.as_numpy_iterator用法及代码示例
- Python tf.data.experimental.make_saveable_from_iterator用法及代码示例
- Python tf.distribute.TPUStrategy用法及代码示例
- Python tf.data.TextLineDataset.random用法及代码示例
- Python tf.data.FixedLengthRecordDataset.repeat用法及代码示例
- Python tf.data.TFRecordDataset.random用法及代码示例
- Python tf.data.Dataset.cardinality用法及代码示例
- Python tf.distribute.experimental_set_strategy用法及代码示例
- Python tf.data.FixedLengthRecordDataset.cardinality用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.gather用法及代码示例
- Python tf.distribute.cluster_resolver.TFConfigClusterResolver用法及代码示例
- Python tf.data.TextLineDataset.take_while用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.dynamic_stitch。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。