This article collects typical usage examples of the tensorflow.scatter_update method in Python. If you have been wondering what exactly tensorflow.scatter_update does and how to use it, the curated code examples below may help. You can also explore further usage examples from the tensorflow module it belongs to.
The following shows 15 code examples of tensorflow.scatter_update, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
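Before diving into the examples, here is a minimal sketch of the basic call, assuming TensorFlow 1.x graph mode; the variable table, the indices, and the values are made up purely for illustration:

import tensorflow as tf

# A 1-D variable with four slots (hypothetical values).
table = tf.Variable([10.0, 20.0, 30.0, 40.0])
# Overwrite rows 1 and 3; the remaining rows keep their current values.
update_op = tf.scatter_update(table, [1, 3], [99.0, 77.0])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(update_op))  # expected: [10. 99. 30. 77.]

tf.scatter_update returns the updated variable, so running update_op both applies the write and yields the new contents.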
Example 1: make_update_op
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                   batch_size, use_recent_idx, intended_output):
  """Function that creates all the update ops."""
  mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                 dtype=tf.float32))
  with tf.control_dependencies([mem_age_incr]):
    mem_age_upd = tf.scatter_update(
        self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

  mem_key_upd = tf.scatter_update(
      self.mem_keys, upd_idxs, upd_keys)
  mem_val_upd = tf.scatter_update(
      self.mem_vals, upd_idxs, upd_vals)

  if use_recent_idx:
    recent_idx_upd = tf.scatter_update(
        self.recent_idx, intended_output, upd_idxs)
  else:
    recent_idx_upd = tf.group()

  return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
Example 2: replace
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def replace(self, episodes, length, rows=None):
  """Replace full episodes.

  Args:
    episodes: Tuple of transition quantities with batch and time dimensions.
    length: Batch of sequence lengths.
    rows: Episodes to replace, defaults to all.

  Returns:
    Operation.
  """
  rows = tf.range(self._capacity) if rows is None else rows
  assert rows.shape.ndims == 1
  assert_capacity = tf.assert_less(
      rows, self._capacity, message='capacity exceeded')
  with tf.control_dependencies([assert_capacity]):
    assert_max_length = tf.assert_less_equal(
        length, self._max_length, message='max length exceeded')
  replace_ops = []
  with tf.control_dependencies([assert_max_length]):
    for buffer_, elements in zip(self._buffers, episodes):
      replace_op = tf.scatter_update(buffer_, rows, elements)
      replace_ops.append(replace_op)
  with tf.control_dependencies(replace_ops):
    return tf.scatter_update(self._length, rows, length)
Example 3: reinit_nested_vars
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def reinit_nested_vars(variables, indices=None):
  """Reset all variables in a nested tuple to zeros.

  Args:
    variables: Nested tuple or list of variables.
    indices: Indices along the first dimension to reset, defaults to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    return tf.group(*[
        reinit_nested_vars(variable, indices) for variable in variables])
  if indices is None:
    return variables.assign(tf.zeros_like(variables))
  else:
    zeros = tf.zeros([tf.shape(indices)[0]] + variables.shape[1:].as_list())
    return tf.scatter_update(variables, indices, zeros)
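A hypothetical way to exercise this helper, assuming the function above is in scope; the state variables and the chosen indices below are made up for illustration:

# Two variables sharing the same batch dimension (hypothetical shapes).
state = (tf.Variable(tf.ones([4, 2])), tf.Variable(tf.ones([4, 3])))
reset_op = reinit_nested_vars(state, indices=tf.constant([0, 2]))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(reset_op)  # rows 0 and 2 of both variables are now zeros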
Example 4: reset
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def reset(self, indices=None):
  """Reset the batch of environments.

  Args:
    indices: The batch indices of the environments to reset; defaults to all.

  Returns:
    Batch tensor of the new observations.
  """
  if indices is None:
    indices = tf.range(len(self._batch_env))
  observ_dtype = self._parse_dtype(self._batch_env.observation_space)
  observ = tf.py_func(
      self._batch_env.reset, [indices], observ_dtype, name='reset')
  observ = tf.check_numerics(observ, 'observ')
  reward = tf.zeros_like(indices, tf.float32)
  done = tf.zeros_like(indices, tf.bool)
  with tf.control_dependencies([
      tf.scatter_update(self._observ, indices, observ),
      tf.scatter_update(self._reward, indices, reward),
      tf.scatter_update(self._done, indices, done)]):
    return tf.identity(observ)
Example 5: reinit_nested_vars
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def reinit_nested_vars(variables, indices=None):
  """Reset all variables in a nested tuple to zeros.

  Args:
    variables: Nested tuple or list of variables.
    indices: Batch indices to reset, defaults to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    return tf.group(*[
        reinit_nested_vars(variable, indices) for variable in variables])
  if indices is None:
    return variables.assign(tf.zeros_like(variables))
  else:
    zeros = tf.zeros([tf.shape(indices)[0]] + variables.shape[1:].as_list())
    return tf.scatter_update(variables, indices, zeros)
Example 6: assign_nested_vars
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def assign_nested_vars(variables, tensors, indices=None):
  """Assign tensors to matching nested tuple of variables.

  Args:
    variables: Nested tuple or list of variables to update.
    tensors: Nested tuple or list of tensors to assign.
    indices: Batch indices to assign to; defaults to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    return tf.group(*[
        # Pass indices through so nested variables are updated consistently.
        assign_nested_vars(variable, tensor, indices)
        for variable, tensor in zip(variables, tensors)])
  if indices is None:
    return variables.assign(tensors)
  else:
    return tf.scatter_update(variables, indices, tensors)
Example 7: _reset_non_empty
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def _reset_non_empty(self, indices):
  """Reset the batch of environments.

  Args:
    indices: The batch indices of the environments to reset.

  Returns:
    Batch tensor of the new observations.
  """
  observ_dtype = utils.parse_dtype(self._batch_env.observation_space)
  observ = tf.py_func(
      self._batch_env.reset, [indices], observ_dtype, name='reset')
  observ = tf.check_numerics(observ, 'observ')
  with tf.control_dependencies([
      tf.scatter_update(self._observ, indices, observ)]):
    return tf.identity(observ)
Example 8: sparse_moving_average
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def sparse_moving_average(self, variable, unique_indices, accumulant,
                          name='Accumulator', decay=.9):
  """Updates a decayed moving average of the accumulant at the given indices."""
  accumulant = tf.clip_by_value(accumulant, -self.clip, self.clip)
  first_dim = variable.get_shape().as_list()[0]
  accumulator = self.get_accumulator(name, variable)
  indexed_accumulator = tf.gather(accumulator, unique_indices)
  iteration = self.get_accumulator('{}/iteration'.format(name), variable,
                                   shape=[first_dim, 1])
  indexed_iteration = tf.gather(iteration, unique_indices)
  iteration = tf.scatter_add(iteration, unique_indices,
                             tf.ones_like(indexed_iteration))
  indexed_iteration = tf.gather(iteration, unique_indices)
  if decay < 1:
    current_indexed_decay = (decay * (1 - decay**(indexed_iteration - 1)) /
                             (1 - decay**indexed_iteration))
  else:
    current_indexed_decay = (indexed_iteration - 1) / indexed_iteration
  accumulator = tf.scatter_update(accumulator, unique_indices,
                                  current_indexed_decay * indexed_accumulator)
  accumulator = tf.scatter_add(accumulator, unique_indices,
                               (1 - current_indexed_decay) * accumulant)
  return accumulator, iteration
#=============================================================
Example 9: _reset_non_empty
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def _reset_non_empty(self, indices):
  # pylint: disable=protected-access
  new_values = self._batch_env._reset_non_empty(indices)
  # pylint: enable=protected-access
  initial_frames = getattr(self._batch_env, "history_observations", None)
  if initial_frames is not None:
    # Using history buffer frames for initialization, if they are available.
    with tf.control_dependencies([new_values]):
      # Transpose to [batch, height, width, history, channels] and merge
      # history and channels into one dimension.
      initial_frames = tf.transpose(initial_frames, [0, 2, 3, 1, 4])
      initial_frames = tf.reshape(initial_frames,
                                  (len(self),) + self.observ_shape)
  else:
    inx = tf.concat(
        [
            tf.ones(tf.size(tf.shape(new_values)),
                    dtype=tf.int64)[:-1],
            [self.history]
        ],
        axis=0)
    initial_frames = tf.tile(new_values, inx)
  assign_op = tf.scatter_update(self._observ, indices, initial_frames)
  with tf.control_dependencies([assign_op]):
    return tf.gather(self.observ, indices)
Example 10: testWhileUpdateVariable_1
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def testWhileUpdateVariable_1(self):
  with self.test_session():
    select = tf.Variable([3.0, 4.0, 5.0])
    n = tf.constant(0)

    def loop_iterator(j):
      return tf.less(j, 3)

    def loop_body(j):
      ns = tf.scatter_update(select, j, 10.0)
      nj = tf.add(j, 1)
      op = control_flow_ops.group(ns)
      nj = control_flow_ops.with_dependencies([op], nj)
      return [nj]

    r = tf.while_loop(loop_iterator, loop_body, [n],
                      parallel_iterations=1)
    self.assertTrue(check_op_order(n.graph))
    tf.global_variables_initializer().run()
    self.assertEqual(3, r.eval())
    result = select.eval()
    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
Example 11: testWhileUpdateVariable_3
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def testWhileUpdateVariable_3(self):
  with self.test_session():
    select = tf.Variable([3.0, 4.0, 5.0])
    n = tf.constant(0)

    def loop_iterator(j, _):
      return tf.less(j, 3)

    def loop_body(j, _):
      ns = tf.scatter_update(select, j, 10.0)
      nj = tf.add(j, 1)
      return [nj, ns]

    r = tf.while_loop(loop_iterator, loop_body,
                      [n, tf.identity(select)],
                      parallel_iterations=1)
    tf.global_variables_initializer().run()
    result = r[1].eval()
    self.assertTrue(check_op_order(n.graph))
    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
# b/24814703
Example 12: insert
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def insert(self, ids, scores):
  """Insert the ids and scores into the TopN."""
  with tf.control_dependencies(self.last_ops):
    scatter_op = tf.scatter_update(self.id_to_score, ids, scores)
    larger_scores = tf.greater(scores, self.sl_scores[0])

    def shortlist_insert():
      larger_ids = tf.boolean_mask(tf.to_int64(ids), larger_scores)
      larger_score_values = tf.boolean_mask(scores, larger_scores)
      shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
          self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
      u1 = tf.scatter_update(self.sl_ids, shortlist_ids, new_ids)
      u2 = tf.scatter_update(self.sl_scores, shortlist_ids, new_scores)
      return tf.group(u1, u2)

    # We only need to insert into the shortlist if there are any
    # scores larger than the threshold.
    cond_op = tf.cond(
        tf.reduce_any(larger_scores), shortlist_insert, tf.no_op)
    with tf.control_dependencies([cond_op]):
      self.last_ops = [scatter_op, cond_op]
Example 13: remove
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def remove(self, ids):
  """Remove the ids (and their associated scores) from the TopN."""
  with tf.control_dependencies(self.last_ops):
    scatter_op = tf.scatter_update(
        self.id_to_score,
        ids,
        tf.ones_like(ids, dtype=tf.float32) * tf.float32.min)
    # We assume that removed ids are almost always in the shortlist,
    # so it makes no sense to hide the Op behind a tf.cond.
    shortlist_ids_to_remove, new_length = self.ops.top_n_remove(self.sl_ids,
                                                                ids)
    # tf.concat takes (values, axis) in TensorFlow 1.x.
    u1 = tf.scatter_update(
        self.sl_ids, tf.concat([[0], shortlist_ids_to_remove], 0),
        tf.concat([new_length,
                   tf.ones_like(shortlist_ids_to_remove) * -1], 0))
    u2 = tf.scatter_update(
        self.sl_scores,
        shortlist_ids_to_remove,
        tf.float32.min * tf.ones_like(shortlist_ids_to_remove,
                                      dtype=tf.float32))
    self.last_ops = [scatter_op, u1, u2]
Example 14: reset
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def reset(self, indices=None):
  """Reset the batch of environments.

  Args:
    indices: The batch indices of the environments to reset; defaults to all.

  Returns:
    Batch tensor of the new observations.
  """
  if indices is None:
    indices = tf.range(len(self._batch_env))
  observ_dtype = self._parse_dtype(self._batch_env.observation_space)
  observ = tf.py_func(
      self._batch_env.reset, [indices], observ_dtype, name='reset')
  reward = tf.zeros_like(indices, tf.float32)
  done = tf.zeros_like(indices, tf.int32)
  with tf.control_dependencies([
      tf.scatter_update(self._observ, indices, observ),
      tf.scatter_update(self._reward, indices, reward),
      tf.scatter_update(self._done, indices, tf.to_int32(done))]):
    return tf.identity(observ)
Example 15: curvature_range
# Required import: import tensorflow [as alias]
# Or: from tensorflow import scatter_update [as alias]
def curvature_range(self):
  # Set up the curvature window.
  self._curv_win = tf.Variable(
      np.zeros([self._curv_win_width]), dtype=tf.float32,
      name="curv_win", trainable=False)
  self._curv_win = tf.scatter_update(
      self._curv_win, self._global_step % self._curv_win_width,
      self._grad_norm_squared)
  # Note that the iterations start from iteration 0.
  valid_window = tf.slice(
      self._curv_win, tf.constant([0]),
      tf.expand_dims(tf.minimum(tf.constant(self._curv_win_width),
                                self._global_step + 1), dim=0))
  self._h_min_t = tf.reduce_min(valid_window)
  self._h_max_t = tf.reduce_max(valid_window)
  curv_range_ops = []
  with tf.control_dependencies([self._h_min_t, self._h_max_t]):
    avg_op = self._moving_averager.apply([self._h_min_t, self._h_max_t])
    with tf.control_dependencies([avg_op]):
      self._h_min = tf.identity(self._moving_averager.average(self._h_min_t))
      self._h_max = tf.identity(self._moving_averager.average(self._h_max_t))
  curv_range_ops.append(avg_op)
  return curv_range_ops