本文整理汇总了 Python 中 tensorflow_hub.create_module_spec 方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow_hub.create_module_spec 方法的具体用法？Python tensorflow_hub.create_module_spec 怎么用？Python tensorflow_hub.create_module_spec 使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow_hub 的用法示例。
在下文中一共展示了tensorflow_hub.create_module_spec方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testModuleWithVariablesAndNoCheckpoint
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testModuleWithVariablesAndNoCheckpoint(self):
    """Instantiates `module_with_variables` and checks its global variables.

    The op names of the created global variables must match the expected
    list, and after running the global initializer every variable must equal
    its expected zero-filled value.
    """
    variable_names = [
        "module/weights",
        "module/partition/part_0",
        "module/partition/part_1",
        "module/partition/part_2",
    ]
    initial_values = [
        [0.0, 0.0, 0.0],
        [0.0, 0.0],
        [0.0],
        [0.0],
    ]
    with tf.Graph().as_default():
        module_spec = native_module.create_module_spec(module_with_variables)
        module_spec._create_impl(name="module", trainable=False, tags=None)
        created_names = [var.op.name for var in tf_v1.global_variables()]
        self.assertAllEqual(created_names, variable_names)

        with tf_v1.Session() as session:
            session.run(tf_v1.initializers.global_variables())
            fetched_values = session.run(tf_v1.global_variables())
            # Compare variable by variable, in creation order.
            for fetched, wanted in zip(fetched_values, initial_values):
                self.assertAllEqual(fetched, wanted)
示例2: testUnsupportedCollections
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testUnsupportedCollections(self):
    """Checks handling of a module_fn that uses an unsupported collection.

    `create_module_spec` must raise ValueError for a variable placed in the
    custom "my_scope" collection, `get_unsupported_collections` must report
    exactly that collection, and passing it via `drop_collections` must let
    the spec be created.
    """

    def module_fn():
        scale = tf_v1.get_variable("x", (), collections=["my_scope"])
        x = tf_v1.placeholder(tf.float32, shape=[None, 3])
        native_module.add_signature("my_func", {"x": x}, {"y": x*scale})

    with self.assertRaises(ValueError) as cm:
        _ = native_module.create_module_spec(module_fn)
    # The message check must run after the `with` block (statements following
    # the raising call inside it never execute) and must inspect the
    # exception text, not the context-manager object itself.
    self.assertIn("Unsupported collections in graph", str(cm.exception))

    with tf.Graph().as_default() as tmp_graph:
        module_fn()
        unsupported_collections = native_module.get_unsupported_collections(
            tmp_graph.get_all_collection_keys())
        self.assertEqual(["my_scope"], unsupported_collections)

    # Dropping the offending collections makes spec creation succeed.
    _ = native_module.create_module_spec(
        module_fn, drop_collections=unsupported_collections)
示例3: testUseWithinWhileLoop
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testUseWithinWhileLoop(self):
    """Applies a hub.Module inside tf.while_loop; checks outputs and gradients.

    The loop runs a fed number of iterations, applying the module to the
    carried value each time. Loop outputs and their gradients with respect to
    the module variable and the initial value are compared against fixed
    expectations for 1, 2 and 4 iterations.
    """
    with tf.Graph().as_default():
        module = hub.Module(hub.create_module_spec(double_module_fn))
        start_index = tf.constant(0)
        start_value = tf.constant(10.0)
        num_iterations = tf_v1.placeholder(dtype=tf.int32)

        def loop_cond(index, unused_value):
            return tf.less(index, num_iterations)

        def loop_body(index, value):
            return (tf.add(index, 1), module(value))

        # With v the module variable, p the iteration count and x the start
        # value: final_value = v**p * x.
        final_index, final_value = tf.while_loop(
            loop_cond, loop_body, [start_index, start_value])
        module_var = module.variables[0]
        # d final_value / dv = p * v**(p-1) * x
        grad_wrt_var = tf.gradients(final_value, module_var)[0]
        # d final_value / dx = v**p
        grad_wrt_input = tf.gradients(final_value, start_value)[0]

        # (iterations fed, expected fetch result) pairs.
        forward_cases = [(1, [1, 20]), (2, [2, 40]), (4, [4, 160])]
        gradient_cases = [(1, [10, 2]), (2, [40, 4]), (4, [320, 16])]

        with tf_v1.Session() as sess:
            sess.run(tf_v1.global_variables_initializer())
            for steps, expected in forward_cases:
                self.assertAllEqual(
                    sess.run([final_index, final_value],
                             feed_dict={num_iterations: steps}),
                    expected)
            # Gradients also use the control flow structures set up earlier,
            # so check that they work properly as well.
            for steps, expected in gradient_cases:
                self.assertAllEqual(
                    sess.run([grad_wrt_var, grad_wrt_input],
                             feed_dict={num_iterations: steps}),
                    expected)
