当前位置: 首页>>代码示例>>Python>>正文


Python compat.as_str函数代码示例

本文整理汇总了Python中tensorflow.python.util.compat.as_str函数的典型用法代码示例。如果您正苦于以下问题：Python as_str函数的具体用法？Python as_str函数怎么用？Python as_str函数使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了as_str函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: meta_graph_transform

def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Runs the Graph Transform tool over a MetaGraphDef.

  The graph_def of `base_meta_graph_def` is transformed, then the saver,
  collections and signature_defs are carried over into the result with any
  references to nodes that the transforms removed pruned away.

  Args:
    base_meta_graph_def: The MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: Ordered list of graph transform names to apply. These are the
      names understood by the Graph Transform Tool, plus 'freeze_graph'.
    tags: Tags that replace the meta_info_def tags of the result.
    checkpoint_path: Optional checkpoint path to restore from while freezing
      (default None).

  Returns:
    A newly built, transformed MetaGraphDef protocol buffer.
  """
  base_graph_def = base_meta_graph_def.graph_def
  mandatory_ops = _find_all_mandatory_retain_ops(base_meta_graph_def)
  new_graph_def = _do_transforms(
      base_graph_def,
      input_names,
      output_names,
      mandatory_ops,
      transforms,
      base_meta_graph_def.saver_def,
      checkpoint_path)

  result = _meta_graph_pb2.MetaGraphDef()
  result.graph_def.CopyFrom(new_graph_def)

  # Start from the base meta info, but replace its tags with `tags`.
  info = result.meta_info_def
  info.CopyFrom(base_meta_graph_def.meta_info_def)
  info.ClearField('tags')
  for tag in tags:
    info.tags.append(tag)

  # Op names present in the base graph but absent after transformation.
  kept_names = {compat.as_str(node.name) for node in result.graph_def.node}
  removed_op_names = {compat.as_str(node.name)
                      for node in base_graph_def.node} - kept_names

  # Saver, collections and signatures all drop references to pruned nodes.
  _add_pruned_saver(base_meta_graph_def, result, removed_op_names)
  for key in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, result, key, removed_op_names)
  for key in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, result, key, removed_op_names)

  return result
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:60,代码来源:meta_graph_transform.py

示例2: _PopulateTFImportGraphDefOptions

def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements):
  """Configures the TF_ImportGraphDefOptions handle `options` in place.

  Args:
    options: The TF_ImportGraphDefOptions handle to configure.
    prefix: Name prefix handed to TF_ImportGraphDefOptionsSetPrefix.
    input_map: Dict whose keys name things in the imported graph and whose
      values are what they are remapped to. A key starting with '^' remaps a
      control dependency (the value's op is used); any other key is parsed as
      a tensor name and remapped as an input.
    return_elements: Optional iterable of names to return. Names containing
      ':' are requested as tensor outputs, all others as operations.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
  c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)

  # pylint: disable=protected-access
  for raw_src, target in input_map.items():
    src = compat.as_str(raw_src)
    if not src.startswith('^'):
      # Regular tensor input remapping.
      tensor_name, tensor_idx = _ParseTensorName(src)
      tensor_name = compat.as_str(tensor_name)
      target_output = target._as_tf_output()
      c_api.TF_ImportGraphDefOptionsAddInputMapping(
          options, tensor_name, tensor_idx, target_output)
      continue
    # Control dependency remapping: strip the leading '^'.
    control_name = compat.as_bytes(src[1:])
    target_op = target._as_tf_output().oper
    c_api.TF_ImportGraphDefOptionsRemapControlDependency(
        options, control_name, target_op)
  # pylint: enable=protected-access

  for element in return_elements or ():
    if ':' not in element:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(
          options, compat.as_str(element))
    else:
      op_name, output_idx = _ParseTensorName(element)
      op_name = compat.as_str(op_name)
      c_api.TF_ImportGraphDefOptionsAddReturnOutput(
          options, op_name, output_idx)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:28,代码来源:importer.py

示例3: _init_from_proto

  def _init_from_proto(self, hparam_def):
    """Populates this HParams from an `HParamDef` protocol buffer.

    Every entry of `hparam_def.hparam` becomes a hyperparameter. Kinds ending
    in '_value' are scalars; the others are lists read from `.value`. int64
    entries are converted to Python `int` and bytes entries to `str` (UTF-8
    assumed) so the stored types agree between Python 2 and Python 3.

    Args:
      hparam_def: `HParamDef` protocol buffer.
    """
    assert isinstance(hparam_def, hparam_pb2.HParamDef)
    for hparam_name, proto_value in hparam_def.hparam.items():
      kind = proto_value.WhichOneof('kind')
      if kind.startswith('int64'):
        # Use 'int' so the type is compatible with Python 2 and Python 3.
        convert = int
      elif kind.startswith('bytes'):
        # Use 'str' so the type is compatible with Python 2 and Python 3.
        convert = compat.as_str
      else:
        convert = None

      field = getattr(proto_value, kind)
      if kind.endswith('_value'):
        # Single value.
        hparam_value = field if convert is None else convert(field)
      else:
        # List of values.
        hparam_value = [item if convert is None else convert(item)
                        for item in field.value]
      self.add_hparam(hparam_name, hparam_value)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:34,代码来源:hparam.py

