当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow_hub.create_module_spec方法代码示例

本文整理汇总了Python中tensorflow_hub.create_module_spec方法的典型用法代码示例。如果您正苦于以下问题:Python tensorflow_hub.create_module_spec方法的具体用法?Python tensorflow_hub.create_module_spec怎么用?Python tensorflow_hub.create_module_spec使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow_hub的用法示例。


在下文中一共展示了tensorflow_hub.create_module_spec方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testModuleWithVariablesAndNoCheckpoint

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testModuleWithVariablesAndNoCheckpoint(self):
    """Checks variable names and initial values of a freshly created module.

    Instantiates the spec built from `module_with_variables` under the name
    "module", compares the names of the resulting global variables with the
    expected list, then runs the global initializer and compares each
    variable's value with the expected zeros.
    """
    expected_names = [
        "module/weights",
        "module/partition/part_0",
        "module/partition/part_1",
        "module/partition/part_2",
    ]
    # One entry per variable name above, in the same order.
    expected_values = [
        [0.0, 0.0, 0.0],
        [0.0, 0.0],
        [0.0],
        [0.0],
    ]
    with tf.Graph().as_default():
      module_spec = native_module.create_module_spec(module_with_variables)
      module_spec._create_impl(name="module", trainable=False, tags=None)
      actual_names = [var.op.name for var in tf_v1.global_variables()]
      self.assertAllEqual(actual_names, expected_names)

      with tf_v1.Session() as sess:
        sess.run(tf_v1.initializers.global_variables())
        actual_values = sess.run(tf_v1.global_variables())
        for actual, expected in zip(actual_values, expected_values):
          self.assertAllEqual(actual, expected)
开发者ID:tensorflow,项目名称:hub,代码行数:25,代码来源:native_module_test.py

示例2: testUnsupportedCollections

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testUnsupportedCollections(self):
    """Checks that unknown collections are rejected unless explicitly dropped.

    The module_fn below places a variable in the custom collection
    "my_scope". create_module_spec is expected to raise ValueError for it,
    get_unsupported_collections is expected to report exactly that key, and
    passing the key via `drop_collections` is expected to succeed.
    """

    def module_fn():
      scale = tf_v1.get_variable("x", (), collections=["my_scope"])
      x = tf_v1.placeholder(tf.float32, shape=[None, 3])
      native_module.add_signature("my_func", {"x": x}, {"y": x*scale})

    with self.assertRaises(ValueError) as cm:
      _ = native_module.create_module_spec(module_fn)
    # The message check must live outside the `with` block: statements after
    # the raising call inside the block never execute. The text to inspect is
    # the raised exception's message, not the context manager object.
    self.assertIn("Unsupported collections in graph", str(cm.exception))

    with tf.Graph().as_default() as tmp_graph:
      module_fn()
      unsupported_collections = native_module.get_unsupported_collections(
          tmp_graph.get_all_collection_keys())
      self.assertEqual(["my_scope"], unsupported_collections)

    _ = native_module.create_module_spec(
        module_fn, drop_collections=unsupported_collections)
开发者ID:tensorflow,项目名称:hub,代码行数:21,代码来源:native_module_test.py

