This page collects typical usage examples of the Python function tensorflow.python.ops.math_ops.unsorted_segment_sum. If you are wondering exactly what unsorted_segment_sum does, how to call it, or what it looks like in practice, the curated code samples here may help.
Below are 15 code examples of the unsorted_segment_sum function, ordered by popularity.
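Before diving into the examples, here is a minimal, self-contained sketch of what unsorted_segment_sum computes. It uses the public tf.math API rather than the internal math_ops module the examples import, and the data is made up:

import tensorflow as tf

# Row i of `data` is added into output row `segment_ids[i]`; the ids may
# appear in any order, and num_segments fixes the size of the output.
data = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
segment_ids = tf.constant([1, 0, 1])
result = tf.math.unsorted_segment_sum(data, segment_ids, num_segments=2)
print(result.numpy())  # [[3. 4.], [6. 8.]]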
Example 1: quantiles_ready
def quantiles_ready():
  """The subgraph for when the quantiles are ready."""
  quantized_feature = quantile_ops.quantiles([sparse_column_values], [],
                                             [quantile_buckets], [])
  quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64)
  quantized_feature = array_ops.reshape(quantized_feature, [-1])
  example_indices, _ = array_ops.split(
      sparse_column_indices, num_or_size_splits=2, axis=1)
  example_indices = array_ops.squeeze(example_indices, [1])
  filtered_gradients = array_ops.gather(gradients, example_indices)
  filtered_hessians = array_ops.gather(hessians, example_indices)
  filtered_partition_ids = array_ops.gather(example_partition_ids,
                                            example_indices)
  unique_partitions, mapped_partitions = array_ops.unique(
      example_partition_ids)
  # Compute aggregate stats for each partition.
  per_partition_gradients = math_ops.unsorted_segment_sum(
      gradients, mapped_partitions, array_ops.size(unique_partitions))
  per_partition_hessians = math_ops.unsorted_segment_sum(
      hessians, mapped_partitions, array_ops.size(unique_partitions))
  # Prepend a bias feature per partition that accumulates the stats for all
  # examples in that partition.
  bias_feature_ids = array_ops.fill(
      array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
  bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
  partition_ids = array_ops.concat(
      [unique_partitions, filtered_partition_ids], 0)
  filtered_gradients = array_ops.concat(
      [per_partition_gradients, filtered_gradients], 0)
  filtered_hessians = array_ops.concat(
      [per_partition_hessians, filtered_hessians], 0)
  bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)
  return partition_ids, bucket_ids, filtered_gradients, filtered_hessians
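A recurring pattern in this example (and several below) is the pairing of array_ops.unique with unsorted_segment_sum: unique maps arbitrary partition ids onto a dense 0..k-1 range, and the segment sum then aggregates one row per distinct id. A toy sketch of just that pattern, using the public tf API and made-up values:

import tensorflow as tf

partition_ids = tf.constant([7, 3, 7, 9, 3])    # arbitrary, unsorted ids
gradients = tf.constant([1., 2., 3., 4., 5.])
unique_ids, mapped = tf.unique(partition_ids)   # [7 3 9] and [0 1 0 2 1]
per_partition = tf.math.unsorted_segment_sum(
    gradients, mapped, tf.size(unique_ids))
print(per_partition.numpy())                    # [4. 7. 4.]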
Example 2: _full_batch_training_op
def _full_batch_training_op(self, inputs, cluster_idx_list, cluster_centers):
  """Creates an op for training for the full-batch case.

  Args:
    inputs: list of input Tensors.
    cluster_idx_list: A vector (or list of vectors). Each element in the
      vector corresponds to an input row in 'inputs' and specifies the
      cluster id corresponding to the input.
    cluster_centers: Tensor Ref of cluster centers.

  Returns:
    An op for doing a full-batch update of k-means.
  """
  cluster_sums = []
  cluster_counts = []
  epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
  for inp, cluster_idx in zip(inputs, cluster_idx_list):
    with ops.colocate_with(inp):
      # Sum the inputs assigned to each cluster, and count them.
      cluster_sums.append(
          math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
      cluster_counts.append(
          math_ops.unsorted_segment_sum(
              array_ops.reshape(
                  array_ops.ones(
                      array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                  [-1, 1]), cluster_idx, self._num_clusters))
  with ops.colocate_with(cluster_centers):
    new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
        math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
    if self._clusters_l2_normalized():
      new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
  return state_ops.assign(cluster_centers, new_clusters_centers)
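The update above is a segment mean: per-cluster sums divided by per-cluster counts, with a small epsilon so empty clusters do not divide by zero. A toy version of the same arithmetic in eager TF2, with made-up points:

import tensorflow as tf

points = tf.constant([[0., 0.], [2., 2.], [4., 0.]])
assignments = tf.constant([0, 1, 1])
num_clusters = 2
sums = tf.math.unsorted_segment_sum(points, assignments, num_clusters)
counts = tf.math.unsorted_segment_sum(
    tf.ones_like(points[:, :1]), assignments, num_clusters)
centers = sums / (counts + 1e-6)  # epsilon guards against empty clusters
print(centers.numpy())            # approximately [[0. 0.], [3. 1.]]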
Example 3: active_inputs
def active_inputs():
  """The normal flow when the handler is active."""
  # Remove the second column of the example indices matrix since it is not
  # useful.
  example_indices, _ = array_ops.split(
      self._sparse_int_column.indices, num_or_size_splits=2, axis=1)
  example_indices = array_ops.squeeze(example_indices, [1])
  filtered_gradients = array_ops.gather(gradients, example_indices)
  filtered_hessians = array_ops.gather(hessians, example_indices)
  filtered_partition_ids = array_ops.gather(example_partition_ids,
                                            example_indices)
  unique_partitions, mapped_partitions = array_ops.unique(
      example_partition_ids)
  # Compute aggregate stats for each partition.
  # The bias is computed on gradients and hessians (and not
  # filtered_gradients), which have exactly one value per example, so we
  # don't double count a gradient in multivalent columns.
  # Since unsorted_segment_sum can be numerically unstable, use a 64-bit
  # operation.
  gradients64 = math_ops.cast(gradients, dtypes.float64)
  hessians64 = math_ops.cast(hessians, dtypes.float64)
  per_partition_gradients = math_ops.unsorted_segment_sum(
      gradients64, mapped_partitions, array_ops.size(unique_partitions))
  per_partition_hessians = math_ops.unsorted_segment_sum(
      hessians64, mapped_partitions, array_ops.size(unique_partitions))
  per_partition_gradients = math_ops.cast(per_partition_gradients,
                                          dtypes.float32)
  per_partition_hessians = math_ops.cast(per_partition_hessians,
                                         dtypes.float32)
  # Prepend a bias feature per partition that accumulates the stats for all
  # examples in that partition.
  # Bias is added to the stats even if there are no examples with values in
  # the current sparse column. The reason is that the other example batches
  # might have values in these partitions, so we have to keep the bias
  # updated.
  bias_feature_ids = array_ops.fill(
      array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
  bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
  partition_ids = array_ops.concat(
      [unique_partitions, filtered_partition_ids], 0)
  filtered_gradients = array_ops.concat(
      [per_partition_gradients, filtered_gradients], 0)
  filtered_hessians = array_ops.concat(
      [per_partition_hessians, filtered_hessians], 0)
  feature_ids = array_ops.concat(
      [bias_feature_ids, self._sparse_int_column.values], 0)
  # Dimension is always zero for sparse int features.
  dimension_ids = array_ops.zeros_like(feature_ids, dtype=dtypes.int64)
  feature_ids_and_dimensions = array_ops.stack(
      [feature_ids, dimension_ids], axis=1)
  return (partition_ids, feature_ids_and_dimensions, filtered_gradients,
          filtered_hessians)
Example 4: quantiles_ready
def quantiles_ready():
  """The subgraph for when the quantiles are ready."""
  quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [],
                                             [quantile_buckets],
                                             [sparse_column_indices])
  quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64)
  quantized_feature = array_ops.squeeze(quantized_feature, axis=0)
  example_indices, _ = array_ops.split(
      sparse_column_indices, num_or_size_splits=2, axis=1)
  example_indices = array_ops.squeeze(example_indices, [1])
  filtered_gradients = array_ops.gather(gradients, example_indices)
  filtered_hessians = array_ops.gather(hessians, example_indices)
  filtered_partition_ids = array_ops.gather(example_partition_ids,
                                            example_indices)
  unique_partitions, mapped_partitions = array_ops.unique(
      example_partition_ids)
  # Compute aggregate stats for each partition.
  # Since unsorted_segment_sum can be numerically unstable, use a 64-bit
  # operation.
  gradients64 = math_ops.cast(gradients, dtypes.float64)
  hessians64 = math_ops.cast(hessians, dtypes.float64)
  per_partition_gradients = math_ops.unsorted_segment_sum(
      gradients64, mapped_partitions, array_ops.size(unique_partitions))
  per_partition_hessians = math_ops.unsorted_segment_sum(
      hessians64, mapped_partitions, array_ops.size(unique_partitions))
  per_partition_gradients = math_ops.cast(per_partition_gradients,
                                          dtypes.float32)
  per_partition_hessians = math_ops.cast(per_partition_hessians,
                                         dtypes.float32)
  # Prepend a bias feature per partition that accumulates the stats for all
  # examples in that partition.
  bias_feature_ids = array_ops.fill(
      array_ops.shape(unique_partitions), _BIAS_FEATURE_ID)
  bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64)
  zeros = array_ops.zeros_like(bias_feature_ids)
  bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1)
  partition_ids = array_ops.concat(
      [unique_partitions, filtered_partition_ids], 0)
  filtered_gradients = array_ops.concat(
      [per_partition_gradients, filtered_gradients], 0)
  filtered_hessians = array_ops.concat(
      [per_partition_hessians, filtered_hessians], 0)
  bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0)
  return partition_ids, bucket_ids, filtered_gradients, filtered_hessians
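Examples 3 and 4 cast to float64 before summing because accumulating many float32 values can silently lose low-order bits. A small, exaggerated illustration of the effect (the float32 result depends on accumulation order, so treat the printed value as typical rather than guaranteed):

import numpy as np
import tensorflow as tf

values = np.concatenate([[1e8], np.ones(1000)]).astype(np.float32)
ids = np.zeros(1001, dtype=np.int32)  # everything lands in segment 0
sum32 = tf.math.unsorted_segment_sum(values, ids, 1)
sum64 = tf.math.unsorted_segment_sum(values.astype(np.float64), ids, 1)
print(sum32.numpy())  # often 1e8: each added 1.0 can be rounded away
print(sum64.numpy())  # 100001000.0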
Example 5: _apply_sparse
def _apply_sparse(self, g_t, x_tm1, prepare):
  """Applies a sparse gradient update, de-duplicating repeated indices first."""
  idxs, idxs_ = array_ops.unique(g_t.indices)
  g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_, array_ops.size(idxs))
  updates = []
  if self._mu > 0:
    m_and_t = self._sparse_moving_average(x_tm1, idxs, g_t_, 'm', self._mu)
    m_t_ = array_ops.gather(m_and_t[0], idxs)
    gamma_t = ops.convert_to_tensor(self._gamma)
    m_bar_t_ = (1 - gamma_t) * m_t_ + gamma_t * g_t_
    updates.extend(m_and_t)
  else:
    m_bar_t_ = g_t_
  if self._ups > 0:
    v_and_t = self._sparse_moving_average(x_tm1, idxs, g_t_**2, 'v', self._ups)
    v_t_ = array_ops.gather(v_and_t[0], idxs)
    eps_t = ops.convert_to_tensor(self._eps)
    v_bar_t_ = math_ops.sqrt(v_t_ + eps_t)
    updates.extend(v_and_t)
  else:
    v_bar_t_ = 1.
  lr_t = ops.convert_to_tensor(self._lr)
  s_t_ = lr_t * m_bar_t_ / v_bar_t_
  return [[s_t_, x_tm1, idxs, g_t]] + updates
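The first two lines are the standard de-duplication trick for IndexedSlices gradients: an embedding lookup can repeat an index, so duplicate rows are merged before any moving-average math touches them. In isolation, with made-up values:

import tensorflow as tf

indices = tf.constant([2, 0, 2])  # row 2 appears twice
values = tf.constant([[1., 1.], [3., 3.], [5., 5.]])
unique_idxs, positions = tf.unique(indices)
merged = tf.math.unsorted_segment_sum(values, positions, tf.size(unique_idxs))
print(unique_idxs.numpy())  # [2 0]
print(merged.numpy())       # [[6. 6.], [3. 3.]]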
Example 6: _prepare
def _prepare(self, grads_and_vars):
  """Initializes the learning rate from the stored s and y statistics."""
  if self._lr is None:
    sTy = 0
    sTs = 0
    yTy = 0
    for g_t, x_tm1 in grads_and_vars:
      if g_t is None:
        continue
      with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
        if isinstance(g_t, ops.Tensor):
          g_tm1 = self.get_slot(x_tm1, 'g')
          s_tm1 = self.get_slot(x_tm1, 's')
          y_t = (g_t - g_tm1)
          sTy += math_ops.reduce_sum(s_tm1 * y_t)
          sTs += math_ops.reduce_sum(s_tm1**2)
          yTy += math_ops.reduce_sum(y_t**2)
        else:
          # De-duplicate the sparse gradient before forming y_t.
          idxs, idxs_ = array_ops.unique(g_t.indices)
          g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_,
                                               array_ops.size(idxs))
          g_tm1 = self.get_slot(x_tm1, 'g')
          g_tm1_ = array_ops.gather(g_tm1, idxs)
          s_tm1 = self.get_slot(x_tm1, 's')
          s_tm1_ = array_ops.gather(s_tm1, idxs)
          y_t_ = (g_t_ - g_tm1_)
          sTy += math_ops.reduce_sum(s_tm1_ * y_t_)
          sTs += math_ops.reduce_sum(s_tm1_**2)
          yTy += math_ops.reduce_sum(y_t_**2)
    sTy = math_ops.abs(sTy)
    self._lr = sTs / (sTy + self._eps)
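For context on the final line: if s is the previous parameter step and y the corresponding change in gradient, then sTs / sTy is the classic (first) Barzilai-Borwein step size. So this _prepare appears to seed the learning rate with lr = (s^T s) / (|s^T y| + eps), a Barzilai-Borwein-style estimate; the yTy accumulator is computed but not used in the snippet shown.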
Example 7: testAggregateGradients
def testAggregateGradients(self):

  def fn(x):
    ind1 = tensor.Tensor(np.array([0, 1]))
    ind2 = tensor.Tensor(np.array([2, 3]))
    ind3 = tensor.Tensor(np.array([1, 3]))
    # A mixture of IndexedSlices and dense tensor to aggregate.
    g1 = embedding_ops.embedding_lookup(x, ind1)
    g2 = embedding_ops.embedding_lookup(x, ind2)
    g3 = embedding_ops.embedding_lookup(x, ind3)
    g4 = math_ops.reduce_sum(x * tensor.Tensor(2.0))
    return g1 * g2 * g3 * g4

  var_np = np.random.rand(4, 2).astype(np.float32)
  var = tensor.Tensor(var_np)
  grad = backprop.gradients_function(fn, [0])(var)[0]

  with context.graph_mode(), self.test_session():
    tf_var = array_ops.constant(var_np, dtypes.float32)
    tf_ind1 = array_ops.constant([0, 1])
    tf_ind2 = array_ops.constant([2, 3])
    tf_ind3 = array_ops.constant([1, 3])
    tf_g1 = embedding_ops.embedding_lookup(tf_var, tf_ind1)
    tf_g2 = embedding_ops.embedding_lookup(tf_var, tf_ind2)
    tf_g3 = embedding_ops.embedding_lookup(tf_var, tf_ind3)
    tf_g4 = math_ops.reduce_sum(tf_var * 2.0, reduction_indices=(0, 1))
    tf_y = tf_g1 * tf_g2 * tf_g3 * tf_g4
    tf_grad = gradients.gradients(tf_y, [tf_var])[0]
    tf_dense_grad = math_ops.unsorted_segment_sum(
        tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
    self.assertAllClose(grad.numpy(), tf_dense_grad.eval())
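The tf_dense_grad line shows the common recipe for densifying an IndexedSlices gradient: segment-sum its values over its indices, using the first entry of dense_shape as num_segments. A TF2 sketch of the same recipe (the gradient of a gather typically comes back as tf.IndexedSlices):

import tensorflow as tf

var = tf.Variable(tf.ones([4, 2]))
with tf.GradientTape() as tape:
  y = tf.reduce_sum(tf.gather(var, [0, 1, 1]))
grad = tape.gradient(y, var)  # tf.IndexedSlices with repeated index 1
dense = tf.math.unsorted_segment_sum(grad.values, grad.indices,
                                     grad.dense_shape[0])
print(dense.numpy())  # [[1. 1.], [2. 2.], [0. 0.], [0. 0.]]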
Example 8: approximate_hessian
def approximate_hessian(self, grads_and_vars, name=None):
  """
  I haven't tested this yet, so I have no idea if it works; even if it
  does, it's probably very slow, and nothing else has been modified to
  deal with it.
  """
  gv = 0
  var_refs = []
  for g_t, x_tm1 in grads_and_vars:
    var_refs.append(x_tm1.ref())
    if g_t is None:
      continue
    with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
      if isinstance(g_t, ops.Tensor):
        # Accumulate the inner product of the gradient with a random vector.
        gv += math_ops.reduce_sum(
            g_t * random_ops.random_normal(g_t.get_shape()))
      else:
        # De-duplicate sparse gradients before forming the inner product.
        idxs, idxs_ = array_ops.unique(g_t.indices)
        g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_,
                                             array_ops.size(idxs))
        gv += math_ops.reduce_sum(
            g_t_ * random_ops.random_normal(g_t_.get_shape()))
  # Note: the original snippet passed gate_gradients, aggregation_method and
  # colocate_gradients_with_ops here, but none of those names are defined in
  # this scope, so the call is reduced to its required arguments.
  hesses = gradients.gradients(gv, var_refs)
  return zip([g_t for g_t, _ in grads_and_vars],
             [x_tm1 for _, x_tm1 in grads_and_vars], hesses)
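What the snippet is driving at: for a fixed random vector v, the gradient of the scalar g^T v with respect to the variables is the Hessian-vector product Hv. So hesses holds one stochastic Hessian-vector product per variable rather than an explicit Hessian, and the unsorted_segment_sum call merely de-duplicates sparse gradient rows (the same trick as in Examples 5 and 6) before g^T v is formed.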
Example 9: _apply_sparse
def _apply_sparse(self, grad, var):
  if len(grad.indices.get_shape()) == 1:
    grad_indices = grad.indices
    grad_values = grad.values
  else:
    grad_indices = array_ops.reshape(grad.indices, [-1])
    grad_values = array_ops.reshape(
        grad.values, [-1, grad.values.get_shape()[-1].value])
  gidxs, metagidxs = array_ops.unique(grad_indices)
  sizegidxs = array_ops.size(gidxs)
  gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
  # m_t = mu * m + (1 - mu) * g_t
  m = self.get_slot(var, "m")
  m_scaled_g_values = gvals * (1 - self._mu_t)
  m_t = state_ops.scatter_update(m, gidxs,
                                 array_ops.gather(m, gidxs) * self._mu_t,
                                 use_locking=self._use_locking)
  m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
                              use_locking=self._use_locking)
  m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
  # m_bar = mu * m_t + (1 - mu) * g_t
  m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
  var_update = state_ops.scatter_sub(var, gidxs,
                                     self._lr_t * m_bar,
                                     use_locking=self._use_locking)
  return control_flow_ops.group(*[var_update, m_t])
Example 10: DynamicStitchGrads
def DynamicStitchGrads(op, grad):
  num_values = len(op.inputs) // 2
  indices_grad = [None] * num_values

  def AsInt32(x):
    return (x if op.inputs[0].dtype == dtypes.int32 else
            math_ops.cast(x, dtypes.int32))

  idxs = [AsInt32(array_ops.reshape(op.inputs[i], (-1,)))
          for i in range(num_values)]
  if isinstance(grad, ops.IndexedSlices):
    output_shape = array_ops.shape(op.outputs[0])
    output_rows = output_shape[0]
    # Densify the sparse upstream gradient before stitching.
    grad = math_ops.unsorted_segment_sum(grad.values, grad.indices,
                                         output_rows)
  values_grad = []
  zeros = array_ops.zeros_like(grad)
  idx_zeros = [zeros[:array_ops.shape(x)[0]] for x in idxs]
  grad_range = math_ops.range(array_ops.shape(grad)[0])
  for i in range(num_values):
    if i == num_values - 1:
      v_grad = grad
    else:
      v_grad = data_flow_ops.dynamic_stitch(
          [grad_range] + idxs[i + 1:], [grad] + idx_zeros[i + 1:])
    v_grad = array_ops.gather(v_grad, AsInt32(op.inputs[i]))
    values_grad += [v_grad]
  return indices_grad + values_grad
Example 11: testAggregateGradients
def testAggregateGradients(self):

  def fn(x):
    ind1 = constant_op.constant(np.array([0, 1]))
    ind2 = constant_op.constant(np.array([2, 3]))
    ind3 = constant_op.constant(np.array([1, 3]))
    # A mixture of IndexedSlices and dense tensor to aggregate.
    g1 = embedding_ops.embedding_lookup(x, ind1)
    g2 = embedding_ops.embedding_lookup(x, ind2)
    g3 = embedding_ops.embedding_lookup(x, ind3)
    g4 = math_ops.reduce_sum(x * constant_op.constant(2.0))
    return g1 * g2 * g3 * g4

  var_np = np.random.rand(4, 2).astype(np.float32)
  var = constant_op.constant(var_np)
  grad = backprop.gradients_function(fn, [0])(var)[0]
  grad = self.evaluate(ops.convert_to_tensor(grad))

  if not context.executing_eagerly():
    tf_var = array_ops.constant(var_np, dtypes.float32)
    tf_ind1 = array_ops.constant([0, 1])
    tf_ind2 = array_ops.constant([2, 3])
    tf_ind3 = array_ops.constant([1, 3])
    tf_g1 = embedding_ops.embedding_lookup(tf_var, tf_ind1)
    tf_g2 = embedding_ops.embedding_lookup(tf_var, tf_ind2)
    tf_g3 = embedding_ops.embedding_lookup(tf_var, tf_ind3)
    tf_g4 = math_ops.reduce_sum(tf_var * 2.0, axis=(0, 1))
    tf_y = tf_g1 * tf_g2 * tf_g3 * tf_g4
    tf_grad = gradients.gradients(tf_y, [tf_var])[0]
    tf_dense_grad = math_ops.unsorted_segment_sum(
        tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
    self.assertAllClose(grad, self.evaluate(tf_dense_grad))
Example 12: _TileGrad
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0])
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape. For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices.
  if isinstance(grad, ops.IndexedSlices):
    grad = math_ops.unsorted_segment_sum(
        grad.values,
        math_ops.mod(grad.indices, input_shape[0]),
        input_shape[0])
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference.
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]
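The IndexedSlices branch folds sparse gradient rows back onto the original first dimension by summing with indices taken modulo input_shape[0]. For a dense gradient of a tile along axis 0, the same fold looks like this (toy values; the upstream gradient is for tf.tile(x, [2, 1]) where x has 2 rows):

import tensorflow as tf

upstream = tf.constant([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
n = 2  # rows in the original input
folded = tf.math.unsorted_segment_sum(
    upstream, tf.math.mod(tf.range(4), n), n)
print(folded.numpy())  # [[4. 4.], [6. 6.]]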
Example 13: testGradientMatchesSegmentSum
def testGradientMatchesSegmentSum(self):
  # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum
  # and compare the outputs, which should be identical.
  # NB: for this test to work, the indices must be valid for SegmentSum:
  # they must be sorted and contiguous, and num_segments must be
  # max(indices) + 1.
  indices = [0, 0, 1, 1, 1, 2, 3, 4, 5]
  n = len(indices)
  num_cols = 2
  shape = [n, num_cols]
  num_segments = max(indices) + 1
  for dtype in self.differentiable_dtypes:
    with self.cached_session(use_gpu=True):
      tf_x, np_x = self._input(shape, dtype=dtype)
      # Results from UnsortedSegmentSum.
      unsorted_s = math_ops.unsorted_segment_sum(
          data=tf_x, segment_ids=indices, num_segments=num_segments)
      unsorted_jacob_t, unsorted_jacob_n = (
          gradient_checker.compute_gradient(tf_x, shape, unsorted_s,
                                            [num_segments, num_cols],
                                            x_init_value=np_x, delta=1))
      # Results from SegmentSum.
      sorted_s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
      sorted_jacob_t, sorted_jacob_n = gradient_checker.compute_gradient(
          tf_x,
          shape,
          sorted_s, [num_segments, num_cols],
          x_init_value=np_x,
          delta=1)
      self.assertAllClose(unsorted_jacob_t, sorted_jacob_t)
      self.assertAllClose(unsorted_jacob_n, sorted_jacob_n)
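The forward-path equivalence the test relies on is easy to see directly: when the segment ids are already sorted and contiguous, segment_sum and unsorted_segment_sum agree. A quick check with made-up data:

import tensorflow as tf

data = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
ids = tf.constant([0, 0, 1])  # sorted and contiguous
a = tf.math.segment_sum(data, ids)
b = tf.math.unsorted_segment_sum(data, ids, num_segments=2)
print(a.numpy())  # [[4. 6.], [5. 6.]]
print(b.numpy())  # identical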
Example 14: testDropNegatives
def testDropNegatives(self):
  # Note: the test works by replacing segment id 8 with -1 in the indices
  # and zeroing out the corresponding rows of the numpy-computed result.
  dtypes = [
      dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
      dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128
  ]
  indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
  num_segments = 12
  for indices in indices_flat, indices_flat.reshape(5, 2):
    shape = indices.shape + (2,)
    for dtype in dtypes:
      with self.test_session(use_gpu=True):
        tf_x, np_x = self._input(shape, dtype=dtype)
        np_ans = self._segmentReduce(
            indices, np_x, np.add, op2=None, num_segments=num_segments)
        # Zero out rows 8 and above in the expected result.
        np_ans[8:] = 0
        # Replace 8 with -1 in the indices; negative ids should be dropped.
        np.place(indices, indices == 8, [-1])
        s = math_ops.unsorted_segment_sum(
            data=tf_x, segment_ids=indices, num_segments=num_segments)
        tf_ans = s.eval()
        self.assertAllClose(np_ans, tf_ans)
        self.assertShapeEqual(np_ans, s)
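The dropping behavior the test exercises can be seen with a one-liner: negative segment ids are silently ignored by unsorted_segment_sum.

import tensorflow as tf

data = tf.constant([10., 20., 30.])
ids = tf.constant([0, -1, 1])  # the -1 row is dropped
print(tf.math.unsorted_segment_sum(data, ids, 2).numpy())  # [10. 30.]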
Example 15: _GatherV2Grad
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op."""
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.to_int32(params_shape)
  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    values_shape = array_ops.concat([indices_size, params_shape[1:]], 0)
    values = array_ops.reshape(grad, values_shape)
    indices = array_ops.reshape(indices, indices_size)
    return [ops.IndexedSlices(values, indices, params_shape), None, None]
  outer_shape = params_shape[:axis]
  outer_dims = array_ops.size(outer_shape)
  inner_shape = params_shape[axis:][1:]
  inner_dims = array_ops.size(inner_shape)
  outer_axes_indices = math_ops.range(outer_dims)
  inner_axes_indices = math_ops.range(outer_dims + 1,
                                      outer_dims + 1 + inner_dims)
  values_shape = array_ops.concat([outer_shape, indices_size, inner_shape], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, indices_size)
  # We need to sum up every slice `values[..., i, ...]` corresponding to
  # `params[..., indices[i], ...]`. Since `unsorted_segment_sum` does not
  # support an axis parameter, we transpose the gather dimension to the front,
  # then use `unsorted_segment_sum` to build a
  # [gather_axis, outer_axes, inner_axes] tensor with all the gradients
  # affecting each index in `gather_axis` summed up.
  transpose_dims = array_ops.concat(
      [[outer_dims], outer_axes_indices, inner_axes_indices], 0)
  values_transpose = array_ops.transpose(values, transpose_dims)
  num_segments = params_shape[axis]
  params_grad = math_ops.unsorted_segment_sum(
      values_transpose, indices, num_segments)
  # Inverts the above transpose by moving dimension 0 back to its original
  # position.
  invert_transpose_dims = array_ops.concat(
      [outer_axes_indices + 1, [0], inner_axes_indices], 0)
  params_grad = array_ops.transpose(params_grad, invert_transpose_dims)
  return [params_grad, None, None]
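For the axis > 0 branch, the transpose trick can be played out on a toy case. A sketch of the gradient of tf.gather(params, indices, axis=1) with an all-ones upstream gradient (public tf API, made-up shapes):

import tensorflow as tf

params = tf.constant([[1., 2., 3.], [4., 5., 6.]])
indices = tf.constant([2, 0, 2])
grad = tf.ones([2, 3])  # upstream gradient, shape [batch, len(indices)]
# Move the gather axis to the front, segment-sum over indices, move it back.
g_t = tf.transpose(grad, [1, 0])
summed = tf.math.unsorted_segment_sum(g_t, indices, tf.shape(params)[1])
params_grad = tf.transpose(summed, [1, 0])
print(params_grad.numpy())  # [[1. 0. 2.], [1. 0. 2.]]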