本文整理汇总了Python中tensorflow.tensor_scatter_nd_update方法的典型用法代码示例。如果您正苦于以下问题:Python tensorflow.tensor_scatter_nd_update方法的具体用法?Python tensorflow.tensor_scatter_nd_update怎么用?Python tensorflow.tensor_scatter_nd_update使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的用法示例。
在下文中一共展示了tensorflow.tensor_scatter_nd_update方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: disjoint_signal_to_batch
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def disjoint_signal_to_batch(X, I):
    """Scatter a disjoint-mode node signal into a zero-padded batch tensor.

    Nodes of graph ``g`` (as labelled by ``I``) are written to rows
    ``g * N_max`` onwards of a flat zero buffer, which is then reshaped to
    ``(batch, N_max, F)``.

    :param X: Tensor, node features of shape ``(nodes, features)``.
    :param I: Tensor, graph IDs of shape ``(nodes, )``. It is passed to
        ``tf.math.segment_sum``, so it is expected to be sorted.
    :return: Tensor of shape ``(batch, N_max, F)``; padded slots are zero.
    """
    graph_ids = tf.cast(I, tf.int32)
    # Node count of every graph, and the flat offset where each graph starts.
    nodes_per_graph = tf.math.segment_sum(tf.ones_like(graph_ids), graph_ids)
    graph_offsets = tf.cumsum(nodes_per_graph, exclusive=True)
    num_graphs = tf.shape(nodes_per_graph)[0]
    pad_to = tf.reduce_max(nodes_per_graph)
    n_features = tf.shape(X)[-1]

    # Position of each node inside its own graph, shifted into the block of
    # rows reserved for that graph in the padded buffer.
    local_pos = tf.range(tf.shape(I)[0]) - tf.gather(graph_offsets, graph_ids)
    target_rows = local_pos + graph_ids * pad_to

    flat = tf.tensor_scatter_nd_update(
        tf.zeros((num_graphs * pad_to, n_features), dtype=X.dtype),
        tf.expand_dims(target_rows, -1),
        X,
    )
    return tf.reshape(flat, (num_graphs, pad_to, n_features))
示例2: update_active_spikes
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def update_active_spikes(self, spikes):
    """Schedule newly emitted spikes into the circular ``active_spikes`` buffer.

    Every non-zero entry of ``spikes * self.delays`` is converted to an arrival
    step via ``self.spike_arrival_step``. That step is appended as the last
    coordinate, and a 1 is written at the resulting position of
    ``self.active_spikes``.

    Parameters:
        spikes (array like): The spikes that have just occurred.
    Returns:
        None
    """
    # Zero wherever no spike happened (or the delay itself is zero).
    delays_some_hot = spikes * self.delays
    # Coordinates of every pending spike/delay pair (tf.where yields int64).
    idxs = tf.where(tf.not_equal(delays_some_hot, 0))
    just_delays = tf.gather_nd(delays_some_hot, idxs)
    # Translate delays into arrival positions in the circular time dimension
    # (adjusting for step size), shaped as one column to append to idxs.
    # NOTE(review): tf.concat needs spike_arrival_step to return int64 - confirm.
    delay_dim_idxs = tf.reshape(self.spike_arrival_step(just_delays), [-1, 1])
    full_idxs = tf.concat([idxs, delay_dim_idxs], axis=1)
    # Use the dynamic row count (the static one is None when the number of
    # spikes is unknown at trace time) and match the buffer's dtype.
    ones = tf.ones(tf.shape(full_idxs)[:1], dtype=self.active_spikes.dtype)
    self.active_spikes = tf.tensor_scatter_nd_update(self.active_spikes, full_idxs, ones)
