当前位置: 首页>>代码示例>>Python>>正文


Python parameterized.parameters方法代码示例

本文整理汇总了Python中absl.testing.parameterized.parameters方法的典型用法代码示例。如果您想了解Python parameterized.parameters方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块absl.testing.parameterized的其他用法示例。


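下文一共展示了parameterized.parameters方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。注意：各示例中原有的@parameterized.parameters(...)装饰器在整理时未被保留，因此部分带参数的测试方法（如示例2、示例12）看不到参数的来源；另有部分示例只是在名称、属性或局部变量中出现了parameters，并未直接调用该方法。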

示例1: test_basic_encode_decode_tf_constructor_parameters

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def test_basic_encode_decode_tf_constructor_parameters(self):
    """Tests the core functionality with `tf.Variable` constructor parameters.

    The stage is constructed from two variables instead of Python constants,
    so reassigning those variables afterwards must change what the stage
    computes when the same tensors are evaluated again.
    """
    a_var = tf.compat.v1.get_variable('a_var', initializer=self._DEFAULT_A)
    b_var = tf.compat.v1.get_variable('b_var', initializer=self._DEFAULT_B)
    stage = test_utils.SimpleLinearEncodingStage(a_var, b_var)

    # Initialize the variables before evaluating anything that reads them.
    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
    x = self.default_input()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.common_asserts_for_test_data(test_data)

    # Change the variables and verify the behavior of stage changes.
    # The same encoded_x/decoded_x tensors are re-evaluated below, which
    # presumably relies on graph-mode execution -- TODO confirm for eager.
    self.evaluate(
        [tf.compat.v1.assign(a_var, 5.0),
         tf.compat.v1.assign(b_var, 6.0)])
    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    # After the reassignment the encoded values must equal x * 5.0 + 6.0.
    self.assertAllClose(test_data.x * 5.0 + 6.0,
                        test_data.encoded_x[self._ENCODED_VALUES_KEY])
开发者ID:tensorflow,项目名称:model-optimization,代码行数:26,代码来源:test_utils_test.py

示例2: test_dynamic_graph_convolution_keras_layer_exception_not_raised_shapes

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def test_dynamic_graph_convolution_keras_layer_exception_not_raised_shapes(
      self, batch_size, num_vertices, in_channels, out_channels, reduction):
    """Runs the layer on dummy data and checks the output shape.

    Calling the layer must not raise. Its output must have shape
    (batch_size, num_vertices, out_channels). Nothing is checked outside
    eager execution.
    """
    if tf.executing_eagerly():
      data, neighbors = _dummy_data(batch_size, num_vertices, in_channels)
      conv_layer = gc_layer.DynamicGraphConvolutionKerasLayer(
          num_output_channels=out_channels, reduction=reduction)

      try:
        result = conv_layer(inputs=[data, neighbors], sizes=None)
      except Exception as err:  # pylint: disable=broad-except
        self.fail("Exception raised: %s" % str(err))

      expected_shape = (batch_size, num_vertices, out_channels)
      self.assertAllEqual(expected_shape, result.shape)
开发者ID:tensorflow,项目名称:graphics,代码行数:18,代码来源:graph_convolution_test.py

示例3: test_get_defun_argspec_with_typed_non_eager_defun

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def test_get_defun_argspec_with_typed_non_eager_defun(self):
    """Checks the signature of a tf.function with an explicit input signature.

    Such a function may not use **kwargs or default values, but it may use
    *args, and the input signature is allowed to extend into *args.
    """
    input_signature = (
        tf.TensorSpec(None, tf.int32),
        tf.TensorSpec(None, tf.bool),
        tf.TensorSpec(None, tf.float32),
        tf.TensorSpec(None, tf.float32),
    )
    fn = tf.function(lambda x, y, *z: None, input_signature)
    positional = inspect.Parameter.POSITIONAL_OR_KEYWORD
    actual = collections.OrderedDict(function_utils.get_signature(fn).parameters)
    expected = collections.OrderedDict(
        x=inspect.Parameter('x', positional),
        y=inspect.Parameter('y', positional),
        z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
    )
    self.assertEqual(actual, expected)
开发者ID:tensorflow,项目名称:federated,代码行数:19,代码来源:function_utils_test.py

示例4: test_get_signature_with_class_instance_method

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def test_get_signature_with_class_instance_method(self):
    """Checks that the signature of a bound method lists only `y`, not `self`."""

    class C:

      def __init__(self, x):
        self._x = x

      def foo(self, y):
        return self._x * y

    bound_foo = C(5).foo
    actual_params = function_utils.get_signature(bound_foo).parameters
    only_y = collections.OrderedDict(
        y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD))
    self.assertEqual(actual_params, only_y)
开发者ID:tensorflow,项目名称:federated,代码行数:18,代码来源:function_utils_test.py

示例5: _modify_class

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def _modify_class(class_object, testcases, naming_type):
  assert not getattr(class_object, '_test_method_ids', None), (
      'Cannot add parameters to %s. Either it already has parameterized '
      'methods, or its super class is also a parameterized class.' % (
          class_object,))
  class_object._test_method_ids = test_method_ids = {}
  for name, obj in six.iteritems(class_object.__dict__.copy()):
    if (name.startswith(unittest.TestLoader.testMethodPrefix)
        and isinstance(obj, types.FunctionType)):
      delattr(class_object, name)
      methods = {}
      _update_class_dict_for_param_test_case(
          class_object.__name__, methods, test_method_ids, name,
          _ParameterizedTestIter(obj, testcases, naming_type, name))
      for name, meth in six.iteritems(methods):
        setattr(class_object, name, meth) 
开发者ID:abseil,项目名称:abseil-py,代码行数:18,代码来源:parameterized.py

示例6: parameters

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def parameters(*testcases):
  """Decorator that turns one test method into a family of parameterized tests.

  See the module docstring for a usage example.

  Args:
    *testcases: Parameters for the decorated method, either a single
        iterable, or a list of tuples/dicts/objects (for tests with only one
        argument).

  Raises:
    NoTestsError: Raised when the decorator generates no tests.

  Returns:
    A test generator to be handled by TestGeneratorMetaclass.
  """
  # Presumably, _ARGUMENT_REPR names generated tests by argument repr.
  naming_type = _ARGUMENT_REPR
  return _parameter_decorator(naming_type, testcases)
开发者ID:abseil,项目名称:abseil-py,代码行数:19,代码来源:parameterized.py

示例7: testFactorisedKLGaussian

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def testFactorisedKLGaussian(self, dist1_type, dist2_type):
    """Checks that the two factorised KL terms add up to the exact KL."""
    p, p_mean, p_cov = self._create_gaussian(dist1_type)
    q, q_mean, q_cov = self._create_gaussian(dist2_type)
    both_diagonal = _is_diagonal(p.scale) and _is_diagonal(q.scale)
    # When both scales are diagonal, pass the factorised op the `scale_diag`
    # parameters instead of the covariances from _create_gaussian.
    if both_diagonal:
      p_cov = p.parameters['scale_diag']
      q_cov = q.parameters['scale_diag']
    exact_kl = tfp.distributions.kl_divergence(p, q)
    kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
        p_mean, p_cov, q_mean, q_cov, both_diagonal=both_diagonal)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      fetched = sess.run([exact_kl, kl_mean, kl_cov])
      exact_kl_np, kl_mean_np, kl_cov_np = fetched
      self.assertAllClose(exact_kl_np, kl_mean_np + kl_cov_np, rtol=1e-4)
开发者ID:deepmind,项目名称:trfl,代码行数:21,代码来源:distribution_ops_test.py

示例8: _get_attributes_test_params

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def _get_attributes_test_params():
  """Groups the public attributes of a humanoid MjModel and MjData.

  Attribute names are discovered by introspection and passed as parameters
  to each of the test methods in AttributesTest.

  Returns:
    A tuple `(array_args, scalar_args, skipped_args)`. Each element is a list
    of `(parent_name, attr_name)` pairs, where `parent_name` is "model" or
    "data". Callable attributes are not listed in any group.
  """
  model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
  data = core.MjData(model)
  array_args, scalar_args, skipped_args = [], [], []
  for parent_name, parent_obj in (("model", model), ("data", data)):
    for attr_name in dir(parent_obj):
      if attr_name.startswith("_"):
        continue  # Skip 'private' attributes.
      entry = (parent_name, attr_name)
      attr = getattr(parent_obj, attr_name)
      if isinstance(attr, ARRAY_TYPES):
        array_args.append(entry)
      elif isinstance(attr, SCALAR_TYPES):
        scalar_args.append(entry)
      elif not callable(attr):
        skipped_args.append(entry)
      # Methods etc. are dropped here; CoreTest covers them specifically.
  return array_args, scalar_args, skipped_args
开发者ID:deepmind,项目名称:dm_control,代码行数:26,代码来源:core_test.py

示例9: _runSingleTrainingStep

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Builds an SSGAN on cifar10 and trains it for a single step.

    `loss_fn` and `penalty_fn` are bound through gin as "loss.fn" and
    "penalty.fn". Training runs on CPU (use_tpu=False) with batch size 2.
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir, tpu_config=tpu_config)
    gan = SSGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={
            "architecture": architecture,
            "lambda": 1,
            "z_dim": 128,
        },
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_size=4)
    gan.as_estimator(run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
开发者ID:google,项目名称:compare_gan,代码行数:25,代码来源:ssgan_test.py

示例10: _runSingleTrainingStep

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Trains a ModularGAN on cifar10 for one step with the given functions.

    `loss_fn` and `penalty_fn` are bound through gin. The GAN is built as
    conditional whenever `architecture` contains the substring "biggan".
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    gan_params = {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 128,
    }
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters=gan_params,
        model_dir=self.model_dir,
        conditional="biggan" in architecture)
    gan.as_estimator(self.run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
开发者ID:google,项目名称:compare_gan,代码行数:19,代码来源:modular_gan_test.py

示例11: testSingleTrainingStepWithJointGenForDisc

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def testSingleTrainingStepWithJointGenForDisc(self):
    """Trains a conditional ModularGAN for one step with joint gen-for-disc.

    Uses two discriminator iterations. Both the
    `experimental_joint_gen_for_disc` and `experimental_force_graph_unroll`
    flags are enabled.
    """
    cifar10 = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=cifar10,
        parameters={
            "architecture": c.DUMMY_ARCH,
            "lambda": 1,
            "z_dim": 120,
            "disc_iters": 2,
        },
        model_dir=self.model_dir,
        experimental_joint_gen_for_disc=True,
        experimental_force_graph_unroll=True,
        conditional=True)
    gan.as_estimator(self.run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
开发者ID:google,项目名称:compare_gan,代码行数:19,代码来源:modular_gan_test.py

示例12: testSingleTrainingStepDiscItersWithEma

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
    """Trains one step with generator EMA enabled, then inspects the checkpoint.

    Args:
      disc_iters: Number of discriminator iterations per step. It is passed
        to ModularGAN via the "disc_iters" entry of `parameters`.
    """
    parameters = {
        "architecture": c.DUMMY_ARCH,
        "lambda": 1,
        "z_dim": 128,
        # This key was previously misspelled "dics_iters", so the
        # parameterized `disc_iters` value was most likely ignored.
        "disc_iters": disc_iters,
    }
    # unlock_config() restores the previous lock state on exit, so this is
    # safe whether or not the gin config is currently locked.
    with gin.unlock_config():
      gin.bind_parameter("ModularGAN.g_use_ema", True)
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=self.model_dir)
    estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
    # Check for moving average variables in checkpoint.
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    ema_vars = sorted(
        v[0] for v in tf.train.list_variables(checkpoint_path)
        if v[0].endswith("ExponentialMovingAverage"))
    tf.logging.info("ema_vars=%s", ema_vars)
    # Only the generator's fc_noise kernel and bias are expected to have
    # ExponentialMovingAverage shadow variables.
    expected_ema_vars = sorted([
        "generator/fc_noise/kernel/ExponentialMovingAverage",
        "generator/fc_noise/bias/ExponentialMovingAverage",
    ])
    self.assertAllEqual(ema_vars, expected_ema_vars)
开发者ID:google,项目名称:compare_gan,代码行数:27,代码来源:modular_gan_test.py

示例13: _runSingleTrainingStep

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                           labeled_dataset):
    """Trains a conditional ModularGAN on cifar10 for a single step.

    NOTE(review): `labeled_dataset` is accepted but never used below; the
    dataset is always cifar10. Confirm whether callers expect it to matter.
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    model_dir = self._get_empty_model_dir()
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    gan_params = {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 120,
    }
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters=gan_params,
        conditional=True,
        model_dir=model_dir)
    gan.as_estimator(run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
开发者ID:google,项目名称:compare_gan,代码行数:24,代码来源:modular_gan_conditional_test.py

示例14: _generate_message_parameters

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def _generate_message_parameters(want_permutations=False):
  """Generate message parameters for test cases.

  Args:
    want_permutations: bool, whether or not to run the messages through various
        permutations.

  Yields:
    A single-element list wrapping a list of the three Question messages.
    The messages are yielded once in their original order. When
    `want_permutations` is True, one item is yielded per ordering instead.
  """
  answer_message = survey_messages.Answer(
      text='Left my laptop at home.',
      more_info_enabled=False,
      placeholder_text=None)
  # Question 1: enabled defaults, required.
  survey_messages_1 = survey_messages.Question(
      question_type=survey_models.QuestionType.ASSIGNMENT,
      question_text=_QUESTION.format(num=1),
      answers=[answer_message],
      rand_weight=1,
      required=True)
  # Question 2: explicitly disabled and not required.
  survey_messages_2 = survey_messages.Question(
      question_type=survey_models.QuestionType.ASSIGNMENT,
      question_text=_QUESTION.format(num=2),
      answers=[answer_message],
      rand_weight=1,
      enabled=False,
      required=False)
  # Question 3: a RETURN question, explicitly enabled.
  survey_messages_3 = survey_messages.Question(
      question_type=survey_models.QuestionType.RETURN,
      question_text=_QUESTION.format(num=3),
      answers=[answer_message],
      rand_weight=1,
      enabled=True)
  messages = [
      survey_messages_1, survey_messages_2,
      survey_messages_3]
  if want_permutations:
    for p in itertools.permutations(messages):
      # permutations() produces tuples. Convert each one so both branches
      # yield the same shape: a list of messages, as documented above.
      yield [list(p)]
  else:
    yield [messages]
开发者ID:google,项目名称:loaner,代码行数:43,代码来源:survey_api_test.py

示例15: _create_template_parameters

# 需要导入模块: from absl.testing import parameterized [as 别名]
# 或者: from absl.testing.parameterized import parameters [as 别名]
def _create_template_parameters():
  """Creates a template list of parameters for parameterized test cases.

  Yields:
    A list containing values for template parameters
  """
  template_name_value = 'this_template'
  body_value = 'body update test'
  title_value = 'title update test'

  template_parameters = [template_name_value, title_value, body_value]
  yield [template_parameters] 
开发者ID:google,项目名称:loaner,代码行数:14,代码来源:template_model_test.py


注:本文中的absl.testing.parameterized.parameters方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。