当前位置: 首页>>代码示例>>Python>>正文


Python function.define_function函数代码示例

本文整理汇总了Python中tensorflow.python.framework.function.define_function函数的典型用法代码示例。如果您正苦于以下问题：Python define_function函数的具体用法？Python define_function怎么用？Python define_function使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了define_function函数的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testGradientFunc

  def testGradientFunc(self):
    """Checks a function together with a second function computing its gradient.

    For f(x) = x * x + 1, f(2) == 5, and the symbolic gradient at x = 2 with
    upstream gradient dy = 0.1 is 2 * x * dy == 0.4.
    """

    def XSquarePlusOne(x):
      return x * x + 1.0

    # The gradient function refers to the forward function by the string
    # "XSquarePlusOne", which presumably matches the name define_function
    # registers (derived from the Python function name) — keep them in sync.
    def XSquarePlusOneGrad(x, dy):
      return functional_ops._symbolic_gradient(input=[x, dy],
                                               Tout=[tf.float32],
                                               f="XSquarePlusOne",
                                               name="dx")

    graph = tf.Graph()
    with graph.as_default():
      fwd_def = function.define_function(XSquarePlusOne, {"x": tf.float32})
      grad_def = function.define_function(
          XSquarePlusOneGrad, {"x": tf.float32, "dy": tf.float32})
      epsilon = tf.constant([0.1])
      two = tf.constant([2.0])
      fwd_out = function.call_function(fwd_def, two)
      grad_out = function.call_function(grad_def, two, epsilon)

      with tf.Session() as sess:
        self.assertAllClose([5.0], sess.run(fwd_out))
        self.assertAllClose([0.4], sess.run(grad_out))
开发者ID:13331151,项目名称:tensorflow,代码行数:25,代码来源:function_test.py

示例2: testDefineFunctionNoArgs

    def testDefineFunctionNoArgs(self):
        """A zero-argument function can be defined and called.

        The call op is named after the Python function, and running it
        yields the constant that the function returns.
        """
        def AConstant():
            return tf.constant([42])

        with tf.Graph().as_default():
            f_def = function.define_function(AConstant, {})
            call = function.call_function(f_def)
            # assertEqual: the assertEquals alias is deprecated and was
            # removed from unittest in Python 3.12.
            self.assertEqual("AConstant", call.op.name)
            with tf.Session() as sess:
                self.assertAllEqual([42], sess.run(call))
开发者ID:bgyss,项目名称:tensorflow,代码行数:10,代码来源:function_test.py

示例3: testStrippedOpListNestedFunctions

  def testStrippedOpListNestedFunctions(self):
    """stripped_op_list_for_graph should report ops only once a function is called.

    f1 calls f0 (a square), so Square sits two function levels deep. Defining
    the functions alone must contribute no ops. Calling f1 on a constant must
    surface both the Const and the nested Square op.
    """
    with self.test_session():
      # Square two levels deep. The Python function names (f0, f1) are left
      # untouched because define_function presumably derives the function
      # name from them; only the resulting definitions get new names.
      def f0(x):
        return tf.square(x)
      f0_def = function.define_function(f0, {"x": tf.int32})

      def f1(x):
        return function.call_function(f0_def, x)
      f1_def = function.define_function(f1, {"x": tf.int32})

      # At this point we've defined two functions but haven't called them, so
      # there should be no used ops.
      op_list = tf.contrib.util.stripped_op_list_for_graph(
          tf.get_default_graph().as_graph_def())
      self.assertEqual(len(op_list.op), 0)

      # Calling f1 on a constant should surface exactly two op types: the
      # Const that feeds the call and the Square nested inside f0.
      function.call_function(f1_def, tf.constant(7))
      op_list = tf.contrib.util.stripped_op_list_for_graph(
          tf.get_default_graph().as_graph_def())
      self.assertEqual(["Const", "Square"], [op.name for op in op_list.op])
开发者ID:2er0,项目名称:tensorflow,代码行数:21,代码来源:saver_test.py

