当前位置: 首页>>代码示例>>Python>>正文


Python ops.get_collection_proto_type函数代码示例

本文整理汇总了Python中tensorflow.python.framework.ops.get_collection_proto_type函数的典型用法代码示例。如果您正苦于以下问题：Python get_collection_proto_type函数的具体用法？Python get_collection_proto_type怎么用？Python get_collection_proto_type使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了get_collection_proto_type函数的9个代码示例（其中示例8和示例9仅节选了部分代码），这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: assert_meta_graph_protos_equal

def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Collections whose key has a registered proto type (as reported by
  `ops.get_collection_proto_type`) are compared entry by entry after parsing
  each serialized `bytes_list` item, so the comparison is on proto content
  rather than raw bytes. All other collections are compared as whole
  `CollectionDef` messages.

  Note: this clears the `collection_def` field of BOTH `a` and `b` in place
  before comparing the remaining fields.

  Args:
    tester: Test case object providing `assertEqual` and `assertProtoEquals`.
    a: First `MetaGraphDef`; its `collection_def` is cleared by this call.
    b: Second `MetaGraphDef`; its `collection_def` is cleared by this call.
  """
  # Carefully check the collection_defs: the key sets must match first.
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same.
      tester.assertEqual(len(a_value.bytes_list.value),
                         len(b_value.bytes_list.value))
      for a_value_item, b_value_item in zip(a_value.bytes_list.value,
                                            b_value.bytes_list.value):
        # ParseFromString clears the message before parsing, so the two
        # scratch protos can be reused for every item.
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEqual`, not the deprecated `assertEquals` alias (removed in
      # Python 3.12).
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")
  tester.assertProtoEquals(a, b)
开发者ID:LUTAN,项目名称:tensorflow,代码行数:28,代码来源:test_util.py

示例2: add_collection_def

def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Collections with a registered to-proto function are stored as serialized
  protos in `bytes_list`. Other collections are stored in the
  `CollectionDef` field that `_get_kind_name` picks for their first item.
  If serialization fails for any reason, a warning is logged and the entry
  for `key` is removed from `meta_graph_def`.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. If not given, the
      default graph is used.
    export_scope: Optional `string`. Name scope to remove.

  Raises:
    TypeError: If `graph` is given and is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Format the message explicitly. Passing `type(graph)` as a second
    # exception argument would leave "%s" unformatted in the error text.
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  collection_list = graph.get_collection(key)
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x, export_scope=export_scope)
        # A falsy result (e.g. None) means this item is not serialized.
        if proto:
          # Additional type check to make sure the returned proto is indeed
          # what we expect. Raised explicitly rather than with `assert` so
          # the check still runs under `python -O`. The handler below turns
          # it into a warning.
          if not isinstance(proto, proto_type):
            raise TypeError("proto %s is not type %s" % (proto, proto_type))
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          # Only nodes inside `export_scope` are exported, with the scope
          # stripped from their names.
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", key, str(e))
    # Drop any partially written entry so the MetaGraphDef stays consistent.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:59,代码来源:meta_graph.py

示例3: _import_meta_graph_def

def _import_meta_graph_def(meta_graph_def):
  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function adds all the nodes from the meta graph def proto to the current
  graph, recreates all the collections, and returns a saver from saver_def.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer.

  Returns:
    A `Saver` constructed from `saver_def` in `meta_graph_def`, or `Saver()`
    with default arguments if `meta_graph_def` has no `saver_def` field.

  Raises:
    AssertionError: If a collection with a registered from-proto function is
      not stored as a `bytes_list`. This check uses `assert`, so it is skipped
      under `python -O`.
  """
  # Gathers the list of nodes we are interested in.
  importer.import_graph_def(meta_graph_def.graph_def, name="")

  # Restores all the other collections.
  for key, col_def in meta_graph_def.collection_def.items():
    kind = col_def.WhichOneof("kind")
    if kind is None:
      # Pass `key` as a logging argument instead of pre-formatting with `%`.
      logging.error("Cannot identify data type for collection %s. Skipping.",
                    key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto:
      # Collections with a registered from-proto function must have been
      # saved as serialized protos.
      assert kind == "bytes_list"
      proto_type = ops.get_collection_proto_type(key)
      for value in col_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        ops.add_to_collection(key, from_proto(proto))
    else:
      field = getattr(col_def, kind)
      if kind == "node_list":
        # Node names are resolved against the graph just imported above.
        for value in field.value:
          col_op = ops.get_default_graph().as_graph_element(value)
          ops.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python2 distinguishes between int and long, while Python3 has
        # only int.
        for value in field.value:
          ops.add_to_collection(key, int(value))
      else:
        for value in field.value:
          ops.add_to_collection(key, value)

  if meta_graph_def.HasField("saver_def"):
    return Saver(saver_def=meta_graph_def.saver_def)
  return Saver()
开发者ID:AboorvaDevarajan,项目名称:tensorflow,代码行数:50,代码来源:saver.py

示例4: _get_all_protos_from_collection

def _get_all_protos_from_collection(meta_graph_def, collection_key):
  """Parses every serialized proto stored in a `bytes_list` collection.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer to read from.
    collection_key: Key of the collection in `meta_graph_def.collection_def`.

  Returns:
    A list with one parsed proto per `bytes_list` entry, in order, using the
    proto type registered for `collection_key` via
    `_ops.get_collection_proto_type`. Returns an empty list if the collection
    is absent or its `bytes_list` is empty.

  Raises:
    ValueError: If the collection is present but is not stored as a
      `bytes_list`.
  """
  if collection_key not in meta_graph_def.collection_def:
    return []
  collection = meta_graph_def.collection_def[collection_key]
  # Check the oneof discriminator rather than emptiness. An empty bytes_list
  # is a valid (empty) collection, not a type mismatch.
  if collection.WhichOneof('kind') != 'bytes_list':
    raise ValueError(
        'Collection {} is present but type is not bytes_list.'.format(
            collection_key))
  proto_type = _ops.get_collection_proto_type(collection_key)
  result = []
  for value in collection.bytes_list.value:
    proto = proto_type()
    proto.ParseFromString(value)
    result.append(proto)
  return result
开发者ID:bikong2,项目名称:tensorflow,代码行数:16,代码来源:meta_graph_transform.py

示例5: add_collection_def

def add_collection_def(meta_graph_def, key):
    """Serializes the collection stored under `key` into a MetaGraphDef.

    Non-string keys are skipped with a warning. Empty collections are skipped
    silently. Collections with a registered to-proto function are written as
    serialized protos into `bytes_list`. Any other collection goes into the
    `CollectionDef` field that `_get_kind_name` reports for its first item.
    On any error, a warning is logged and the entry for `key` is dropped from
    `meta_graph_def`.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer, updated in place.
      key: One of the GraphKeys or user-defined string.
    """
    has_string_key = isinstance(key, six.string_types) or isinstance(key, bytes)
    if not has_string_key:
        logging.warning(
            "Only collections with string type keys will be serialized. "
            "This key has %s",
            type(key),
        )
        return

    items = ops.get_collection(key)
    if not items:
        return

    try:
        entry = meta_graph_def.collection_def[key]
        serialize = ops.get_to_proto_function(key)
        expected_type = ops.get_collection_proto_type(key)
        if not serialize:
            field_name = _get_kind_name(items[0])
            values = getattr(entry, field_name).value
            if field_name == "node_list":
                values.extend([item.name for item in items])
            elif field_name == "bytes_list":
                # Python 3 keeps str and bytes apart, so coerce every item
                # to bytes before storing it.
                values.extend([compat.as_bytes(item) for item in items])
            else:
                values.extend(list(items))
        else:
            for item in items:
                proto = serialize(item)
                # Guard against a serializer returning the wrong message type.
                if not isinstance(proto, expected_type):
                    raise TypeError("proto %s is not type %s" % (proto, expected_type))
                entry.bytes_list.value.append(proto.SerializeToString())
    except Exception as err:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef.\n%s",
            key,
            str(err),
        )
        # Remove whatever was partially written for this key.
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
开发者ID:rhuangq,项目名称:tensorflow,代码行数:47,代码来源:meta_graph.py

