当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.Module方法代码示例

本文整理汇总了Python中tensorflow.Module方法的典型用法代码示例。如果您正苦于以下问题:Python tensorflow.Module方法的具体用法?Python tensorflow.Module怎么用?Python tensorflow.Module使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow的用法示例。


在下文中一共展示了tensorflow.Module方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: build

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def build(self, input_shape):
    """Builds the child inputters, optionally sharing one set of parameters.

    When ``self.share_parameters`` is set, only the first leaf inputter is
    built. Every attribute of that inputter that is a ``tf.Variable``, or a
    ``tf.Module`` with a non-empty ``variables`` collection, is then assigned
    onto each remaining leaf inputter. Each receiving inputter is flagged with
    ``built = True`` as part of that assignment. If no attribute qualifies,
    nothing is assigned and no flag is set.

    When ``self.share_parameters`` is not set, every inputter in
    ``self.inputters`` is built independently.

    Args:
      input_shape: Forwarded unchanged to the child ``build`` calls and to the
        parent class ``build``.
    """
    if not self.share_parameters:
      for inputter in self.inputters:
        inputter.build(input_shape)
    else:
      leaves = self.get_leaf_inputters()
      source = leaves[0]
      followers = leaves[1:]
      source.build(input_shape)
      # Iterate over a snapshot of the attribute dict so that the setattr
      # calls below cannot change it mid-iteration.
      for attr_name, attr_value in dict(source.__dict__).items():
        shareable = isinstance(attr_value, tf.Variable) or (
            isinstance(attr_value, tf.Module) and attr_value.variables)
        if not shareable:
          continue
        for follower in followers:
          setattr(follower, attr_name, attr_value)
          follower.built = True
    super(ParallelInputter, self).build(input_shape)
开发者ID:OpenNMT,项目名称:OpenNMT-tf,代码行数:18,代码来源:inputter.py

示例2: set_dropout

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def set_dropout(root_layer, dropout):
  """Overrides every dropout setting of :obj:`root_layer` and its submodules.

  Two kinds of attributes are updated on each module visited. For every
  ``tf.keras.layers.Dropout`` attribute, the layer's ``rate`` is set. Any
  other attribute whose name contains ``"dropout"`` is replaced by the value.

  Args:
    root_layer: The ``tf.Module`` to update. The module itself and every
      entry of its ``submodules`` are modified in place.
    dropout: The dropout value to set.

  Raises:
    ValueError: if :obj:`root_layer` is not a ``tf.Module``.
  """
  if not isinstance(root_layer, tf.Module):
    raise ValueError("Layer should be a tf.Module")
  modules = [root_layer]
  modules.extend(root_layer.submodules)
  for module in modules:
    # Work on a snapshot because setattr below mutates the instance dict.
    snapshot = dict(module.__dict__)
    for name, value in snapshot.items():
      if isinstance(value, tf.keras.layers.Dropout):
        value.rate = dropout
        continue
      if "dropout" in name:
        setattr(module, name, dropout)
开发者ID:OpenNMT,项目名称:OpenNMT-tf,代码行数:19,代码来源:misc.py

示例3: testGetVariableName

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def testGetVariableName(self):
    """Checks the checkpoint key of a variable nested inside a list attribute."""

    class Layer(tf.Module):
      def __init__(self):
        super(Layer, self).__init__()
        self.variable = tf.Variable(0)

    class Model(tf.Module):
      def __init__(self):
        super(Model, self).__init__()
        self.layers = [Layer()]

    model = Model()
    nested_variable = model.layers[0].variable
    expected = "model/layers/0/variable/.ATTRIBUTES/VARIABLE_VALUE"

    self.assertEqual(misc.get_variable_name(nested_variable, model), expected)

    var_to_name, name_to_var = misc.get_variables_name_mapping(
        model, root_key="model")
    self.assertDictEqual(var_to_name, {nested_variable.ref(): expected})
    self.assertDictEqual(name_to_var, {expected: nested_variable})