示例4: _ProcessReturnElementsParam

def _ProcessReturnElementsParam(return_elements):
  """Validates `return_elements` and converts each entry to `str`.

  Args:
    return_elements: None, or an iterable of bytes/text names.

  Returns:
    None if `return_elements` is None, otherwise a tuple of `str` names.

  Raises:
    TypeError: If any element is not a bytes or text string.
  """
  if return_elements is None:
    return None
  for element in return_elements:
    if not isinstance(element, compat.bytes_or_text_types):
      raise TypeError('return_elements must be a list of strings.')
  return tuple(compat.as_str(element) for element in return_elements)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:7,代码来源:importer.py

示例5: _clean_save_and_restore

def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Drops removed variables from a save or restore op and its inputs.

  Rewrites, in place, the '<op>/tensor_names' and '<op>/shape_and_slices'
  constant nodes and the op's 'dtypes' attr so that only the entries whose
  tensor name is not removed remain. The first dimension recorded in the
  constants' tensor_shape and '_output_shapes' attr is updated to match.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: Container of removed op names, forwarded to
      `_is_removed` (example callers pass a set).
  """
  name_op = _find_op(graph_def, op.name + '/tensor_names')
  shape_op = _find_op(graph_def, op.name + '/shape_and_slices')
  names_tensor = name_op.attr['value'].tensor
  shapes_tensor = shape_op.attr['value'].tensor

  kept_names = []
  kept_shapes = []
  kept_dtypes = []
  for position, tensor_name in enumerate(names_tensor.string_val):
    if _is_removed(compat.as_str(tensor_name), removed_op_names):
      continue
    # The three lists are parallel, so keep the same position in each.
    kept_names.append(tensor_name)
    kept_shapes.append(shapes_tensor.string_val[position])
    kept_dtypes.append(op.attr['dtypes'].list.type[position])

  names_tensor.string_val[:] = kept_names
  names_tensor.tensor_shape.dim[0].size = len(kept_names)
  shapes_tensor.string_val[:] = kept_shapes
  shapes_tensor.tensor_shape.dim[0].size = len(kept_shapes)
  op.attr['dtypes'].list.type[:] = kept_dtypes

  # Keep the annotated output shapes consistent with the new lengths.
  name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(kept_names)
  shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(kept_shapes)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:33,代码来源:meta_graph_transform.py

示例6: encode_arg

  def encode_arg(arg, path):
    """Returns a signature-friendly stand-in for `arg`.

    Tensors become `TensorSpec`s. The spec takes the op's
    '_user_specified_name' attr when that is set and differs from `path[0]`;
    otherwise it takes the '/'-joined `path`. ints, floats, bools, None,
    `DType`s and `TensorSpec`s are returned unchanged, and anything else
    becomes `UnknownArgument()`.
    """
    if not isinstance(arg, ops.Tensor):
      passthrough_types = (
          int,
          float,
          bool,
          type(None),
          dtypes.DType,
          tensor_spec.TensorSpec,
      )
      if isinstance(arg, passthrough_types):
        return arg
      return UnknownArgument()

    try:
      user_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
    except ValueError:
      # The op has no such attr.
      user_name = None

    if path and user_name and user_name != path[0]:
      # The user explicitly named the argument differently from the name of
      # the function argument, so honor the user's name.
      spec_name = user_name
    else:
      spec_name = "/".join(str(component) for component in path)
    return tensor_spec.TensorSpec(arg.shape, arg.dtype, spec_name)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:func_graph.py

示例7: assert_equal_graph_def

def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  The comparison is delegated to `EqualGraphDefWrapper` on the serialized
  protos. Versions and the ordering of nodes, attrs and control inputs are
  ignored. Nodes are matched up by name, so node naming must be consistent
  between the two graphs.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
        values that appear in V2 checkpoints. NOTE(review): the stripping
        helper is called on the arguments themselves, so it presumably
        modifies them in place — confirm.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2:
    for graph_def in (actual, expected):
      _strip_checkpoint_v2_randomized(graph_def)

  actual_bytes = actual.SerializeToString()
  expected_bytes = expected.SerializeToString()
  # A non-empty result describes the first mismatch found.
  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual_bytes, expected_bytes)
  if diff:
    raise AssertionError(compat.as_str(diff))
开发者ID:LUTAN,项目名称:tensorflow,代码行数:32,代码来源:test_util.py

示例8: _create_new_tf_function

def _create_new_tf_function(func_graph):
  """Builds a TF_Function from `func_graph` and adds it to its outer graph.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function, i.e. `func_graph.name`.
  """
  # pylint: disable=protected-access
  graph_inputs = [tensor._as_tf_output() for tensor in func_graph.inputs]
  graph_outputs = [tensor._as_tf_output() for tensor in func_graph.outputs]
  c_func = c_api.TF_GraphToFunction_wrapper(
      func_graph._c_graph,
      compat.as_str(func_graph.name),
      False,  # append_hash_to_fn_name
      None,  # opers
      graph_inputs,
      graph_outputs,
      [],
      None,  # opts
      None)  # description
  # Held for the rest of this function so the wrapper outlives the
  # conversion below; presumably it releases `c_func` when collected.
  scoped_c_func = c_api_util.ScopedTFFunction(c_func)  # pylint: disable=unused-variable

  # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
  # deserializing it into a Python FunctionDef, then reserializing it to create
  # a new TF_Function that we add to the graph.
  function_def = _function.function_def_from_tf_function(c_func)
  defined_func = _function._from_definition(function_def)
  defined_func._sub_functions = func_graph._functions
  defined_func.add_to_graph(func_graph._outer_graph)
  # pylint: enable=protected-access

  return func_graph.name
开发者ID:godyd2702,项目名称:tensorflow,代码行数:30,代码来源:cond_v2_impl.py

示例9: _node_def

def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs inside `export_scope` have the scope stripped. Inputs outside it get
  `_UNBOUND_INPUT_PREFIX` inserted (after a leading '^', if any) and are
  recorded in `unbound_inputs`. The node name and the '_class' attr entries
  are stripped of the scope, and '_class' entries outside the scope are
  dropped. For Enter/RefEnter nodes, the 'frame_name' is stripped when it lies
  in the scope and left unchanged otherwise.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. It is not
      modified; a deep copy is edited instead.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list to which the rewritten names of inputs outside
      `export_scope` are appended.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if export_scope and not v.lstrip("^").startswith(export_scope):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable. The regex keeps a leading '^' (control input) in front.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      # Entries look like "loc:@<name>"; keep those whose <name> is in scope.
      new_class_list = [
          compat.as_bytes(ops.strip_name_scope(s, export_scope))
          for s in v.list.s
          if not export_scope or
          compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_class_list)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_frame_name = compat.as_bytes(
            ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_frame_name))
      else:
        # Frame outside export_scope: leave it untouched rather than writing
        # an undefined (or leftover) value into the attr.
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:46,代码来源:meta_graph.py

示例10: __init__

  def __init__(self, name, graph, operations, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Converts `operations` of `graph` into a TF_Function, sets `attrs` on it,
    and records the resulting FunctionDef and its signature on `self`.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        [],
        None,
        compat.as_str(""))
    # Wrap `fn` immediately so that, if any step below raises, the wrapper is
    # dropped with this frame instead of `fn` being left unowned.
    # (ScopedTFFunction presumably deletes the C function when collected.)
    scoped_fn = c_api_util.ScopedTFFunction(fn)

    # A distinct loop variable keeps the `name` argument from being shadowed.
    for attr_name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(attr_name), serialized)

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    if context.executing_eagerly():
      _register(fn)
    self.definition = function_def
    self.name = function_def.signature.name
    self.signature = function_def.signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = scoped_fn
    self._grad_func = None
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:46,代码来源:function.py

示例11: save

  def save(self, sess, save_path, global_step=None, latest_filename=None):
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path of the newly created checkpoint file.  This
    path can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Path to the checkpoint filename.  If the saver is
        `sharded`, this is the prefix of the sharded checkpoint filename.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filename. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contains the list of most recent checkpoint filenames.  That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.

    Returns:
      A string: path at which the variables were saved.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components.
    """
    if latest_filename is None:
      latest_filename = "checkpoint"

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    # Check `sess` before anything uses it (e.g. evaluating a tensor-valued
    # `global_step`), so an invalid session raises the documented TypeError.
    if not isinstance(sess, session.SessionInterface):
      raise TypeError("'sess' must be a Session; %s" % sess)

    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
    else:
      checkpoint_file = save_path
    # The checkpoint state file lives next to the checkpoint files.
    save_dir = os.path.dirname(save_path)

    model_checkpoint_path = sess.run(
        self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
    model_checkpoint_path = compat.as_str(model_checkpoint_path)
    self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
    update_checkpoint_state(save_dir, model_checkpoint_path,
                            self.last_checkpoints, latest_filename)
    return model_checkpoint_path
开发者ID:hessenh,项目名称:Human-Activity-Recognition,代码行数:55,代码来源:saver.py

示例12: __init__

  def __init__(self, name, graph, operations, inputs, outputs):
    """Builds a TF_Function from part of `graph` and records its definition.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
    """
    # pylint: disable=protected-access
    c_graph = graph._c_graph
    c_ops = [op._c_op for op in operations]
    c_inputs = [tensor._as_tf_output() for tensor in inputs]
    c_outputs = [tensor._as_tf_output() for tensor in outputs]
    # pylint: enable=protected-access
    with errors.raise_exception_on_not_ok_status() as status:
      fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
          c_graph, compat.as_str(name), False, c_ops, c_inputs, c_outputs,
          [], None, compat.as_str(""), status)

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))

    if context.executing_eagerly():
      _register(fn)

    signature = function_def.signature
    self.definition = function_def
    self.name = signature.name
    self.signature = signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = fn
    self._grad_func = None
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:40,代码来源:function.py

