本文整理汇总了Python中tensorflow.contrib.layers.python.layers.utils.smart_cond函数的典型用法代码示例。如果您正苦于以下问题：Python smart_cond函数的具体用法？Python smart_cond怎么用？Python smart_cond使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了smart_cond函数的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_value
def test_value(self):
    """A constant predicate makes smart_cond return the chosen branch's value.

    The branch functions return plain Python strings, and the test checks
    that the string from the selected branch comes back unchanged.
    """
    def true_branch():
        return 'fn1'

    def false_branch():
        return 'fn2'

    for flag in [True, False, 1, 0]:
        result = utils.smart_cond(constant_op.constant(flag), true_branch,
                                  false_branch)
        wanted = 'fn1' if flag else 'fn2'
        self.assertEqual(result, wanted)
示例2: test_tensors
def test_tensors(self):
    """smart_cond picks the branch whose tensor arithmetic matches the predicate.

    Each branch builds `0 - k` from constant tensors. Evaluating the result
    must give -1 for a truthy predicate and -2 otherwise.
    """
    def minus_one():
        return constant_op.constant(0) - constant_op.constant(1)

    def minus_two():
        return constant_op.constant(0) - constant_op.constant(2)

    for flag in [True, False, 1, 0]:
        picked = utils.smart_cond(constant_op.constant(flag), minus_one,
                                  minus_two)
        with self.test_session():
            wanted = -1 if flag else -2
            self.assertEqual(picked.eval(), wanted)
示例3: test_constant
def test_constant(self):
    """String-constant branches evaluate to the bytes of the selected branch.

    A truthy predicate should yield b'fn1' and a falsy one b'fn2'. String
    tensors evaluate to `bytes`, hence the b'' literals.
    """
    def make_fn1():
        return constant_op.constant('fn1')

    def make_fn2():
        return constant_op.constant('fn2')

    for flag in [True, False, 1, 0]:
        chosen = utils.smart_cond(constant_op.constant(flag), make_fn1,
                                  make_fn2)
        with self.test_session():
            self.assertEqual(chosen.eval(), b'fn1' if flag else b'fn2')
示例4: test_variable
def test_variable(self):
    """Branches that create string Variables evaluate to the selected value.

    Every iteration creates fresh Variables through the chosen branch. That
    is why the global initializer is built and run inside the loop, before
    the result is read.
    """
    def new_fn1_var():
        return variables.Variable('fn1')

    def new_fn2_var():
        return variables.Variable('fn2')

    for flag in [True, False, 1, 0]:
        picked = utils.smart_cond(constant_op.constant(flag), new_fn1_var,
                                  new_fn2_var)
        with self.test_session() as sess:
            # Initialize the Variables this iteration just created before reading them.
            sess.run(variables.global_variables_initializer())
            wanted = b'fn1' if flag else b'fn2'
            self.assertEqual(picked.eval(), wanted)
示例5: test_constant
def test_constant(self):
    """A placeholder predicate selects the string-constant branch via the feed.

    The predicate `p` is a bool placeholder, so its value is only supplied
    through `feed_dict` at evaluation time. The conditional does not depend
    on the loop variable, so it is built once. Evaluating that single node
    under every feed value checks that branch selection happens at run time,
    and it avoids adding a redundant copy of the conditional to the graph on
    each iteration. The integers 1 and 0 are fed to the bool placeholder as
    well.
    """
    def fn1():
        return tf.constant('fn1')

    def fn2():
        return tf.constant('fn2')

    p = tf.placeholder(tf.bool, [])
    # Loop-invariant graph construction: build the conditional exactly once.
    o = utils.smart_cond(p, fn1, fn2)
    with self.test_session():
        for v in [True, False, 1, 0]:
            expected = b'fn1' if v else b'fn2'
            self.assertEqual(o.eval(feed_dict={p: v}), expected)
