当前位置: 首页>>代码示例>>Python>>正文


Python graph_to_function_def.graph_to_function_def函数代码示例

本文整理汇总了Python中tensorflow.python.framework.graph_to_function_def.graph_to_function_def函数的典型用法代码示例。如果您正苦于以下问题:Python graph_to_function_def函数的具体用法?Python graph_to_function_def怎么用?Python graph_to_function_def使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了graph_to_function_def函数的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _compute_backprop

 def _compute_backprop(self):
   """Builds and registers the forward and backward functions for this function.

   Side effects: sets `self._has_backprop`, `self._out_grad_placeholders`,
   `self._forward_fdef` and `self._backward_function`, and registers the
   forward and backward function definitions by name.
   """

   def _name_of(obj):
     # Sort key giving a deterministic order for tensors/ops.
     return obj.name

   self._has_backprop = True
   with self._graph.as_default(), context.graph_mode():
     capture_ctx = _CapturingContext()
     with capture_ctx:
       # Only the non-None returns take part in differentiation.
       live_returns = [ret for ret in self._returns if ret is not None]
       # One placeholder per live return; these are fed as `grad_ys`.
       self._out_grad_placeholders = [
           graph_placeholder(ret.dtype, ret.shape) for ret in live_returns
       ]
       grads = gradients_impl.gradients(
           live_returns,
           self._input_placeholders,
           grad_ys=self._out_grad_placeholders)
       grad_shapes = [g.shape for g in grads if g is not None]
   captured = sorted(capture_ctx.captured_tensors, key=_name_of)

   # Forward function: the traced ops, emitting live returns plus captures.
   forward_def = graph_to_function_def.graph_to_function_def(
       self._graph, self._ops, self._input_placeholders,
       live_returns + captured)
   self._forward_fdef = _DefinedFunction(forward_def)
   _register_with_name(_forward_name(self._func_name), forward_def)

   # Backward function: fed with the output-gradient placeholders and the
   # captured tensors, emitting the non-None input gradients.
   nonnull_grads = [g for g in grads if g is not None]
   backward_inputs = self._out_grad_placeholders + captured
   backward_ops = [ph.op for ph in self._out_grad_placeholders]
   backward_ops.extend(sorted(capture_ctx.known_ops, key=_name_of))
   backward_def = graph_to_function_def.graph_to_function_def(
       self._graph, backward_ops, backward_inputs, nonnull_grads)
   _register_with_name(_backward_name(self._func_name), backward_def)
   self._backward_function = _GraphModeFunction(
       backward_inputs, [], backward_def, self._graph, capture_ctx.known_ops,
       grads, _map_sequence_obj_to_idx(nonnull_grads), grad_shapes)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:33,代码来源:function.py

