本文整理汇总了 Python 中 tensorflow.contrib.gan.python.train.gan_loss 函数的典型用法代码示例。如果您想了解 gan_loss 函数的具体用法、调用方式或实际使用示例，下面精选的代码示例或许能为您提供帮助。
下文共展示了 gan_loss 函数的 15 个代码示例，默认按受欢迎程度排序。这些示例均为测试类中的方法（第一个参数为 self）。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。
示例1: _test_acgan_helper
def _test_acgan_helper(self, create_gan_model_fn):
    """Checks gan_loss with and without the `aux_cond_*` loss weights.

    Three losses are built from the same model: the plain loss, one with
    `aux_cond_generator_weight=1.0` and one with
    `aux_cond_discriminator_weight=1.0`. All must be GANLoss tuples. The plain
    generator loss must be below the plain discriminator loss, and every
    auxiliary-weighted loss must evaluate to a scalar.

    Args:
      create_gan_model_fn: Zero-argument callable returning the model passed
        to `train.gan_loss`.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
    loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
    # assertIsInstance / assertLess report the offending values on failure,
    # unlike assertTrue(<bool expression>).
    self.assertIsInstance(loss, namedtuples.GANLoss)
    self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss)
    self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
            [loss.generator_loss,
             loss_ac_gen.generator_loss,
             loss_ac_dis.generator_loss])
        loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
            [loss.discriminator_loss,
             loss_ac_gen.discriminator_loss,
             loss_ac_dis.discriminator_loss])

    self.assertLess(loss_gen_np, loss_dis_np)
    self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
    self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
    self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
    self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
示例2: test_doesnt_crash_when_in_nested_scope
def test_doesnt_crash_when_in_nested_scope(self):
    """A gradient-penalized gan_loss must build both inside and outside a scope.

    The model is created under the variable scope 'outer_scope'. The loss is
    then built once while that scope is still open and once after it closes;
    neither call may raise.
    """
    def add_penalized_loss(gan_model_to_score):
        # Same call in both places; only the surrounding scope differs.
        return train.gan_loss(gan_model_to_score, gradient_penalty_weight=1.0)

    with variable_scope.variable_scope('outer_scope'):
        nested_model = train.gan_model(
            generator_model,
            discriminator_model,
            real_data=array_ops.zeros([1, 2]),
            generator_inputs=random_ops.random_normal([1, 2]))

        # Inside the scope.
        add_penalized_loss(nested_model)

    # Outside the scope.
    add_penalized_loss(nested_model)
示例3: _test_tensor_pool_helper
def _test_tensor_pool_helper(self, create_gan_model_fn):
    """Checks that gan_loss accepts a tensor pool and evaluates repeatedly.

    InfoGAN models get a pool function that pools the generated data together
    with every generator input; all other models pool their inputs directly.
    The pooled loss must be a GANLoss, and both of its losses are then
    evaluated ten times.

    Args:
      create_gan_model_fn: Zero-argument callable returning the model passed
        to `train.gan_loss`.
    """
    model = create_gan_model_fn()
    if isinstance(model, namedtuples.InfoGANModel):
        def tensor_pool_fn(input_values):
            generated_data, generator_inputs = input_values
            # `generator_inputs` is concatenated onto a list here, so it must
            # itself be a list. The first pooled value is the generated data;
            # the rest are the pooled generator inputs.
            output_values = random_tensor_pool.tensor_pool(
                [generated_data] + generator_inputs, pool_size=5)
            return output_values[0], output_values[1:]
    else:
        def tensor_pool_fn(input_values):
            return random_tensor_pool.tensor_pool(input_values, pool_size=5)

    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        for _ in range(10):
            sess.run([loss.generator_loss, loss.discriminator_loss])
示例4: _test_grad_penalty_helper
def _test_grad_penalty_helper(self, create_gan_model_fn):
    """Checks that a gradient penalty raises only the discriminator loss.

    The same model is scored with and without `gradient_penalty_weight=1.0`.
    The generator losses must be equal, and the penalized discriminator loss
    must be strictly larger.

    Args:
      create_gan_model_fn: Zero-argument callable returning the model passed
        to `train.gan_loss`.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
    self.assertIsInstance(loss_gp, namedtuples.GANLoss)

    # Check values. Plain and penalized losses are fetched in the same run so
    # both are computed from the same inputs.
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        loss_gen_np, loss_gen_gp_np = sess.run(
            [loss.generator_loss, loss_gp.generator_loss])
        loss_dis_np, loss_dis_gp_np = sess.run(
            [loss.discriminator_loss, loss_gp.discriminator_loss])

    self.assertEqual(loss_gen_np, loss_gen_gp_np)
    self.assertLess(loss_dis_np, loss_dis_gp_np)