示例13: request_stop

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    # NOTE(review): `unicode` is not a Python 3 builtin; this relies on the
    # module providing it (e.g. via six) — confirm.
    with self._lock:
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str(unicode(ex[1])))
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str(unicode(ex)))
            current_exc_info = sys.exc_info()
            if current_exc_info[1] is ex:
              # `ex` is the exception being handled right now; keep its
              # original traceback.
              self._exc_info_to_raise = current_exc_info
            else:
              # `ex` is not being handled, so sys.exc_info() would describe
              # another exception or none at all. Raise `ex` to obtain an
              # exc_info tuple that actually refers to it.
              try:
                raise ex
              except BaseException:  # pylint: disable=broad-except
                self._exc_info_to_raise = sys.exc_info()
        self._stop_event.set()
开发者ID:peace195,项目名称:tensorflow,代码行数:22,代码来源:coordinator.py

示例14: _set_c_attrs

  def _set_c_attrs(self, attrs):
    """Copies each entry of `attrs` onto `self._c_func` as a function attr.

    `self._c_func` must not be None when `attrs` is non-empty.

    Args:
      attrs: dict mapping attribute names to attribute proto values.
    """
    for attr_name, attr_proto in attrs.items():
      payload = attr_proto.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      c_api.TF_FunctionSetAttrValueProto(
          self._c_func.func, compat.as_str(attr_name), payload)