示例2: _create_definition_if_needed

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet.

    Traces `self._func` into a fresh `_FuncGraph`, converts that graph into
    `self._definition`, copies extra keyword arguments onto it as attrs, and
    fixes the function name. When no name was given, a deterministic name is
    derived from the Python function name plus a hash of the definition.

    Raises:
      ValueError: if the traced function returns None, or a sequence that
        contains None.
    """

    if self._definition is not None:
      return

    # Create the func_def object.
    temp_graph = _FuncGraph()
    with temp_graph.as_default():
      # One placeholder per declared (name, dtype) argument, in order.
      inputs = [array_ops.placeholder(argtype, name=argname)
                for (argname, argtype) in self._args]
      # Call func and gather the output tensors.
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(out is None for out in outputs):
        raise ValueError("Function can not return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(out) for out in outputs]
    self._extra_inputs = temp_graph.extra_inputs
    # Extra arguments collected by the _FuncGraph while tracing are appended
    # after the declared inputs.
    inputs.extend(temp_graph.extra_args)
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Build the FunctionDef
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        inputs,
        outputs,
        out_names=self._out_names)

    # Extra kwargs are treated as attrs on the function def.
    sig_pre_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(sig_pre_func_name,
                                         **self._extra_kwargs)
    for k in kwargs_attr:
      self._definition.attr[k].CopyFrom(kwargs_attr[k])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        self._definition.signature.input_arg,
        self._definition.signature.output_arg, self._definition.node_def)

    # Finally, we decide the function name to use.  If not specified,
    # make up something which is almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([_get_func_name(self._func), self._hash_str])
    self._definition.signature.name = self._func_name
    if self._func.__doc__:
      self._definition.signature.description = self._func.__doc__
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:57,代码来源:function.py

示例3: testTwoInputsSameOp

 def testTwoInputsSameOp(self):
   """Three outputs of one op used as inputs yield three input args."""
   graph = ops.Graph()
   with graph.as_default():
     matrix = array_ops.placeholder(dtypes.float32)
     s, u, v = linalg_ops.svd(matrix)
     # Reduce all three first, then add, matching the original op order.
     reduced = [math_ops.reduce_sum(t) for t in (s, u, v)]
     result = reduced[0] + reduced[1] + reduced[2]
   # Operation 0 is the placeholder; the SVD outputs are the inputs instead.
   body_ops = graph.get_operations()[1:]
   fdef = graph_to_function_def.graph_to_function_def(
       graph, body_ops, [s, u, v], [result])
   self.assertEqual(len(fdef.signature.input_arg), 3)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:15,代码来源:function_test.py

示例4: _defun_internal

def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func.

  Traces `func(*inputs, **kwds)` into a fresh graph whose collections are
  seeded with copies of the current default graph's collections. The traced
  graph is converted into a FunctionDef and registered, together with any
  functions defined while tracing. The result is wrapped in a
  `_GraphModeFunction`.

  Args:
    name: base name; the inference function is registered under
      `_inference_name(name)`.
    func: the Python callable to trace.
    args: positional arguments, turned into graph inputs by
      `_get_defun_inputs`.
    kwds: keyword arguments forwarded to `func` unchanged.

  Returns:
    A `_GraphModeFunction` wrapping the traced graph.
  """
  with context.graph_mode():
    tmp_graph = ops.Graph()
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      captures = {}
      with capture_tensors(captures):
        func_outputs = func(*func_inputs, **kwds)
      # Each captures entry is an (extra input, placeholder) pair. Visit them
      # in sorted key order so the argument order is deterministic. Build
      # plain lists so the types do not depend on whether anything was
      # captured; zip(*...) would yield tuples only in the non-empty case.
      extra_inputs = []
      extra_placeholders = []
      for key in sorted(captures):
        extra_input, placeholder = captures[key]
        extra_inputs.append(extra_input)
        extra_placeholders.append(placeholder)
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [x.shape for x in outputs_list if x is not None]

  flat_inputs = [
      x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
  ]
  all_inputs = flat_inputs + extra_placeholders

  # None outputs are dropped from the FunctionDef signature.
  func_def_outputs = [x for x in outputs_list if x is not None]
  inference_function_def = graph_to_function_def.graph_to_function_def(
      tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    # TODO(ashankar): What about the gradient registry?
    _register_with_name(f.name, f.definition)
  _register_with_name(_inference_name(name), inference_function_def)

  return _GraphModeFunction(
      all_inputs, extra_inputs, inference_function_def, tmp_graph,
      tmp_graph.get_operations(), func_outputs,
      _map_sequence_obj_to_idx(func_def_outputs), output_shapes)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:45,代码来源:function.py

示例5: _build_function_def

  def _build_function_def(self):
    """Returns a FunctionDef named "_whats_in_a_name".

    It takes float32 placeholders `x` and `y` and produces two outputs:
    `sum_squares` (x**2 + y**2) and `sum_cubes` (x**3 + y**3).
    """
    graph = ops.Graph()
    with graph.as_default():
      x = array_ops.placeholder(dtypes.float32, name="x")
      y = array_ops.placeholder(dtypes.float32, name="y")

      def _power_sum(exponent, op_name):
        # x**exponent + y**exponent, as a single named add_n op.
        return math_ops.add_n(
            [math_ops.pow(x, exponent), math_ops.pow(y, exponent)],
            name=op_name)

      sum_squares = _power_sum(2, "sum_squares")
      sum_cubes = _power_sum(3, "sum_cubes")
    function_def = graph_to_function_def.graph_to_function_def(
        graph, graph.get_operations(), [x, y], [sum_squares, sum_cubes])
    function_def.signature.name = "_whats_in_a_name"
    return function_def
开发者ID:aeverall,项目名称:tensorflow,代码行数:18,代码来源:function_def_to_graph_test.py

示例6: make_function_def

def make_function_def(graph, operations, inputs, outputs):
  """Makes function def where accesses to resources are serialized.

  Walks `operations` in order. For every input tensor of dtype
  `dtypes.resource`, the current op is made control-dependent on the previous
  op (if any) that consumed the same resource tensor, matched by tensor name.
  The ops are mutated in place before the graph is converted.

  Args:
    graph: the graph containing `operations`.
    operations: ops to include in the FunctionDef; their order defines the
      order in which resource accesses are serialized.
    inputs: tensors passed as the function inputs to `graph_to_function_def`.
    outputs: tensors passed as the function outputs to
      `graph_to_function_def`.

  Returns:
    The FunctionDef produced by `graph_to_function_def.graph_to_function_def`.
  """
  last_op_using_resource_tensor = {}

  # TODO(apassos) probably control flow has to be handled delicately here as in
  # if a resource is accessed inside a control flow context we need the control
  # dependency to point to something outside the context which is guaranteed to
  # happen after the access.
  #
  # TODO(apassos) this should do some form of alias analysis as ops which
  # forward the resources such as Identity and Switch can cause serialization to
  # fail.
  for op in operations:
    for t in op.inputs:
      if t.dtype != dtypes.resource:
        continue
      prev_op = last_op_using_resource_tensor.get(t.name)
      # An op may consume the same resource tensor more than once. Never make
      # an op control-depend on itself: that self-edge would be a cycle.
      if prev_op is not None and prev_op is not op:
        op._add_control_input(prev_op)  # pylint: disable=protected-access
      last_op_using_resource_tensor[t.name] = op
  return graph_to_function_def.graph_to_function_def(
      graph, operations, inputs, outputs)
开发者ID:SylChan,项目名称:tensorflow,代码行数:20,代码来源:function.py

示例7: _create_definition_if_needed_impl

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Builds the function graph from `self._func` and converts it into one of
    two forms, depending on whether the traced graph has a C graph:
    - a Python-side FunctionDef, stored in `self._definition`;
    - a C API function, stored in `self._c_func`.

    Also sets `self._extra_inputs`, `self._sub_functions`, `self._op_def`,
    `self._func_name` (when it was unset) and `self._stateful_ops`. Returns
    immediately if a definition already exists.
    """
    if self._definition is not None or self._c_func is not None:
      return

    temp_graph = func_graph_from_py_func(
        self._func, self._arg_names, self._arg_types, self._func_name,
        self._capture_by_value, self._caller_device)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = _get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          # Treat any falsy name (None or "") as unset, exactly like the
          # Python branch and the name fix-up below, so an unnamed function
          # always gets the uniqueness hash appended.
          not self._func_name,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    # Record (name, type) of every stateful op in the traced graph.
    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op.op_def.is_stateful]
开发者ID:didukhle,项目名称:tensorflow,代码行数:78,代码来源:function.py

示例8: _create_definition_if_needed_impl

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Traces `self._func` into a fresh `_FuncGraph` and converts it into one of
    two forms, depending on whether the traced graph has a C graph:
    - a Python-side FunctionDef, stored in `self._definition`;
    - a C API function, stored in `self._c_func`.

    Returns immediately if either already exists.

    Raises:
      ValueError: if the function returns a sequence that contains None.
    """
    if self._definition is not None or self._c_func is not None:
      return

    # Create the func_def object.
    temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
      # List of placeholders for the function_def.
      inputs = []
      for (argname, argtype) in self._args:
        argholder = array_ops.placeholder(argtype, name=argname)
        inputs.append(argholder)
      # Call func and gather the output tensors.
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)

      # There is no way of distinguishing between a function not returning
      # anything and a function returning None in Python.
      # We need to allow the former and ideally want to forbid the latter as
      # it is most likely user error.
      # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
      # allow users to explicitly mark the function as not returning anything.
      # For now, we allow a single None return and interpret it as a function
      # with no output.
      if outputs is None:
        outputs = []
      else:
        # If func only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
          outputs = (outputs,)
        if any(out is None for out in outputs):
          raise ValueError("Function can not return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(out) for out in outputs]
    self._extra_inputs = temp_graph.extra_inputs
    # Extra arguments collected by the _FuncGraph while tracing are appended
    # after the declared inputs.
    inputs.extend(temp_graph.extra_args)
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    base_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                         **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          inputs,
          outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      with errors.raise_exception_on_not_ok_status() as status:
        self._c_func = c_api.TF_GraphToFunction_wrapper(
            temp_graph._c_graph,
            base_func_name,
            # Treat any falsy name (None or "") as unset, exactly like the
            # Python branch and the name fix-up below, so an unnamed function
            # always gets the uniqueness hash appended.
            not self._func_name,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in inputs],
            [t._as_tf_output() for t in outputs],
            output_names,
            None,  # opts
            description,
            status)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:98,代码来源:function.py

