本文整理汇总了Python中tensorflow.compat.v1.while_loop方法的典型用法代码示例。如果您正苦于以下问题：Python v1.while_loop方法的具体用法？Python v1.while_loop怎么用？Python v1.while_loop使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.compat.v1的用法示例。
在下文中一共展示了v1.while_loop方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _build
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def _build(self, inputs, labels):
    """Run the inner attack with restarts and keep a per-example result.

    The inner attack is rebuilt inside a `tf.while_loop` up to
    `self._num_restarts` times, stopping early once every example in the
    batch has been attacked successfully. The first iteration always takes the
    inner attack's output; later iterations overwrite an example only where the
    new attack succeeded.

    Args:
      inputs: Batch of inputs; its first dimension is the batch size.
      labels: Passed through to `self._inner_attack`.

    Returns:
      The selected attack tensor (also stored as `self._attack`).
    """

    def _keep_going(step, unused_attack, success):
        # Break as soon as every example already succeeded.
        under_budget = step < self._num_restarts
        not_all_done = tf.logical_not(tf.reduce_all(success))
        return tf.logical_and(under_budget, not_all_done)

    def _restart(step, best_attack, success):
        candidate = self._inner_attack(inputs, labels)
        candidate_success = self._inner_attack.success
        # Always take the candidate on the first pass; afterwards only where
        # the candidate succeeded.
        take_candidate = tf.logical_or(tf.equal(step, 0), candidate_success)
        next_step = step + 1
        next_attack = tf.where(take_candidate, candidate, best_attack)
        next_success = tf.logical_or(success, candidate_success)
        return (next_step, next_attack, next_success)

    initial_state = [
        tf.constant(0, dtype=tf.int32),
        inputs,
        tf.zeros([tf.shape(inputs)[0]], dtype=tf.bool),
    ]
    _, self._attack, self._success = tf.while_loop(
        _keep_going, _restart, back_prop=False, parallel_iterations=1,
        loop_vars=initial_state)
    self._logits = self._eval_fn(self._attack, mode='final')
    return self._attack
示例2: search
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def search(self, initial_ids, initial_cache):
    """Beam search for sequences with highest scores.

    Args:
      initial_ids: Forwarded to `self._create_initial_state`.
      initial_cache: Forwarded to `self._create_initial_state`.

    Returns:
      A `(sequences, scores)` pair. For any batch item with no finished
      sequence, the alive sequence and its log-probability are returned instead.
    """
    state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
    final_state = tf.while_loop(
        self._continue_search,
        self._search_step,
        loop_vars=[state],
        shape_invariants=[state_shapes],
        parallel_iterations=1,
        back_prop=False,
    )[0]

    # Per batch item: did any beam finish?
    has_finished = tf.reduce_any(final_state[_StateKeys.FINISHED_FLAGS], 1)
    # Corner case: no finished sequence for a batch item -> fall back to the
    # alive sequence/log-prob for that item.
    best_seq = tf.where(has_finished,
                        final_state[_StateKeys.FINISHED_SEQ],
                        final_state[_StateKeys.ALIVE_SEQ])
    best_scores = tf.where(has_finished,
                           final_state[_StateKeys.FINISHED_SCORES],
                           final_state[_StateKeys.ALIVE_LOG_PROBS])
    return best_seq, best_scores
示例3: test_loop_2_vars
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_2_vars():
    """Convert a two-variable while loop and post-process its first output."""
    graph = tf.Graph()
    with graph.as_default():
        start_i = tf.constant(0)
        start_j = tf.ones([2, 2])

        def keep_going(i, j):
            return i < 10

        def step(i, j):
            # Only the counter advances; `j` is carried through unchanged.
            return [tf.add(i, 1), j]

        counter, _ = tf.while_loop(keep_going, step, loop_vars=[start_i, start_j])
        counter += tf.constant(1337)

        with tf.Session() as sess:
            tf_out = sess.run(counter)

    check_equal(graph, tf_out)
示例4: test_loop_3_vars
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_3_vars():
    """Convert a while loop carrying three interdependent scalar variables."""
    graph = tf.Graph()
    with graph.as_default():
        initial = [tf.constant(1), tf.constant(2), tf.constant(4)]
        # NOTE(review): `j` exceeds the int32 range on the 9th iteration, so
        # this test presumably relies on both sides wrapping identically.
        r = tf.while_loop(
            lambda i, j, k: i < 10,
            lambda i, j, k: [i + 1, j * k, k + i],
            loop_vars=initial,
        )
        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
示例5: test_loop_conditions
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_conditions():
    """Convert a while loop whose condition combines several comparisons."""
    graph = tf.Graph()
    with graph.as_default():
        start_i = tf.constant(1)
        start_j = tf.constant(1)
        start_k = tf.constant(5)

        def keep_going(i, j, k):
            sum_small = tf.less(i + j, 10)
            prod_small = tf.less(j * k, 100)
            # not_equal on booleans: exactly one of the two holds.
            exactly_one = tf.not_equal(sum_small, prod_small)
            k_dominates = tf.greater_equal(k, i + j)
            return tf.equal(exactly_one, k_dominates)

        def step(i, j, k):
            return [i + j, j + k, k + 1]

        r = tf.while_loop(keep_going, step, loop_vars=[start_i, start_j, start_k])
        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
