当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.tensor_scatter_nd_update方法代码示例

本文整理汇总了Python中tensorflow.tensor_scatter_nd_update方法的典型用法代码示例。如果您想了解tensorflow.tensor_scatter_nd_update的具体用法、调用方式或实际使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


下文共展示了tensorflow.tensor_scatter_nd_update方法的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: disjoint_signal_to_batch

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def disjoint_signal_to_batch(X, I):
    """
    Zero-pad a disjoint-mode node signal into a dense batch tensor.

    The nodes of graph ``g`` are written to rows ``g * N_max`` onwards of a flat
    ``(batch * N_max, F)`` buffer, which is then reshaped. Rows beyond the size
    of each graph stay zero.

    :param X: Tensor, node features of shape ``(N, F)``.
    :param I: Tensor, graph IDs of shape ``(N, )``; ``tf.math.segment_sum``
        expects these to be sorted.
    :return: Tensor, batched node features of shape ``(batch, N_max, F)``.
    """
    graph_ids = tf.cast(I, tf.int32)
    nodes_per_graph = tf.math.segment_sum(tf.ones_like(graph_ids), graph_ids)
    n_graphs = tf.shape(nodes_per_graph)[0]
    n_max = tf.reduce_max(nodes_per_graph)
    n_features = tf.shape(X)[-1]

    # Position of every node inside its own graph.
    first_node = tf.cumsum(nodes_per_graph, exclusive=True)
    local_pos = tf.range(tf.shape(graph_ids)[0]) - tf.gather(first_node, graph_ids)
    # Row of every node in the flat, padded buffer.
    flat_rows = graph_ids * n_max + local_pos

    padded = tf.tensor_scatter_nd_update(
        tf.zeros((n_graphs * n_max, n_features), dtype=X.dtype),
        tf.expand_dims(flat_rows, -1),
        X,
    )
    return tf.reshape(padded, (n_graphs, n_max, n_features))
开发者ID:danielegrattarola,项目名称:spektral,代码行数:26,代码来源:modes.py

示例2: update_active_spikes

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def update_active_spikes(self, spikes):
    """ Given some spikes, add them to active spikes with the appropriate delays

    Each non-zero entry of ``spikes * self.delays`` is written as a 1 into
    ``self.active_spikes``. Its coordinates are extended with the arrival slot
    returned by ``self.spike_arrival_step``.

    Parameters:
      spikes (array like): The spikes that have just occurred

    Returns:
      None. ``self.active_spikes`` is rebound to the updated tensor.
    """
    # Non-zero entries are the delays of connections whose source just spiked.
    # A zero delay cannot be told apart from "no spike" here and is dropped.
    delays_some_hot = spikes * self.delays  # shape of self.delays (original note: (100, 2))
    # int64 coordinates of the non-zero entries of delays_some_hot.
    idxs = tf.where(tf.not_equal(delays_some_hot, 0))
    # The delay value at each of those coordinates.
    just_delays = tf.gather_nd(delays_some_hot, idxs)

    # Adjust for variable step size and the circular buffer: one arrival slot per spike.
    delay_dim_idxs = tf.reshape(self.spike_arrival_step(just_delays), [-1, 1])

    # Add the arrival slot as an extra index column (it is a coordinate, not more examples).
    full_idxs = tf.concat([idxs, delay_dim_idxs], axis=1)

    # Use the dynamic row count (the static shape[0] is None inside tf.function).
    # Build the updates in the buffer's own dtype so the scatter also works for
    # non-float32 buffers.
    num_updates = tf.shape(full_idxs)[0]
    self.active_spikes = tf.tensor_scatter_nd_update(
        self.active_spikes, full_idxs,
        tf.ones([num_updates], dtype=self.active_spikes.dtype))
开发者ID:jotia1,项目名称:spiking-net-tensorflow,代码行数:21,代码来源:delayedmodels.py