示例3: encode
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def encode(self, x=None):
    """Look up word embeddings for ``x`` (or ``self.x`` when ``x`` is None).

    The row at ``Offsets.PAD`` is zeroed before embedding dropout and lookup,
    so padding tokens map to a zero vector.

    :param x: An optional input sub-graph to bind to this operation; when
        ``None`` the previously bound ``self.x`` is used.
    :return: The sub-graph output (the looked-up embeddings).
    """
    if x is not None:
        self.x = x
    # tensor_scatter_nd_update returns a NEW tensor and never mutates self.W,
    # so the zeroed table itself must be fed to dropout/lookup; using it only
    # as a control dependency would discard the zeroing entirely.
    weights_pad_zeroed = tf.tensor_scatter_nd_update(
        self.W,
        tf.constant(Offsets.PAD, dtype=tf.int32, shape=[1, 1]),
        tf.zeros(shape=[1, self.dsz], dtype=self.W.dtype),
    )
    # The ablation table (4) in https://arxiv.org/pdf/1708.02182.pdf shows this has a massive impact
    embedding_w_dropout = self.drop(weights_pad_zeroed, training=TRAIN_FLAG())
    word_embeddings = tf.nn.embedding_lookup(embedding_w_dropout, self.x)
    return word_embeddings
示例4: transform_targets_for_output
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def transform_targets_for_output(y_true, grid_y, grid_x, anchor_idxs, classes):
    """Scatter ground-truth boxes into a (possibly non-square) YOLO target grid.

    :param y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor)); entries
        whose x2 is 0 are skipped as padding.
    :param grid_y: number of grid rows.
    :param grid_x: number of grid columns.
    :param anchor_idxs: anchor indices handled by this output scale; only
        boxes whose best_anchor is among them are written.
    :param classes: unused here; kept for interface compatibility.
    :return: (N, grid_y, grid_x, len(anchor_idxs), 6) tensor holding
        (x1, y1, x2, y2, 1, class) in the cell containing each box centre.
    """
    N = tf.shape(y_true)[0]
    y_true_out = tf.zeros((N, grid_y, grid_x, tf.shape(anchor_idxs)[0], 6))
    anchor_idxs = tf.cast(anchor_idxs, tf.int32)
    # Loop-invariant: per-axis scale in (x, y) order to match box_xy, and the
    # largest valid cell index along each axis.
    grid_size = tf.cast(tf.stack([grid_x, grid_y], axis=-1), tf.float32)
    grid_max = tf.cast(tf.stack([grid_x, grid_y], axis=-1), tf.int32) - 1
    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))
            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2.
                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                # Assuming normalized coordinates, a centre exactly on the far
                # edge (1.0) would map to index == grid size, out of range for
                # the scatter below; clamp it into the last cell.
                grid_xy = tf.clip_by_value(
                    tf.cast(box_xy * grid_size, tf.int32), 0, grid_max)
                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                indexes = indexes.write(idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1
    return tf.tensor_scatter_nd_update(y_true_out, indexes.stack(), updates.stack())
示例5: basis_message_func
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def basis_message_func(self, edges):
    """Message function for basis regularizer.

    Builds the per-relation weights (combined from bases when
    ``num_bases < num_rels``) and returns ``{'msg': h_src @ W_type}``,
    multiplied by the edge ``'norm'`` data when present.
    """
    if self.num_bases < self.num_rels:
        # generate all weights from bases
        weight = tf.reshape(self.weight, (self.num_bases,
                                          self.in_feat * self.out_feat))
        weight = tf.reshape(tf.matmul(self.w_comp, weight), (
            self.num_rels, self.in_feat, self.out_feat))
    else:
        weight = self.weight

    # calculate msg @ W_r before put msg into edge
    # if src is tf.int64 we expect it is an index select
    src_h = edges.src['h']
    if src_h.dtype != tf.int64 and self.low_mem:
        # Low-memory path: one matmul per distinct edge type, scattered back
        # into the rows of the edges carrying that type.
        etypes, _ = tf.unique(edges.data['type'])
        # Match the feature dtype so the scatter also works for non-float32 inputs.
        msg = tf.zeros([src_h.shape[0], self.out_feat], dtype=src_h.dtype)
        idx = tf.range(src_h.shape[0])
        for etype in etypes:
            loc = (edges.data['type'] == etype)
            w = weight[etype]
            src = tf.boolean_mask(src_h, loc)
            sub_msg = tf.matmul(src, w)
            indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
            msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
    else:
        msg = utils.bmm_maybe_select(
            src_h, weight, edges.data['type'])
    if 'norm' in edges.data:
        msg = msg * edges.data['norm']
    return {'msg': msg}
示例6: bdd_message_func
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def bdd_message_func(self, edges):
    """Message function for block-diagonal-decomposition regularizer.

    Each relation weight is viewed as ``num_bases`` blocks of shape
    (submat_in, submat_out). Source features are split into matching chunks
    and multiplied block-wise. Returns ``{'msg': ...}``, multiplied by the
    edge ``'norm'`` data when present.

    Raises:
        TypeError: if the source features are 1-D int64 IDs.
    """
    src_h = edges.src['h']
    if ((src_h.dtype == tf.int64) and
            len(src_h.shape) == 1):
        raise TypeError(
            'Block decomposition does not allow integer ID feature.')

    # calculate msg @ W_r before put msg into edge
    if self.low_mem:
        # Low-memory path: one block-wise product per distinct edge type,
        # scattered back into the rows of the edges carrying that type.
        etypes, _ = tf.unique(edges.data['type'])
        # Match the feature dtype so the scatter also works for non-float32 inputs.
        msg = tf.zeros([src_h.shape[0], self.out_feat], dtype=src_h.dtype)
        idx = tf.range(src_h.shape[0])
        for etype in etypes:
            loc = (edges.data['type'] == etype)
            w = tf.reshape(self.weight[etype],
                           (self.num_bases, self.submat_in, self.submat_out))
            src = tf.reshape(tf.boolean_mask(src_h, loc),
                             (-1, self.num_bases, self.submat_in))
            # (edges, bases, submat_in) x (bases, submat_in, submat_out)
            sub_msg = tf.einsum('abc,bcd->abd', src, w)
            sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))
            indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
            msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
    else:
        weight = tf.reshape(tf.gather(
            self.weight, edges.data['type']), (-1, self.submat_in, self.submat_out))
        node = tf.reshape(src_h, (-1, 1, self.submat_in))
        msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))
    if 'norm' in edges.data:
        msg = msg * edges.data['norm']
    return {'msg': msg}