示例3: testUseWithinWhileLoop

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testUseWithinWhileLoop(self):
    """Applies a module inside tf.while_loop and checks outputs and gradients.

    The loop runs `num_steps` iterations (fed at run time), applying the
    module to the loop value each time. Both the loop results and the
    gradients with respect to the module variable and the loop input are
    compared with the expected numbers for 1, 2 and 4 iterations.
    """
    with tf.Graph().as_default():
      module = hub.Module(hub.create_module_spec(double_module_fn))
      start_step = tf.constant(0)
      start_value = tf.constant(10.0)
      num_steps = tf_v1.placeholder(dtype=tf.int32)

      def keep_going(step, unused_value):
        return tf.less(step, num_steps)

      def apply_module(step, value):
        return (tf.add(step, 1), module(value))

      # final_value = v**p * x  (p = num_steps, x = start_value)
      final_step, final_value = tf.while_loop(
          keep_going, apply_module, [start_step, start_value])
      variable = module.variables[0]
      # d final_value / dv = p*v**(p-1) * x
      grad_wrt_variable = tf.gradients(final_value, variable)[0]
      # d final_value / dx = v**p
      grad_wrt_input = tf.gradients(final_value, start_value)[0]

      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        for steps, expected in ((1, [1, 20]), (2, [2, 40]), (4, [4, 160])):
          self.assertAllEqual(
              sess.run([final_step, final_value],
                       feed_dict={num_steps: steps}),
              expected)
        # Gradients also use the control flow structures setup earlier.
        # Also check they are working properly.
        for steps, expected in ((1, [10, 2]), (2, [40, 4]), (4, [320, 16])):
          self.assertAllEqual(
              sess.run([grad_wrt_variable, grad_wrt_input],
                       feed_dict={num_steps: steps}),
              expected)

  # tf.map_fn() is merely a wrapper around tf.while_loop(), but just to be sure...
开发者ID:tensorflow,项目名称:hub,代码行数:27,代码来源:native_module_test.py

示例4: testSparseTensors

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testSparseTensors(self):
    """Feeds sparse values through a module twice and checks the result.

    A SparseTensorValue is fed to the module built from
    `sparse_square_module_fn`, and the output is fed back in a second time.
    Indices and dense shape are expected to be unchanged, and each value is
    expected to be raised to the fourth power.
    """
    square_spec = hub.create_module_spec(sparse_square_module_fn)

    with tf.Graph().as_default():
      square = hub.Module(square_spec)
      v = tf_v1.sparse_placeholder(dtype=tf.int64, name="v")
      y = square(v)

      # Enter the Session itself instead of Session().as_default(): both
      # install it as the default session needed by y.eval(), but only the
      # Session context manager closes the session on exit.
      with tf_v1.Session():
        indices = [[0, 0], [0, 1], [1, 1]]
        values = [10, 2, 1]
        shape = [2, 2]
        v1 = tf_v1.SparseTensorValue(indices, values, shape)
        v2 = y.eval(feed_dict={v: v1})
        v4 = y.eval(feed_dict={v: v2})

        self.assertAllEqual(v4.indices, indices)  # Unchanged.
        self.assertAllEqual(v4.values, [t**4 for t in values])  # Squared twice.
        self.assertAllEqual(v4.dense_shape, shape)  # Unchanged.
开发者ID:tensorflow,项目名称:hub,代码行数:21,代码来源:native_module_test.py

示例5: testNonResourceVariableInWhileLoop

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testNonResourceVariableInWhileLoop(self):
    """Calls a module inside a while_loop body and checks its colocation.

    Inside the loop body the module output's op is expected to be colocated
    with "module/var123". The loop itself runs four iterations, doubling the
    loop value each time.
    """
    expected_colocation = [tf.compat.as_bytes("loc:@module/var123")]
    with tf.Graph().as_default():
      # This test uses non-Resource variables to see an actual colocation
      # constraint propagated to the context Enter op. The long comment on
      # colocation in testResourceVariables explains why they may not offer
      # that.
      module = hub.Module(hub.create_module_spec(stateful_non_rv_module_fn))

      def keep_going(step, unused_value):
        return tf.less(step, 4)

      def loop_body(step, value):
        module_output = module()
        self.assertItemsEqual(module_output.op.colocation_groups(),
                              expected_colocation)
        return (tf.add(step, 1), 2*value)

      final_step, final_value = tf.while_loop(
          keep_going, loop_body, [0, 10.0])
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        results = sess.run([final_step, final_value])
        self.assertAllEqual(results, [4, 160.0])
开发者ID:tensorflow,项目名称:hub,代码行数:20,代码来源:native_module_test.py