示例4: testDefineFunction2Args

    def testDefineFunction2Args(self):
        """A two-argument function computes a + 2 * b and is named APlus2B."""
        def APlus2B(a, b):
            return a + b * 2

        with tf.Graph().as_default():
            f_def = function.define_function(
                APlus2B, {"a": tf.float32, "b": tf.float32})
            one = tf.constant([1.0])
            two = tf.constant([2.0])
            call = function.call_function(f_def, one, two)
            # assertEqual: the assertEquals alias is deprecated and was
            # removed from unittest in Python 3.12.
            self.assertEqual("APlus2B", call.op.name)
            with tf.Session() as sess:
                # 1 + 2 * 2 == 5
                self.assertAllEqual([5.0], sess.run(call))
开发者ID:bgyss,项目名称:tensorflow,代码行数:12,代码来源:function_test.py

示例5: testCallErrors

  def testCallErrors(self):
    """call_function rejects wrong argument counts and unknown keywords."""

    def Const():
      return tf.constant(1)

    def PlusOne(a):
      return a + 1

    def PlusMinus(a, b):
      return a + b, b - a

    with tf.Graph().as_default():
      one = tf.constant([1])
      two = tf.constant([2])
      const = function.define_function(Const, {})
      plus_one = function.define_function(PlusOne, {"a": tf.int32})
      plus_minus = function.define_function(PlusMinus, {"a": tf.int32,
                                                        "b": tf.int32})

      # Each row is (definition, positional args, expected error regex).
      # A regex of None means the call must succeed. Rows are executed in
      # order, so ops are created in the same sequence as a hand-written list.
      cases = [
          (const, (), None),
          (const, (one,), "arguments: 0"),
          (const, (one, two), "arguments: 0"),
          (plus_one, (), "arguments: 1"),
          (plus_one, (one,), None),
          (plus_one, (one, two), "arguments: 1"),
          (plus_minus, (), "arguments: 2"),
          (plus_minus, (one,), "arguments: 2"),
          (plus_minus, (one, two), None),
      ]
      for fdef, args, error_regex in cases:
        if error_regex is None:
          function.call_function(fdef, *args)
          continue
        with self.assertRaisesRegexp(ValueError, error_regex):
          function.call_function(fdef, *args)

      # "name" is an accepted keyword; any other keyword (e.g. "device") is not.
      function.call_function(plus_one, one, name="p1")
      with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
        function.call_function(plus_one, one, device="/gpu:0")
开发者ID:13331151,项目名称:tensorflow,代码行数:40,代码来源:function_test.py

示例6: testDefineFunctionNames

    def testDefineFunctionNames(self):
        """Call ops get unique default names and honor name= and name scopes."""
        def Foo(a):
            return a + 1

        with tf.Graph().as_default():
            f_def = function.define_function(Foo, {"a": tf.float32})
            one = tf.constant([1.0])
            # Default names come from the function name and are uniquified.
            # (assertEqual: the assertEquals alias was removed in Python 3.12.)
            call1 = function.call_function(f_def, one)
            self.assertEqual("Foo", call1.op.name)
            call2 = function.call_function(f_def, one)
            self.assertEqual("Foo_1", call2.op.name)
            # An explicit name overrides the default...
            call3 = function.call_function(f_def, one, name="mine")
            self.assertEqual("mine", call3.op.name)
            # ...and is still prefixed by the enclosing name scope.
            with tf.name_scope("my"):
                call4 = function.call_function(f_def, one, name="precious")
                self.assertEqual("my/precious", call4.op.name)
开发者ID:bgyss,项目名称:tensorflow,代码行数:16,代码来源:function_test.py

示例7: testDefineErrors

  def testDefineErrors(self):
    """define_function rejects unsupported signatures and mismatched type dicts."""

    def NoResult():
      pass

    def VarArgs(*unused_b):
      return tf.constant([1])

    def DefaultArg(unused_a=12):
      return tf.constant([1])

    def KwArgs(**unused_kwargs):
      return tf.constant([1])

    def PlusMinus(a, b):
      return a + b, b - a

    with tf.Graph().as_default():
      # Each row is (expected error regex, Python function, input-type dict),
      # checked in order. The PlusMinus rows cover missing, unknown and extra
      # argument names in the type dict.
      bad_definitions = [
          ("return at least one tensor", NoResult, {}),
          ("plain arglists are supported", VarArgs, {}),
          ("plain arglists are supported", DefaultArg, {}),
          ("plain arglists are supported", KwArgs, {}),
          ("specified input types", PlusMinus, {}),
          ("specified input types", PlusMinus, {"c": tf.float32}),
          ("type for argument: b", PlusMinus,
           {"a": tf.float32, "c": tf.float32}),
          ("specified input types", PlusMinus,
           {"a": tf.float32, "b": tf.float32, "c": tf.float32}),
      ]
      for error_regex, py_func, input_types in bad_definitions:
        with self.assertRaisesRegexp(ValueError, error_regex):
          function.define_function(py_func, input_types)