示例5: test_train_hooks_exist_in_get_hooks_fn
def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    """Both hook factories must carry the sync-optimizer hooks of the train ops.

    The sequential factory must yield 4 hooks and the joint factory 5. In each
    case, exactly two of the hooks wrap a sync optimizer, one for the
    generator and one for the discriminator.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
    """
    def sync_optimizers_of(hooks):
        # The optimizer behind every _SyncReplicasOptimizerHook in `hooks`.
        return [
            h._sync_optimizer
            for h in hooks
            if isinstance(h, sync_replicas_optimizer._SyncReplicasOptimizerHook)
        ]

    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    gen_opt = get_sync_optimizer()
    dis_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model, loss, gen_opt, dis_opt,
        summarize_gradients=True, colocate_gradients_with_ops=True)
    expected_opts = frozenset((gen_opt, dis_opt))

    sequential_hooks = train.get_sequential_train_hooks()(train_ops)
    self.assertLen(sequential_hooks, 4)
    sequential_sync = sync_optimizers_of(sequential_hooks)
    self.assertLen(sequential_sync, 2)
    self.assertSetEqual(frozenset(sequential_sync), expected_opts)

    joint_hooks = train.get_joint_train_hooks()(train_ops)
    self.assertLen(joint_hooks, 5)
    joint_sync = sync_optimizers_of(joint_hooks)
    self.assertLen(joint_sync, 2)
    self.assertSetEqual(frozenset(joint_sync), expected_opts)