示例6: test_variable
def test_variable(self):
    """A placeholder predicate selects the Variable-creating branch via the feed.

    The predicate `p` is a bool placeholder fed at evaluation time. The
    conditional, and the two Variables its branches create, do not depend on
    the loop variable, so they are built once instead of once per iteration.
    The same node is then evaluated for every feed value. Variables are
    initialized once in the session before any read.
    """
    def fn1():
        return tf.Variable('fn1')

    def fn2():
        return tf.Variable('fn2')

    p = tf.placeholder(tf.bool, [])
    # Loop-invariant graph construction: one conditional, one pair of Variables.
    o = utils.smart_cond(p, fn1, fn2)
    # `tf.initialize_all_variables` is deprecated; this is its replacement.
    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
        sess.run(init_op)
        for v in [True, False, 1, 0]:
            expected = b'fn1' if v else b'fn2'
            self.assertEqual(o.eval(feed_dict={p: v}), expected)
示例7: batch_norm_mine_old
#.........这里部分代码省略.........
'moving_mean', init_ops.zeros_initializer())
moving_mean = variables.model_variable(
'moving_mean',
shape=params_shape,
dtype=dtype,
initializer=moving_mean_initializer,
trainable=False,
collections=moving_mean_collections)
moving_variance_collections = utils.get_variable_collections(
variables_collections, 'moving_variance')
moving_variance_initializer = param_initializers.get(
'moving_variance', init_ops.ones_initializer())
moving_variance = variables.model_variable(
'moving_variance',
shape=params_shape,
dtype=dtype,
initializer=moving_variance_initializer,
trainable=False,
collections=moving_variance_collections)
finally:
variable_scope.get_variable_scope().set_partitioner(partitioner)
# If `is_training` doesn't have a constant value, because it is a `Tensor`,
# a `Variable` or `Placeholder` then is_training_value will be None and
# `needs_moments` will be true.
is_training_value = utils.constant_value(is_training)
need_moments = is_training_value is None or is_training_value
if need_moments:
# Calculate the moments based on the individual batch.
if batch_weights is None:
if data_format == DATA_FORMAT_NCHW:
mean, _ = nn.moments(inputs, moments_axes, keep_dims=True)
variance,_ = nn.moments( (inputs-moving_mean)**2, moments_axes, keep_dims=True)
mean = array_ops.reshape(mean, [-1])
variance = array_ops.reshape(variance, [-1])
else:
mean, _ = nn.moments(inputs, moments_axes)
variance, _ = nn.moments( (inputs-moving_mean)**2, moments_axes)
else:
if data_format == DATA_FORMAT_NCHW:
mean, _ = nn.weighted_moments(inputs, moments_axes,
batch_weights, keep_dims=True)
variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes,
batch_weights, keep_dims=True)
mean = array_ops.reshape(mean, [-1])
variance = array_ops.reshape(variance, [-1])
else:
mean, _ = nn.weighted_moments(inputs, moments_axes,
batch_weights)
variance, _ = nn.weighted_moments( (inputs-moving_mean)**2, moments_axes,
batch_weights)
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
def _force_updates():
"""Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay, zero_debias=False)
with ops.control_dependencies([update_moving_mean,
update_moving_variance]):
return array_ops.identity(mean), array_ops.identity(variance)
mean, variance = utils.smart_cond(is_training,
_force_updates,
moving_vars_fn)
else:
def _delay_updates():
"""Internal function that delay updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay, zero_debias=False)
return update_moving_mean, update_moving_variance
update_mean, update_variance = utils.smart_cond(is_training,
_delay_updates,
moving_vars_fn)
ops.add_to_collections(updates_collections, update_mean)
ops.add_to_collections(updates_collections, update_variance)
# Use computed moments during training and moving_vars otherwise.
vars_fn = lambda: (mean, variance)
mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
else:
mean, variance = moving_mean, moving_variance
if data_format == DATA_FORMAT_NCHW:
mean = array_ops.reshape(mean, params_shape_broadcast)
variance = array_ops.reshape(variance, params_shape_broadcast)
beta = array_ops.reshape(beta, params_shape_broadcast)
if gamma is not None:
gamma = array_ops.reshape(gamma, params_shape_broadcast)
# Compute batch_normalization.
outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
epsilon)
outputs.set_shape(inputs_shape)
if activation_fn is not None:
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections,
sc.original_name_scope, outputs)