示例6: _restore_collections

 def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
   """Restores collections that we need to keep.

   Re-adds each collection named in `collection_keys` from
   `src_meta_graph_def` to `dest_graph`. Entries that cannot be rebuilt or
   resolved in `dest_graph` (e.g. nodes removed by graph optimization) are
   skipped rather than raising.

   Args:
     dest_graph: Graph to add the restored collection entries to.
     src_meta_graph_def: `MetaGraphDef` holding the original collections.
     collection_keys: Iterable of keys into `src_meta_graph_def.collection_def`.
   """
   # Entries are restored without any name-scope prefix.
   scope = ""
   for key in collection_keys:
     collection_def = src_meta_graph_def.collection_def[key]
     kind = collection_def.WhichOneof("kind")
     if kind is None:
       tf_logging.error(
           "Cannot identify data type for collection %s. Skipping.", key)
       continue
     from_proto = ops.get_from_proto_function(key)
     if from_proto and kind == "bytes_list":
       proto_type = ops.get_collection_proto_type(key)
       # It is assumed that there are no Variables Keys in collections
       for value in collection_def.bytes_list.value:
         proto = proto_type()
         proto.ParseFromString(value)
         try:
           new_value = from_proto(proto, import_scope=scope)
         except Exception:  # pylint: disable=broad-except
           # Best effort: skip entries that cannot be rebuilt in dest_graph.
           # Catch `Exception` rather than using a bare `except:` so that
           # KeyboardInterrupt and SystemExit still propagate.
           continue
         dest_graph.add_to_collection(key, new_value)
     else:
       field = getattr(collection_def, kind)
       if kind == "node_list":
         for value in field.value:
           name = ops.prepend_name_scope(value, scope)
           # Since the graph has been optimized, the node may no longer
           # exist.
           try:
             col_op = dest_graph.as_graph_element(name)
           except (TypeError, ValueError, KeyError):
             continue
           dest_graph.add_to_collection(key, col_op)
       elif kind == "int64_list":
         # NOTE(opensource): This force conversion is to work around the
         # fact that Python2 distinguishes between int and long, while
         # Python3 has only int.
         for value in field.value:
           dest_graph.add_to_collection(key, int(value))
       else:
         for value in field.value:
           dest_graph.add_to_collection(key,
                                        ops.prepend_name_scope(value, scope))
