This article collects typical usage examples of the Python function tensorflow.python.ops.state_ops.scatter_add. If you have been wondering what exactly scatter_add does and how to use it, the curated code samples below may help.
The sections below show 15 code examples of the scatter_add function, ordered roughly by popularity.
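Before the collected examples, here is a minimal, standalone sketch of the basic contract of scatter_add (not taken from the examples below; the variable names are made up and it assumes a TensorFlow 1.x graph session): rows of the updates argument are added into the variable at the given indices, and duplicate indices accumulate.
```python
import tensorflow as tf
from tensorflow.python.ops import state_ops, variables

ref = variables.Variable([1.0, 2.0, 3.0, 4.0])
# Rows 0 and 2 receive increments; the duplicated index 2 accumulates both updates.
add_op = state_ops.scatter_add(ref, [0, 2, 2], [10.0, 1.0, 1.0])

with tf.Session() as sess:
    sess.run(variables.global_variables_initializer())
    print(sess.run(add_op))  # [11.0, 2.0, 5.0, 4.0]
```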
Example 1: _sparse_moving_average
def _sparse_moving_average(self, x_tm1, idxs, b_t_, name, beta=.9):
"""
Creates a moving average for a sparse variable.
Inputs:
x_tm1: the associated parameter (e.g. a weight matrix)
idxs: the tensor representing the indices used
b_t_: the value to accumulate (e.g. slices of the gradient)
name: a string to use to retrieve it later (e.g. 'm')
beta: the decay factor (defaults to .9)
Outputs:
a_t: the average after moving (same shape as x_tm1, not b_t_)
t: the internal timestep (used to correct initialization bias)
"""
a_tm1 = self._zeros_slot(x_tm1, '%s' % name, self._name)
a_tm1_ = array_ops.gather(a_tm1, idxs)
tm1 = self._zeros_idx_slot(x_tm1, '%s/tm1' % name, self._name)
tm1_ = array_ops.gather(tm1, idxs)
t = state_ops.scatter_add(tm1, idxs, tm1_*0+1, use_locking=self._use_locking)
t_ = array_ops.gather(t, idxs)
if beta < 1:
beta_t = ops.convert_to_tensor(beta, name='%s/decay' % name)
beta_t_ = beta_t * (1-beta_t**tm1_) / (1-beta_t**t_)
else:
beta_t_ = tm1_/t_
a_t = state_ops.scatter_update(a_tm1, idxs, beta_t_*a_tm1_, use_locking=self._use_locking)
a_t = state_ops.scatter_add(a_t, idxs, (1-beta_t)*b_t_, use_locking=self._use_locking)
return a_t, t
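The scatter_update / scatter_add pair above is the recurring pattern in the optimizer examples that follow: decay only the gathered rows, then add the scaled update to those same rows. Below is a small sketch of that pattern with made-up values (a, idxs, b, decay are all hypothetical names), again assuming a TensorFlow 1.x graph session.
```python
import tensorflow as tf
from tensorflow.python.ops import array_ops, state_ops, variables

a = variables.Variable([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
idxs = tf.constant([0, 2])
b = tf.constant([[5.0, 5.0], [7.0, 7.0]])  # stand-in for the sliced update b_t_
decay = 0.9

a_rows = array_ops.gather(a, idxs)
# Decay only the touched rows, then add the (1 - decay)-scaled update to them.
a_t = state_ops.scatter_update(a, idxs, decay * a_rows)
a_t = state_ops.scatter_add(a_t, idxs, (1 - decay) * b)

with tf.Session() as sess:
    sess.run(variables.global_variables_initializer())
    print(sess.run(a_t))
    # row 0 -> 0.9*1 + 0.1*5 = 1.4, row 1 unchanged, row 2 -> 0.9*3 + 0.1*7 = 3.4
```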
Example 2: _apply_sparse
def _apply_sparse(self, grad, var):
beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad.values * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t,
use_locking=self._use_locking)
m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values,
use_locking=self._use_locking)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad.values * grad.values) * (1 - beta2_t)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values,
use_locking=self._use_locking)
v_sqrt = math_ops.sqrt(v_t)
var_update = state_ops.assign_sub(var,
lr * m_t / (v_sqrt + epsilon_t),
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
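Note that in this variant the decay assign ops rewrite every row of m and v, and only the gradient contribution added through scatter_add is sparse; Example 1 above and Example 5 below instead confine the decay to the gathered rows via scatter_update, the so-called lazy treatment of sparse gradients.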
Example 3: _get_coordinatewise_learning_rate
def _get_coordinatewise_learning_rate(self, grad, var):
# Compute the learning rate using a moving average for the diagonal of BB^T
avg_first = self.get_slot(var, 'first_moment')
avg_second = self.get_slot(var, 'second_moment')
decay_tensor = math_ops.cast(self._decay_tensor, var.dtype)
batch_size = math_ops.cast(self._batch_size_tensor, var.dtype)
# Create an estimator for the moving average of gradient mean and variance
# via Welford's algorithm
if isinstance(grad, ops.Tensor):
delta = grad - avg_first
first_moment_update = avg_first.assign_add(
array_ops.where(self._counter < 1, math_ops.cast(1, var.dtype),
1. - decay_tensor) * delta)
with ops.control_dependencies([first_moment_update]):
second_moment_update = avg_second.assign_add(
math_ops.cast(self._counter < 1, var.dtype) *
-(1. - decay_tensor) * (
avg_second - decay_tensor * math_ops.square(delta)))
diag_preconditioner = control_flow_ops.with_dependencies(
[second_moment_update],
clip_ops.clip_by_value(avg_second, 1e-12, 1e12))
elif isinstance(grad, ops.IndexedSlices):
delta = grad.values - array_ops.gather_nd(avg_first, grad.indices)
first_moment_update = state_ops.scatter_add(
avg_first,
grad.indices,
array_ops.where(self._counter < 1,
math_ops.cast(1., var.dtype),
1. - decay_tensor) * delta)
with ops.control_dependencies([first_moment_update]):
avg_second = state_ops.scatter_add(
avg_second,
grad.indices,
math_ops.cast(self._counter < 1, var.dtype) *
-(1. - decay_tensor) * (
array_ops.gather_nd(avg_second, grad.indices) - decay_tensor *
math_ops.square(delta)))
avg_second = array_ops.gather_nd(avg_second, grad.indices)
# TODO(b/70783772)
diag_preconditioner = clip_ops.clip_by_value(avg_second, 1e-12, 1e12)
else:
raise errors.InvalidArgumentError(
None, None, 'grad must be of type Tensor or IndexedSlices')
diag_preconditioner *= batch_size
if self._use_single_learning_rate:
diag_preconditioner = math_ops.reduce_mean(diag_preconditioner)
# From Theorem 2 Corollary 1 of Mandt et al. 2017
return 2. * batch_size / (
math_ops.cast(self._total_num_examples, var.dtype.base_dtype) *
diag_preconditioner)
Example 4: testWrongShape
def testWrongShape(self):
# Indices and values mismatch.
var = variables.Variable(
array_ops.zeros(shape=[1024, 64, 64], dtype=dtypes.float32))
indices = array_ops.placeholder(dtypes.int32, shape=[32])
values = array_ops.placeholder(dtypes.float32, shape=[33, 64, 64])
with self.assertRaises(ValueError):
state_ops.scatter_add(var, indices, values)
# Var and values mismatch.
values = array_ops.placeholder(dtypes.float32, shape=[32, 64, 63])
with self.assertRaises(ValueError):
state_ops.scatter_add(var, indices, values)
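The shape rule exercised by this test is that updates.shape must equal indices.shape + ref.shape[1:]; for a [1024, 64, 64] variable and [32] indices, the only valid update shape is [32, 64, 64].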
Example 5: _apply_sparse
def _apply_sparse(self, grad, var):
if len(grad.indices.get_shape()) == 1:
grad_indices = grad.indices
grad_values = grad.values
else:
grad_indices = array_ops.reshape(grad.indices, [-1])
grad_values = array_ops.reshape(grad.values, [-1, grad.values.get_shape()[-1].value])
gidxs, metagidxs = array_ops.unique(grad_indices)
sizegidxs = array_ops.size(gidxs)
gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
# m_t = mu * m + (1 - mu) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = gvals * (1 - self._mu_t)
m_t = state_ops.scatter_update(m, gidxs,
array_ops.gather(m, gidxs) * self._mu_t,
use_locking=self._use_locking)
m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
use_locking=self._use_locking)
m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
# m_bar = mu * m_t + (1 - mu) * g_t
m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
var_update = state_ops.scatter_sub(var, gidxs,
self._lr_t * m_bar,
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t])
Example 6: _apply_sparse
def _apply_sparse(self, grad, var):
beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad.values * (1 - beta1_t)
m_t = state_ops.assign(m, m * beta1_t,
use_locking=self._use_locking)
m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values,
use_locking=self._use_locking)
# u_t = max(beta2 * u_{t-1}, |g_t|)
# theta_t = theta_{t-1} - alpha / (1 - beta1) * m_t / u_t
v = self.get_slot(var, "v")
g_abs_values = math_ops.abs(grad.values)
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
# scatter_max keeps, per touched row, the larger of the decayed running
# maximum and the absolute gradient value.
v_t = state_ops.scatter_max(v_t, grad.indices, g_abs_values,
use_locking=self._use_locking)
var_update = state_ops.assign_sub(var,
lr * m_t / (v_t * (1 - beta1_t)),
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
Example 7: _apply_sparse
def _apply_sparse(self, grad, var):
return self._apply_sparse_shared(
grad.values, var, grad.indices,
lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
x, i, v, use_locking=self._use_locking),
lambda x, i, v: state_ops.scatter_update( # pylint: disable=g-long-lambda
x, i, v, use_locking=self._use_locking))
Example 8: _apply_sparse
def _apply_sparse(self, grad, var):
lr = self._lr_t * math_ops.sqrt(1 - self._beta2_power) / (1 - self._beta1_power)
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, "m")
m_scaled_g_values = grad.values * (1 - self._beta1_t)
m_t = state_ops.assign(m, m * self._beta1_t, use_locking=self._use_locking)
m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values, use_locking=self._use_locking)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v = self.get_slot(var, "v")
v_scaled_g_values = (grad.values * grad.values) * (1 - self._beta2_t)
v_t = state_ops.assign(v, v * self._beta2_t, use_locking=self._use_locking)
v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values, use_locking=self._use_locking)
v_sqrt = math_ops.pow(v_t, self._pow_t)
var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + self._epsilon_t), use_locking=self._use_locking)
# regularization
var_update = state_ops.assign_sub(var_update, self._sparse_regularization * var, use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, m_t, v_t])
Example 9: _apply_sparse
def _apply_sparse(self, grad, var):
# ms_t = decay * ms + (1 - decay) * (g_t * g_t)
ms = self.get_slot(var, "rms") # should not be named rms when it's ms
print('---SPARSE TIME---')
print('lr: ' + str(self._learning_rate_tensor.get_shape()))
print('decay: ' + str(self._decay_tensor.get_shape()))
print('momentum: ' + str(self._momentum_tensor.get_shape()))
print('epsilon: ' + str(self._epsilon_tensor.get_shape()))
print('ms: ' + str(ms.get_shape()))
print('grad.values: ' + str(grad.values.get_shape()))
ms_scaled_g_values = (grad.values * grad.values) * \
(1 - self._decay_tensor)
print('ms_scaled_g_values:' + str(ms_scaled_g_values.get_shape()))
# The next two ops apply the decay to every row of ms, then add the scaled squared gradient to the touched rows only.
ms_t = state_ops.assign(ms, ms * self._decay_tensor,
use_locking=self._use_locking)
print('ms_t: ' + str(ms_t.get_shape()))
ms_t = state_ops.scatter_add(ms_t, grad.indices, ms_scaled_g_values,
use_locking=self._use_locking)
print('ms_t: ' + str(ms_t.get_shape()))
rms = math_ops.sqrt(ms_t)
print('rms: ' + str(rms.get_shape()))
rms += self._epsilon_tensor
print('rms: ' + str(rms.get_shape()))
mom = self.get_slot(var, "momentum")
print('mom: ' + str(mom.get_shape()))
sparse_grad = self.get_slot(var, "sparse_grad")
sparse_grad_t = state_ops.assign(sparse_grad, sparse_grad, use_locking=self._use_locking)
sparse_grad_t = state_ops.scatter_add(sparse_grad, grad.indices, grad.values*self._learning_rate, use_locking=self._use_locking)
mom_scaled_g_values = sparse_grad_t / rms
print('mom_scaled_g_values: ' + str(mom_scaled_g_values.get_shape()))
mom_t = state_ops.assign(mom, mom * self._momentum_tensor,
use_locking=self._use_locking)
print('mom_t: ' + str(mom_t.get_shape()))
mom_t += mom_scaled_g_values
# mom_t = state_ops.scatter_add(mom_t, grad.indices, mom_scaled_g_values,
# use_locking=self._use_locking)
print('mom_t: ' + str(mom_t.get_shape()))
var_update = state_ops.assign_sub(var, mom_t,
use_locking=self._use_locking)
return control_flow_ops.group(*[var_update, ms_t, mom_t])
Example 10: _get_partitioned_update_ops
def _get_partitioned_update_ops(self,
v_num,
num_partitions_by_var,
p_assignments_by_var,
gather_ids_by_var,
weights,
full_update,
p_assignments,
num_partitions):
"""Get updates for partitioned variables."""
num_partitions = num_partitions_by_var[v_num]
p_assignments = p_assignments_by_var[v_num]
gather_ids = gather_ids_by_var[v_num]
updates = data_flow_ops.dynamic_partition(
full_update, p_assignments, num_partitions)
update_ops = []
for p in range(num_partitions):
with ops.colocate_with(weights[p]):
result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p])
update_ops.append(result)
return update_ops
Example 11: testScatterAddStateOps
def testScatterAddStateOps(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="add")
state_ops.scatter_add(v, [1], [3])
self.assertAllEqual([1.0, 5.0], v.numpy())
Example 12: histogram_fixed_width
def histogram_fixed_width(values, value_range, nbins=100, use_locking=True, dtype=dtypes.int32, name=None):
"""Return histogram of values.
Given the tensor `values`, this operation returns a rank 1 histogram counting
the number of entries in `values` that fell into every bin. The bins are
equal width and determined by the arguments `value_range` and `nbins`.
Args:
values: Numeric `Tensor`.
value_range: Shape [2] `Tensor`. values <= value_range[0] will be
mapped to hist[0], values >= value_range[1] will be mapped to hist[-1].
Must be the same dtype as `values`.
nbins: Integer number of bins in this histogram.
use_locking: Boolean.
If `True`, use locking during the operation (optional).
dtype: dtype for returned histogram.
name: A name for this operation (defaults to 'histogram_fixed_width').
Returns:
A `Variable` holding histogram of values.
Examples:
```python
# Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
nbins = 5
value_range = [0.0, 5.0]
new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
with tf.Session() as sess:
hist = tf.histogram_fixed_width(new_values, value_range, nbins=5)
variables.initialize_all_variables().run()
sess.run(hist) => [2, 1, 1, 0, 2]
```
"""
with variable_scope.variable_op_scope([values, value_range], name, "histogram_fixed_width") as scope:
values = ops.convert_to_tensor(values, name="values")
values = array_ops.reshape(values, [-1])
value_range = ops.convert_to_tensor(value_range, name="value_range")
# Map tensor values that fall within value_range to [0, 1].
scaled_values = math_ops.truediv(values - value_range[0], value_range[1] - value_range[0], name="scaled_values")
# map tensor values within the open interval value_range to {0,.., nbins-1},
# values outside the open interval will be zero or less, or nbins or more.
indices = math_ops.floor(nbins * scaled_values, name="indices")
# Clip edge cases (e.g. value = value_range[1]) or "outliers."
indices = math_ops.cast(clip_ops.clip_by_value(indices, 0, nbins - 1), dtypes.int32)
# Dummy vector to scatter.
# TODO(langmore) Replace non-ideal creation of large dummy vector once an
# alternative to scatter is available.
updates = array_ops.ones_like(indices, dtype=dtype)
hist = variable_scope.get_variable(
"hist", initializer=array_ops.zeros_initializer([nbins], dtype=dtype), trainable=False
)
hist_assign_zero = hist.assign(array_ops.zeros_like(hist))
with ops.control_dependencies([hist_assign_zero]):
return state_ops.scatter_add(hist, indices, updates, use_locking=use_locking, name=scope.name)
Example 13: minimize
def minimize(self, global_step=None, name=None):
"""Add operations to train a linear model by minimizing the loss function.
Args:
global_step: Optional `Variable` to increment by one after the
variables have been updated.
name: Optional name for the returned operation.
Returns:
An Operation that updates the variables passed in the constructor.
"""
# Technically, the op depends on a lot more than the variables,
# but we'll keep the list short.
with name_scope(name, 'sdca/minimize'):
sparse_example_indices = []
sparse_feature_indices = []
sparse_features_values = []
for sf in self._examples['sparse_features']:
sparse_example_indices.append(sf.example_indices)
sparse_feature_indices.append(sf.feature_indices)
# If feature values are missing, sdca assumes a value of 1.0f.
if sf.feature_values is not None:
sparse_features_values.append(sf.feature_values)
example_ids_hashed = sdca_fprint(
convert_to_tensor(self._examples['example_ids']))
example_state_data = self._hashtable.lookup(example_ids_hashed)
# Solver returns example_state_update, new delta sparse_feature_weights
# and delta dense_feature_weights.
weights_tensor = self._convert_n_to_tensor(self._slots[
'unshrinked_sparse_features_weights'])
sparse_weights = []
sparse_indices = []
for w, i in zip(weights_tensor, sparse_feature_indices):
# Find the feature ids to lookup in the variables.
with ops.device(w.device):
sparse_indices.append(
math_ops.cast(
array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
dtypes.int64))
sparse_weights.append(array_ops.gather(w, sparse_indices[-1]))
esu, sfw, dfw = sdca_optimizer(
sparse_example_indices,
sparse_feature_indices,
sparse_features_values,
self._convert_n_to_tensor(self._examples['dense_features']),
convert_to_tensor(self._examples['example_weights']),
convert_to_tensor(self._examples['example_labels']),
sparse_indices,
sparse_weights,
self._convert_n_to_tensor(self._slots[
'unshrinked_dense_features_weights']),
example_state_data,
loss_type=self._options['loss_type'],
l1=self._options['symmetric_l1_regularization'],
l2=self._symmetric_l2_regularization(),
num_loss_partitions=self._num_loss_partitions(),
num_inner_iterations=1)
with ops.control_dependencies([esu]):
update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
# Update the weights before the proximal step.
for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'],
sparse_indices, sfw):
update_ops.append(state_ops.scatter_add(w, i, u))
for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
update_ops.append(w.assign_add(u))
if not global_step:
return control_flow_ops.group(*update_ops)
with ops.control_dependencies(update_ops):
return state_ops.assign_add(global_step, 1, name=name).op
Example 14: minimize
#......... some code omitted here .........
# There really should not be more than 2^32 partitions.
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
# Partition list of ids based on assignments into num_partitions
# separate lists.
gather_ids = data_flow_ops.dynamic_partition(new_ids,
p_assignments,
num_partitions)
# Add these into the dictionaries for use in the later update.
num_partitions_by_var[v_num] = num_partitions
p_assignments_by_var[v_num] = p_assignments
gather_ids_by_var[v_num] = gather_ids
# Gather the weights from each partition.
partition_gathered_weights = []
for p in range(num_partitions):
with ops.colocate_with(w[p]):
partition_gathered_weights.append(
array_ops.gather(w[p], gather_ids[p]))
# Stitch the weights back together in the same order they were before
# we dynamic_partitioned them.
condition_indices = data_flow_ops.dynamic_partition(
math_ops.range(array_ops.shape(new_ids)[0]),
p_assignments, num_partitions)
batch_gathered_weights = data_flow_ops.dynamic_stitch(
condition_indices, partition_gathered_weights)
else:
w_as_tensor = internal_convert_to_tensor(w)
with ops.device(w_as_tensor.device):
batch_gathered_weights = array_ops.gather(
w_as_tensor, sparse_idx)
sparse_weights.append(batch_gathered_weights)
# pylint: disable=protected-access
if compat.forward_compatible(year=2018, month=10, day=30):
esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2(
sparse_example_indices,
sparse_feature_indices,
sparse_features_values,
self._convert_n_to_tensor(self._examples['dense_features']),
internal_convert_to_tensor(self._examples['example_weights']),
internal_convert_to_tensor(self._examples['example_labels']),
sparse_indices,
sparse_weights,
self._convert_n_to_tensor(self._slots[
'unshrinked_dense_features_weights']),
example_state_data,
loss_type=self._options['loss_type'],
l1=self._options['symmetric_l1_regularization'],
l2=self._symmetric_l2_regularization(),
num_loss_partitions=self._num_loss_partitions(),
num_inner_iterations=1,
adaptive=self._adaptive())
else:
esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
sparse_example_indices,
sparse_feature_indices,
sparse_features_values,
self._convert_n_to_tensor(self._examples['dense_features']),
internal_convert_to_tensor(self._examples['example_weights']),
internal_convert_to_tensor(self._examples['example_labels']),
sparse_indices,
sparse_weights,
self._convert_n_to_tensor(self._slots[
'unshrinked_dense_features_weights']),
example_state_data,
loss_type=self._options['loss_type'],
l1=self._options['symmetric_l1_regularization'],
l2=self._symmetric_l2_regularization(),
num_loss_partitions=self._num_loss_partitions(),
num_inner_iterations=1,
adaptative=self._adaptive())
# pylint: enable=protected-access
with ops.control_dependencies([esu]):
update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
# Update the weights before the proximal step.
for v_num, (w, i, u) in enumerate(
zip(self._slots['unshrinked_sparse_features_weights'],
sparse_indices, sfw)):
if (isinstance(w, var_ops.PartitionedVariable) or
isinstance(w, list)):
update_ops += self._get_partitioned_update_ops(
v_num, num_partitions_by_var, p_assignments_by_var,
gather_ids_by_var, w, u, p_assignments, num_partitions)
else:
update_ops.append(state_ops.scatter_add(w, i, u))
for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
if (isinstance(w, var_ops.PartitionedVariable) or
isinstance(w, list)):
split_updates = array_ops.split(
u, num_or_size_splits=[v.shape.as_list()[0] for v in w])
for v, split_update in zip(w, split_updates):
update_ops.append(state_ops.assign_add(v, split_update))
else:
update_ops.append(state_ops.assign_add(w, u))
if not global_step:
return control_flow_ops.group(*update_ops)
with ops.control_dependencies(update_ops):
return state_ops.assign_add(global_step, 1, name=name).op
Example 15: _mini_batch_training_op
def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
cluster_centers_var, total_counts):
"""Creates an op for training for mini batch case.
Args:
inputs: list of input Tensors.
cluster_idx_list: A vector (or list of vectors). Each element in the
vector corresponds to an input row in 'inp' and specifies the cluster id
corresponding to the input.
cluster_centers: Tensor of cluster centers, possibly normalized.
cluster_centers_var: Tensor Ref of cluster centers.
total_counts: Tensor Ref of cluster counts.
Returns:
An op for doing an update of mini-batch k-means.
"""
update_ops = []
for inp, cluster_idx in zip(inputs, cluster_idx_list):
with ops.colocate_with(inp):
assert total_counts is not None
cluster_idx = array_ops.reshape(cluster_idx, [-1])
# Dedupe the unique ids of cluster_centers being updated so that updates
# can be locally aggregated.
unique_ids, unique_idx = array_ops.unique(cluster_idx)
num_unique_cluster_idx = array_ops.size(unique_ids)
# Fetch the old values of counts and cluster_centers.
with ops.colocate_with(total_counts):
old_counts = array_ops.gather(total_counts, unique_ids)
with ops.colocate_with(cluster_centers):
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
# Locally aggregate the increment to counts.
count_updates = math_ops.unsorted_segment_sum(
array_ops.ones_like(
unique_idx, dtype=total_counts.dtype),
unique_idx,
num_unique_cluster_idx)
# Locally compute the sum of inputs mapped to each id.
# For a cluster with old cluster value x, old count n, and with data
# d_1,...d_k newly assigned to it, we recompute the new value as
# x += (sum_i(d_i) - k * x) / (n + k).
# Compute sum_i(d_i), see comment above.
cluster_center_updates = math_ops.unsorted_segment_sum(
inp, unique_idx, num_unique_cluster_idx)
# Shape to enable broadcasting count_updates and learning_rate to inp.
# It extends the shape with 1's to match the rank of inp.
broadcast_shape = array_ops.concat(
[
array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
dtype=dtypes.int32)
],
0)
# Subtract k * x, see comment above.
cluster_center_updates -= math_ops.cast(
array_ops.reshape(count_updates, broadcast_shape),
inp.dtype) * old_cluster_centers
learning_rate = math_ops.reciprocal(
math_ops.cast(old_counts + count_updates, inp.dtype))
learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
# scale by 1 / (n + k), see comment above.
cluster_center_updates *= learning_rate
# Apply the updates.
update_counts = state_ops.scatter_add(total_counts, unique_ids,
count_updates)
update_cluster_centers = state_ops.scatter_add(cluster_centers_var,
unique_ids,
cluster_center_updates)
update_ops.extend([update_counts, update_cluster_centers])
return control_flow_ops.group(*update_ops)
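For reference, the update described in the comments above, x += (sum_i(d_i) - k * x) / (n + k), is algebraically the running-mean update x_new = (n * x + sum_i(d_i)) / (n + k): each cluster center stays the mean of every point ever assigned to it, with scatter_add accumulating both the per-cluster counts and the center updates.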