开发者ID:13331151,项目名称:tensorflow,代码行数:37,代码来源:function_test.py

示例8: testDefineErrors

  def testDefineErrors(self):
    """define_function(...).definition rejects bad signatures and type dicts."""

    def NoResult():
      pass

    def DefaultArg(unused_a=12):
      return tf.constant([1])

    def KwArgs(**unused_kwargs):
      return tf.constant([1])

    def PlusMinus(a, b):
      return a + b, b - a

    with tf.Graph().as_default():
      # Each row is (expected error regex, Python function, input-type dict),
      # checked in order.
      bad_definitions = [
          ("return at least one tensor", NoResult, {}),
          ("are not supported", DefaultArg, {}),
          ("are not supported", KwArgs, {}),
          ("specified input types", PlusMinus, {}),
          ("specified input types", PlusMinus, {"c": tf.float32}),
          ("type for argument: b", PlusMinus,
           {"a": tf.float32, "c": tf.float32}),
          ("specified input types", PlusMinus,
           {"a": tf.float32, "b": tf.float32, "c": tf.float32}),
      ]
      # pylint: disable=expression-not-assigned
      for error_regex, py_func, input_types in bad_definitions:
        with self.assertRaisesRegexp(ValueError, error_regex):
          # The bare .definition access is deliberate: the definition (and
          # therefore the validation error) is presumably built lazily.
          function.define_function(py_func, input_types).definition
开发者ID:apollos,项目名称:tensorflow,代码行数:33,代码来源:function_test.py

示例9: XSquarePlusOne

import tensorflow as tf
from tensorflow.python.framework import function
from tensorflow.python.ops import functional_ops
graph = tf.Graph()
with graph.as_default():
  # NOTE(review): `tt` is not used by the computation below; it is kept
  # because it is still part of the graph dumped to simple.pbtxt.
  tt = tf.constant([4.2])

  def XSquarePlusOne(x):
    """Forward function f(x) = x * x + 1."""
    # NOTE(review): `ph` does not feed the return value, but it is created
    # inside the function body and so presumably ends up in the traced
    # definition; confirm whether that is intentional before removing it.
    ph = tf.placeholder("float", shape=[1])
    return x * x + 1.0

  def XSquarePlusOneGrad(x, dy):
    """Gradient of XSquarePlusOne (looked up by name) given upstream dy."""
    dx = functional_ops._symbolic_gradient(input=[x, dy],
                                           Tout=[tf.float32],
                                           f="XSquarePlusOne",
                                           name="dx")
    return dx

  f = function.define_function(XSquarePlusOne, {"x": tf.float32})
  g = function.define_function(XSquarePlusOneGrad, {"x": tf.float32,
                                                    "dy": tf.float32})
  epsilon = tf.constant([1.0])
  two = tf.constant([2.0])
  call_f = function.call_function(f, two)
  call_g = function.call_function(g, two, epsilon)

  # Dump the graph as text to /tmp/tfb/simple.pbtxt for inspection.
  tf.train.write_graph(graph.as_graph_def(), '/tmp/tfb', 'simple.pbtxt',
                       as_text=True)

  with tf.Session() as sess:
    # print(...) with a single argument behaves identically on Python 2 and 3.
    # The original `print expr` statement is a SyntaxError on Python 3.
    print(sess.run(call_f))
    print(sess.run(call_g))
开发者ID:LaurentMazare,项目名称:tensorflow-ocaml,代码行数:30,代码来源:gradient.py


注：本文中的tensorflow.python.framework.function.define_function函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。