当前位置: 首页>>代码示例>>Python>>正文


Python op_def_registry.get_registered_ops方法代码示例

本文整理汇总了Python中tensorflow.python.framework.op_def_registry.get_registered_ops方法的典型用法代码示例。如果您想了解以下问题：Python op_def_registry.get_registered_ops方法的具体用法是什么？该怎么用？有哪些使用示例？那么这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.framework.op_def_registry的其他用法示例。


下文共展示了op_def_registry.get_registered_ops方法的9个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。

示例1: _get_ref_args

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _get_ref_args(self, node):
    """Return the tensor names of the ref-type outputs of an op.

    The op's registered `OpDef` is looked up by `node.op`. Every entry of
    `op_def.output_arg` whose `is_ref` flag is set contributes the name of
    the tensor it produces. (Note: this inspects *output* args, not inputs.)

    NOTE(review): the position of an arg within `op_def.output_arg` is used
    as the tensor output index. That only matches when no earlier output arg
    expands to several tensors (`number_attr` / `type_list_attr`). Confirm
    this holds for the ops this is used with.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of tensor names (as strs) for the ref-type outputs of `node`.
      Output 0 is named `node.name`; output i > 0 is named
      `"<node.name>:<i>"`. The list is empty when the op type of `node` is
      not found in the op registry.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    ref_args = []
    if op_def:
      for i, output_arg in enumerate(op_def.output_arg):
        if output_arg.is_ref:
          # Output 0 is addressed by the bare node name; later outputs use
          # the "name:index" tensor-name convention.
          arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
          ref_args.append(arg_name)
    return ref_args
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:19,代码来源:debug_graphs.py

示例2: stripped_op_list_for_graph

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def stripped_op_list_for_graph(graph_def):
  """Return an `OpList` with the registered OpDefs of every op a graph uses.

  The result is what goes into the `stripped_op_list` field of
  `MetaGraphDef` and similar protos. A consumer that receives it can hand it
  to the C++ `RemoveNewDefaultAttrsFromGraphDef` to improve forwards
  compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` holding the registered OpDef of each used op, ordered by op
    name. Internal function-support ops that are not registered are omitted.

  Raises:
    ValueError: If the graph uses an op that is neither registered nor one
      of the internal function-support ops.
  """
  # Python twin of the C++ StrippedOpListForGraph. The Python and C++ op
  # registries can differ, so the logic is duplicated here rather than
  # shared naively through swig.
  # TODO(irving): Support taking graphs directly.
  used_ops = ops_used_by_graph_def(graph_def)
  registered_ops = op_def_registry.get_registered_ops()

  # Ops that functions use internally are never registered, so they are
  # tolerated explicitly.  TODO(irving): Do something better here.
  internal_ops = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
  for name in used_ops:
    if name in registered_ops or name in internal_ops:
      continue
    raise ValueError(
        "Op %s is used by the graph, but is not registered" % name)

  stripped_defs = [
      registered_ops[name]
      for name in sorted(used_ops)
      if name in registered_ops
  ]
  return op_def_pb2.OpList(op=stripped_defs)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:38,代码来源:meta_graph.py

示例3: _get_op_def

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _get_op_def(op):
  """Return the OpDef describing `op`.

  The OpDef attached to the op (`op.op_def`) is used when it is truthy.
  Otherwise the entry for `op.type` in the global op registry is returned.
  """
  attached_def = op.op_def
  if attached_def:
    return attached_def
  return op_def_registry.get_registered_ops()[op.type]
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:4,代码来源:function.py

示例4: __call__

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell inside `jit.experimental_jit_scope`.

    When `self._compile_stateful` is truthy, `compile_ops=True` is passed to
    the scope. Otherwise a predicate is passed that accepts only ops whose
    registered OpDef is not stateful.

    Args:
      inputs: Forwarded unchanged to the wrapped cell.
      state: Forwarded unchanged to the wrapped cell.
      scope: Optional scope, forwarded to the wrapped cell by keyword.

    Returns:
      Whatever `self._cell(...)` returns.
    """
    if self._compile_stateful:
      compile_ops = True
    else:
      def compile_ops(node_def):
        global _REGISTERED_OPS
        # The registry is fetched lazily on first use and cached module-wide.
        if _REGISTERED_OPS is None:
          _REGISTERED_OPS = op_def_registry.get_registered_ops()
        op_def = _REGISTERED_OPS.get(node_def.op)
        # An op type absent from the registry has no known statefulness.
        # Leave it uncompiled rather than failing graph construction with a
        # KeyError.
        return op_def is not None and not op_def.is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:14,代码来源:rnn_cell.py

示例5: _get_op_def

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _get_op_def(op):
  """Return the OpDef describing `op`.

  If the op has a `_sig` attribute, that value is returned as-is. Otherwise
  the entry for `op.type` in the global op registry is returned.
  """
  # pylint: disable=protected-access
  if not hasattr(op, "_sig"):
    return op_def_registry.get_registered_ops()[op.type]
  return op._sig
  # pylint: enable=protected-access
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:9,代码来源:function.py

示例6: __call__

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell inside `jit.experimental_jit_scope`.

    When `self._compile_stateful` is truthy, `compile_ops=True` is passed to
    the scope. Otherwise a predicate is passed that accepts only ops whose
    registered OpDef is not stateful.

    Args:
      inputs: Forwarded unchanged to the wrapped cell.
      state: Forwarded unchanged to the wrapped cell.
      scope: Optional scope, forwarded to the wrapped cell by keyword.

    Returns:
      Whatever `self._cell(...)` returns.
    """
    if self._compile_stateful:
      compile_ops = True
    else:
      def compile_ops(node_def):
        global _REGISTERED_OPS
        # The registry is fetched lazily on first use and cached module-wide.
        if _REGISTERED_OPS is None:
          _REGISTERED_OPS = op_def_registry.get_registered_ops()
        op_def = _REGISTERED_OPS.get(node_def.op)
        # An op type absent from the registry has no known statefulness.
        # Leave it uncompiled rather than failing graph construction with a
        # KeyError.
        return op_def is not None and not op_def.is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
开发者ID:shaohua0116,项目名称:Multiview2Novelview,代码行数:14,代码来源:rnn_cell.py

示例7: get

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def get(name):
    """Return the OpDef registered under `name`.

    The lookup uses the registry mapping's `.get`, so a name with no entry
    yields `None` instead of raising.
    """
    return op_def_registry.get_registered_ops().get(name)
开发者ID:tensorflow,项目名称:hub,代码行数:5,代码来源:native_module.py

示例8: sync

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def sync():
    """Copy every OpDef known to the C++ op registry into the Python registry.

    The serialized `OpList` is fetched through the C API. Per the TensorFlow
    C API header, `TF_GetAllOpList` transfers ownership of the returned
    `TF_Buffer` to the caller. The buffer is therefore released with
    `TF_DeleteBuffer` once it has been parsed, even if parsing fails.

    Each OpDef has its non-deprecated descriptions stripped before it is
    stored, overwriting any existing Python entry with the same name.
    """
    p_buffer = c_api.TF_GetAllOpList()
    try:
      cpp_op_list = op_def_pb2.OpList()
      cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
    finally:
      c_api.TF_DeleteBuffer(p_buffer)

    registered_ops = op_def_registry.get_registered_ops()
    for op_def in cpp_op_list.op:
      # An OpList registered from a gen_*_ops.py module does not carry any
      # descriptions. Strip them here as well so these entries satisfy the
      # validation in register_op_list.
      _remove_non_deprecated_descriptions(op_def)
      registered_ops[op_def.name] = op_def
开发者ID:tensorflow,项目名称:hub,代码行数:14,代码来源:native_module.py

示例9: _add_op_node

# 需要导入模块: from tensorflow.python.framework import op_def_registry [as 别名]
# 或者: from tensorflow.python.framework.op_def_registry import get_registered_ops [as 别名]
def _add_op_node(op, func):
  """Converts an op to a function def node and adds it to `func`.

  A `FunctionDef.Node` is built from `op`:
    * `node.ret` receives names for the op's outputs, or the op's own name
      when its OpDef declares no outputs.
    * `node.arg` receives names for the op's data inputs.
    * `node.dep` receives names for the op's control inputs.
    * `node.attr` receives a copy of the op's node attributes.
  The node is then appended to `func.node`.

  Args:
    op: The operation to convert. Its OpDef is taken from its `_sig`
      attribute when present, otherwise from the global op registry under
      `op.type`.
    func: Presumably a `FunctionDef` proto (it receives `FunctionDef.Node`s).
      It is also passed to the `_add_output_array`, `_add_output_list` and
      `_add_input_array` helpers.
  """
  node = function_pb2.FunctionDef.Node()
  node.op = op.type
  # pylint: disable=protected-access
  if hasattr(op, "_sig"):
    op_def = getattr(op, "_sig")
  else:
    op_def = op_def_registry.get_registered_ops()[op.type]
  # pylint: enable=protected-access
  # Fetched once and reused both for resolving list lengths/dtypes below and
  # for copying the attributes onto the node at the end.
  attrs = _get_node_def_attr(op)

  # Outputs. An arg with `number_attr` spans `attrs[number_attr].i` tensors
  # of one dtype; an arg with `type_list_attr` spans one tensor per listed
  # type; any other arg is a single tensor. `out_index` is the position in
  # `op.outputs`.
  if not op_def.output_arg:
    node.ret.append(_make_argname_from_tensor_name(op.name))
  else:
    out_index = 0
    for arg_def in op_def.output_arg:
      if arg_def.number_attr:
        dtype = arg_def.type or attrs[arg_def.type_attr].type
        num = attrs[arg_def.number_attr].i
        node.ret.append(
            _add_output_array(op, out_index, out_index + num, dtype, func))
        out_index += num
      elif arg_def.type_list_attr:
        dtype_lst = attrs[arg_def.type_list_attr].list.type
        num = len(dtype_lst)
        node.ret.append(
            _add_output_list(op, out_index, out_index + num, dtype_lst, func))
        out_index += num
      else:
        node.ret.append(
            _make_argname_from_tensor_name(op.outputs[out_index].name))
        out_index += 1

  # Inputs, walked with the same arg-expansion rules. `inp_index` is the
  # position in `op.inputs`.
  inp_index = 0
  for arg_def in op_def.input_arg:
    if arg_def.number_attr:
      dtype = arg_def.type or attrs[arg_def.type_attr].type
      num = attrs[arg_def.number_attr].i
      node.arg.append(
          _add_input_array(op, inp_index, inp_index + num, dtype, func))
      inp_index += num
    elif arg_def.type_list_attr:
      num = len(attrs[arg_def.type_list_attr].list.type)
      node.arg.extend([
          _make_argname_from_tensor_name(op.inputs[i].name)
          for i in range(inp_index, inp_index + num)
      ])
      inp_index += num
    else:
      node.arg.append(_make_argname_from_tensor_name(op.inputs[inp_index].name))
      inp_index += 1

  node.dep.extend(
      [_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
  for k, v in attrs.items():
    node.attr[k].CopyFrom(v)
  func.node.extend([node])
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:57,代码来源:function.py


注:本文中的tensorflow.python.framework.op_def_registry.get_registered_ops方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。