示例6: test_nested_loop
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_nested_loop():
    """Convert a while loop whose body contains another while loop."""
    graph = tf.Graph()
    with graph.as_default():

        def outer_body(x):
            def inner_step(c):
                return tf.multiply(c, 2)

            def inner_keep_going(c):
                return tf.less(c, 10)

            seed = tf.constant(2)
            inner_result = tf.while_loop(inner_keep_going, inner_step, loop_vars=[seed])
            return tf.nn.relu(x + inner_result)

        def outer_keep_going(x):
            return tf.greater(x, 100)

        start = tf.constant(3)
        # NOTE(review): 3 > 100 is false, so the outer body never runs at
        # runtime; the nested graph is still built and presumably still
        # exercised by check_equal.
        result = tf.while_loop(outer_keep_going, outer_body, loop_vars=[start])
        with tf.Session() as sess:
            tf_out = sess.run(result)

    check_equal(graph, tf_out)
示例7: test_loop_in_cond
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_loop_in_cond():
    """Convert a tf.cond whose true branch contains a while loop."""
    graph = tf.Graph()
    with graph.as_default():

        def true_branch(a, b):
            counter = tf.constant(0)
            loop_result = tf.while_loop(lambda c: tf.less(c, 10),
                                        lambda c: tf.add(c, 1),
                                        [counter])
            return tf.multiply(tf.add(20, loop_result), 10)

        def false_branch(a, b):
            return tf.add(10, 20)

        x = tf.constant(7)
        y = tf.constant(20)
        z = tf.constant(10)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: true_branch(x, y), lambda: false_branch(y, z))

        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})

    check_equal(graph, tf_out)
示例8: test_cond_in_loop
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_cond_in_loop():
    """Convert a while loop whose body contains a tf.cond."""
    graph = tf.Graph()
    with graph.as_default():

        def body(unused_x):
            # The loop variable is not read: the body uses a constant 7
            # instead, so every iteration yields the same value.
            factor = tf.constant(7)
            unused_const = tf.constant(20)  # kept: part of the tested graph
            branch_value = tf.cond(tf.less(factor, 10),
                                   lambda: tf.add(10, 20),
                                   lambda: tf.square(10))
            return tf.multiply(branch_value, factor)

        start = tf.constant(21)

        def keep_going(x):
            return tf.less(x, 100)

        result = tf.while_loop(keep_going, body, loop_vars=[start])
        with tf.Session() as sess:
            tf_out = sess.run(result)

    check_equal(graph, tf_out)
示例9: should_generate_summaries
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def should_generate_summaries():
    """Report whether the current context is suitable for creating summaries.

    Returns:
      False when the current name scope contains "while/" (i.e. inside a
      tf.while_loop) or when the current variable scope is reusing variables;
      True otherwise.
    """
    scope_name = contrib.framework().get_name_scope()
    # Summaries don't work well within tf.while_loop().
    if scope_name and "while/" in scope_name:
        return False
    # Avoid generating separate summaries for different data shards: a
    # reusing variable scope yields False.
    return not tf.get_variable_scope().reuse
示例10: _should_cache_variables
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def _should_cache_variables():
    """Return True if a default caching device should be set, otherwise False.

    Caching is skipped whenever the default graph's current control-flow
    context lies inside a v1 while loop: train steps may be wrapped in a
    tf.while_loop, and caching would stop later iterations from re-reading the
    updated weights.
    """
    flow_ctx = tf.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    enclosing_while = control_flow_util.GetContainingWhileContext(flow_ctx)
    return enclosing_while is None
示例11: adapt
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def adapt(self, original_inputs, adversarial_inputs, labels,
          binary_search_iterations=10):
    """Runs binary search to find the first misclassified input.

    Bisects the interpolation factor m in [0, 1] (one per example) along
    `original + m * (adversarial - original)`. A point counts as "successful"
    per `self._success_fn` applied to `self._specification.evaluate` of the
    logits from `self._eval_fn`. The search assumes success is monotonic along
    the segment (not checked here).

    Args:
      original_inputs: Batch of clean inputs; the batch size is its first
        dimension. Its static rank must be known.
      adversarial_inputs: Batch of adversarial inputs (presumably the same
        shape as `original_inputs`).
      labels: Not used by this method.
      binary_search_iterations: Python int, number of bisection steps
        (default 10, as before).

    Returns:
      The interpolated inputs at the lower bound where that point is already
      successful, otherwise at the upper bound.
    """
    binary_search_iterations = int(binary_search_iterations)
    batch_size = tf.shape(original_inputs)[0]
    # Per-example factor reshaped to broadcast over all non-batch dims.
    factor_shape = [batch_size] + [1] * (len(original_inputs.shape) - 1)

    def interpolate(m):
        m = tf.reshape(m, factor_shape)
        return (adversarial_inputs - original_inputs) * m + original_inputs

    def is_attack_successful(m):
        logits = self._eval_fn(interpolate(m))
        return self._success_fn(self._specification.evaluate(logits))

    def keep_searching(i, *_):
        return tf.less(i, binary_search_iterations)

    def bisect(i, lower, upper):
        mid = (lower + upper) * .5
        success = is_attack_successful(mid)
        # Successful midpoint -> shrink upper; otherwise raise lower.
        new_lower = tf.where(success, lower, mid)
        new_upper = tf.where(success, mid, upper)
        return i + 1, new_lower, new_upper

    lower = tf.zeros(shape=[batch_size])
    upper = tf.ones(shape=[batch_size])
    # Integer loop counter (the original used a float32 constant).
    _, lower, upper = tf.while_loop(
        keep_searching,
        bisect,
        loop_vars=[tf.constant(0, dtype=tf.int32), lower, upper],
        parallel_iterations=1,
        back_prop=False)
    # If lower is incorrectly classified (attack succeeds), pick lower;
    # otherwise pick upper.
    success = is_attack_successful(lower)
    return interpolate(tf.where(success, lower, upper))