开发者ID:OpenNMT,项目名称:OpenNMT-tf,代码行数:23,代码来源:misc_test.py

示例4: test_computation_callable

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def test_computation_callable(self):
    """Exports ``x + 1.0`` as a SavedModel, loads it via IREE and calls it.

    The function is wrapped in a ``ComputationModule``, run through a
    ``ComputationCallable`` on ``backend_info.VULKAN_SPIRV``, and checked to
    map 5.0 to 6.0.
    """
    tf_module = tf.Module()
    tf_module.foo = tf.function(
        lambda x: x + 1.0, input_signature=[tf.TensorSpec([], tf.float32)])
    save_options = tf.saved_model.SaveOptions(save_debug_info=True)
    with tempfile.TemporaryDirectory() as model_dir:
      tf.saved_model.save(tf_module, model_dir, options=save_options)
      iree_compiler_module = iree_compiler.tf_load_saved_model(
          model_dir, exported_names=['foo'])
    float_to_float = computation_types.FunctionType(tf.float32, tf.float32)
    my_computation_module = computation_module.ComputationModule(
        iree_compiler_module, 'foo', float_to_float)
    computation_callable = runtime.ComputationCallable(
        my_computation_module, backend_info.VULKAN_SPIRV)
    self.assertTrue(callable(computation_callable))
    self.assertEqual(computation_callable(np.float32(5.0)), 6.0)
开发者ID:tensorflow,项目名称:federated,代码行数:20,代码来源:runtime_test.py

示例5: test_module_class_with_add_one

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def test_module_class_with_add_one(self):
    """Wraps an exported ``x + 1.0`` tf.function in a ``ComputationModule``.

    The SavedModel is written to a fresh temporary directory instead of a
    fixed path such as '/tmp/foo'. This keeps concurrent test runs from
    overwriting each other, leaves nothing on disk afterwards, and avoids
    depending on a POSIX ``/tmp``. It mirrors ``test_computation_callable``.
    """
    import tempfile  # Local import keeps this test self-contained.

    tf_module = tf.Module()
    tf_module.foo = tf.function(
        lambda x: x + 1.0,
        input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
    save_options = tf.saved_model.SaveOptions(save_debug_info=True)
    with tempfile.TemporaryDirectory() as model_dir:
      tf.saved_model.save(tf_module, model_dir, options=save_options)
      iree_compiler_module = iree_compiler.tf_load_saved_model(
          model_dir, exported_names=['foo'])
    my_computation_module = computation_module.ComputationModule(
        iree_compiler_module, 'foo',
        computation_types.FunctionType(tf.float32, tf.float32))
    self.assertIs(my_computation_module.compiler_module, iree_compiler_module)
    self.assertEqual(my_computation_module.function_name, 'foo')
    self.assertEqual(
        str(my_computation_module.type_signature), '(float32 -> float32)')
开发者ID:tensorflow,项目名称:federated,代码行数:19,代码来源:computation_module_test.py

示例6: _create_tflite_model

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def _create_tflite_model():
    """Builds a tiny TFLite model that multiplies its input by ``[1., 2.]``.

    Returns:
      The result of ``TFLiteConverter.convert()`` for a function taking a
      float32 tensor of shape ``[2]``. Returns ``None`` after printing a skip
      message if the TVM TFLite runtime is unavailable or TensorFlow cannot be
      imported.
    """
    if (not tvm.runtime.enabled("tflite")
            or not tvm.get_global_func("tvm.tflite_runtime.create", True)):
        print("skip because tflite runtime is not enabled...")
        return

    try:
        import tensorflow as tf
    except ImportError:
        print('skip because tensorflow not installed...')
        return

    root = tf.Module()
    root.const = tf.constant([1., 2.], tf.float32)
    root.f = tf.function(lambda x: root.const * x)

    spec = tf.TensorSpec(shape=[2], dtype=tf.float32)
    concrete = root.f.get_concrete_function(spec)
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete])
    return converter.convert()
