当前位置: 首页>>代码示例>>Python>>正文


Python function._from_library函数代码示例

本文整理汇总了 Python 中 tensorflow.python.framework.function._from_library 函数的典型用法代码示例。如果您想了解 _from_library 函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了_from_library函数的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testFromLibraryMissingFuncDef

  def testFromLibraryMissingFuncDef(self):
    """Checks `_from_library` rejects a library lacking a referenced def.

    One GradientDef links F1 to G1. Each library built below carries that
    gradient entry but only one of the two FunctionDefs, so a ValueError
    naming the absent function is expected.
    """

    @function.Defun(dtypes.float32, dtypes.float32)
    def G1(x, dy):
      return x * dy

    @function.Defun(dtypes.float32)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    grad_def = function_pb2.GradientDef()
    grad_def.function_name = F1.name
    grad_def.gradient_func = G1.name

    def BuildLibrary(only_func):
      # The gradient entry is always present; just one FunctionDef is added.
      lib = function_pb2.FunctionDefLibrary()
      lib.gradient.extend([grad_def])
      lib.function.extend([only_func.definition])
      return lib

    # (function kept in the library, regex for the one reported missing)
    cases = (
        (F1, "FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"),
        (G1, "FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"),
    )
    for kept_func, expected_regex in cases:
      incomplete_library = BuildLibrary(kept_func)
      with self.assertRaisesRegexp(ValueError, expected_regex):
        function._from_library(incomplete_library)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:33,代码来源:function_test.py

示例2: testFromLibraryCyclicGradFuncs

  def testFromLibraryCyclicGradFuncs(self):
    """Checks `_from_library` rejects gradients that reference each other."""

    @function.Defun(dtypes.float32)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    @function.Defun(dtypes.float32)
    def F2(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    def GradientEntry(func, grad_func):
      # Builds a GradientDef declaring `grad_func` as the gradient of `func`.
      entry = function_pb2.GradientDef()
      entry.function_name = func.name
      entry.gradient_func = grad_func.name
      return entry

    # Both function defs are present; only the gradient wiring is invalid:
    # F1's gradient is F2 while F2's gradient is F1.
    cyclic_library = function_pb2.FunctionDefLibrary()
    cyclic_library.function.extend([F1.definition, F2.definition])
    cyclic_library.gradient.extend(
        [GradientEntry(F1, F2), GradientEntry(F2, F1)])

    with self.assertRaisesRegexp(
        ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
      function._from_library(cyclic_library)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:28,代码来源:function_test.py

示例3: testFromLibrary

  def testFromLibrary(self):
    """Round-trips functions through a graph library via `_from_library`.

    Six functions are defined (two gradient functions, three functions that
    use them, and one with no gradient function), instantiated in a fresh
    graph, and then rebuilt from that graph's FunctionDefLibrary. Each
    original must have exactly one rebuilt counterpart that compares equal.
    The function bodies are deliberately repetitive; only the gradient
    wiring differs.
    """

    @function.Defun(dtypes.float32, dtypes.float32)
    def G1(x, dy):
      return x * dy

    @function.Defun(dtypes.float32, dtypes.float32)
    def G2(x, dy):
      return x * dy

    # F1 and F2 share G1 as their gradient function.
    @function.Defun(dtypes.float32, grad_func=G1)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    @function.Defun(dtypes.float32, grad_func=G1)
    def F2(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    # F3 uses G2 instead.
    @function.Defun(dtypes.float32, grad_func=G2)
    def F3(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    # F4 declares no gradient function at all.
    @function.Defun(dtypes.float32)
    def F4(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    # Call every F* once and request gradients of the results.
    graph = ops.Graph()
    with graph.as_default():
      one = constant_op.constant(1.0, dtypes.float32)
      outputs = [F1(one), F2(one), F3(one), F4(one)]
      gradients_impl.gradients(outputs, one)

    restored = function._from_library(graph.as_graph_def().library)

    def CheckNewFunc(func):
      # Exactly one restored function may carry the original's name.
      matches = [
          candidate for candidate in restored if candidate.name == func.name
      ]
      self.assertEqual(len(matches), 1)
      self.expectFunctionsEqual(func, new_func=matches[0])

    for original in (G1, G2, F1, F2, F3, F4):
      CheckNewFunc(original)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:56,代码来源:function_test.py

示例4: testFromLibraryEmptyLib

 def testFromLibraryEmptyLib(self):
   """Checks that an empty FunctionDefLibrary produces no functions."""
   empty_library = function_pb2.FunctionDefLibrary()
   imported = function._from_library(empty_library)
   self.assertEqual(len(imported), 0)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:3,代码来源:function_test.py

示例5: import_graph_def

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

#.........这里部分代码省略.........
开发者ID:Immexxx,项目名称:tensorflow,代码行数:101,代码来源:importer.py

示例6: import_graph_def

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
#.........这里部分代码省略.........
开发者ID:andrewharp,项目名称:tensorflow,代码行数:101,代码来源:importer.py

示例7: import_graph_def


#.........这里部分代码省略.........
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    _ProcessNewOps(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(missing_unused_input_keys))

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:101,代码来源:importer.py


注:本文中的tensorflow.python.framework.function._from_library函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。