示例12: benchmark_handwritten
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def benchmark_handwritten(self):
    """Time a hand-written graph-mode training loop built on tf.while_loop.

    The loop body pulls the next minibatch, computes the loss, builds an
    optimizer step and prints progress every 100 steps. `target` runs the loop
    to completion and is timed via `self.time_execution`.
    """
    with tf.Graph().as_default():
        ds, opt, hp, w, b = get_data_and_params()
        iterator = ds.make_one_shot_iterator()

        def loop_body(i, unused_previous_loss_t):
            """Manual implementation of training loop."""
            # Call get_next() inside body or else training happens repeatedly on
            # the first minibatch only.
            x, y = iterator.get_next()
            loss_t = loss_fn(x, y, w, b)
            train_op = opt.minimize(loss_t, var_list=(w, b))
            # When i % 100 == 0, route `i` through tf.Print (logging step and
            # loss); rebinding `i` to the cond output keeps the print on the
            # data path of the returned counter.
            i = tf.cond(tf.equal(i % 100, 0),
                        lambda: tf.Print(i, [i, loss_t], message='Step, loss: '),
                        lambda: i)
            # The incremented counter is created under a control dependency on
            # train_op, so the optimizer step runs before the loop advances.
            with tf.control_dependencies([train_op]):
                return i + 1, loss_t

        _, final_loss_t = tf.while_loop(
            lambda i, _: i < hp.train_steps,
            loop_body,
            [tf.constant(0), tf.constant(0.0)])

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            def target():
                loss_val = sess.run(final_loss_t)
                # Fail if the final loss is outside (0.1, 1).
                assert 0.1 < loss_val < 1, loss_val

            self.time_execution(
                'Handwritten',
                target,
                iter_volume=hp.train_steps,
                iter_unit='training steps')
示例13: test_vanilla_loop
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_vanilla_loop():
    """Convert a single-counter while loop whose initial constant is named "while/constant"."""
    graph = tf.Graph()
    with graph.as_default():
        start = tf.constant(0, name="while/constant")
        result = tf.while_loop(lambda n: tf.less(n, 10),
                               lambda n: tf.add(n, 1),
                               [start])
        with tf.Session() as sess:
            tf_out = sess.run(result)

    check_equal(graph, tf_out)
示例14: test_callnode_loop_vars
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_callnode_loop_vars():
    """Convert a while loop whose initial value is a computed op, not a constant."""
    graph = tf.Graph()
    with graph.as_default():
        start = tf.add(tf.constant(0), 1)

        def keep_going(n):
            return tf.less(n, 10)

        def step(n):
            return tf.add(n, 1)

        result = tf.while_loop(keep_going, step, [start])
        with tf.Session() as sess:
            tf_out = sess.run(result)

    check_equal(graph, tf_out)
示例15: test_vanilla_loop_bound
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import while_loop [as 别名]
def test_vanilla_loop_bound():
    """Convert a while loop whose body closes over a tensor defined outside it."""
    graph = tf.Graph()
    with graph.as_default():
        dshape = (2, 10)
        dtype = "float32"
        dname = "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        sliced = tf.slice(data, [1, 4], [1, 4])
        # Captured by the loop body from outside the loop.
        outer = sliced + 5.0

        def body(x, y):
            first = tf.cond(tf.less(y, 10),
                            lambda: tf.add(10.0, 20.0),
                            lambda: tf.square(10.0))
            seven = tf.constant(7)
            second = tf.cond(tf.less(seven, 10),
                             lambda: first * 5,
                             lambda: first + 10)
            return tf.multiply(second, x * outer), y + 1

        counter = tf.constant(0)

        def keep_going(x, y):
            return tf.less(y, 20)

        # NOTE(review): each iteration multiplies x by at least 150 * 5, so the
        # result almost certainly overflows float32 to inf; confirm intended.
        result = tf.while_loop(keep_going, body, loop_vars=[sliced, counter])
        with tf.Session() as sess:
            tf_out = sess.run(result, feed_dict={"%s:0" % dname: np_data})

    check_equal(graph, tf_out, {dname: np_data})