示例7: scatter_row
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def scatter_row(data, row_index, value):
    """Return a copy of ``data`` whose rows at ``row_index`` are replaced by ``value``."""
    # tensor_scatter_nd_update addresses whole rows when each index is a
    # length-1 coordinate vector, hence the extra trailing axis.
    row_coords = tf.expand_dims(row_index, axis=1)
    updated = tf.tensor_scatter_nd_update(data, row_coords, value)
    return updated
示例8: clear_current_active_spikes
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def clear_current_active_spikes(self):
    """Remove any spikes that arrived at the current time step.

    Every non-zero entry in the current time slot of ``self.active_spikes``
    is overwritten with 0.

    Parameters:
        None
    Returns:
        None
    """
    # Look up the current slot once instead of twice.
    step = self.get_active_spike_idx()
    # (num_arrived, 2) int64 coordinates of non-zero entries in this slot.
    spike_idxs = tf.where(tf.not_equal(self.active_spikes[:, :, step], 0))
    # Append the time-slot coordinate. The dynamic row count keeps this valid
    # when the number of arrivals is unknown at trace time (static shape None).
    num_arrived = tf.shape(spike_idxs)[0]
    step_col = tf.fill([num_arrived, 1], tf.cast(step, tf.int64))
    full_idxs = tf.concat([spike_idxs, step_col], axis=1)
    # Fill any 1's with zeros of the buffer's own dtype.
    zeros = tf.zeros([num_arrived], dtype=self.active_spikes.dtype)
    self.active_spikes = tf.tensor_scatter_nd_update(self.active_spikes, full_idxs, zeros)