开发者ID:aritratony,项目名称:tensorflow,代码行数:44,代码来源:trt_convert.py

示例7: _add_collection_def

def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Collections with a registered to-proto function are stored as serialized
  protos in `bytes_list`. Other collections are stored in the
  `CollectionDef` field that `_get_kind_name` picks for their first item.
  On any error, a warning is logged and the entry for `key` is removed.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  # `unicode` only exists on Python 2. Referencing it unguarded raises
  # NameError on Python 3.
  try:
    string_types = (str, bytes, unicode)  # pylint: disable=undefined-variable
  except NameError:
    string_types = (str, bytes)
  if not isinstance(key, string_types):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x)
        # Additional type check to make sure the returned proto is indeed
        # what we expect. Raised explicitly rather than with `assert` so the
        # check still runs under `python -O`.
        if not isinstance(proto, proto_type):
          raise TypeError("proto %s is not type %s" % (proto, proto_type))
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  # `except ... as e` is valid on both Python 2.6+ and Python 3. The old
  # `except Exception, e:` form is a SyntaxError on Python 3.
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", str(e))
    # Drop any partially written entry so the MetaGraphDef stays consistent.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
开发者ID:hdzz,项目名称:tensorflow,代码行数:38,代码来源:saver.py

示例8: import_scoped_meta_graph


#.........这里部分代码省略.........
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value
                                     if not input_map or v not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:101,代码来源:meta_graph.py

示例9: import_scoped_meta_graph

def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a`Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer ,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a saver
  constructed from the `saver_def` field.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        assert kind == "bytes_list"
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
#.........这里部分代码省略.........
开发者ID:DavidNemeskey,项目名称:tensorflow,代码行数:101,代码来源:meta_graph.py


注:本文中的tensorflow.python.framework.ops.get_collection_proto_type函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。