示例6: testNonResourceVariableInCond

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testNonResourceVariableInCond(self):
    """Calls a module inside one branch of tf.cond and checks both branches.

    The true branch applies the module and asserts that its output op is
    colocated with "module/var123"; the false branch returns a constant.
    """
    expected_colocation = [tf.compat.as_bytes("loc:@module/var123")]
    with tf.Graph().as_default():
      module = hub.Module(hub.create_module_spec(stateful_non_rv_module_fn))
      use_module = tf_v1.placeholder(tf.bool)

      def module_branch():
        module_output = module()
        self.assertItemsEqual(module_output.op.colocation_groups(),
                              expected_colocation)
        return module_output

      def constant_branch():
        return tf.constant(9.0)

      result = tf.cond(use_module, module_branch, constant_branch)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        with_module = sess.run(result, feed_dict={use_module: True})
        self.assertEqual(with_module, 10.0)
        without_module = sess.run(result, feed_dict={use_module: False})
        self.assertEqual(without_module, 9.0)
开发者ID:tensorflow,项目名称:hub,代码行数:19,代码来源:native_module_test.py

示例7: testVariableColocationPropagation

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testVariableColocationPropagation(self):
    """Checks that the module output keeps the colocation of its input.

    The input constant is created while colocated with "u1" and "u2"; the
    module output op is expected to list those two groups in addition to
    "module/var123".
    """
    expected_groups = [
        tf.compat.as_bytes("loc:@module/var123"),
        tf.compat.as_bytes("loc:@u1"),
        tf.compat.as_bytes("loc:@u2"),
    ]
    with tf.Graph().as_default():
      module = hub.Module(
          hub.create_module_spec(stateful_module_fn_with_colocation))
      anchor_1 = tf.constant(1, name="u1")
      anchor_2 = tf.constant(2, name="u2")
      with tf_v1.colocate_with(anchor_1):
        with tf_v1.colocate_with(anchor_2):
          module_input = tf.constant(100.0, name="x")
      module_output = module(module_input)
      self.assertItemsEqual(module_output.op.colocation_groups(),
                            expected_groups)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        value = sess.run(module_output)
        self.assertEqual(value, 101.0)
开发者ID:tensorflow,项目名称:hub,代码行数:18,代码来源:native_module_test.py

示例8: testPartitionedVariables

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testPartitionedVariables(self):
    """Checks variable_map and variables for a partitioned-variable module.

    The module is instantiated under the name "test". Its variable_map is
    expected to hold two entries: one normal variable and one list of three
    partitions. `variables` is expected to list all four in a fixed order.
    """
    normal_name = "test/normal_variable:0"
    partition_names = [
        "test/partitioned_variable/part_0:0",
        "test/partitioned_variable/part_1:0",
        "test/partitioned_variable/part_2:0",
    ]
    with tf.Graph().as_default():
      module_fn = create_partitioned_variable_module_fn(
          partitions=3, shape=[7, 3])
      module = hub.Module(hub.create_module_spec(module_fn), name="test")
      output = module()
      variable_map = module.variable_map
      self.assertEqual(len(variable_map), 2)
      self.assertEqual(variable_map["normal_variable"].name, normal_name)
      self.assertAllEqual(
          [part.name for part in variable_map["partitioned_variable"]],
          partition_names)
      # Check deterministic order (by variable_map key).
      self.assertAllEqual(
          [var.name for var in module.variables],
          [normal_name] + partition_names)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        value = sess.run(output)
        self.assertAllClose(value, 2 * np.ones([7, 3]))
开发者ID:tensorflow,项目名称:hub,代码行数:27,代码来源:native_module_test.py