示例3: encode

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def encode(self, x):
        """Build a simple Lookup Table and set as input `x` if it exists, or `self.x` otherwise.

        Before dropout and lookup, the ``Offsets.PAD`` row of the table is
        replaced by zeros, so padding ids always embed to a zero vector.

        :param x: An optional input sub-graph to bind to this operation or use `self.x` if `None`
        :return: The sub-graph output
        """
        # Honour the documented fallback: only rebind self.x when an input is given.
        if x is not None:
            self.x = x
        # tensor_scatter_nd_update returns a NEW tensor and never mutates self.W.
        # The PAD-zeroed table must therefore be the one fed to dropout and lookup.
        # Previously it was only used as a control dependency, and the
        # unmodified self.W was looked up, leaving the PAD row non-zero.
        pad_zeroed_w = tf.tensor_scatter_nd_update(
            self.W,
            tf.constant(Offsets.PAD, dtype=tf.int32, shape=[1, 1]),
            tf.zeros(shape=[1, self.dsz], dtype=self.W.dtype),
        )
        # The ablation table (4) in https://arxiv.org/pdf/1708.02182.pdf shows this has a massive impact
        embedding_w_dropout = self.drop(pad_zeroed_w, training=TRAIN_FLAG())
        word_embeddings = tf.nn.embedding_lookup(embedding_w_dropout, self.x)

        return word_embeddings
开发者ID:dpressel,项目名称:mead-baseline,代码行数:18,代码来源:embeddings.py

示例4: transform_targets_for_output

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def transform_targets_for_output(y_true, grid_y, grid_x, anchor_idxs, classes):
    """Scatter ground-truth boxes into the target grid of one YOLO output scale.

    Args:
      y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor)). Entries whose
        x2 is 0 are skipped (presumably zero padding).
      grid_y: number of grid rows of this scale.
      grid_x: number of grid columns of this scale.
      anchor_idxs: indices of the anchors assigned to this scale. Boxes whose
        best_anchor is not among them are skipped.
      classes: not used by this function.

    Returns:
      float32 tensor of shape (N, grid_y, grid_x, len(anchor_idxs), 6) holding
      [x1, y1, x2, y2, 1, class] in the cell containing each box centre, at the
      slot of its anchor. All other entries are zero. The first four values are
      the box corners copied from y_true, not (x, y, w, h).
    """
    N = tf.shape(y_true)[0]

    y_true_out = tf.zeros((N, grid_y, grid_x, tf.shape(anchor_idxs)[0], 6))

    anchor_idxs = tf.cast(anchor_idxs, tf.int32)
    # (grid_x, grid_y) matches the (x, y) order of box_xy. It does not depend on
    # the loop variables, so compute it once.
    grid_size = tf.cast(tf.stack([grid_x, grid_y], axis=-1), tf.float32)

    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))

            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2.

                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                grid_xy = tf.cast(box_xy * grid_size, tf.int32)
                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                indexes = indexes.write(idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1

    return tf.tensor_scatter_nd_update(y_true_out, indexes.stack(), updates.stack())
开发者ID:akkaze,项目名称:tf2-yolo3,代码行数:34,代码来源:dataset.py

示例5: basis_message_func

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def basis_message_func(self, edges):
        """Message function for the basis regularizer.

        When ``num_bases < num_rels``, the per-relation weights are rebuilt as
        ``w_comp @ bases`` and reshaped to ``(num_rels, in_feat, out_feat)``.
        Otherwise ``self.weight`` is used as-is.

        With ``low_mem`` set and non-int64 source features, messages are computed
        one edge type at a time and scattered into place. Otherwise
        ``utils.bmm_maybe_select`` handles all edges at once.

        Messages are multiplied by ``edges.data['norm']`` when that key exists.

        Returns:
            dict with the single key ``'msg'``.
        """
        if self.num_bases >= self.num_rels:
            weight = self.weight
        else:
            flat_bases = tf.reshape(
                self.weight, (self.num_bases, self.in_feat * self.out_feat))
            weight = tf.reshape(
                tf.matmul(self.w_comp, flat_bases),
                (self.num_rels, self.in_feat, self.out_feat))

        h = edges.src['h']
        etype_per_edge = edges.data['type']
        # int64 source features are left to bmm_maybe_select (per the original
        # note, treated as an index select).
        if h.dtype != tf.int64 and self.low_mem:
            num_edges = h.shape[0]
            unique_etypes, _ = tf.unique(etype_per_edge)
            msg = tf.zeros([num_edges, self.out_feat])
            edge_ids = tf.range(num_edges)
            for etype in unique_etypes:
                is_etype = (etype_per_edge == etype)
                typed_msg = tf.matmul(tf.boolean_mask(h, is_etype), weight[etype])
                rows = tf.reshape(tf.boolean_mask(edge_ids, is_etype), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, rows, typed_msg)
        else:
            msg = utils.bmm_maybe_select(h, weight, etype_per_edge)

        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
开发者ID:dmlc,项目名称:dgl,代码行数:32,代码来源:relgraphconv.py

示例6: bdd_message_func

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def bdd_message_func(self, edges):
        """Message function for the block-diagonal-decomposition regularizer.

        Each relation weight is viewed as ``num_bases`` blocks of shape
        ``(submat_in, submat_out)``, and each source feature as blocks of
        ``submat_in``. The message is the concatenation of the block-wise
        products, reshaped to ``(-1, out_feat)``.

        With ``low_mem`` set, edge types are processed one at a time and scattered
        into place. Messages are multiplied by ``edges.data['norm']`` when present.

        Returns:
            dict with the single key ``'msg'``.

        Raises:
            TypeError: if the source features are a 1-D int64 tensor (integer IDs).
        """
        h = edges.src['h']
        if h.dtype == tf.int64 and len(h.shape) == 1:
            raise TypeError(
                'Block decomposition does not allow integer ID feature.')

        etype_per_edge = edges.data['type']
        if not self.low_mem:
            blocks = tf.reshape(tf.gather(self.weight, etype_per_edge),
                                (-1, self.submat_in, self.submat_out))
            h_blocks = tf.reshape(h, (-1, 1, self.submat_in))
            msg = tf.reshape(tf.matmul(h_blocks, blocks), (-1, self.out_feat))
        else:
            num_edges = h.shape[0]
            unique_etypes, _ = tf.unique(etype_per_edge)
            msg = tf.zeros([num_edges, self.out_feat])
            edge_ids = tf.range(num_edges)
            for etype in unique_etypes:
                is_etype = (etype_per_edge == etype)
                blocks = tf.reshape(self.weight[etype],
                                    (self.num_bases, self.submat_in, self.submat_out))
                h_blocks = tf.reshape(tf.boolean_mask(h, is_etype),
                                      (-1, self.num_bases, self.submat_in))
                typed_msg = tf.reshape(tf.einsum('abc,bcd->abd', h_blocks, blocks),
                                       (-1, self.out_feat))
                rows = tf.reshape(tf.boolean_mask(edge_ids, is_etype), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, rows, typed_msg)

        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}