示例9: transform_targets_for_output
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def transform_targets_for_output(y_true, grid_size, anchor_idxs):
    """Scatter ground-truth boxes into a square YOLO target grid.

    :param y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor)); entries
        whose x2 is 0 are skipped as padding.
    :param grid_size: number of cells along each side of the grid.
    :param anchor_idxs: anchor indices handled by this output scale; only
        boxes whose best_anchor is among them are written.
    :return: (N, grid_size, grid_size, len(anchor_idxs), 6) tensor holding
        (x1, y1, x2, y2, 1, class) in the cell containing each box centre.
    """
    N = tf.shape(y_true)[0]
    y_true_out = tf.zeros(
        (N, grid_size, grid_size, tf.shape(anchor_idxs)[0], 6))
    anchor_idxs = tf.cast(anchor_idxs, tf.int32)
    # Loop-invariant scale and largest valid cell index.
    grid_scale = tf.cast(grid_size, tf.float32)
    grid_max = tf.cast(grid_size, tf.int32) - 1
    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(
                anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))
            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2
                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                # Multiply rather than floor-divide by the rounded reciprocal
                # 1/grid_size, which can land one cell low. Assuming normalized
                # coordinates, clamp so a centre at exactly 1.0 stays in range.
                grid_xy = tf.clip_by_value(
                    tf.cast(box_xy * grid_scale, tf.int32), 0, grid_max)
                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                indexes = indexes.write(
                    idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(
                    idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1
    return tf.tensor_scatter_nd_update(
        y_true_out, indexes.stack(), updates.stack())
示例10: call
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def call(self, inputs, training=True):
    """Propagate atom features along per-atom DAGs, one calculation step at a time.

    parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms

    Args:
      inputs: sequence of five tensors, in order: atom_features, parents,
        calculation_orders, calculation_masks, n_atoms.
      training: forwarded to ``_DAGgraph_step`` (presumably toggles dropout -
        confirm against that helper).

    Returns:
      ``batch_outputs`` of the LAST step of the loop only (the loop rebinds it
      every iteration); the accumulated ``graph_features`` buffer is not
      returned. NOTE(review): confirm downstream layers expect only the final
      step. If ``self.max_atoms`` is 0 the loop never runs and the return
      raises NameError.
    """
    atom_features = inputs[0]
    # each atom corresponds to a graph, which is represented by the `max_atoms*max_atoms` int32 matrix of index
    # each graph includes `max_atoms` steps (corresponding to rows) of calculating graph features
    parents = tf.cast(inputs[1], dtype=tf.int32)
    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    calculation_orders = inputs[2]
    calculation_masks = inputs[3]
    n_atoms = tf.squeeze(inputs[4])
    # Feature buffer filled step by step; the extra `+ 1` slot in axis 1 is
    # presumably a dummy parent position for padding - confirm.
    graph_features = tf.zeros((self.max_atoms * self.batch_size,
                               self.max_atoms + 1, self.n_graph_feat))
    for count in range(self.max_atoms):
        # `count`-th step
        # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
        mask = calculation_masks[:, count]
        current_round = tf.boolean_mask(calculation_orders[:, count], mask)
        batch_atom_features = tf.gather(atom_features, current_round)
        # generating index for graph features used in the inputs:
        # stack1 = row id of each unmasked entry, repeated once per parent slot;
        # stack2 = the matching parent column ids (parents[:, count, 1:]).
        stack1 = tf.reshape(
            tf.stack(
                [tf.boolean_mask(tf.range(n_atoms), mask)] * (self.max_atoms - 1),
                axis=1), [-1])
        stack2 = tf.reshape(tf.boolean_mask(parents[:, count, 1:], mask), [-1])
        index = tf.stack([stack1, stack2], axis=1)
        # extracting graph features for parents of the target atoms, then flatten
        # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
        batch_graph_features = tf.reshape(
            tf.gather_nd(graph_features, index),
            [-1, (self.max_atoms - 1) * self.n_graph_feat])
        # concat into the input tensor: (batch_size*max_atoms) * n_inputs
        batch_inputs = tf.concat(
            axis=1, values=[batch_atom_features, batch_graph_features])
        # DAGgraph_step maps from batch_inputs to a batch of graph_features
        # of shape: (batch_size*max_atoms) * n_graph_features
        # representing the graph features of target atoms in each graph
        batch_outputs = _DAGgraph_step(batch_inputs, self.W_list, self.b_list,
                                       self.activation_fn, self.dropouts,
                                       training)
        # index for target atoms: (row id, parents[:, count, 0]) for unmasked rows
        target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
        target_index = tf.boolean_mask(target_index, mask)
        # write this step's outputs back so later steps can gather them as parents
        graph_features = tf.tensor_scatter_nd_update(graph_features, target_index,
                                                     batch_outputs)
    return batch_outputs
示例11: segment_top_k
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def segment_top_k(x, I, ratio, top_k_var):
    """
    Returns indices to get the top K values in x segment-wise, according to
    the segments defined in I. K is not fixed, but it is defined as a ratio of
    the number of elements in each segment (rounded up with ceil).
    :param x: a rank 1 Tensor;
    :param I: a rank 1 Tensor with segment IDs for x (passed to
    tf.math.segment_sum, so expected to be sorted);
    :param ratio: float, ratio of elements to keep for each segment;
    :param top_k_var: a tf.Variable created without shape validation (i.e.,
    `tf.Variable(0.0, validate_shape=False)`); it is re-assigned with a filler
    buffer of shape (n_graphs * max_n_nodes,) on every call;
    :return: a rank 1 Tensor containing the indices to get the top K values of
    each segment in x; within each segment the indices follow descending x.
    """
    I = tf.cast(I, tf.int32)
    num_nodes = tf.math.segment_sum(tf.ones_like(I), I)  # Number of nodes in each graph
    cumsum = tf.cumsum(num_nodes)  # Cumulative number of nodes (A, A+B, A+B+C)
    cumsum_start = cumsum - num_nodes  # Start index of each graph
    n_graphs = tf.shape(num_nodes)[0]  # Number of graphs in batch
    max_n_nodes = tf.reduce_max(num_nodes)  # Order of biggest graph in batch
    batch_n_nodes = tf.shape(I)[0]  # Number of overall nodes in batch
    to_keep = tf.math.ceil(ratio * tf.cast(num_nodes, tf.float32))
    to_keep = tf.cast(to_keep, I.dtype)  # Nodes to keep in each graph
    # Flat position of every node in a padded (n_graphs * max_n_nodes) layout:
    # offset within its own graph plus graph_id * max_n_nodes.
    index = tf.range(batch_n_nodes)
    index = (index - tf.gather(cumsum_start, I)) + (I * max_n_nodes)
    y_min = tf.reduce_min(x)
    dense_y = tf.ones((n_graphs * max_n_nodes,))
    # subtract 1 to ensure that filler values do not get picked
    dense_y = dense_y * tf.cast(y_min - 1, dense_y.dtype)
    dense_y = tf.cast(dense_y, top_k_var.dtype)
    # top_k_var is a variable of unknown shape defined elsewhere; assigning the
    # filler buffer (allowed by validate_shape=False) sizes it for this batch.
    top_k_var.assign(dense_y)
    dense_y = tf.tensor_scatter_nd_update(top_k_var, index[..., None], tf.cast(x, top_k_var.dtype))
    dense_y = tf.reshape(dense_y, (n_graphs, max_n_nodes))
    # Rank the padded slots of each graph (descending; filler sorts last), then
    # shift the per-graph positions back to global node indices.
    perm = tf.argsort(dense_y, direction='DESCENDING')
    perm = perm + cumsum_start[:, None]
    perm = tf.reshape(perm, (-1,))
    # Keep-mask per graph: to_keep[g] ones followed by (max_n_nodes - to_keep[g]) zeros.
    to_rep = tf.tile(tf.constant([1., 0.]), (n_graphs,))
    rep_times = tf.reshape(tf.concat((to_keep[:, None], (max_n_nodes - to_keep)[:, None]), -1), (-1,))
    # NOTE(review): `repeat` is defined elsewhere; presumably it repeats each
    # element of to_rep by the matching count in rep_times - confirm.
    mask = repeat(to_rep, rep_times)
    perm = tf.boolean_mask(perm, mask)
    return perm
示例12: version_11
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def version_11(cls, node, **kwargs):
    """Convert an ONNX ScatterElements (opset 11) node to tf.tensor_scatter_nd_update.

    Reads ``data``, ``indices`` and ``updates`` from ``kwargs["tensor_dict"]``
    and the optional ``axis`` attribute (default 0). It asserts the indices are
    in bounds, normalizes negative indices, and returns a single-element list
    holding the scattered tensor.
    """
    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]
    updates = kwargs["tensor_dict"][node.inputs[2]]
    # process negative axis. Prefer a Python int when the rank is static:
    # `axis` is used below to index a Python list, which a symbolic tensor
    # cannot do in graph mode. Fall back to a tensor only for unknown rank.
    if axis < 0:
        data_rank = data.shape.rank
        axis = axis + data_rank if data_rank is not None else tf.add(tf.rank(data), axis)
    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'ScatterElements indices are out of bounds, please double check the indices and retry.'
    with tf.control_dependencies(
            [tf.compat.v1.assert_equal(result, True, message=msg)]):
        # process negative indices
        indices = cls.process_neg_idx_along_axis(data, axis, indices)
        # Calculate shape of the tensorflow version of indices tensor.
        sparsified_dense_idx_shape = tf_shape(updates)
        # Move on to convert ONNX indices to tensorflow indices in 2 steps:
        #
        # Step 1:
        # What would the index tensors look like if updates are all
        # dense? In other words, produce a coordinate tensor for updates:
        #
        # coordinate[i, j, k ...] = [i, j, k ...]
        # where the shape of "coordinate" tensor is same as that of updates.
        #
        # Step 2:
        # But the coordinate tensor needs some correction because coord
        # vector at position axis is wrong (since we assumed update is dense,
        # but it is not at the axis specified).
        # So we update coordinate vector tensor elements at position=axis with
        # the sparse coordinate indices.
        idx_tensors_per_axis = tf.meshgrid(
            *list(
                map(lambda x: tf.range(x, dtype=tf.dtypes.int64),
                    sparsified_dense_idx_shape)),
            indexing='ij')
        # The meshgrid coordinates are int64; cast so tf.concat sees one dtype
        # even when ONNX supplies int32 indices.
        idx_tensors_per_axis[axis] = tf.cast(indices, tf.int64)
        dim_expanded_idx_tensors_per_axis = list(
            map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
        coordinate = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)
        # Now the coordinate tensor is in the shape
        # [updates.shape, updates.rank]
        # we need it to flattened into the shape:
        # [product(updates.shape), updates.rank]
        indices = tf.reshape(coordinate, [-1, tf.rank(data)])
        updates = tf.reshape(updates, [-1])
        return [tf.tensor_scatter_nd_update(data, indices, updates)]
示例13: version_11
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def version_11(cls, node, **kwargs):
    """Convert an ONNX GatherElements (opset 11) node into TensorFlow ops.

    GatherElements takes two inputs data and indices of the same rank r >= 1
    and an optional attribute axis that identifies an axis of data (by
    default, the outer-most axis, that is axis 0). It is an indexing operation
    that produces its output by indexing into the input data tensor at index
    positions determined by elements of the indices tensor. Its output shape
    is the same as the shape of indices and consists of one value (gathered
    from the data) for each element in indices.

    Returns a single-element list holding the gathered tensor.
    """
    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]
    # process negative axis. Prefer a Python int when the rank is static:
    # `axis` drives the Python `if axis == 0` below, which a symbolic tensor
    # cannot do in graph mode. Fall back to a tensor only for unknown rank.
    if axis < 0:
        data_rank = data.shape.rank
        axis = axis + data_rank if data_rank is not None else tf.add(tf.rank(data), axis)
    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'GatherElements indices are out of bounds,'\
          ' please double check the indices and retry.'
    with tf.control_dependencies(
            [tf.compat.v1.assert_equal(result, True, message=msg)]):
        # process negative indices
        indices = cls.process_neg_idx_along_axis(data, axis, indices)
        # adapted from reference implementation in onnx/onnx/backend/test/case/node/gatherelements.py
        # Bring `axis` to the front so the gather always runs along dim 0.
        if axis == 0:
            axis_perm = tf.range(tf.rank(data))
            data_swapped = data
            index_swapped = indices
        else:
            # Swap dims 0 and `axis`; this permutation is its own inverse.
            axis_perm = tf.tensor_scatter_nd_update(tf.range(tf.rank(data)),
                                                    tf.constant([[0], [axis]]),
                                                    tf.constant([axis, 0]))
            data_swapped = tf.transpose(data, perm=axis_perm)
            index_swapped = tf.transpose(indices, perm=axis_perm)
        # Coordinate grid for every output position, with dim 0 replaced by the
        # supplied indices.
        idx_tensors_per_axis = tf.meshgrid(*list(
            map(lambda x: tf.range(x, dtype=index_swapped.dtype),
                index_swapped.shape.as_list())),
            indexing='ij')
        idx_tensors_per_axis[0] = index_swapped
        dim_expanded_idx_tensors_per_axis = list(
            map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
        index_expanded = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)
        gathered = tf.gather_nd(data_swapped, index_expanded)
        # Undo the swap (applying the same permutation again restores the order).
        y = tf.transpose(gathered, perm=axis_perm)
        return [y]