示例9: testLoadTrainableModuleFromFuncDef

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testLoadTrainableModuleFromFuncDef(self):
    """Trains a trainable module with gradient descent towards fixed targets.

    Runs 50 optimizer steps on the mean squared error between the module
    output and the targets, then checks the output is close to the targets.
    """
    targets = [3.1, 3.2, 3.3]
    with tf_v1.Session() as sess:
      module = hub.Module(
          hub.create_module_spec(stateful_module_fn), trainable=True)
      output = module()
      global_step = tf.Variable(0, trainable=False, name="global_step")
      optimizer = tf_v1.train.GradientDescentOptimizer(0.40)
      loss = tf_v1.losses.mean_squared_error(output, targets)
      train_op = optimizer.minimize(loss=loss, global_step=global_step)
      sess.run(tf_v1.global_variables_initializer())
      for _ in range(50):
        sess.run(train_op)
      trained_value = sess.run(output)
      self.assertAllClose(trained_value, targets)

  # TODO(b/112575006): The following tests verify functionality of function call
  # within a TPU context. Work to generalize this for all function calls is
  # ongoing. 
开发者ID:tensorflow,项目名称:hub,代码行数:20,代码来源:native_module_test.py

示例10: testTPUModuleInitializeOnceWithDefun

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUModuleInitializeOnceWithDefun(self):
    """Calls a module twice inside a Defun and expects equal outputs."""
    module_spec = hub.create_module_spec(stateful_random_rv_module_fn)

    @function.Defun()
    def import_computation():
      # The module is instantiated after entering a TPUReplicateContext.
      tpu_context = TPUReplicateContext()
      tpu_context.Enter()
      module = hub.Module(module_spec, name="module_", trainable=True)
      first_call = module()
      second_call = module()
      return [first_call, second_call]

    with tf_v1.Graph().as_default():
      with tf_v1.Session() as sess:
        outputs = import_computation()
        sess.run(tf_v1.global_variables_initializer())
        values = sess.run(outputs)
        # Check the values are equal. If the initializer ran on each call,
        # the values would be different.
        self.assertEqual(values[0], values[1])
开发者ID:tensorflow,项目名称:hub,代码行数:19,代码来源:native_module_test.py

示例11: testTPUPruneWithUnusedInput

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUPruneWithUnusedInput(self):
    """Runs a module with an extra "unused" input inside a Defun.

    The module built from `unused_input_module_fn` is fed both "x" and
    "unused"; for x = 5 the result is expected to be 25.
    """
    module_spec = hub.create_module_spec(unused_input_module_fn)

    @function.Defun()
    def import_computation(x):
      # The module is instantiated after entering a TPUReplicateContext.
      tpu_context = TPUReplicateContext()
      tpu_context.Enter()
      module = hub.Module(module_spec, name="module_", trainable=True)
      module_inputs = {
          "x": tf.cast(x, dtype=tf.int64),
          "unused": tf.constant(2, dtype=tf.int64),
      }
      return module(module_inputs)

    with tf_v1.Graph().as_default():
      with tf_v1.Session() as sess:
        output = import_computation(5)
        value = sess.run(output)
        self.assertEqual(value, 25)
开发者ID:tensorflow,项目名称:hub,代码行数:19,代码来源:native_module_test.py

示例12: testTPUModuleDoesntPruneControlDependencies

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUModuleDoesntPruneControlDependencies(self):
    """Checks a module's control-dependency op is still present in the graph.

    After running the module built from `control_dependency_module_fn`
    inside a Defun, the op "module_/dependency_op" is looked up by name in
    the default graph.
    """
    module_spec = hub.create_module_spec(control_dependency_module_fn)

    @function.Defun()
    def import_computation():
      # The module is instantiated after entering a TPUReplicateContext.
      tpu_context = TPUReplicateContext()
      tpu_context.Enter()
      module = hub.Module(module_spec, name="module_", trainable=True)
      return module()

    with tf_v1.Graph().as_default():
      with tf_v1.Session() as sess:
        output = import_computation()
        value = sess.run(output)
        self.assertEqual(value, 5.0)
        # If the op got pruned, the following get_operation_by_name should
        # fail with a dependency error.
        graph = tf_v1.get_default_graph()
        graph.get_operation_by_name("module_/dependency_op")
开发者ID:tensorflow,项目名称:hub,代码行数:19,代码来源:native_module_test.py

