当前位置: 首页>>代码示例>>Python>>正文


Python template.make_template函数代码示例

本文整理汇总了Python中tensorflow.python.ops.template.make_template函数的典型用法代码示例。如果您正苦于以下问题：Python make_template函数的具体用法是什么？Python make_template怎么用？有哪些Python make_template的使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了make_template函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_custom_getter

  def test_custom_getter(self):
    """Checks that custom_getter_ sees every variable lookup of a template.

    A counting getter is attached to two templates: one whose scope is
    created lazily and one built with `create_scope_now_=True`. In both cases
    the getter must not run during construction and the count must grow by
    one with each template invocation (the creating call and the reusing
    call alike).
    """
    calls = [0]  # boxed in a list so the nested getter can mutate it

    def custom_getter(getter, name, *args, **kwargs):
      # Record the lookup, then hand off to the real getter.
      calls[0] += 1
      return getter(name, *args, **kwargs)

    # Lazily scoped template: the getter fires on creation and on reuse.
    lazy_tmpl = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(calls[0], 0)
    for expected_calls in (1, 2):
      lazy_tmpl()
      self.assertEqual(calls[0], expected_calls)

    # Scope entered at construction time: constructing the template must
    # still not trigger the getter; only calling it does.
    calls[0] = 0
    scoped_now_tmpl = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(calls[0], 0)
    for expected_calls in (1, 2):
      scoped_now_tmpl()
      self.assertEqual(calls[0], expected_calls)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:31,代码来源:template_test.py

示例2: nested_template

 def nested_template():
   """Builds two inner templates that share the name "nested".

   Returns the variable produced by the second inner template.
   """
   inner_a = template.make_template("nested", variable_scoped_function)
   inner_b = template.make_template("nested", variable_scoped_function)
   var_a, var_b = inner_a(), inner_b()
   # Identical template names must still yield distinct variables.
   self.assertNotEqual(var_a, var_b)
   return var_b
开发者ID:TianyouLi,项目名称:tensorflow,代码行数:7,代码来源:template_test.py

示例3: _outer_template

 def _outer_template():
   """Builds inner templates "i1" and "i2" and calls them (the second twice).

   Returns:
     `((first_inner, second_inner), (v1, v2, v3))`: the two inner templates
     and the results of the three calls, in call order.
   """
   first_inner = template.make_template("i1", _inner_template)
   second_inner = template.make_template("i2", _inner_template)
   outputs = tuple(
       inner() for inner in (first_inner, second_inner, second_inner))
   return (first_inner, second_inner), outputs
开发者ID:DILASSS,项目名称:tensorflow,代码行数:7,代码来源:checkpointable_utils_test.py

示例4: test_unique_name_raise_error_in_eager

 def test_unique_name_raise_error_in_eager(self):
   """unique_name_ must be rejected by make_template under eager mode."""
   # NOTE(review): "exeuction" is kept verbatim because the regex has to
   # match the library's error text; confirm against template.py before
   # correcting the spelling here.
   expected_msg = ("unique_name_ cannot be used when eager exeuction is "
                   "enabled.")
   with context.eager_mode(), self.assertRaisesRegexp(
       ValueError, expected_msg):
     template.make_template(
         "_", variable_scoped_function, unique_name_="s1")
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:7,代码来源:template_test.py