开发者ID:apache,项目名称:incubator-tvm,代码行数:25,代码来源:test_tflite_runtime.py

示例7: _get_direct_children

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def _get_direct_children(layer):
  """Lists the public attributes of ``layer`` that hold child modules.

  An attribute qualifies when its name does not start with ``_`` and its value
  is either a ``tf.Module`` or a non-empty list whose first element is a
  ``tf.Module``. Only the first list element is inspected.

  Args:
    layer: Object whose instance ``__dict__`` is scanned.

  Returns:
    A list of ``(attribute_name, value)`` pairs in ``__dict__`` order.
  """
  def holds_modules(value):
    if isinstance(value, tf.Module):
      return True
    return (isinstance(value, list) and bool(value)
            and isinstance(value[0], tf.Module))

  return [(name, value)
          for name, value in layer.__dict__.items()
          if not name.startswith("_") and holds_modules(value)]
开发者ID:OpenNMT,项目名称:OpenNMT-tf,代码行数:11,代码来源:misc.py

示例8: save_checkpoint

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def save_checkpoint(self, state: Any, round_num: int) -> None:
    """Writes `state` as the checkpoint for training round `round_num`.

    The flattened state is attached to a throwaway `tf.Module` together with
    a zero-argument `tf.function` returning it, and exported as a SavedModel.
    The export is first written to `<root_dir>/.temp_<prefix><round_num>` and
    then renamed to `<root_dir>/<prefix><round_num>`. Old checkpoints are
    pruned afterwards via `self._clear_old_checkpoints()`.

    Args:
      state: A nested structure which `tf.convert_to_tensor` supports.
      round_num: An integer representing the current training round.
    """
    checkpoint_name = '{}{}'.format(self._prefix, round_num)
    final_path = os.path.join(self._root_dir, checkpoint_name)

    flat_state = tf.nest.flatten(state)
    container = tf.Module()
    container.obj = flat_state
    container.build_obj_fn = tf.function(
        lambda: container.obj, input_signature=())

    # Stage the export in a scratch directory, clearing any leftover one.
    staging_path = os.path.join(self._root_dir,
                                '.temp_{}'.format(checkpoint_name))
    try:
      tf.io.gfile.rmtree(staging_path)
    except tf.errors.NotFoundError:
      pass  # No leftover staging directory to remove.
    tf.io.gfile.makedirs(staging_path)
    tf.saved_model.save(container, staging_path, signatures={})

    # Move the finished export into place with a single rename.
    tf.io.gfile.rename(staging_path, final_path)
    logging.info('Checkpoint saved: %s', final_path)

    self._clear_old_checkpoints()
开发者ID:tensorflow,项目名称:federated,代码行数:31,代码来源:checkpoint_manager.py

示例9: save

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def save(obj, export_dir, prefix=None):
  r"""Save a nested structure to `export_dir`.

  Note: to be compatible with `latest_checkpoint`, the basename of `export_dir`
  must follow the regular expression pattern `<prefix>\d+`, where the final
  digit matcher determines the ordering of the checkpoints.

  The structure is flattened with `tf.nest.flatten` and attached to a
  throwaway `tf.Module`, together with a zero-argument `tf.function` that
  returns it. The module is exported as a SavedModel into the sibling
  directory `.temp_<basename of export_dir>`, which is then renamed to
  `export_dir`.

  Args:
    obj: A nested structure which `tf.convert_to_tensor` supports.
    export_dir: A directory in which to write the state.
    prefix: The common prefix shared by all checkpoint directories. If provided,
      we will fail if the export directory doesn't match this prefix. If not
      provided, no check will be performed.

  Raises:
    ValueError: If `prefix` is provided and `export_dir` doesn't use the prefix.
  """
  if prefix is not None and get_serial_number(export_dir, prefix) < 0:
    # Quote the expected pattern correctly, e.g. `"ckpt_XXXX"!`.
    raise ValueError('Checkpoint dir "{}" is not named like "{}XXXX"!'.format(
        export_dir, prefix))

  model = tf.Module()
  model.obj = tf.nest.flatten(obj)
  model.build_obj_fn = tf.function(lambda: model.obj, input_signature=())

  # First write to a temporary directory next to the final location.
  temp_export_dir = os.path.join(
      os.path.dirname(export_dir), '.temp_' + os.path.basename(export_dir))
  try:
    tf.io.gfile.rmtree(temp_export_dir)
  except tf.errors.NotFoundError:
    pass  # No leftover temporary directory from a previous attempt.
  tf.io.gfile.makedirs(temp_export_dir)
  tf.saved_model.save(model, temp_export_dir, signatures={})

  # Move the finished export into place with a single rename, so readers do
  # not see a partially written checkpoint. This assumes the underlying
  # filesystem implements rename atomically.
  tf.io.gfile.rename(temp_export_dir, export_dir)
  logging.info('Checkpoint saved to: %s', export_dir)