# tf.map_fn() is merely a wrapper around tf.while(), but just to be sure...
示例4: testSparseTensors
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testSparseTensors(self):
    """Passes sparse values through the `sparse_square_module_fn` module.

    A SparseTensorValue is fed through the module, and the result is fed
    through it again; indices and dense shape must be unchanged while every
    value ends up raised to the fourth power.
    """
    square_spec = hub.create_module_spec(sparse_square_module_fn)

    with tf.Graph().as_default():
        square = hub.Module(square_spec)
        v = tf_v1.sparse_placeholder(dtype=tf.int64, name="v")
        y = square(v)

        # Enter the Session itself instead of `Session().as_default()`: it is
        # still installed as the default session (needed by `y.eval`), and it
        # is closed on exit, which `as_default()` alone does not do.
        with tf_v1.Session():
            indices = [[0, 0], [0, 1], [1, 1]]
            values = [10, 2, 1]
            shape = [2, 2]
            v1 = tf_v1.SparseTensorValue(indices, values, shape)
            v2 = y.eval(feed_dict={v: v1})
            v4 = y.eval(feed_dict={v: v2})

            self.assertAllEqual(v4.indices, indices)  # Unchanged.
            self.assertAllEqual(
                v4.values, [t**4 for t in values])  # Squared twice.
            self.assertAllEqual(v4.dense_shape, shape)  # Unchanged.
示例5: testNonResourceVariableInWhileLoop
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testNonResourceVariableInWhileLoop(self):
    """Calls a module from a tf.while_loop body and checks colocation.

    Inside the loop body the module output op must be colocated with
    "module/var123"; the loop itself must yield [4, 160.0].
    """
    with tf.Graph().as_default():
        # Non-Resource variables are used here so that an actual colocation
        # constraint is propagated to the context Enter op. The long comment
        # on colocation in testResourceVariables explains why Resource
        # variables may not offer that.
        module = hub.Module(hub.create_module_spec(stateful_non_rv_module_fn))
        expected_group = "loc:@module/var123"

        def keep_looping(index, unused_value):
            return tf.less(index, 4)

        def loop_body(index, value):
            module_out = module()
            self.assertItemsEqual(module_out.op.colocation_groups(),
                                  [tf.compat.as_bytes(expected_group)])
            return (tf.add(index, 1), 2*value)

        final_index, final_value = tf.while_loop(
            keep_looping, loop_body, [0, 10.0])

        with tf_v1.Session() as sess:
            sess.run(tf_v1.global_variables_initializer())
            results = sess.run([final_index, final_value])
            self.assertAllEqual(results, [4, 160.0])
示例6: testNonResourceVariableInCond
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testNonResourceVariableInCond(self):
    """Calls a module from one branch of tf.cond and checks both branches.

    The true branch returns the module output (asserted to be colocated with
    "module/var123"); the false branch returns the constant 9.0.
    """
    with tf.Graph().as_default():
        module = hub.Module(hub.create_module_spec(stateful_non_rv_module_fn))
        pred = tf_v1.placeholder(tf.bool)

        def module_branch():
            module_out = module()
            self.assertItemsEqual(
                module_out.op.colocation_groups(),
                [tf.compat.as_bytes("loc:@module/var123")])
            return module_out

        out = tf.cond(pred, module_branch, lambda: tf.constant(9.0))

        with tf_v1.Session() as sess:
            sess.run(tf_v1.global_variables_initializer())
            # (value fed for pred, expected output) pairs, true branch first.
            for branch, expected in ((True, 10.0), (False, 9.0)):
                self.assertEqual(
                    sess.run(out, feed_dict={pred: branch}), expected)
示例7: testVariableColocationPropagation
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testVariableColocationPropagation(self):
    """Checks the colocation groups of a module output fed a colocated input.

    The input constant is created colocated with both `u1` and `u2`; the
    module output op must report those two groups plus "module/var123", and
    must evaluate to 101.0.
    """
    with tf.Graph().as_default():
        module = hub.Module(
            hub.create_module_spec(stateful_module_fn_with_colocation))
        u1 = tf.constant(1, name="u1")
        u2 = tf.constant(2, name="u2")
        # Only the input constant is created under the colocation scopes.
        with tf_v1.colocate_with(u1):
            with tf_v1.colocate_with(u2):
                x = tf.constant(100.0, name="x")
        y = module(x)

        group_names = ("loc:@module/var123", "loc:@u1", "loc:@u2")
        self.assertItemsEqual(
            y.op.colocation_groups(),
            [tf.compat.as_bytes(group) for group in group_names])

        with tf_v1.Session() as sess:
            sess.run(tf_v1.global_variables_initializer())
            self.assertEqual(sess.run(y), 101.0)
