当前位置: 首页>>代码示例>>Python>>正文


Python train.gan_model函数代码示例

本文整理汇总了Python中tensorflow.contrib.gan.python.train.gan_model函数的典型用法代码示例。如果您正苦于以下问题：Python gan_model函数的具体用法？Python gan_model怎么用？Python gan_model使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了gan_model函数的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_no_shape_check

 def test_no_shape_check(self):
   """Checks that `check_shapes` controls whether bad outputs are rejected.

   The stub generator returns a tuple of `None`s and the stub discriminator
   returns a plain int. Building the model is expected to raise an
   `AttributeError` with `check_shapes=True` and to succeed with
   `check_shapes=False`.
   """
   def shapeless_generator(_):
     return (None, None)
   def int_discriminator(data, conditioning):  # pylint: disable=unused-argument
     return 1

   def build_model(check_shapes):
     # Fresh input tensors are created on every call, as in a direct call.
     return train.gan_model(
         shapeless_generator,
         int_discriminator,
         real_data=array_ops.zeros([1, 2]),
         generator_inputs=array_ops.zeros([1]),
         check_shapes=check_shapes)

   with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
     build_model(check_shapes=True)
   build_model(check_shapes=False)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:18,代码来源:train_test.py

示例2: test_doesnt_crash_when_in_nested_scope

  def test_doesnt_crash_when_in_nested_scope(self):
    """Builds a model inside a variable scope and computes its loss twice.

    The loss (with a gradient penalty) is requested once while still inside
    'outer_scope' and once after leaving it; neither call should raise.
    """
    with variable_scope.variable_scope('outer_scope'):
      real = array_ops.zeros([1, 2])
      noise = random_ops.random_normal([1, 2])
      scoped_model = train.gan_model(
          generator_model,
          discriminator_model,
          real_data=real,
          generator_inputs=noise)

      # Loss construction while the scope is still active.
      train.gan_loss(scoped_model, gradient_penalty_weight=1.0)

    # Loss construction after the scope has been exited.
    train.gan_loss(scoped_model, gradient_penalty_weight=1.0)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:13,代码来源:train_test.py

示例3: test_discriminator_only_sees_pool

 def test_discriminator_only_sees_pool(self):
   """Verifies the discriminator is only fed values that came from the pool."""
   def constant_generator(_):
     return constant_op.constant(0.0)
   def pool_fn(_):
     return (random_ops.random_uniform([]), random_ops.random_uniform([]))
   def pool_checking_discriminator(inputs, _):
     """Discriminator asserting that its input is not a constant Tensor."""
     self.assertFalse(constant_op.is_constant(inputs))
     return inputs

   # Build the model with the regular discriminator first, then swap in the
   # checking discriminator so that only the loss computation exercises it.
   base_model = train.gan_model(
       constant_generator,
       discriminator_model,
       real_data=array_ops.zeros([]),
       generator_inputs=random_ops.random_normal([]))
   checked_model = base_model._replace(
       discriminator_fn=pool_checking_discriminator)
   train.gan_loss(checked_model, tensor_pool_fn=pool_fn)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:18,代码来源:train_test.py

示例4: _make_train_gan_model

def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
                          generator_inputs, generator_scope, add_summaries):
  """Build a `GANModel` for training and optionally attach summaries.

  Args:
    generator_fn: Generator function forwarded to `tfgan_train.gan_model`.
    discriminator_fn: Discriminator function forwarded to
      `tfgan_train.gan_model`.
    real_data: Real data forwarded to `tfgan_train.gan_model`; also passed to
      `_use_check_shapes` to decide the `check_shapes` flag.
    generator_inputs: Generator inputs forwarded to `tfgan_train.gan_model`.
    generator_scope: Scope forwarded as `generator_scope`.
    add_summaries: Falsy for no summaries, otherwise a single summary type or
      a tuple/list of them; each must be a key of `_summary_type_map`.

  Returns:
    The model returned by `tfgan_train.gan_model`.
  """
  model = tfgan_train.gan_model(
      generator_fn,
      discriminator_fn,
      real_data,
      generator_inputs,
      generator_scope=generator_scope,
      check_shapes=_use_check_shapes(real_data))

  if not add_summaries:
    return model

  # A lone summary type is treated as a one-element list.
  summary_types = (
      add_summaries if isinstance(add_summaries, (tuple, list))
      else [add_summaries])
  with ops.name_scope(None):
    for kind in summary_types:
      _summary_type_map[kind](model)
  return model
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:18,代码来源:gan_estimator_impl.py

示例5: _make_gan_model