示例14: scatter_to_2d
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
    """Scatters a flattened 1-D `tensor` into a padded 2-D `Tensor` by `segments`.

    Element k of `tensor` goes to row `segments[k]`, at the column given by its
    position among the elements of the same segment; untouched cells hold
    `pad_value`. For example, tensor = [1, 2, 3], segments = [0, 1, 0] and
    pad_value = -1 yield [[1, 3], [2, -1]].

    When `output_shape` is None it is inferred from the data, so the result has
    a dynamic shape that may not be compatible with TPU. For TPU use, pass
    `output_shape` explicitly.

    Args:
      tensor: A 1-D numeric `Tensor`.
      segments: A 1-D int `Tensor` such as the idx output of tf.unique, e.g.
        [0, 0, 1, 0, 2]. The segments need not be sorted.
      pad_value: A numeric value used for cells that receive no element.
      output_shape: Optional size-2 `Tensor` giving the desired output shape.
        If it is smaller than needed, the result is truncated.

    Returns:
      A 2-D Tensor.
    """
    with tf.compat.v1.name_scope(name='scatter_to_2d'):
        tensor = tf.convert_to_tensor(value=tensor)
        segments = tf.convert_to_tensor(value=segments)
        for t in (tensor, segments):
            t.get_shape().assert_has_rank(1)
        tensor.get_shape().assert_is_compatible_with(segments.get_shape())

        # Column of each element inside its row: for segments
        # [0, 0, 0, 1, 2, 2] this is [0, 1, 2, 0, 0, 1].
        index_2nd_dim = _in_segment_indices(segments)

        if output_shape is not None:
            # Send every element outside the requested shape to cell
            # [output_shape[0], 0], which lies in the extra padding row that
            # the final slice drops.
            in_range = tf.logical_and(
                tf.less(segments, output_shape[0]),
                tf.less(index_2nd_dim, output_shape[1]))
            segments = tf.compat.v1.where(in_range, segments,
                                          output_shape[0] * tf.ones_like(segments))
            index_2nd_dim = tf.compat.v1.where(in_range, index_2nd_dim,
                                               tf.zeros_like(index_2nd_dim))
        else:
            output_shape = [
                tf.reduce_max(input_tensor=segments) + 1,
                tf.reduce_max(input_tensor=index_2nd_dim) + 1
            ]

        # Canvas is one row and one column larger than output_shape so the
        # overflow cell exists; scatter into it, then crop.
        canvas = pad_value * tf.ones(
            shape=(output_shape + tf.ones_like(output_shape)), dtype=tensor.dtype)
        coords = tf.stack([segments, index_2nd_dim], axis=1)
        filled = tf.tensor_scatter_nd_update(canvas, coords, tensor)
        return tf.slice(filled, begin=[0, 0], size=output_shape)
