当前位置: 首页>>代码示例>>Python>>正文


Python op_def_registry.get_registered_ops函数代码示例

本文整理汇总了Python中tensorflow.python.framework.op_def_registry.get_registered_ops函数的典型用法代码示例。如果您正苦于以下问题:Python get_registered_ops函数的具体用法?Python get_registered_ops怎么用?Python get_registered_ops使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了get_registered_ops函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: stripped_op_list_for_graph

def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

    This computes the `stripped_op_list` field of `MetaGraphDef` (and of
    similar protos). The producer can send the result to a consumer, which
    can then call the C++ function `RemoveNewDefaultAttrsFromGraphDef` to
    improve forwards compatibility.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` of the registered ops used by the graph, sorted by name.

    Raises:
      ValueError: If the graph uses an op that is neither registered nor one
        of the whitelisted function-internal ops.
    """
    # Python counterpart of the C++ StrippedOpListForGraph. The Python and C++
    # op registries may differ, so the logic cannot simply be shared via swig.
    # TODO(irving): Support taking graphs directly.
    used_ops = ops_used_by_graph_def(graph_def)
    registered_ops = op_def_registry.get_registered_ops()

    # Function-internal ops are not registered and must be tolerated.
    # TODO(irving): Do something better here.
    op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op_name in used_ops:
        if op_name in registered_ops or op_name in op_whitelist:
            continue
        raise ValueError("Op %s is used by the graph, but is not registered" % op_name)

    # Whitelisted ops have no OpDef, so only registered ones are emitted.
    stripped_ops = [
        registered_ops[op_name]
        for op_name in sorted(used_ops)
        if op_name in registered_ops
    ]
    return op_def_pb2.OpList(op=stripped_ops)
开发者ID:rhuangq,项目名称:tensorflow,代码行数:35,代码来源:meta_graph.py

示例2: testStripDefaultAttrsInconsistentConsumerDefaults

  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
    """Loads a default-stripped SavedModel against altered consumer defaults.

    The SavedModel is saved with `strip_default_attrs=True`. The registered
    "Complex" OpDef is then edited in place twice. First, all of its attr
    defaults are removed, and loading must fail with a missing-attr
    ValueError. Second, the "Tout" default is switched to complex128, and
    loading must fail because no kernel handles that attr combination. The
    OpDef is restored in a `finally` clause so the registry edits cannot leak
    into other tests running in the same process.
    """
    if ops._USE_C_API: return  # TODO(skyewm): get this working

    export_dir = self._get_export_dir(
        "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled. This must remove the following
    # defaults for the "Complex" Op:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # The OpDef returned here is the registry's own entry: the loads below
    # depend on the in-place edits being seen. Keep a copy so the entry can
    # be restored afterwards.
    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
    original_complex_op_def = op_def_pb2.OpDef()
    original_complex_op_def.CopyFrom(complex_op_def)
    try:
      # Update the Op registry to remove defaults for all attrs("T", "Tout")
      # from the "Complex" OpDef.
      for attr_def in complex_op_def.attr:
        attr_def.ClearField("default_value")

      # Loading the SavedModel via the loader must fail because the
      # SavedModel does not have any attr values for the "Complex" node and
      # the current op registry does not have any default values for the
      # "Complex" op.
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(
            ValueError,
            "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
          loader.load(sess, ["foo"], export_dir)

      # Update the Op registry to change the defaults for attr "Tout"
      # (complex64 -> complex128).
      complex_op_def.CopyFrom(original_complex_op_def)
      for attr_def in complex_op_def.attr:
        if attr_def.name == "Tout":
          attr_def.default_value.type = types_pb2.DT_COMPLEX128

      # Loading the SavedModel via the loader must set "Tout" attr_value for
      # the "Complex" node according to the latest defaults (complex128).
      # This is expected to fail the model import as there is no OpKernel
      # registered to handle attrs "T" (float32) and "Tout" (complex128).
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(
            errors.InvalidArgumentError,
            ".*No OpKernel was registered to support Op \'Complex\' with these "
            "attrs..*"):
          loader.load(sess, ["foo"], export_dir)
    finally:
      # Undo the in-place registry edits so later tests see the real OpDef.
      complex_op_def.CopyFrom(original_complex_op_def)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:57,代码来源:saved_model_test.py

示例3: _is_array_type_input

def _is_array_type_input(op, i):
    """Return whether input ``i`` of op type ``op`` is a ``number_attr`` list.

    Args:
      op: Op type name to look up in the registered ops.
      i: Index into that op's ``input_arg`` list.

    Returns:
      ``False`` if ``op`` is not registered. Otherwise, whether the ``i``-th
      input arg declares a ``number_attr``, meaning its tensor count is taken
      from an attr.

    Raises:
      TypeError: If ``i`` is not a valid index into the op's input args.
    """
    registered_ops = op_def_registry.get_registered_ops()
    op_def = registered_ops.get(op)
    if op_def is None:
        return False
    num_inputs = len(op_def.input_arg)
    # A plain bounds comparison: ``xrange`` only exists on Python 2.
    if not 0 <= i < num_inputs:
        raise TypeError("Expected arg index to be in [0, %d)" % num_inputs)
    return bool(op_def.input_arg[i].number_attr)
开发者ID:yuanms2,项目名称:tensorflow,代码行数:9,代码来源:function.py

示例4: _register_function_ops

def _register_function_ops(func_list):
  """Registers the signatures of custom function ops in the op registry.

  This is needed because the checkpoint was saved with ops that are not part
  of TensorFlow. Each function's signature OpDef is stored in the dict
  returned by `op_def_registry.get_registered_ops()`. An unknown-shape shape
  function is also registered for it.

  Args:
    func_list: Iterable of function objects that expose
      `_create_definition_if_needed()` and `_definition.signature`
      (presumably `_DefinedFunction`s; confirm against callers).
  """
  registry = op_def_registry.get_registered_ops()
  for defined_func in func_list:
    # pylint: disable=protected-access
    defined_func._create_definition_if_needed()
    signature = defined_func._definition.signature
    # pylint: enable=protected-access
    registry[signature.name] = signature
    RegisterShape(signature.name)(common_shapes.unknown_shape)
开发者ID:AbhinavJain13,项目名称:seq2seq,代码行数:10,代码来源:profile.py

示例5: _add_op_node

def _add_op_node(op, func):
  """Converts an op to a function def node and adds it to `func`.

  Builds a `function_pb2.FunctionDef.Node` for `op` with these fields:
  `ret` names the op's outputs, `arg` names its inputs, `dep` names its
  control inputs, and `attr` copies its node-def attrs. The node is then
  appended to `func.node`.

  Args:
    op: The operation to convert. Its OpDef comes from `op._sig` when that
      attribute exists. Otherwise it is looked up in the registered ops by
      `op.type`.
    func: The function definition under construction. It receives the new
      node and is also passed to the `_add_output_*` / `_add_input_array`
      helpers (presumably so they can add auxiliary nodes; TODO confirm).
  """
  node = function_pb2.FunctionDef.Node()
  node.op = op.type
  # pylint: disable=protected-access
  if hasattr(op, "_sig"):
    op_def = getattr(op, "_sig")
  else:
    op_def = op_def_registry.get_registered_ops()[op.type]
  # pylint: enable=protected-access
  attrs = _get_node_def_attr(op)
  if not op_def.output_arg:
    # An op with no outputs is still referenced by a name derived from the op.
    node.ret.append(_make_argname_from_tensor_name(op.name))
  else:
    # One arg_def may cover several tensors, so keep a running index into
    # op.outputs.
    out_index = 0
    for arg_def in op_def.output_arg:
      if arg_def.number_attr:
        # `num` tensors sharing one dtype: the arg's fixed type if set,
        # otherwise the type held in its type attr.
        dtype = arg_def.type or attrs[arg_def.type_attr].type
        num = attrs[arg_def.number_attr].i
        node.ret.append(
            _add_output_array(op, out_index, out_index + num, dtype, func))
        out_index += num
      elif arg_def.type_list_attr:
        # One tensor per dtype listed in the type-list attr.
        dtype_lst = attrs[arg_def.type_list_attr].list.type
        num = len(dtype_lst)
        node.ret.append(
            _add_output_list(op, out_index, out_index + num, dtype_lst, func))
        out_index += num
      else:
        # Single-tensor output.
        node.ret.append(
            _make_argname_from_tensor_name(op.outputs[out_index].name))
        out_index += 1
  # Same bookkeeping for inputs, using a running index into op.inputs.
  inp_index = 0
  for arg_def in op_def.input_arg:
    if arg_def.number_attr:
      dtype = arg_def.type or attrs[arg_def.type_attr].type
      num = attrs[arg_def.number_attr].i
      node.arg.append(
          _add_input_array(op, inp_index, inp_index + num, dtype, func))
      inp_index += num
    elif arg_def.type_list_attr:
      # Type-list inputs are referenced individually, one name per tensor.
      num = len(attrs[arg_def.type_list_attr].list.type)
      node.arg.extend([
          _make_argname_from_tensor_name(op.inputs[i].name)
          for i in range(inp_index, inp_index + num)
      ])
      inp_index += num
    else:
      node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
      inp_index += 1
  node.dep.extend(
      [_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
  # NOTE(review): attrs are recomputed here. `attrs` from above could likely
  # be reused; confirm _get_node_def_attr has no side effects first.
  for k, v in _get_node_def_attr(op).items():
    node.attr[k].CopyFrom(v)
  func.node.extend([node])
开发者ID:KalraA,项目名称:tensorflow,代码行数:55,代码来源:function.py

示例6: _stripped_op_list_for_graph

def _stripped_op_list_for_graph(graph_def):
  """Returns OpDefs of ops used in graph_def.

  Op types are gathered from the top-level nodes and from the nodes of every
  library function. Types missing from the op registry are dropped. The
  remaining OpDefs are returned sorted by op name.
  """
  registered_ops = op_def_registry.get_registered_ops()
  node_groups = [graph_def.node]
  node_groups.extend(func.node for func in graph_def.library.function)

  used_op_names = set()
  for nodes in node_groups:
    used_op_names.update(
        node.op for node in nodes if node.op in registered_ops)

  return op_def_pb2.OpList(
      op=[registered_ops[name] for name in sorted(used_op_names)])
开发者ID:sherrym,项目名称:tensorflow,代码行数:12,代码来源:saver.py

示例7: list_registered_stateful_ops_without_inputs

def list_registered_stateful_ops_without_inputs():
  """Returns the names of registered stateful ops that take no inputs.

  The result identifies the ops to include in the state-graph, which are
  then fed into the apply-graphs.

  Returns:
    A set of strings.
  """
  registered = op_def_registry.get_registered_ops()
  return {
      op_name
      for op_name, op_def in registered.items()
      if op_def.is_stateful and not op_def.input_arg
  }
开发者ID:nicolas-ivanov,项目名称:hub,代码行数:14,代码来源:native_module.py

示例8: _get_ref_args

  def _get_ref_args(self, node):
    """Determine which outputs of a node are ref-type.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of tensor names (as strs) for the node's ref-type outputs, in
      output order. Output 0 is named by the bare node name, and output
      `i > 0` as "<node name>:<i>". The list is empty if the node's op is not
      registered.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    if not op_def:
      return []
    return [
        node.name if i == 0 else "%s:%d" % (node.name, i)
        for i, output_arg in enumerate(op_def.output_arg)
        if output_arg.is_ref
    ]
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:17,代码来源:debug_graphs.py

示例9: _strip_graph_default_valued_attrs

def _strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  Nodes are edited in place, both in the top-level graph and inside every
  library function. Nodes that call a library function, or whose op is not
  registered, are left as they are. Afterwards
  `meta_info_def.stripped_default_attrs` is set to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def
  # Only membership is needed: nodes invoking a library function are skipped.
  function_op_names = {
      function_def.signature.name
      for function_def in graph_def.library.function
  }
  registered_ops = op_def_registry.get_registered_ops()

  def _strip_node_default_valued_attrs(node_def):
    """Deletes attrs of one node def that hold their registered default."""
    op_type = node_def.op
    if op_type in function_op_names or op_type not in registered_ops:
      return
    op_def = registered_ops[op_type]
    # Collect names first: deleting while iterating the attr map is unsafe.
    defaulted_names = [
        attr_name
        for attr_name, attr_value in node_def.attr.items()
        if _is_default_attr_value(op_def, attr_name, attr_value)
    ]
    for attr_name in defaulted_names:
      del node_def.attr[attr_name]

  for top_level_node in graph_def.node:
    _strip_node_default_valued_attrs(top_level_node)

  for function_def in graph_def.library.function:
    for inner_node in function_def.node_def:
      _strip_node_default_valued_attrs(inner_node)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:45,代码来源:meta_graph.py

示例10: register_ops_if_needed

def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Any op in `graph_ops` that the Python registry lacks is copied over from
  the C++ registry. Whatever can be found is registered before the error for
  unresolved ops is raised.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python or
      c++ registry.
  """
  python_op_names = set(op_def_registry.get_registered_ops().keys())
  missing_ops = graph_ops - python_op_names
  if not missing_ops:
    return

  # Parse the complete C++ op registry into a name -> OpDef mapping.
  all_ops_buffer = c_api.TF_GetAllOpList()
  cpp_op_list = op_def_pb2.OpList()
  cpp_op_list.ParseFromString(c_api.TF_GetBuffer(all_ops_buffer))
  cpp_registry_ops = {op_def.name: op_def for op_def in cpp_op_list.op}

  ops_to_register = op_def_pb2.OpList()
  for op_name in missing_ops:
    if op_name in cpp_registry_ops:
      ops_to_register.op.extend([cpp_registry_ops[op_name]])
      tf.logging.info(
          "Adding op %s from c++ registry to python registry.",
          op_name)
    else:
      tf.logging.info(
          "Op %s is missing from both the python and C++ registry.",
          op_name)

  op_def_registry.register_op_list(ops_to_register)

  # The RuntimeError is raised only after the recoverable ops have been
  # registered. This lets tests exercise every call into TensorFlow without
  # needing a combined C + Python test.
  unresolved_ops = missing_ops - set(cpp_registry_ops.keys())
  if unresolved_ops:
    raise RuntimeError(
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry."
        % unresolved_ops)
开发者ID:nicolas-ivanov,项目名称:hub,代码行数:42,代码来源:native_module.py

示例11: _stripped_op_list_for_graph

def _stripped_op_list_for_graph(graph_def):
  """Returns the registered OpDefs of ops used by `graph_def`'s nodes.

  Node op types missing from the op registry (e.g. nodes that invoke library
  functions) are skipped, where they previously raised `KeyError`.

  Args:
    graph_def: A `GraphDef` proto.

  Returns:
    An `OpList` of the registered ops used by the top-level nodes, sorted by
    op name.
  """
  registered_ops = op_def_registry.get_registered_ops()
  used_ops = {n.op for n in graph_def.node if n.op in registered_ops}
  op_list = [registered_ops[op_name] for op_name in sorted(used_ops)]
  return op_def_pb2.OpList(op=op_list)
开发者ID:CdricGmd,项目名称:tensorflow,代码行数:5,代码来源:saver.py

示例12: import_graph_def

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

#.........这里部分代码省略.........
开发者ID:Immexxx,项目名称:tensorflow,代码行数:101,代码来源:importer.py

示例13: _get_op_def

def _get_op_def(op):
  return op.op_def or op_def_registry.get_registered_ops()[op.type]
开发者ID:cameronphchen,项目名称:tensorflow,代码行数:2,代码来源:function.py

示例14: _get_op_def

def _get_op_def(op):
  # pylint: disable=protected-access
  if hasattr(op, "_sig"):
    return getattr(op, "_sig")
  else:
    return op_def_registry.get_registered_ops()[op.type]
开发者ID:Hwhitetooth,项目名称:tensorflow,代码行数:6,代码来源:function.py

示例15: import_graph_def

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
#.........这里部分代码省略.........
开发者ID:andrewharp,项目名称:tensorflow,代码行数:101,代码来源:importer.py


注:本文中的tensorflow.python.framework.op_def_registry.get_registered_ops函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。