def _generator_fn_accepts_mode(generator_fn):
  """Return True if `generator_fn` declares a `mode` argument usable by keyword."""
  if not hasattr(inspect, 'signature'):
    # Interpreters without `inspect.signature` (Python 2) keep the legacy path.
    return 'mode' in inspect.getargspec(generator_fn).args
  # `inspect.getargspec` was removed in Python 3.11 and raised on keyword-only
  # or annotated signatures; `inspect.signature` handles all of these.
  mode_param = inspect.signature(generator_fn).parameters.get('mode')
  return mode_param is not None and mode_param.kind in (
      inspect.Parameter.POSITIONAL_OR_KEYWORD,
      inspect.Parameter.KEYWORD_ONLY)


def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
  """Make a `GANModel`, and optionally pass in `mode`.

  Args:
    generator_fn: Generator function. If it declares a `mode` argument, `mode`
      is bound to it before the model is built.
    discriminator_fn: Discriminator function forwarded to
      `tfgan_train.gan_model`.
    real_data: Real data forwarded to `tfgan_train.gan_model`.
    generator_inputs: Generator inputs forwarded to `tfgan_train.gan_model`.
    generator_scope: Scope forwarded as `generator_scope`.
    add_summaries: Falsy for no summaries, otherwise a single summary type or
      a tuple/list of them; each must be a key of `_summary_type_map`.
    mode: Value bound to the generator's `mode` argument when it has one.

  Returns:
    The model returned by `tfgan_train.gan_model` (built with
    `check_shapes=False`).
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if _generator_fn_accepts_mode(generator_fn):
    generator_fn = functools.partial(generator_fn, mode=mode)
  gan_model = tfgan_train.gan_model(
      generator_fn,
      discriminator_fn,
      real_data,
      generator_inputs,
      generator_scope=generator_scope,
      check_shapes=False)
  if add_summaries:
    if not isinstance(add_summaries, (tuple, list)):
      add_summaries = [add_summaries]
    with ops.name_scope(None):
      for summary_type in add_summaries:
        _summary_type_map[summary_type](gan_model)

  return gan_model
开发者ID:Fair-Child,项目名称:tensorflow,代码行数:21,代码来源:gan_estimator_impl.py

示例6: model_fn

  def model_fn(features, labels, mode, params):
    """Model function defining an inpainting estimator.

    Creates a variable `z` named `INPUT_NAME`, feeds it through a GAN model
    as the generator input, and builds a train op that minimizes `loss_fn`'s
    result with respect to `z` only.

    Args:
      features: Passed through to `loss_fn`.
      labels: Used as the GAN model's `real_data` and passed to `loss_fn`.
      mode: Bound to the generator and discriminator functions as `mode` and
        recorded on the returned spec.
      params: Dict read for the keys 'batch_size', 'z_shape', 'add_summaries',
        'input_clip', 'learning_rate' and 'opt_kwargs'.

    Returns:
      An `EstimatorSpec` whose predictions are the model's generated data.
    """
    z_shape = [params['batch_size']] + params['z_shape']
    add_summaries = params['add_summaries']
    input_clip = params['input_clip']

    def _clip_to_input_range(value):
      return clip_ops.clip_by_value(value, -input_clip, input_clip)

    z = variable_scope.get_variable(
        name=INPUT_NAME,
        initializer=random_ops.truncated_normal(z_shape),
        constraint=_clip_to_input_range)

    gan_model = tfgan_train.gan_model(
        generator_fn=functools.partial(generator_fn, mode=mode),
        discriminator_fn=functools.partial(discriminator_fn, mode=mode),
        real_data=labels,
        generator_inputs=z,
        check_shapes=False)

    loss = loss_fn(gan_model, features, labels, add_summaries)

    # Use a variable scope to make sure that estimator variables don't cause
    # save/load problems when restoring from checkpoints.
    with variable_scope.variable_scope(OPTIMIZER_NAME):
      opt = optimizer(
          learning_rate=params['learning_rate'], **params['opt_kwargs'])
      global_step = training_util.get_or_create_global_step()
      # Only `z` is optimized.
      train_op = opt.minimize(
          loss=loss, global_step=global_step, var_list=[z])

    if add_summaries:
      z_grads = gradients_impl.gradients(loss, z)
      summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads))
      summary.scalar('z_loss/loss', loss)

    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=gan_model.generated_data,
        loss=loss,
        train_op=train_op)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:39,代码来源:latent_gan_estimator_impl.py

示例7: create_callable_gan_model

def create_callable_gan_model():
  """Build a GAN model from `Generator` and `Discriminator` instances.

  Returns:
    The result of `train.gan_model` for zero-valued real data of shape [1, 2]
    and random-normal generator inputs of shape [1, 2].
  """
  generator = Generator()
  discriminator = Discriminator()
  real = array_ops.zeros([1, 2])
  noise = random_ops.random_normal([1, 2])
  return train.gan_model(
      generator, discriminator, real_data=real, generator_inputs=noise)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:6,代码来源:train_test.py


注：本文中的tensorflow.contrib.gan.python.train.gan_model函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。