开发者ID:dmlc,项目名称:dgl,代码行数:33,代码来源:relgraphconv.py

示例7: scatter_row

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def scatter_row(data, row_index, value):
    """Return a copy of ``data`` whose rows at ``row_index`` are replaced by ``value``.

    ``row_index`` is presumably a 1-D integer tensor. It is expanded to shape
    ``(k, 1)`` so that each entry addresses one whole slice along axis 0.
    """
    return tf.tensor_scatter_nd_update(data, tf.expand_dims(row_index, axis=1), value)
开发者ID:dmlc,项目名称:dgl,代码行数:5,代码来源:tensor.py

示例8: clear_current_active_spikes

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def clear_current_active_spikes(self):
    """ Remove any spikes that arrived at the current time step

    Every non-zero entry of ``self.active_spikes[:, :, k]`` is set to zero,
    where ``k = self.get_active_spike_idx()``.

    Parameters:
      None

    Returns:
      None. ``self.active_spikes`` is rebound to the updated tensor.
    """
    # Read the current slot once and reuse it for both the lookup and the write.
    active_idx = self.get_active_spike_idx()

    # int64 coordinates (within the slice) of the non-zero entries in that slot.
    spike_idxs = tf.where(tf.not_equal(self.active_spikes[:, :, active_idx], 0))
    # Use the dynamic count; the static shape[0] is None inside tf.function.
    num_spikes = tf.shape(spike_idxs)[0]

    # Append the slot as the last coordinate. Cast it so tf.concat sees int64 on
    # both sides even if the index is an int32 tensor.
    slot_col = tf.ones([num_spikes, 1], dtype=tf.int64) * tf.cast(active_idx, tf.int64)
    full_idxs = tf.concat([spike_idxs, slot_col], axis=1)

    # Zeros in the buffer's own dtype so the scatter does not fail for non-float32 buffers.
    self.active_spikes = tf.tensor_scatter_nd_update(
        self.active_spikes, full_idxs,
        tf.zeros([num_spikes], dtype=self.active_spikes.dtype))