示例6: test_sync_replicas
def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    """Checks gan_train_ops with synchronous-replica optimizers.

    Verifies that gan_train_ops adds no trainable variables and exposes exactly
    one _SyncReplicasOptimizerHook per optimizer in `train_hooks`. It also
    verifies that running either train op leaves the global step unchanged.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
      create_global_step: If true, an int32 variable named 'custom_gstep' is
        registered in the GLOBAL_STEP collection before the train ops are
        built.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    # Snapshot taken before the train ops are built; compared against below.
    num_trainable_vars = len(variables_lib.get_trainable_variables())

    if create_global_step:
        # Pre-register a non-default global step (custom name, int32).
        # Presumably get_or_create_global_step() below then returns this
        # variable rather than creating a new one -- confirm.
        gstep = variable_scope.get_variable(
            'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
        ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

    # get_sync_optimizer is defined elsewhere; judging by the methods used
    # below it returns a SyncReplicasOptimizer -- confirm.
    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
        self.assertIsInstance(
            hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    # Two hooks, and together they wrap both optimizers: one hook each.
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    # Presumably seeds each optimizer's sync token queue with one token so a
    # single train step can complete -- confirm against the optimizer API.
    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = training_util.get_or_create_global_step()
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        variables.local_variables_initializer().run()

        g_opt.chief_init_op.run()
        d_opt.chief_init_op.run()

        # Global step value before either train op runs.
        gstep_before = global_step.eval()

        # Start required queue runner for SyncReplicasOptimizer.
        # NOTE(review): create_threads is called without start=True. If the
        # QueueRunner default is start=False, these threads are never started
        # despite the comment above -- confirm whether that is intended.
        coord = coordinator.Coordinator()
        g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
        d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

        g_sync_init_op.run()
        d_sync_init_op.run()

        train_ops.generator_train_op.eval()
        # Check that global step wasn't incremented.
        self.assertEqual(gstep_before, global_step.eval())

        train_ops.discriminator_train_op.eval()
        # Check that global step wasn't incremented.
        self.assertEqual(gstep_before, global_step.eval())

        coord.request_stop()
        coord.join(g_threads + d_threads)
示例7: test_discriminator_only_sees_pool
def test_discriminator_only_sees_pool(self):
    """Checks that discriminator only sees pooled values.

    The generator emits a constant, and the tensor pool replaces its inputs
    with non-constant random tensors. The replacement discriminator asserts
    that it never receives a constant. The test then also asserts that this
    discriminator ran at all, so the check cannot pass vacuously.
    """
    def checker_gen_fn(_):
        # Constant output: any unpooled generated data reaching the
        # discriminator would be caught by `checker_dis_fn`.
        return constant_op.constant(0.0)

    model = train.gan_model(
        checker_gen_fn,
        discriminator_model,
        real_data=array_ops.zeros([]),
        generator_inputs=random_ops.random_normal([]))

    def tensor_pool_fn(_):
        # Ignore the incoming values and hand back fresh non-constant tensors.
        return (random_ops.random_uniform([]), random_ops.random_uniform([]))

    # Records every tensor the checker discriminator was called with.
    seen_inputs = []

    def checker_dis_fn(inputs, _):
        """Discriminator that checks that it only sees pooled Tensors."""
        seen_inputs.append(inputs)
        self.assertFalse(constant_op.is_constant(inputs))
        return inputs

    model = model._replace(
        discriminator_fn=checker_dis_fn)
    train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    # The real check lives in the callback above; make sure it actually ran.
    self.assertGreater(len(seen_inputs), 0)
示例8: test_grad_penalty
def test_grad_penalty(self, create_gan_model_fn, one_sided):
    """A gradient penalty (optionally one-sided) must change only D's loss.

    The generator loss must be identical with and without the penalty. The
    penalized discriminator loss must be strictly larger than the plain one.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
      one_sided: Forwarded as `gradient_penalty_one_sided`.
    """
    model = create_gan_model_fn()
    plain = train.gan_loss(model)
    penalized = train.gan_loss(
        model,
        gradient_penalty_weight=1.0,
        gradient_penalty_one_sided=one_sided)
    self.assertIsInstance(penalized, namedtuples.GANLoss)

    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        gen_plain, gen_penalized = sess.run(
            [plain.generator_loss, penalized.generator_loss])
        dis_plain, dis_penalized = sess.run(
            [plain.discriminator_loss, penalized.discriminator_loss])

    self.assertEqual(gen_plain, gen_penalized)
    self.assertLess(dis_plain, dis_penalized)
示例9: test_tensor_pool
def test_tensor_pool(self, create_gan_model_fn):
    """Test tensor pool option.

    Builds gan_loss with a random tensor pool of size 5, checks the result is
    a GANLoss, and evaluates both losses ten times without error.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
    """
    model = create_gan_model_fn()

    # A named function rather than a lambda bound to a name (PEP 8, E731).
    def tensor_pool_fn(x):
        return random_tensor_pool.tensor_pool(x, pool_size=5)

    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        for _ in range(10):
            sess.run([loss.generator_loss, loss.discriminator_loss])
示例10: _test_regularization_helper
def _test_regularization_helper(self, get_gan_model_fn):
    """Checks that gan_loss adds each network's regularization losses.

    A constant 3.0 is registered in REGULARIZATION_LOSSES under the
    generator's scope name, and 2.0 under the discriminator's. The generator
    and discriminator losses must then grow by exactly those amounts.

    Args:
      get_gan_model_fn: Zero-argument callable returning the GAN model. It
        must expose `generator_scope` and `discriminator_scope`.
    """
    # Evaluate losses without regularization.
    no_reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
        no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
        no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()

    with ops.name_scope(get_gan_model_fn().generator_scope.name):
        ops.add_to_collection(
            ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
    with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
        ops.add_to_collection(
            ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))

    # Check that losses now include the correct regularization values.
    reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
        reg_loss_gen_np = reg_loss.generator_loss.eval()
        reg_loss_dis_np = reg_loss.discriminator_loss.eval()

    # Compare the actual differences. assertTrue's second argument is only a
    # failure message, so assertTrue(3.0, diff) can never fail. places=5
    # tolerates float32 rounding in the subtraction.
    self.assertAlmostEqual(
        3.0, reg_loss_gen_np - no_reg_loss_gen_np, places=5)
    self.assertAlmostEqual(
        2.0, reg_loss_dis_np - no_reg_loss_dis_np, places=5)
示例11: _test_tensor_pool_helper
def _test_tensor_pool_helper(self, create_gan_model_fn):
    """Checks gan_loss with a tensor pool, using an InfoGAN-specific pool fn.

    The pooled loss must be a GANLoss, and both of its losses are evaluated
    ten times without error.

    Args:
      create_gan_model_fn: Zero-argument callable returning the model passed
        to `train.gan_loss`.
    """
    model = create_gan_model_fn()
    # InfoGAN models get their own pool-fn factory (presumably because their
    # generator inputs are structured differently -- confirm).
    if isinstance(model, namedtuples.InfoGANModel):
        tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
    else:
        tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
        variables.global_variables_initializer().run()
        for _ in range(10):
            sess.run([loss.generator_loss, loss.discriminator_loss])
示例12: _test_run_helper
def _test_run_helper(self, create_gan_model_fn):
    """Runs gan_train for two steps and checks the step it returns.

    Plain gradient-descent optimizers (learning rate 1.0) are used for both
    networks, and a StopAtStepHook ends training after two steps. gan_train
    must return the scalar 2.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
    """
    random_seed.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    step_size = 1.0
    gen_optimizer = gradient_descent.GradientDescentOptimizer(step_size)
    dis_optimizer = gradient_descent.GradientDescentOptimizer(step_size)
    train_ops = train.gan_train_ops(model, loss, gen_optimizer, dis_optimizer)

    stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=2)
    final_step = train.gan_train(train_ops, logdir='', hooks=[stop_hook])

    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)
示例13: _test_output_type_helper
def _test_output_type_helper(self, create_gan_model_fn):
    """Checks that gan_train_ops returns a GANTrainOps namedtuple.

    Uses gradient-descent optimizers, with gradient summaries and op-colocated
    gradients enabled.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)
    # assertIsInstance names the actual type on failure.
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
示例14: test_patchgan
def test_patchgan(self, create_gan_model_fn):
    """Train a model with a patch-based discriminator end to end.

    Two training steps are run with gradient-descent optimizers. gan_train
    must return the scalar 2.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
    """
    random_seed.set_random_seed(1234)
    patch_model = create_gan_model_fn()
    patch_loss = train.gan_loss(patch_model)
    optimizers = [
        gradient_descent.GradientDescentOptimizer(1.0),  # generator
        gradient_descent.GradientDescentOptimizer(1.0),  # discriminator
    ]
    patch_train_ops = train.gan_train_ops(patch_model, patch_loss, *optimizers)
    last_step = train.gan_train(
        patch_train_ops,
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(last_step))
    self.assertEqual(2, last_step)
