本文整理汇总了 Python 中 tensorflow.Module 类的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.Module 的具体用法？Python tensorflow.Module 怎么用？Python tensorflow.Module 使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 tensorflow 的用法示例。
在下文中一共展示了 tensorflow.Module 类的 10 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: build
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def build(self, input_shape):
    """Builds the nested inputters, optionally sharing their parameters.

    When ``self.share_parameters`` is set, only the first leaf inputter is
    built. Each of its attributes that is a ``tf.Variable``, or a ``tf.Module``
    holding at least one variable, is then assigned under the same name to
    every other leaf inputter, which is also flagged as built. Otherwise each
    inputter in ``self.inputters`` is built on its own.

    Args:
      input_shape: The shape forwarded to every ``build`` call.
    """
    if not self.share_parameters:
        for inputter in self.inputters:
            inputter.build(input_shape)
    else:
        leaf_inputters = self.get_leaf_inputters()
        reference, followers = leaf_inputters[0], leaf_inputters[1:]
        reference.build(input_shape)
        # Walk a snapshot of the attributes of the inputter that was built.
        for attr_name, attr_value in reference.__dict__.copy().items():
            holds_parameters = isinstance(attr_value, tf.Variable) or (
                isinstance(attr_value, tf.Module) and attr_value.variables)
            if not holds_parameters:
                continue
            for follower in followers:
                setattr(follower, attr_name, attr_value)
                follower.built = True
    super(ParallelInputter, self).build(input_shape)
示例2: set_dropout
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def set_dropout(root_layer, dropout):
    """Overrides all dropout values in :obj:`root_layer` and its submodules.

    For every attribute of every visited module, a
    ``tf.keras.layers.Dropout`` instance gets its ``rate`` updated, and any
    other attribute whose name contains ``"dropout"`` is replaced by the new
    value.

    Args:
      root_layer: The ``tf.Module`` at the root of the hierarchy to update.
      dropout: The dropout value to set.

    Raises:
      ValueError: if :obj:`root_layer` is not a ``tf.Module``.
    """
    if not isinstance(root_layer, tf.Module):
        raise ValueError("Layer should be a tf.Module")
    modules = (root_layer,) + root_layer.submodules
    for module in modules:
        # Snapshot the attributes because some are reassigned while looping.
        for name, current in list(module.__dict__.items()):
            if isinstance(current, tf.keras.layers.Dropout):
                current.rate = dropout
                continue
            if "dropout" in name:
                setattr(module, name, dropout)
示例3: testGetVariableName
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def testGetVariableName(self):
    """Checks the names reported for a variable nested in a list of modules.

    Covers both ``misc.get_variable_name`` and
    ``misc.get_variables_name_mapping`` with ``root_key="model"``.
    """

    class Layer(tf.Module):
        def __init__(self):
            super().__init__()
            self.variable = tf.Variable(0)

    class Model(tf.Module):
        def __init__(self):
            super().__init__()
            self.layers = [Layer()]

    root = Model()
    nested_variable = root.layers[0].variable
    expected_name = "model/layers/0/variable/.ATTRIBUTES/VARIABLE_VALUE"

    # Single-variable lookup.
    self.assertEqual(misc.get_variable_name(nested_variable, root), expected_name)

    # Bidirectional mapping over the whole model.
    by_variable, by_name = misc.get_variables_name_mapping(root, root_key="model")
    self.assertDictEqual(by_variable, {nested_variable.ref(): expected_name})
    self.assertDictEqual(by_name, {expected_name: nested_variable})
示例4: test_computation_callable
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def test_computation_callable(self):
    """Checks that a compiled add-one function is callable and returns x + 1."""
    module = tf.Module()
    module.foo = tf.function(
        lambda x: x + 1.0,
        input_signature=[tf.TensorSpec([], tf.float32)])
    with tempfile.TemporaryDirectory() as model_dir:
        tf.saved_model.save(
            module,
            model_dir,
            options=tf.saved_model.SaveOptions(save_debug_info=True))
        compiled_module = iree_compiler.tf_load_saved_model(
            model_dir, exported_names=['foo'])
        # The remaining steps run while the temporary directory still exists.
        wrapped_module = computation_module.ComputationModule(
            compiled_module,
            'foo',
            computation_types.FunctionType(tf.float32, tf.float32))
        add_one = runtime.ComputationCallable(
            wrapped_module, backend_info.VULKAN_SPIRV)
        self.assertTrue(callable(add_one))
        self.assertEqual(add_one(np.float32(5.0)), 6.0)