开发者ID:jotia1,项目名称:spiking-net-tensorflow,代码行数:15,代码来源:delayedmodels.py

示例9: transform_targets_for_output

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def transform_targets_for_output(y_true, grid_size, anchor_idxs):
    """Scatter ground-truth boxes into a square YOLO target grid for one scale.

    Args:
      y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor)). Entries whose
        x2 is 0 are skipped (presumably zero padding).
      grid_size: number of cells along each side of the square grid.
      anchor_idxs: indices of the anchors assigned to this scale. Boxes whose
        best_anchor is not among them are skipped.

    Returns:
      float32 tensor of shape (N, grid_size, grid_size, len(anchor_idxs), 6)
      holding [x1, y1, x2, y2, 1, class] in the cell containing each box centre,
      at the slot of its anchor. All other entries are zero.
    """
    N = tf.shape(y_true)[0]

    y_true_out = tf.zeros(
        (N, grid_size, grid_size, tf.shape(anchor_idxs)[0], 6))

    anchor_idxs = tf.cast(anchor_idxs, tf.int32)
    # Loop-invariant: grid size in the coordinate dtype, used to locate box centres.
    grid_size_f = tf.cast(grid_size, y_true.dtype)

    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(
                anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))

            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2

                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                # Multiply instead of floor-dividing by 1/grid_size. Dividing by a
                # rounded reciprocal can push a centre on a cell boundary into the
                # neighbouring cell. For non-negative (presumably normalised)
                # coordinates, the truncating cast equals floor.
                grid_xy = tf.cast(box_xy * grid_size_f, tf.int32)

                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                indexes = indexes.write(
                    idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(
                    idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1

    return tf.tensor_scatter_nd_update(
        y_true_out, indexes.stack(), updates.stack())
开发者ID:microsoft,项目名称:DirectML,代码行数:41,代码来源:dataset.py

示例10: call

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def call(self, inputs, training=True):
    """
    parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms

    Runs ``self.max_atoms`` propagation steps. At each step, the atoms selected
    by ``calculation_masks[:, count]`` are updated:
      1. Their atom features are gathered.
      2. The graph features of their parents (``parents[:, count, 1:]``) are
         read from the running ``graph_features`` buffer.
      3. Both are fed through ``_DAGgraph_step``.
      4. The result is written back at ``parents[:, count, 0]``.

    ``graph_features`` is an internal buffer of shape
    (max_atoms * batch_size, max_atoms + 1, n_graph_feat), initialised to zeros.
    NOTE(review): the extra (max_atoms + 1)-th slot presumably serves as a
    zero-valued target for padded parent entries; confirm against the
    featurizer.

    Returns:
      ``batch_outputs`` of the LAST step only (not the full ``graph_features``
      buffer). NOTE(review): presumably the last step computes each DAG's root
      atom; confirm with the calculation-order construction.
    """
    atom_features = inputs[0]
    # each atom corresponds to a graph, which is represented by the `max_atoms*max_atoms` int32 matrix of index
    # each graph includes `max_atoms` steps (corresponding to rows) of calculating graph features
    parents = tf.cast(inputs[1], dtype=tf.int32)
    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    calculation_orders = inputs[2]
    calculation_masks = inputs[3]

    # tf.range(n_atoms) is boolean-masked with calculation_masks[:, count] below,
    # so n_atoms is assumed to equal the number of rows of calculation_masks.
    n_atoms = tf.squeeze(inputs[4])
    graph_features = tf.zeros((self.max_atoms * self.batch_size,
                               self.max_atoms + 1, self.n_graph_feat))

    for count in range(self.max_atoms):
      # `count`-th step
      # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
      mask = calculation_masks[:, count]
      current_round = tf.boolean_mask(calculation_orders[:, count], mask)
      batch_atom_features = tf.gather(atom_features, current_round)

      # generating index for graph features used in the inputs:
      # stack1 = the DAG (row) id repeated once per parent slot,
      # stack2 = the parent atom positions in that DAG
      stack1 = tf.reshape(
          tf.stack(
              [tf.boolean_mask(tf.range(n_atoms), mask)] * (self.max_atoms - 1),
              axis=1), [-1])
      stack2 = tf.reshape(tf.boolean_mask(parents[:, count, 1:], mask), [-1])
      index = tf.stack([stack1, stack2], axis=1)
      # extracting graph features for parents of the target atoms, then flatten
      # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
      batch_graph_features = tf.reshape(
          tf.gather_nd(graph_features, index),
          [-1, (self.max_atoms - 1) * self.n_graph_feat])

      # concat into the input tensor: (batch_size*max_atoms) * n_inputs
      batch_inputs = tf.concat(
          axis=1, values=[batch_atom_features, batch_graph_features])
      # DAGgraph_step maps from batch_inputs to a batch of graph_features
      # of shape: (batch_size*max_atoms) * n_graph_features
      # representing the graph features of target atoms in each graph
      batch_outputs = _DAGgraph_step(batch_inputs, self.W_list, self.b_list,
                                     self.activation_fn, self.dropouts,
                                     training)

      # index for target atoms: (DAG row, target atom position), masked like the inputs
      target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
      target_index = tf.boolean_mask(target_index, mask)
      # functional update: rebinds graph_features so later steps read these outputs
      graph_features = tf.tensor_scatter_nd_update(graph_features, target_index,
                                                   batch_outputs)
    return batch_outputs
开发者ID:deepchem,项目名称:deepchem,代码行数:54,代码来源:layers.py

示例11: segment_top_k

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def segment_top_k(x, I, ratio, top_k_var):
    """
    Returns indices to get the top K values in x segment-wise, according to
    the segments defined in I. K is not fixed, but it is defined as a ratio of
    the number of elements in each segment.

    Method:
      1. Scatter each segment into its own row of a dense
         (n_segments, max_segment_size) matrix. Unused cells hold min(x) - 1.
      2. Argsort each row in descending order.
      3. Keep the first ceil(ratio * segment_size) positions of every row.

    :param x: a rank 1 Tensor;
    :param I: a rank 1 Tensor with segment IDs for x; `tf.math.segment_sum`
    expects these to be sorted;
    :param ratio: float, ratio of elements to keep for each segment;
    :param top_k_var: a tf.Variable created without shape validation (i.e.,
    `tf.Variable(0.0, validate_shape=False)`); it is overwritten by this call
    and used as scratch storage for the dense matrix;
    :return: a rank 1 Tensor containing the indices to get the top K values of
    each segment in x, grouped by segment and in descending order of x within
    each segment.
    """
    I = tf.cast(I, tf.int32)
    num_nodes = tf.math.segment_sum(tf.ones_like(I), I)  # Number of nodes in each graph
    cumsum = tf.cumsum(num_nodes)  # Cumulative number of nodes (A, A+B, A+B+C)
    cumsum_start = cumsum - num_nodes  # Start index of each graph
    n_graphs = tf.shape(num_nodes)[0]  # Number of graphs in batch
    max_n_nodes = tf.reduce_max(num_nodes)  # Order of biggest graph in batch
    batch_n_nodes = tf.shape(I)[0]  # Number of overall nodes in batch
    to_keep = tf.math.ceil(ratio * tf.cast(num_nodes, tf.float32))
    to_keep = tf.cast(to_keep, I.dtype)  # Nodes to keep in each graph

    # Flat position of each node in the padded (n_graphs * max_n_nodes,) buffer
    index = tf.range(batch_n_nodes)
    index = (index - tf.gather(cumsum_start, I)) + (I * max_n_nodes)

    y_min = tf.reduce_min(x)
    dense_y = tf.ones((n_graphs * max_n_nodes,))
    # subtract 1 to ensure that filler values do not get picked
    # NOTE(review): the filler goes through float32 (tf.ones default dtype). For
    # |min(x)| beyond float32's exact-integer range, y_min - 1 may round back to
    # y_min and tie with real values.
    dense_y = dense_y * tf.cast(y_min - 1, dense_y.dtype)
    dense_y = tf.cast(dense_y, top_k_var.dtype)
    # top_k_var is a variable with unknown shape defined elsewhere; assigning
    # first and then scattering into it lets the dense size vary between calls
    top_k_var.assign(dense_y)
    dense_y = tf.tensor_scatter_nd_update(top_k_var, index[..., None], tf.cast(x, top_k_var.dtype))
    dense_y = tf.reshape(dense_y, (n_graphs, max_n_nodes))

    # Per-row descending order, shifted by each graph's start to index into x
    perm = tf.argsort(dense_y, direction='DESCENDING')
    perm = perm + cumsum_start[:, None]
    perm = tf.reshape(perm, (-1,))

    # Keep mask: per graph, to_keep ones followed by (max_n_nodes - to_keep) zeros.
    # Assumes `repeat` (defined elsewhere) repeats to_rep[k] rep_times[k] times.
    to_rep = tf.tile(tf.constant([1., 0.]), (n_graphs,))
    rep_times = tf.reshape(tf.concat((to_keep[:, None], (max_n_nodes - to_keep)[:, None]), -1), (-1,))
    mask = repeat(to_rep, rep_times)

    perm = tf.boolean_mask(perm, mask)

    return perm
开发者ID:danielegrattarola,项目名称:spektral,代码行数:49,代码来源:ops.py

示例12: version_11

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def version_11(cls, node, **kwargs):
    """Convert an ONNX ScatterElements (opset 11) node to TensorFlow ops.

    Every element of ``updates`` is written into ``data`` at its own
    coordinate, except that the coordinate along ``axis`` is taken from the
    matching element of ``indices``.

    Returns:
      A one-element list holding the tf.tensor_scatter_nd_update result.
    """
    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]
    updates = kwargs["tensor_dict"][node.inputs[2]]

    # Process a negative axis. Keep it a Python int whenever the rank is static,
    # because `axis` is used below to index a Python list, which a symbolic
    # (graph-mode) tensor cannot do.
    if axis < 0:
      rank = data.shape.rank
      axis = rank + axis if rank is not None else tf.add(tf.rank(data), axis)

    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'ScatterElements indices are out of bounds, please double check the indices and retry.'
    with tf.control_dependencies(
        [tf.compat.v1.assert_equal(result, True, message=msg)]):
      # process negative indices
      indices = cls.process_neg_idx_along_axis(data, axis, indices)
      # ONNX allows int32 or int64 indices. The coordinate grid built below is
      # int64, and tf.concat requires matching dtypes.
      indices = tf.cast(indices, tf.int64)

      # Calculate shape of the tensorflow version of indices tensor.
      sparsified_dense_idx_shape = tf_shape(updates)

      # Move on to convert ONNX indices to tensorflow indices in 2 steps:
      #
      # Step 1:
      #   What would the index tensors look like if updates are all
      #   dense? In other words, produce a coordinate tensor for updates:
      #
      #   coordinate[i, j, k ...] = [i, j, k ...]
      #   where the shape of "coordinate" tensor is same as that of updates.
      #
      # Step 2:
      #   But the coordinate tensor needs some correction because coord
      #   vector at position axis is wrong (since we assumed update is dense,
      #   but it is not at the axis specified).
      #   So we update coordinate vector tensor elements at position=axis with
      #   the sparse coordinate indices.

      idx_tensors_per_axis = tf.meshgrid(
          *list(
              map(lambda x: tf.range(x, dtype=tf.dtypes.int64),
                  sparsified_dense_idx_shape)),
          indexing='ij')
      idx_tensors_per_axis[axis] = indices
      dim_expanded_idx_tensors_per_axis = list(
          map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
      coordinate = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)

      # Now the coordinate tensor is in the shape
      # [updates.shape, updates.rank]
      # we need it to flattened into the shape:
      # [product(updates.shape), updates.rank]
      indices = tf.reshape(coordinate, [-1, tf.rank(data)])
      updates = tf.reshape(updates, [-1])

      return [tf.tensor_scatter_nd_update(data, indices, updates)]
开发者ID:onnx,项目名称:onnx-tensorflow,代码行数:56,代码来源:scatter_elements.py

示例13: version_11

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def version_11(cls, node, **kwargs):
    """Convert an ONNX GatherElements (opset 11) node to TensorFlow ops.

    Returns:
      A one-element list holding the gathered tensor, which has the shape of
      ``indices``.
    """
    # GatherElements takes two inputs data and indices of the same rank r >= 1 and an optional attribute axis that identifies
    # an axis of data (by default, the outer-most axis, that is axis 0). It is an indexing operation that produces its output by
    # indexing into the input data tensor at index positions determined by elements of the indices tensor. Its output shape is the
    # same as the shape of indices and consists of one value (gathered from the data) for each element in indices.

    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]

    # Process a negative axis. Keep it a Python int whenever the rank is static.
    # The `axis == 0` branch and tf.constant([[0], [axis]]) below need a
    # concrete value, which a symbolic (graph-mode) tensor does not provide.
    if axis < 0:
      rank = data.shape.rank
      axis = rank + axis if rank is not None else tf.add(tf.rank(data), axis)

    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'GatherElements indices are out of bounds,'\
      ' please double check the indices and retry.'
    with tf.control_dependencies(
        [tf.compat.v1.assert_equal(result, True, message=msg)]):
      # process negative indices
      indices = cls.process_neg_idx_along_axis(data, axis, indices)

      # adapted from reference implementation in onnx/onnx/backend/test/case/node/gatherelements.py
      # Swap `axis` with axis 0 so the gather always runs along the leading
      # axis. The swap permutation is its own inverse, so it is reused below.
      if axis == 0:
        axis_perm = tf.range(tf.rank(data))
        data_swaped = data
        index_swaped = indices
      else:
        axis_perm = tf.tensor_scatter_nd_update(tf.range(tf.rank(data)),
                                                tf.constant([[0], [axis]]),
                                                tf.constant([axis, 0]))
        data_swaped = tf.transpose(data, perm=axis_perm)
        index_swaped = tf.transpose(indices, perm=axis_perm)

      idx_tensors_per_axis = tf.meshgrid(*list(
          map(lambda x: tf.range(x, dtype=index_swaped.dtype),
              index_swaped.shape.as_list())),
                                        indexing='ij')
      idx_tensors_per_axis[0] = index_swaped
      dim_expanded_idx_tensors_per_axis = list(
          map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
      index_expanded = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)

      gathered = tf.gather_nd(data_swaped, index_expanded)
      y = tf.transpose(gathered, perm=axis_perm)

      return [y]