示例13: testTPUModuleWithWrapFunc

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUModuleWithWrapFunc(self):
    """Applies a module to two inputs through tf_v1.wrap_function.

    The wrapped computation takes two scalar float32 inputs and returns the
    module output for each; inputs 9.0 and 6.0 are expected to give 19.0 and
    16.0.
    """
    module_spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

    def import_computation(first, second):
      # The module is instantiated after entering a TPUReplicateContext.
      tpu_context = TPUReplicateContext()
      tpu_context.Enter()
      module = hub.Module(module_spec, trainable=True)
      return [module(first), module(second)]

    with tf_v1.Graph().as_default():
      with tf_v1.Session() as sess:
        signature = [
            tf.TensorSpec((), tf.float32),
            tf.TensorSpec((), tf.float32),
        ]
        wrapped = tf_v1.wrap_function(import_computation, signature)
        sess.run(tf_v1.global_variables_initializer())
        values = sess.run(wrapped(9.0, 6.0))
        self.assertEqual(values, [19.0, 16.0])
开发者ID:tensorflow,项目名称:hub,代码行数:19,代码来源:native_module_test.py

示例14: testModuleWithLayers

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testModuleWithLayers(self):
    """Exports a module and checks the reloaded copy gives the same output.

    The module built from `layers_module_fn` is evaluated on a sample input
    and exported to a temp directory. A second graph loads the module from
    that path and its output on the same input is compared with the first.
    """
    module_dir = os.path.join(self.get_temp_dir(), "layers-module")
    sample_input = [[1.0, 2.0], [3.1, 10.0]]
    module_spec = hub.create_module_spec(layers_module_fn)

    # First graph: build from the spec, record the output, export.
    with tf.Graph().as_default():
      module = hub.Module(module_spec, trainable=False)
      inputs = tf_v1.placeholder(dtype=tf.float32)
      outputs = module(inputs)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        sample_output = sess.run(outputs, feed_dict={inputs: sample_input})
        module.export(module_dir, sess)

    # Second graph: load the exported module and compare its output.
    with tf.Graph().as_default():
      inputs = tf_v1.placeholder(dtype=tf.float32)
      reloaded = hub.Module(module_dir)
      outputs = reloaded(inputs)

      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        reloaded_output = sess.run(outputs, feed_dict={inputs: sample_input})
        self.assertAllEqual(reloaded_output, sample_output)
开发者ID:tensorflow,项目名称:hub,代码行数:25,代码来源:native_module_test.py

示例15: testInputsFromMultivaluedOp

# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testInputsFromMultivaluedOp(self):
    """Tests the helper that reports signature inputs from multivalued ops.

    Two inputs are taken from the outputs of tf.split ops and one from a
    constant; the helper's message is expected to name the two split-derived
    inputs. With constant inputs only, the helper is expected to return None.
    """
    # Ideally, one would be able to write
    #    with self.assertLogs("blah"): hub.create_module_spec(module_fn)
    # but in the absence of assertions on logs, we test the underlying helper
    # in the environment seen from within a module_fn.
    find_inputs = native_module.find_signature_inputs_from_multivalued_ops

    with tf.Graph().as_default():
      first, unused_tail = tf.split([[1, 2], [3, 4]], 2, name="split1")
      unused_head, second = tf.split([[5, 6], [7, 8]], 2, name="split2")
      third = tf.constant(105, name="const")
      message = find_inputs({"first": first, "second": second, "third": third})
    expected_pattern = (
        ".*single output.*\n"
        "Affected inputs: first='split1:0', second='split2:1'$")
    self.assertRegexpMatches(message, expected_pattern)

    # Also test the case of no errors.
    with tf.Graph().as_default():
      clean_inputs = {
          "first": tf.constant(101),
          "second": tf.constant(102),
          "third": tf.constant(103),
      }
      message = find_inputs(clean_inputs)
    self.assertIsNone(message)
开发者ID:tensorflow,项目名称:hub,代码行数:26,代码来源:native_module_test.py


注:本文中的tensorflow_hub.create_module_spec方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。