示例8: testPartitionedVariables
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testPartitionedVariables(self):
    """Checks `variable_map` and `variables` of a module with a partitioned variable.

    The module is built with 3 partitions and shape [7, 3] under the name
    "test". Its variable map must hold two entries, the variable names must
    match the expected ones, and its output must be close to 2 * ones([7, 3]).
    """
    partition_names = [
        "test/partitioned_variable/part_0:0",
        "test/partitioned_variable/part_1:0",
        "test/partitioned_variable/part_2:0"
    ]
    with tf.Graph().as_default():
        module_fn = create_partitioned_variable_module_fn(
            partitions=3, shape=[7, 3])
        module = hub.Module(hub.create_module_spec(module_fn), name="test")
        out = module()

        variable_map = module.variable_map
        self.assertEqual(len(variable_map), 2)
        self.assertEqual(variable_map["normal_variable"].name,
                         "test/normal_variable:0")

        mapped_partitions = [
            part.name for part in module.variable_map["partitioned_variable"]
        ]
        self.assertAllEqual(mapped_partitions, partition_names)

        # Check deterministic order (by variable_map key).
        all_names = [var.name for var in module.variables]
        self.assertAllEqual(
            all_names,
            ["test/normal_variable:0",
             "test/partitioned_variable/part_0:0",
             "test/partitioned_variable/part_1:0",
             "test/partitioned_variable/part_2:0"])

        with tf_v1.Session() as sess:
            sess.run(tf_v1.global_variables_initializer())
            self.assertAllClose(sess.run(out), 2 * np.ones([7, 3]))
示例9: testLoadTrainableModuleFromFuncDef
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testLoadTrainableModuleFromFuncDef(self):
    """Trains a trainable module with gradient descent towards a fixed target.

    After 50 optimizer steps on a mean-squared-error loss, the module output
    must be close to [3.1, 3.2, 3.3].
    """
    target = [3.1, 3.2, 3.3]
    with tf_v1.Session() as sess:
        module = hub.Module(
            hub.create_module_spec(stateful_module_fn), trainable=True)
        output = module()
        global_step = tf.Variable(0, trainable=False, name="global_step")
        optimizer = tf_v1.train.GradientDescentOptimizer(0.40)
        loss = tf_v1.losses.mean_squared_error(output, [3.1, 3.2, 3.3])
        train_op = optimizer.minimize(loss=loss, global_step=global_step)

        sess.run(tf_v1.global_variables_initializer())
        for _ in range(50):
            sess.run(train_op)
        self.assertAllClose(sess.run(output), target)
# TODO(b/112575006): The following tests verify functionality of function call
# within a TPU context. Work to generalize this for all function calls is
# ongoing.
示例10: testTPUModuleInitializeOnceWithDefun
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUModuleInitializeOnceWithDefun(self):
    """Calls one module twice inside a Defun and compares the two results."""
    module_spec = hub.create_module_spec(stateful_random_rv_module_fn)

    @function.Defun()
    def import_computation():
        tpu_context = TPUReplicateContext()
        tpu_context.Enter()
        module = hub.Module(module_spec, name="module_", trainable=True)
        return [module(), module()]

    with tf_v1.Graph().as_default():
        with tf_v1.Session() as sess:
            outputs = import_computation()
            sess.run(tf_v1.global_variables_initializer())
            first_value, second_value = sess.run(outputs)
            # Both calls must yield the same value: had the initializer run
            # on each call, the values would be different.
            self.assertEqual(first_value, second_value)
示例11: testTPUPruneWithUnusedInput
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUPruneWithUnusedInput(self):
    """Calls a module with an extra "unused" input from inside a Defun.

    The function is invoked with 5 and the fetched result must equal 25.
    """
    module_spec = hub.create_module_spec(unused_input_module_fn)

    @function.Defun()
    def import_computation(x):
        tpu_context = TPUReplicateContext()
        tpu_context.Enter()
        module = hub.Module(module_spec, name="module_", trainable=True)
        module_inputs = {
            "x": tf.cast(x, dtype=tf.int64),
            "unused": tf.constant(2, dtype=tf.int64)
        }
        return module(module_inputs)

    with tf_v1.Graph().as_default():
        with tf_v1.Session() as sess:
            output = import_computation(5)
            self.assertEqual(sess.run(output), 25)