示例5: test_module_class_with_add_one
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def test_module_class_with_add_one(self):
    """Checks the attributes of a ``ComputationModule`` built for add-one.

    The SavedModel is written to a fresh temporary directory instead of a
    fixed shared path, so runs cannot collide and nothing is left on disk.
    """
    import tempfile  # Local import keeps this block self-contained.

    tf_module = tf.Module()
    tf_module.foo = tf.function(
        lambda x: x + 1.0,
        input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
    with tempfile.TemporaryDirectory() as model_dir:
        save_options = tf.saved_model.SaveOptions(save_debug_info=True)
        tf.saved_model.save(tf_module, model_dir, options=save_options)
        iree_compiler_module = iree_compiler.tf_load_saved_model(
            model_dir, exported_names=['foo'])
        # Assertions run before the temporary directory is removed.
        my_computation_module = computation_module.ComputationModule(
            iree_compiler_module, 'foo',
            computation_types.FunctionType(tf.float32, tf.float32))
        self.assertIs(my_computation_module.compiler_module,
                      iree_compiler_module)
        self.assertEqual(my_computation_module.function_name, 'foo')
        self.assertEqual(
            str(my_computation_module.type_signature), '(float32 -> float32)')
示例6: _create_tflite_model
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def _create_tflite_model():
    """Builds a small TFLite model that multiplies its input by ``[1., 2.]``.

    Returns:
      The converted TFLite model, or ``None`` (after printing a skip message)
      when the TVM tflite runtime is not enabled or TensorFlow cannot be
      imported.
    """
    runtime_missing = (
        not tvm.runtime.enabled("tflite")
        or not tvm.get_global_func("tvm.tflite_runtime.create", True))
    if runtime_missing:
        print("skip because tflite runtime is not enabled...")
        return None

    try:
        import tensorflow as tf
    except ImportError:
        print('skip because tensorflow not installed...')
        return None

    module = tf.Module()
    module.const = tf.constant([1., 2.], tf.float32)
    module.f = tf.function(lambda x: module.const * x)
    # Trace the function for a float32 vector with two elements.
    spec = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
    traced = module.f.get_concrete_function(spec)
    return tf.lite.TFLiteConverter.from_concrete_functions([traced]).convert()
示例7: _get_direct_children
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def _get_direct_children(layer):
    """Lists the public attributes of ``layer`` that hold modules.

    An attribute qualifies when its name does not start with an underscore and
    its value is either a ``tf.Module`` or a non-empty list whose first element
    is a ``tf.Module``.

    Args:
      layer: The object whose ``__dict__`` is inspected.

    Returns:
      A list of ``(name, value)`` tuples in attribute order.
    """

    def _holds_modules(value):
        # Only the first element of a list is inspected.
        if isinstance(value, tf.Module):
            return True
        return (isinstance(value, list)
                and bool(value)
                and isinstance(value[0], tf.Module))

    return [
        (name, value)
        for name, value in layer.__dict__.items()
        if not name.startswith("_") and _holds_modules(value)
    ]
示例8: save_checkpoint
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def save_checkpoint(self, state: Any, round_num: int) -> None:
    """Saves a new checkpointed `state` for the given `round_num`.

    The flattened state is exported as a SavedModel into a `.temp_`-prefixed
    directory under `self._root_dir`, which is then renamed to the final
    checkpoint path. Old checkpoints are cleared afterwards.

    Args:
      state: A nested structure which `tf.convert_to_tensor` supports.
      round_num: An integer representing the current training round.
    """
    basename = '{}{}'.format(self._prefix, round_num)
    checkpoint_path = os.path.join(self._root_dir, basename)
    staging_path = os.path.join(self._root_dir, '.temp_{}'.format(basename))

    # Wrap the flattened state in a module so it can be exported.
    holder = tf.Module()
    holder.obj = tf.nest.flatten(state)
    holder.build_obj_fn = tf.function(lambda: holder.obj, input_signature=())

    # Start from an empty staging directory; a missing one is not an error.
    try:
        tf.io.gfile.rmtree(staging_path)
    except tf.errors.NotFoundError:
        pass
    tf.io.gfile.makedirs(staging_path)
    tf.saved_model.save(holder, staging_path, signatures={})

    # Move the finished export to its final location.
    tf.io.gfile.rename(staging_path, checkpoint_path)
    logging.info('Checkpoint saved: %s', checkpoint_path)

    self._clear_old_checkpoints()
示例9: save
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def save(obj, export_dir, prefix=None):
    r"""Save a nested structure to `export_dir`.

    The structure is flattened, exported as a SavedModel into a sibling
    `.temp_`-prefixed directory, and that directory is then renamed to
    `export_dir`.

    Note: to be compatible with `latest_checkpoint`, the basename of
    `export_dir` must follow the regular expression pattern `<prefix>\d+`,
    where the final digit matcher determines the ordering of the checkpoints.

    Args:
      obj: A nested structure which `tf.convert_to_tensor` supports.
      export_dir: A directory in which to write the state.
      prefix: The common prefix shared by all checkpoint directories. If
        provided, we will fail if the export directory doesn't match this
        prefix. If not provided, no check will be performed.

    Raises:
      ValueError: If `prefix` is provided and `export_dir` doesn't use the
        prefix.
    """
    if prefix is not None and get_serial_number(export_dir, prefix) < 0:
        # The closing double quote after "XXXX" was missing in the message.
        raise ValueError('Checkpoint dir "{}" is not named like "{}XXXX"!'.format(
            export_dir, prefix))
    model = tf.Module()
    model.obj = tf.nest.flatten(obj)
    model.build_obj_fn = tf.function(lambda: model.obj, input_signature=())

    # First write to a temporary directory.
    temp_export_dir = os.path.join(
        os.path.dirname(export_dir), '.temp_' + os.path.basename(export_dir))
    # Clear any previous temporary directory; a missing one is not an error.
    try:
        tf.io.gfile.rmtree(temp_export_dir)
    except tf.errors.NotFoundError:
        pass
    tf.io.gfile.makedirs(temp_export_dir)
    tf.saved_model.save(model, temp_export_dir, signatures={})

    # Rename the temp directory to the final location atomically.
    tf.io.gfile.rename(temp_export_dir, export_dir)
    logging.info('Checkpoint saved to: %s', export_dir)
示例10: serialize_dataset
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def serialize_dataset(
    dataset,
    max_serialized_size_bytes=DEFAULT_MAX_SERIALIZED_SEQUENCE_SIZE_BYTES):
    """Serializes a `tf.data.Dataset` value into a `bytes` object.

    The dataset is attached to a `tf.Module`, exported as a SavedModel into a
    temporary directory, and that directory is packed into a zip archive whose
    raw bytes are returned. Both temporary paths are removed afterwards.

    Args:
      dataset: A `tf.data.Dataset`.
      max_serialized_size_bytes: An `int` size in bytes designating the
        threshold on when to raise an error if the resulting serialization is
        too big.

    Returns:
      A `bytes` object that can be sent to
      `tensorflow_serialization.deserialize_dataset` to recover the original
      `tf.data.Dataset`.

    Raises:
      SerializationError: if there was an error in TensorFlow during
        serialization.
      ValueError: if the serialized size exceeds `max_serialized_size_bytes`.
    """
    py_typecheck.check_type(dataset,
                            type_conversions.TF_DATASET_REPRESENTATION_TYPES)
    wrapper = tf.Module()
    wrapper.dataset = dataset
    wrapper.dataset_fn = tf.function(lambda: wrapper.dataset, input_signature=())

    export_dir = tempfile.mkdtemp('dataset')
    zip_fd, zip_path = tempfile.mkstemp('zip')
    # Only the path is needed; the archive is reopened by name below.
    os.close(zip_fd)
    try:
        tf.saved_model.save(wrapper, export_dir, signatures={})
        root_len = len(export_dir)
        with zipfile.ZipFile(zip_path, 'w') as archive:
            for current_dir, _, names in tf.io.gfile.walk(export_dir):
                # Archive entries are named relative to the export directory.
                relative_dir = current_dir[root_len:]
                for name in names:
                    archive.write(
                        os.path.join(current_dir, name),
                        os.path.join(relative_dir, name))
        with open(zip_path, 'rb') as zip_file:
            serialized = zip_file.read()
    except Exception as e:  # pylint: disable=broad-except
        raise SerializationError(
            'Error serializing tff.Sequence value. Inner error: {!s}'.format(
                e)) from e
    finally:
        tf.io.gfile.rmtree(export_dir)
        tf.io.gfile.remove(zip_path)

    serialized_size = len(serialized)
    if serialized_size > max_serialized_size_bytes:
        raise ValueError('Serialized size of Dataset ({:d} bytes) exceeds maximum '
                         'allowed ({:d} bytes)'.format(
                             serialized_size, max_serialized_size_bytes))
    return serialized