开发者ID:onnx,项目名称:onnx-tensorflow,代码行数:49,代码来源:gather_elements.py

示例14: scatter_to_2d

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
    """Scatters a flattened 1-D `tensor` to 2-D with padding based on `segments`.

    Element k of `tensor` lands in row `segments[k]`, at the column given by its
    position within that segment. Cells that receive nothing hold `pad_value`.
    For example: tensor = [1, 2, 3], segments = [0, 1, 0] and pad_value = -1,
    then the returned 2-D tensor is [[1, 3], [2, -1]].

    Without `output_shape`, the shape is inferred dynamically and may not be
    TPU-compatible. For TPU use, pass `output_shape` explicitly.

    Args:
      tensor: A 1-D numeric `Tensor`.
      segments: A 1-D int `Tensor`, such as the idx output of tf.unique
        (e.g. [0, 0, 1, 0, 2]). The segments may or may not be sorted.
      pad_value: A numeric value to pad the output `Tensor`.
      output_shape: A `Tensor` of size 2 giving the desired output shape. If
        None, the shape is inferred and not fixed at compilation time. When it
        is smaller than needed, the extra elements are truncated.

    Returns:
      A 2-D Tensor.
    """
    with tf.compat.v1.name_scope(name='scatter_to_2d'):
        values = tf.convert_to_tensor(value=tensor)
        rows = tf.convert_to_tensor(value=segments)
        for checked in (values, rows):
            checked.get_shape().assert_has_rank(1)
        values.get_shape().assert_is_compatible_with(rows.get_shape())

        # Column of every element = its position inside its own segment, e.g.
        # segments [0, 0, 0, 1, 2, 2] -> columns [0, 1, 2, 0, 0, 1].
        cols = _in_segment_indices(rows)

        if output_shape is not None:
            # Truncation: every element outside `output_shape` is redirected to
            # the spare cell [output_shape[0], 0], which the final slice drops.
            keep = tf.logical_and(tf.less(rows, output_shape[0]),
                                  tf.less(cols, output_shape[1]))
            rows = tf.compat.v1.where(keep, rows,
                                      output_shape[0] * tf.ones_like(rows))
            cols = tf.compat.v1.where(keep, cols, tf.zeros_like(cols))
        else:
            # Dynamic shape, just large enough for every (row, col) pair.
            output_shape = [
                tf.reduce_max(input_tensor=rows) + 1,
                tf.reduce_max(input_tensor=cols) + 1
            ]

        # Scatter into a canvas one row and one column larger than the target,
        # then crop back to `output_shape`.
        coords = tf.stack([rows, cols], axis=1)
        canvas = pad_value * tf.ones(
            shape=(output_shape + tf.ones_like(output_shape)), dtype=values.dtype)
        filled = tf.tensor_scatter_nd_update(canvas, coords, values)
        return tf.slice(filled, begin=[0, 0], size=output_shape)
