本文整理汇总了 Python 中 tensorflow.contrib.gan.python.train.gan_model 函数的典型用法代码示例。如果您正苦于以下问题:Python gan_model 函数的具体用法?Python gan_model 怎么用?Python gan_model 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 gan_model 函数的 7 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: test_no_shape_check
def test_no_shape_check(self):
    """Verify that `check_shapes` toggles the failure on a degenerate generator.

    The dummy generator returns `(None, None)`. Building the model is expected
    to raise `AttributeError` when `check_shapes=True` and to complete without
    raising when `check_shapes=False`.
    """

    def dummy_generator_model(_):
        return (None, None)

    def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
        return 1

    def build_model(check_shapes):
        # Input tensors are created afresh on each invocation so the two
        # builds below stay independent of one another.
        return train.gan_model(
            dummy_generator_model,
            dummy_discriminator_model,
            real_data=array_ops.zeros([1, 2]),
            generator_inputs=array_ops.zeros([1]),
            check_shapes=check_shapes)

    with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
        build_model(True)

    # With shape checking disabled the same inputs must not raise.
    build_model(False)
示例2: test_doesnt_crash_when_in_nested_scope
def test_doesnt_crash_when_in_nested_scope(self):
    """Build a gradient-penalty `gan_loss` inside and outside a variable scope."""
    loss_kwargs = {'gradient_penalty_weight': 1.0}

    with variable_scope.variable_scope('outer_scope'):
        real = array_ops.zeros([1, 2])
        noise = random_ops.random_normal([1, 2])
        model = train.gan_model(
            generator_model,
            discriminator_model,
            real_data=real,
            generator_inputs=noise)

        # Loss construction must succeed while the scope is still open.
        train.gan_loss(model, **loss_kwargs)

    # It must also succeed once the scope has been exited.
    train.gan_loss(model, **loss_kwargs)
示例3: test_discriminator_only_sees_pool
def test_discriminator_only_sees_pool(self):
    """Checks that the discriminator is only fed values coming from the pool."""

    def checker_gen_fn(_):
        # The generator emits a constant so that un-pooled generator output
        # is recognisable by `constant_op.is_constant`.
        return constant_op.constant(0.0)

    def tensor_pool_fn(_):
        # The pool ignores its input and hands back two random scalars.
        first = random_ops.random_uniform([])
        second = random_ops.random_uniform([])
        return (first, second)

    def checker_dis_fn(inputs, _):
        """Discriminator asserting that its input is not a constant Tensor."""
        self.assertFalse(constant_op.is_constant(inputs))
        return inputs

    base_model = train.gan_model(
        checker_gen_fn,
        discriminator_model,
        real_data=array_ops.zeros([]),
        generator_inputs=random_ops.random_normal([]))

    # Swap in the checking discriminator only after the model is built, so
    # the check runs during loss construction.
    checked_model = base_model._replace(discriminator_fn=checker_dis_fn)
    train.gan_loss(checked_model, tensor_pool_fn=tensor_pool_fn)
示例4: _make_train_gan_model
def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
                          generator_inputs, generator_scope, add_summaries):
    """Build the training `GANModel` and attach any requested summaries.

    Args:
      generator_fn: Forwarded to `tfgan_train.gan_model`.
      discriminator_fn: Forwarded to `tfgan_train.gan_model`.
      real_data: Forwarded to `tfgan_train.gan_model`; also decides
        `check_shapes` through `_use_check_shapes`.
      generator_inputs: Forwarded to `tfgan_train.gan_model`.
      generator_scope: Forwarded as the `generator_scope` keyword.
      add_summaries: Falsy for no summaries; otherwise a single key of
        `_summary_type_map` or a tuple/list of such keys.

    Returns:
      The `GANModel` produced by `tfgan_train.gan_model`.
    """
    check_shapes = _use_check_shapes(real_data)
    gan_model = tfgan_train.gan_model(
        generator_fn,
        discriminator_fn,
        real_data,
        generator_inputs,
        generator_scope=generator_scope,
        check_shapes=check_shapes)

    if not add_summaries:
        return gan_model

    # A lone summary type is treated like a one-element collection.
    if isinstance(add_summaries, (tuple, list)):
        summary_types = add_summaries
    else:
        summary_types = [add_summaries]

    # `name_scope(None)` resets the name scope for the summary ops.
    with ops.name_scope(None):
        for summary_type in summary_types:
            _summary_type_map[summary_type](gan_model)
    return gan_model