开发者ID:tensorflow,项目名称:federated,代码行数:40,代码来源:checkpoint_utils.py

示例10: serialize_dataset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import Module [as 别名]
def serialize_dataset(
    dataset,
    max_serialized_size_bytes=DEFAULT_MAX_SERIALIZED_SEQUENCE_SIZE_BYTES):
  """Serializes a `tf.data.Dataset` value into a `bytes` object.

  The dataset is attached to a throwaway `tf.Module`, together with a
  zero-argument `tf.function` returning it, and saved as a SavedModel in a
  temporary directory. That directory is then packed into a zip archive,
  whose raw bytes are returned. Both temporary paths are removed afterwards.

  Args:
    dataset: A `tf.data.Dataset`.
    max_serialized_size_bytes: An `int` size in bytes designating the threshold
      on when to raise an error if the resulting serialization is too big.

  Returns:
    A `bytes` object that can be sent to
    `tensorflow_serialization.deserialize_dataset` to recover the original
    `tf.data.Dataset`.

  Raises:
    SerializationError: if there was an error in TensorFlow during
      serialization.
    ValueError: if the serialized bytes exceed `max_serialized_size_bytes`.
  """
  py_typecheck.check_type(dataset,
                          type_conversions.TF_DATASET_REPRESENTATION_TYPES)
  holder = tf.Module()
  holder.dataset = dataset
  holder.dataset_fn = tf.function(lambda: holder.dataset, input_signature=())

  saved_model_dir = tempfile.mkdtemp('dataset')
  zip_fd, zip_path = tempfile.mkstemp('zip')
  os.close(zip_fd)
  try:
    tf.saved_model.save(holder, saved_model_dir, signatures={})
    with zipfile.ZipFile(zip_path, 'w') as archive:
      for current_dir, _, names in tf.io.gfile.walk(saved_model_dir):
        # Archive paths are relative to the SavedModel root.
        relative_dir = current_dir[len(saved_model_dir):]
        for name in names:
          archive.write(
              os.path.join(current_dir, name), os.path.join(relative_dir, name))
    with open(zip_path, 'rb') as handle:
      zip_bytes = handle.read()
  except Exception as e:  # pylint: disable=broad-except
    raise SerializationError(
        'Error serializing tff.Sequence value. Inner error: {!s}'.format(
            e)) from e
  finally:
    tf.io.gfile.rmtree(saved_model_dir)
    tf.io.gfile.remove(zip_path)

  serialized_size = len(zip_bytes)
  if serialized_size > max_serialized_size_bytes:
    raise ValueError('Serialized size of Dataset ({:d} bytes) exceeds maximum '
                     'allowed ({:d} bytes)'.format(
                         serialized_size, max_serialized_size_bytes))
  return zip_bytes
开发者ID:tensorflow,项目名称:federated,代码行数:53,代码来源:tensorflow_serialization.py


注:本文中的tensorflow.Module方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。