开发者ID:ULTR-Community,项目名称:ULTRA,代码行数:63,代码来源:metric_utils.py

示例15: scatter_to_2d

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import tensor_scatter_nd_update [as 别名]
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
  """Scatters a flattened 1-D `tensor` to 2-D with padding based on `segments`.

  Element k of `tensor` lands in row `segments[k]`, at the column given by its
  position within that segment. Cells that receive nothing hold `pad_value`.
  For example: tensor = [1, 2, 3], segments = [0, 1, 0] and pad_value = -1,
  then the returned 2-D tensor is [[1, 3], [2, -1]].

  Without `output_shape`, the shape is inferred dynamically and may not be
  TPU-compatible. For TPU use, pass `output_shape` explicitly.

  Args:
    tensor: A 1-D numeric `Tensor`.
    segments: A 1-D int `Tensor`, such as the idx output of tf.unique
      (e.g. [0, 0, 1, 0, 2]). The segments may or may not be sorted.
    pad_value: A numeric value to pad the output `Tensor`.
    output_shape: A `Tensor` of size 2 giving the desired output shape. If
      None, the shape is inferred and not fixed at compilation time. When it is
      smaller than needed, the extra elements are truncated.

  Returns:
    A 2-D Tensor.
  """
  with tf.compat.v1.name_scope(name='scatter_to_2d'):
    values = tf.convert_to_tensor(value=tensor)
    rows = tf.convert_to_tensor(value=segments)
    for checked in (values, rows):
      checked.get_shape().assert_has_rank(1)
    values.get_shape().assert_is_compatible_with(rows.get_shape())

    # Column of every element = its position inside its own segment, e.g.
    # segments [0, 0, 0, 1, 2, 2] -> columns [0, 1, 2, 0, 0, 1].
    cols = _in_segment_indices(rows)

    if output_shape is not None:
      # Truncation: every element outside `output_shape` is redirected to the
      # spare cell [output_shape[0], 0], which the final slice drops.
      keep = tf.logical_and(tf.less(rows, output_shape[0]),
                            tf.less(cols, output_shape[1]))
      rows = tf.compat.v1.where(keep, rows,
                                output_shape[0] * tf.ones_like(rows))
      cols = tf.compat.v1.where(keep, cols, tf.zeros_like(cols))
    else:
      # Dynamic shape, just large enough for every (row, col) pair.
      output_shape = [
          tf.reduce_max(input_tensor=rows) + 1,
          tf.reduce_max(input_tensor=cols) + 1
      ]

    # Scatter into a canvas one row and one column larger than the target, then
    # crop back to `output_shape`.
    coords = tf.stack([rows, cols], axis=1)
    canvas = pad_value * tf.ones(
        shape=(output_shape + tf.ones_like(output_shape)), dtype=values.dtype)
    filled = tf.tensor_scatter_nd_update(canvas, coords, values)
    return tf.slice(filled, begin=[0, 0], size=output_shape)
开发者ID:tensorflow,项目名称:ranking,代码行数:63,代码来源:utils.py


注:本文中的tensorflow.tensor_scatter_nd_update方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。