示例15: scatter_to_2d
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
    """Scatters a flattened 1-D `tensor` to 2-D with padding based on `segments`.

    Row r of the result collects, in order, the elements of `tensor` whose
    segment id is r; the remaining cells are `pad_value`. Example: tensor =
    [1, 2, 3], segments = [0, 1, 0], pad_value = -1 -> [[1, 3], [2, -1]].

    If `output_shape` is None the shape is inferred at run time (dynamic, so
    possibly not TPU-compatible); provide it explicitly for TPU.

    Args:
      tensor: A 1-D numeric `Tensor`.
      segments: A 1-D int `Tensor` which is the idx output from tf.unique like
        [0, 0, 1, 0, 2]. The segments may or may not be sorted.
      pad_value: A numeric value to pad the output `Tensor`.
      output_shape: A `Tensor` of size 2 telling the desired output shape. When
        it is smaller than needed, truncation is applied.

    Returns:
      A 2-D Tensor.
    """
    with tf.compat.v1.name_scope(name='scatter_to_2d'):
        values = tf.convert_to_tensor(value=tensor)
        seg_ids = tf.convert_to_tensor(value=segments)
        values.get_shape().assert_has_rank(1)
        seg_ids.get_shape().assert_has_rank(1)
        values.get_shape().assert_is_compatible_with(seg_ids.get_shape())

        # In-segment position of each element, i.e. its target column.
        cols = _in_segment_indices(seg_ids)

        if output_shape is None:
            shape_2d = [
                tf.reduce_max(input_tensor=seg_ids) + 1,
                tf.reduce_max(input_tensor=cols) + 1
            ]
        else:
            shape_2d = output_shape
            # Collapse out-of-range elements onto [rows, 0] in the extra row
            # that gets sliced away after the scatter.
            keep_row = tf.less(seg_ids, shape_2d[0])
            keep_col = tf.less(cols, shape_2d[1])
            keep = tf.logical_and(keep_row, keep_col)
            seg_ids = tf.compat.v1.where(keep, seg_ids,
                                         shape_2d[0] * tf.ones_like(seg_ids))
            cols = tf.compat.v1.where(keep, cols, tf.zeros_like(cols))

        # One spare row and column hold the collapsed overflow before slicing.
        coords = tf.stack([seg_ids, cols], axis=1)
        background = pad_value * tf.ones(
            shape=(shape_2d + tf.ones_like(shape_2d)), dtype=values.dtype)
        scattered = tf.tensor_scatter_nd_update(background, coords, values)
        cropped = tf.slice(scattered, begin=[0, 0], size=shape_2d)
    return cropped