示例5: test_checkpointable_save_restore

  def test_checkpointable_save_restore(self):
    """Round-trips a template's variables through a checkpoint.

    Variables created by a template under scope "s1" are set to known values
    and saved through a `Checkpoint`. They are then restored into a second
    template built from the same function under scope "s2".
    """

    def _templated():
      # Two zero-initialised, shape-[1] variables plus a tensor derived from
      # the first, so the restore is also checked through `v + 1.`.
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer())
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer())
      return v, v + 1., v2

    # Save side: the Checkpoint is attached before the template is first
    # called. The call creates the variables, which are then overwritten
    # with 12 and 14 and written to a checkpoint in the test temp dir.
    save_template = template.make_template("s1", _templated)
    save_root = checkpointable_utils.Checkpoint(my_template=save_template)
    v1_save, _, v2_save = save_template()
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    # Load side: a fresh template under a different scope name ("s2").
    load_template = template.make_template("s2", _templated)
    load_root = checkpointable_utils.Checkpoint(my_template=load_template)
    # restore() is requested before load_template() has been called, i.e.
    # before the variables it should populate exist.
    status = load_root.restore(save_path)
    var, var_plus_one, var2 = load_template()
    # The dependencies are named after the get_variable names ("v", "v2"),
    # without the "s1"/"s2" scope prefix, and appear in creation order.
    # NOTE(review): presumably this scope-independent naming is what lets a
    # checkpoint written from "s1" load into "s2".
    self.assertEqual(2, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    # NOTE(review): assert_consumed() presumably checks that every saved
    # value found a match; run_restore_ops() then applies the restore.
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
开发者ID:DILASSS,项目名称:tensorflow,代码行数:29,代码来源:checkpointable_utils_test.py

示例6: test_unique_name_raise_error

 def test_unique_name_raise_error(self):
   """A second template reusing unique_name_ "s1" must fail when called."""
   def make_s1_template():
     return template.make_template(
         "_", variable_scoped_function, unique_name_="s1")

   owner = make_s1_template()
   owner()  # first call creates the variable(s) under "s1"
   intruder = make_s1_template()
   with self.assertRaises(ValueError):
     intruder()
开发者ID:Immexxx,项目名称:tensorflow,代码行数:8,代码来源:template_test.py

示例7: test_same_unique_name_raise_error

 def test_same_unique_name_raise_error(self):
   """Calling a second template with the same unique_name_ must fail.

   The error is expected to name the already-created "s1/dummy" variable.
   """
   def build_template():
     return template.make_template(
         "_", variable_scoped_function, unique_name_="s1")

   original = build_template()
   original()
   duplicate = build_template()
   expected = "Variable s1/dummy already exists, disallowed.*"
   with self.assertRaisesRegexp(ValueError, expected):
     duplicate()
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:9,代码来源:template_test.py

示例8: test_nested_templates_with_defun

  def test_nested_templates_with_defun(self):
    """Checks sharing and isolation for nested graph-function templates.

    The inner templates are built by `make_template_internal` with
    `create_graph_function_=True`. The wrapped function returns no value,
    so each template's variables are read back through its `variables`
    property.
    """

    def variable_scoped_function_no_return_value(trainable=True):
      # Creates (or reuses) a shape-[1] variable named "dummy"; `trainable`
      # is forwarded to get_variable.
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      # Outer template body: two inner templates with the same name.
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      # The calls return nothing, so read back what each template created.
      v1 = nested1.variables
      v2 = nested2.variables

      # nested1 and nested2 should not share variables
      self.assertNotEqual(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      # NOTE(review): v1/v2 were just read from these same properties, so
      # the next two checks only confirm the property is stable across reads.
      self.assertEqual(nested1.variables, v1)
      self.assertEqual(nested2.variables, v2)
      self.assertEqual(nested1.trainable_variables, v1)
      self.assertEqual(nested2.trainable_variables, v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    # Snapshot the outer template's variables after each call: tmpl1 is
    # called twice (create, then reuse) and tmpl2 once.
    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertSequenceEqual(v1, v2)

    # tmpl1 and tmpl2 should not share variables.
    self.assertNotEqual(v1, v3)
    # Inner names are uniquified within each outer scope (nested, nested_1),
    # and the second outer template's scope becomes s1_1.
    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:56,代码来源:template_test.py

示例9: test_template_with_name

  def test_template_with_name(self):
    """Same-named templates reuse across calls but never share with each other."""
    first_tmpl = template.make_template("s1", variable_scoped_function)
    second_tmpl = template.make_template("s1", variable_scoped_function)

    created, reused = first_tmpl(), first_tmpl()
    other = second_tmpl()

    self.assertEqual(created, reused)    # reuse within one template
    self.assertNotEqual(created, other)  # isolation between templates
    # The second template's scope is uniquified to "s1_1".
    expectations = (("s1/dummy:0", created), ("s1_1/dummy:0", other))
    for expected_name, var in expectations:
      self.assertEqual(expected_name, var.name)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:11,代码来源:template_test.py

示例10: test_make_template

    def test_make_template(self):
        """make_template accepts positional args plus a forwarded keyword arg.

        The name and function are passed positionally and `scope_name` as a
        keyword. The expected variable names contain "test", which shows the
        keyword reached the wrapped function.
        """
        def build():
            return template.make_template(
                "s1", internally_var_scoped_function, scope_name="test")

        tmpl_a = build()
        tmpl_b = build()

        first, second = tmpl_a(), tmpl_a()
        from_b = tmpl_b()
        self.assertEqual(first, second)
        self.assertNotEqual(first, from_b)
        self.assertEqual("s1/test/dummy:0", first.name)
        self.assertEqual("s1_1/test/dummy:0", from_b.name)
开发者ID:brchiu,项目名称:tensorflow,代码行数:12,代码来源:template_test.py

示例11: test_unique_name_and_reuse

    def test_unique_name_and_reuse(self):
        """With reuse enabled, a second unique_name_ "s1" template shares vars."""
        owner = template.make_template(
            "_", var_scoped_function, unique_name_="s1")
        first, second = owner(), owner()

        # Enable reuse on the current variable scope before building a second
        # template with the same unique name, so it picks up the existing
        # "s1" variables.
        tf.get_variable_scope().reuse_variables()
        sharer = template.make_template(
            "_", var_scoped_function, unique_name_="s1")
        shared = sharer()

        for other in (second, shared):
            self.assertEqual(first, other)
        self.assertEqual("s1/dummy:0", first.name)
开发者ID:brchiu,项目名称:tensorflow,代码行数:12,代码来源:template_test.py

示例12: test_template_with_internal_reuse

  def test_template_with_internal_reuse(self):
    """Checks reuse when the wrapped function opens its own inner scope.

    Calling a template again with the same inner scope name returns the same
    variable. Calling it with a different inner scope name raises ValueError.
    """
    tmpl_a = template.make_template("s1", internally_variable_scoped_function)
    tmpl_b = template.make_template("s1", internally_variable_scoped_function)

    inner = "test"
    first, repeat, other = tmpl_a(inner), tmpl_a(inner), tmpl_b(inner)
    self.assertEqual(first, repeat)
    self.assertNotEqual(first, other)
    self.assertEqual("s1/test/dummy:0", first.name)
    self.assertEqual("s1_1/test/dummy:0", other.name)

    # NOTE(review): presumably rejected because a new inner scope would need
    # fresh variables after the template's first call.
    with self.assertRaises(ValueError):
      tmpl_a("not_test")
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:14,代码来源:template_test.py

示例13: test_checkpointable_save_restore

  def test_checkpointable_save_restore(self):
    """Round-trips template variables, a nested object and optimizer state.

    A template under scope "s1" creates two resource variables and a
    `_ManualScope` object. The template and an Adam optimizer are saved via
    `Checkpoint`, then restored into a fresh template ("s2") and optimizer.
    """

    def _templated():
      # Two zero-initialised, shape-[1] resource variables, a tensor derived
      # from the first, and a `_ManualScope` instance plus the result of
      # calling it. NOTE(review): `_ManualScope` is defined elsewhere in the
      # test module. Per the checks below, it tracks one dependency named
      # "in_manual_scope", namely the value its call returns.
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      manual = _ManualScope()
      return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    # Order-insensitive: everything reachable from the template is its two
    # variables, the manual scope and its variable, and the template itself.
    six.assertCountEqual(
        self,
        [v1_save, v2_save, manual_scope, manual_scope_v, save_template],
        checkpointable_utils.list_objects(save_template))
    manual_dep, = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)
    # NOTE(review): learning rate 0.0 is presumably chosen so that, if
    # minimize() executes immediately (eager mode), it leaves the variable
    # values unchanged.
    optimizer = adam.AdamOptimizer(0.0)
    save_root = checkpointable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    # A callable loss; presumably this makes the optimizer create its
    # per-variable state so that the state is saved too.
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    # Load side: a fresh template ("s2") and optimizer. restore() is issued
    # before load_template() has created any variables.
    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = checkpointable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    # Three dependencies in creation order: the two variables (unscoped
    # names) and the manual scope object.
    self.assertEqual(3, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual("ManualScope",
                     load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:49,代码来源:util_with_v1_optimizers_test.py

示例14: test_template_in_scope

  def test_template_in_scope(self):
    """A template's variables stay under the scope of its first call."""
    tmpl_a = template.make_template("s1", variable_scoped_function)
    tmpl_b = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      first = tmpl_a()
      other = tmpl_b()

    # Per the template contract, calling tmpl_a under "scope2" must ignore
    # that scope and hand back the variable created under "scope".
    with variable_scope.variable_scope("scope2"):
      again = tmpl_a()

    self.assertEqual(first, again)
    self.assertNotEqual(first, other)
    self.assertEqual("scope/s1/dummy:0", first.name)
    self.assertEqual("scope/s1_1/dummy:0", other.name)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:15,代码来源:template_test.py

示例15: test_nested_templates

  def test_nested_templates(self):
    """Nested templates share within one outer template, not across them."""

    def nested_template():
      inner_a = template.make_template("nested", variable_scoped_function)
      inner_b = template.make_template("nested", variable_scoped_function)
      var_a, var_b = inner_a(), inner_b()

      # Same-named inner templates must not share variables.
      self.assertNotEqual(var_a, var_b)

      # Each inner template tracks only the variable it created itself.
      pairs = ((inner_a, [var_a]), (inner_b, [var_b]))
      for inner, expected in pairs:
        self.assertEqual(inner.variables, expected)
      for inner, expected in pairs:
        self.assertEqual(inner.trainable_variables, expected)
      for inner, _ in pairs:
        self.assertEqual(len(inner.non_trainable_variables), 0)
      return var_a, var_b

    outer_a = template.make_template("s1", nested_template)
    outer_b = template.make_template("s1", nested_template)

    first = list(outer_a())
    repeat = list(outer_a())
    other = list(outer_b())

    # Calling outer_a again reuses what its first call created.
    self.assertEqual(first, repeat)
    self.assertEqual(outer_a.variables, first)
    self.assertEqual(outer_a.trainable_variables, first)
    self.assertEqual(len(outer_a.non_trainable_variables), 0)

    # outer_a and outer_b must not share variables.
    self.assertNotEqual(first, other)
    self.assertSequenceEqual(outer_b.variables, other)
    self.assertSequenceEqual(outer_b.trainable_variables, other)
    self.assertEqual(len(outer_b.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", first[0].name)
    self.assertEqual("s1/nested_1/dummy:0", first[1].name)
    self.assertEqual("s1_1/nested/dummy:0", other[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", other[1].name)

    # The outer template depends on its two inner templates, by scope name.
    deps = outer_a._checkpoint_dependencies
    self.assertEqual(2, len(deps))
    self.assertEqual("nested", deps[0].name)
    self.assertEqual("nested_1", deps[1].name)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:48,代码来源:template_test.py


注：本文中的tensorflow.python.ops.template.make_template函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。