开发者ID:didukhle,项目名称:tensorflow,代码行数:14,代码来源:function.py

示例15: _ReadAndCheckRowsUsingFeatures

  def _ReadAndCheckRowsUsingFeatures(self, num_rows):
    """Reads `num_rows` rows through parse_example and validates each one.

    Each row's int64 column must index a matching entry in `_ROWS`, and its
    string column must be "s_<id>" (or "s_default" when that entry's second
    field is falsy). Every id in range(num_rows) must be seen, in any order,
    and one further read must fail because the queue is closed and empty.
    """
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      feature_configs = {
          "int64_col": parsing_ops.FixedLenFeature([1], dtype=dtypes.int64),
          "string_col": parsing_ops.FixedLenFeature(
              [1], dtype=dtypes.string, default_value="s_default"),
      }
      address = self.server.httpd.server_address
      end_point = "%s:%s" % (address[0], address[1])
      reader = cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          features=feature_configs,
          timestamp_millis=1,
          test_end_point=end_point)

      key, value = _SetUpQueue(reader)

      features = parsing_ops.parse_example(
          array_ops.reshape(value, [1]), feature_configs)
      int_tensor = features["int64_col"]
      str_tensor = features["string_col"]
      seen_rows = []
      for _ in range(num_rows):
        int_value, str_value = sess.run([int_tensor, str_tensor])

        # Each fetched feature is a batch of one single-element row.
        self.assertEqual(int_value.shape, (1, 1))
        self.assertEqual(str_value.shape, (1, 1))
        row_id = int_value[0][0]
        row_str = str_value[0][0]
        seen_rows.append(row_id)

        expected_row = _ROWS[row_id]
        self.assertEqual(row_id, expected_row[0])
        if expected_row[1]:
          expected_str = "s_%d" % row_id
        else:
          expected_str = "s_default"
        self.assertEqual(compat.as_str(row_str), expected_str)

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])
开发者ID:brainwy12,项目名称:tensorflow,代码行数:50,代码来源:bigquery_reader_ops_test.py


注:本文中的tensorflow.python.util.compat.as_str函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。