示例15: test_unused_update_ops
def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
    """Checks gan_train_ops' handling of update ops outside the model scopes.

    One counter-incrementing update op is created in each of the generator
    and discriminator scopes, plus a constant "update op" that belongs to
    neither. With `check_for_unused_update_ops=True`, gan_train_ops must raise
    a ValueError about unused update ops. With the check disabled, each train
    op must run only the update op from its own scope.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
      provide_update_ops: If true, all three ops are passed explicitly via the
        `update_ops` argument. Otherwise the stray constant is added to the
        UPDATE_OPS collection.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    # Add generator and discriminator update ops.
    # Each op increments its own int counter so the test can tell below which
    # train op ran which update op.
    with variable_scope.variable_scope(model.generator_scope):
        gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
        gen_update_op = gen_update_count.assign_add(1)
        ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
    with variable_scope.variable_scope(model.discriminator_scope):
        dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
        dis_update_op = dis_update_count.assign_add(1)
        ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

    # Add an update op outside the generator and discriminator scopes.
    if provide_update_ops:
        kwargs = {
            'update_ops': [
                constant_op.constant(1.0), gen_update_op, dis_update_op
            ]
        }
    else:
        ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
        kwargs = {}

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)

    # The stray constant op is in neither scope, so the check must fail.
    # NOTE(review): assertRaisesRegexp is the deprecated unittest alias of
    # assertRaisesRegex; consider switching if the base TestCase supports it.
    with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
        train.gan_train_ops(
            model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
    # With the check disabled, the same arguments must build successfully.
    train_ops = train.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

    with self.test_session(use_gpu=True) as sess:
        sess.run(variables.global_variables_initializer())
        self.assertEqual(0, gen_update_count.eval())
        self.assertEqual(0, dis_update_count.eval())

        # The generator train op must run only the generator's update op.
        train_ops.generator_train_op.eval()
        self.assertEqual(1, gen_update_count.eval())
        self.assertEqual(0, dis_update_count.eval())

        # The discriminator train op must run only the discriminator's op.
        train_ops.discriminator_train_op.eval()
        self.assertEqual(1, gen_update_count.eval())
        self.assertEqual(1, dis_update_count.eval())