示例12: testTPUModuleDoesntPruneControlDependencies
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUModuleDoesntPruneControlDependencies(self):
    """Checks that "module_/dependency_op" is still present after import.

    The module is instantiated inside a Defun; its output must equal 5.0 and
    the dependency op must still be retrievable from the default graph.
    """
    module_spec = hub.create_module_spec(control_dependency_module_fn)

    @function.Defun()
    def import_computation():
        tpu_context = TPUReplicateContext()
        tpu_context.Enter()
        module = hub.Module(module_spec, name="module_", trainable=True)
        return module()

    with tf_v1.Graph().as_default():
        with tf_v1.Session() as sess:
            output = import_computation()
            self.assertEqual(sess.run(output), 5.0)
            # If the op got pruned, the following get_operation_by_name
            # should fail with a dependency error.
            default_graph = tf_v1.get_default_graph()
            default_graph.get_operation_by_name("module_/dependency_op")
示例13: testTPUModuleWithWrapFunc
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testTPUModuleWithWrapFunc(self):
    """Applies one module to two scalar inputs through `wrap_function`.

    Calling the wrapped function with 9.0 and 6.0 must give [19.0, 16.0].
    """
    module_spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

    def import_computation(first, second):
        tpu_context = TPUReplicateContext()
        tpu_context.Enter()
        module = hub.Module(module_spec, trainable=True)
        return [module(first), module(second)]

    with tf_v1.Graph().as_default():
        with tf_v1.Session() as sess:
            # One scalar float32 spec per argument of import_computation.
            signature = [tf.TensorSpec((), tf.float32) for _ in range(2)]
            wrapped = tf_v1.wrap_function(import_computation, signature)
            sess.run(tf_v1.global_variables_initializer())
            results = sess.run(wrapped(9.0, 6.0))
            self.assertEqual(results, [19.0, 16.0])
示例14: testModuleWithLayers
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testModuleWithLayers(self):
export_path = os.path.join(self.get_temp_dir(), "layers-module")
sample_input = [[1.0, 2.0], [3.1, 10.0]]
spec = hub.create_module_spec(layers_module_fn)
with tf.Graph().as_default():
m = hub.Module(spec, trainable=False)
x = tf_v1.placeholder(dtype=tf.float32)
y = m(x)
with tf_v1.Session() as sess:
sess.run(tf_v1.global_variables_initializer())
sample_output = sess.run(y, feed_dict={x: sample_input})
m.export(export_path, sess)
with tf.Graph().as_default():
x = tf_v1.placeholder(dtype=tf.float32)
y = hub.Module(export_path)(x)
with tf_v1.Session() as sess:
sess.run(tf_v1.global_variables_initializer())
got = sess.run(y, feed_dict={x: sample_input})
self.assertAllEqual(got, sample_output)
示例15: testInputsFromMultivaluedOp
# 需要导入模块: import tensorflow_hub [as 别名]
# 或者: from tensorflow_hub import create_module_spec [as 别名]
def testInputsFromMultivaluedOp(self):
    """Tests warning for inputs from multivalued ops in create_module_spec."""
    # Ideally this would be written as
    #   with self.assertLogs("blah"): hub.create_module_spec(module_fn)
    # but in the absence of assertions on logs, the underlying helper is
    # tested in the environment seen from within a module_fn.
    find_inputs = native_module.find_signature_inputs_from_multivalued_ops

    # Two of the three inputs are outputs of tf.split ops.
    with tf.Graph().as_default():
        first, _ = tf.split([[1, 2], [3, 4]], 2, name="split1")
        _, second = tf.split([[5, 6], [7, 8]], 2, name="split2")
        third = tf.constant(105, name="const")
        signature_inputs = {"first": first, "second": second, "third": third}
        message = find_inputs(signature_inputs)
    self.assertRegexpMatches(
        message,
        ".*single output.*\n"
        "Affected inputs: first='split1:0', second='split2:1'$")

    # Also test the case of no errors: every input is a plain constant.
    with tf.Graph().as_default():
        first = tf.constant(101)
        second = tf.constant(102)
        third = tf.constant(103)
        signature_inputs = {"first": first, "second": second, "third": third}
        message = find_inputs(signature_inputs)
    self.assertIsNone(message)