示例5: _make_gan_model
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
    """Make a `GANModel`, and optionally pass in `mode`.

    Args:
      generator_fn: Generator callable. If it declares a positional argument
        named `mode`, `mode` is bound to it before the model is built.
      discriminator_fn: Forwarded to `tfgan_train.gan_model`.
      real_data: Forwarded to `tfgan_train.gan_model`.
      generator_inputs: Forwarded to `tfgan_train.gan_model`.
      generator_scope: Forwarded as the `generator_scope` keyword.
      add_summaries: Falsy for no summaries; otherwise a single key of
        `_summary_type_map` or a tuple/list of such keys.
      mode: Value bound to the generator's `mode` argument when present.

    Returns:
      The `GANModel` produced by `tfgan_train.gan_model` (built with
      `check_shapes=False`).
    """
    # `inspect.getargspec` raises ValueError on Python 3 for callables with
    # annotations or keyword-only parameters and no longer exists in
    # Python 3.11+. Prefer `getfullargspec` and fall back only on
    # interpreters that lack it (Python 2).
    argspec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec

    # If `generator_fn` has an argument `mode`, pass mode to it.
    if 'mode' in argspec_fn(generator_fn).args:
        generator_fn = functools.partial(generator_fn, mode=mode)

    gan_model = tfgan_train.gan_model(
        generator_fn,
        discriminator_fn,
        real_data,
        generator_inputs,
        generator_scope=generator_scope,
        check_shapes=False)

    if add_summaries:
        # A lone summary type is treated like a one-element collection.
        if not isinstance(add_summaries, (tuple, list)):
            add_summaries = [add_summaries]
        # `name_scope(None)` resets the name scope for the summary ops.
        with ops.name_scope(None):
            for summary_type in add_summaries:
                _summary_type_map[summary_type](gan_model)

    return gan_model
示例6: model_fn
def model_fn(features, labels, mode, params):
    """Model function defining an inpainting estimator.

    Creates a clipped variable `z`, builds a `GANModel` that uses `z` as the
    generator input and `labels` as the real data, and minimizes the value of
    `loss_fn` with respect to `z` only. `generator_fn`, `discriminator_fn`,
    `loss_fn` and `optimizer` are resolved from outside this function.

    Args:
      features: Forwarded unchanged to `loss_fn`.
      labels: Used as `real_data` and forwarded to `loss_fn`.
      mode: Bound into the generator and discriminator, and stored in the
        returned spec.
      params: Mapping read for the keys 'batch_size', 'z_shape',
        'add_summaries', 'input_clip', 'learning_rate' and 'opt_kwargs'.

    Returns:
      An `EstimatorSpec` whose predictions are the model's generated data.
    """
    batch_size = params['batch_size']
    z_shape = [batch_size] + params['z_shape']
    add_summaries = params['add_summaries']
    input_clip = params['input_clip']

    def _clip_latent(value):
        # Keeps the variable within [-input_clip, input_clip].
        return clip_ops.clip_by_value(value, -input_clip, input_clip)

    z = variable_scope.get_variable(
        name=INPUT_NAME,
        initializer=random_ops.truncated_normal(z_shape),
        constraint=_clip_latent)

    model = tfgan_train.gan_model(
        generator_fn=functools.partial(generator_fn, mode=mode),
        discriminator_fn=functools.partial(discriminator_fn, mode=mode),
        real_data=labels,
        generator_inputs=z,
        check_shapes=False)

    loss = loss_fn(model, features, labels, add_summaries)

    # Use a variable scope to make sure that estimator variables don't cause
    # save/load problems when restoring from checkpoints.
    with variable_scope.variable_scope(OPTIMIZER_NAME):
        minimizer = optimizer(learning_rate=params['learning_rate'],
                              **params['opt_kwargs'])
        # Only `z` is listed in `var_list`.
        train_op = minimizer.minimize(
            loss=loss,
            global_step=training_util.get_or_create_global_step(),
            var_list=[z])

    # NOTE(review): summaries are emitted outside the optimizer scope here;
    # confirm this nesting against the upstream file.
    if add_summaries:
        z_grads = gradients_impl.gradients(loss, z)
        summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads))
        summary.scalar('z_loss/loss', loss)

    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=model.generated_data,
        loss=loss,
        train_op=train_op)
示例7: create_callable_gan_model
def create_callable_gan_model():
    """Return a `GANModel` built from `Generator` and `Discriminator` instances."""
    generator = Generator()
    discriminator = Discriminator()
    real = array_ops.zeros([1, 2])
    noise = random_ops.random_normal([1, 2])
    return train.gan_model(
        generator,
        discriminator,
        real_data=real,
        generator_inputs=noise)