示例9: _create_definition_if_needed_impl

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Traces `self._func` into a function graph that shares the parent graph's
    variable collections. The graph is then converted into one of two forms,
    depending on whether the traced graph has a C graph:
    - a Python-side FunctionDef, stored in `self._definition`;
    - a C API function, stored in `self._c_func`.

    Returns immediately if either already exists.
    """
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    collections_ref = {}
    parent_collections_ref = ops.get_default_graph()._collections  # pylint: disable=protected-access
    for key in variable_keys:
      if key not in parent_collections_ref:
        # The same new list object is stored in both dicts, so the parent
        # graph and collections_ref share it.
        parent_collections_ref[key] = collections_ref[key] = []
      else:
        collections_ref[key] = parent_collections_ref[key]

    # Trace the Python function into a function graph, passing the shared
    # collections along.
    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        whitelisted_stateful_ops=self._whitelisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    # Without an explicit name, the base name is derived from the Python
    # function, suffixed with the gradient function's name when one is set.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      # NOTE(review): append_hash_to_fn_name below tests `is None`, while the
      # Python branch above treats any falsy name as unset. The two differ only
      # for an empty-string name; confirm which behavior is intended.
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [], # control_outputs
          [], # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
#.........这里部分代码省略.........
开发者ID:aritratony,项目名称:tensorflow,代码行数:101,代码来源:function.py


注:本文中的tensorflow.python.framework.graph_to_function_def.graph_to_function_def函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。