当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.io.decode_proto用法及代码示例


该操作将序列化协议缓冲区消息中的字段提取到张量中。

用法

tf.io.decode_proto(
    bytes, message_type, field_names, output_types,
    descriptor_source='local://', message_format='binary',
    sanitize=False, name=None
)

参数

  • bytes Tensor 类型为 string 。形状为 batch_shape 的序列化原型的张量。
  • message_type 一个string。要解码的原始消息类型的名称。
  • field_names strings 的列表。包含原始字段名称的字符串列表。扩展字段可以通过使用其全名来解码,例如EXT_PACKAGE.EXT_FIELD_NAME。
  • output_types tf.DTypes 的列表。用于field_names 中各个字段的 TF 类型列表。
  • descriptor_source 可选的 string 。默认为 "local://" 。特殊值 local:// 或包含序列化 FileDescriptorSet 的文件的路径。
  • message_format 可选的 string 。默认为 "binary"binarytext
  • sanitize 可选的 bool 。默认为 False 。是否清理结果。
  • name 操作的名称(可选)。

返回

  • Tensor 对象(大小、值)的元组。
  • sizes Tensor 类型为 int32
  • values 类型为 output_typesTensor 对象的列表。

注意:此 API 专为正交性而不是 human-friendliness 而设计。它可用于手动解析输入原型,但它旨在用于生成的代码。

decode_proto 操作将序列化协议缓冲区消息中的字段提取到张量中。如果可能,field_names 中的字段将被解码并转换为相应的output_types

必须提供 message_type 名称以提供字段名称的上下文。可以在linked-in 说明符池或调用者使用descriptor_source 属性提供的文件名中查找实际的消息说明符。

每个输出张量都是一个密集张量。这意味着它被填充以保存在输入小批量中看到的最大数量的重复元素。 (形状也被填充一以防止尺寸为零)。小批量中每个示例的实际重复计数可以在 sizes 输出中找到。在许多情况下,如果不考虑缺失值,decode_proto 的输出会立即输入 tf.squeeze。使用 tf.squeeze 时,始终显式传递挤压维度以避免意外。

在大多数情况下,Proto 字段类型和 TensorFlow dtypes 之间的映射很简单。但是,有一些特殊情况:

  • 包含子消息或组的 proto 字段只能转换为DT_STRING(序列化的子消息)。这是为了降低 API 的复杂性。生成的字符串可用作decode_proto op 的另一个实例的输入。

  • TensorFlow 缺乏对无符号整数的支持。操作将 uint64 类型表示为具有相同二进制补码位模式的 DT_INT64(显而易见的方式)。无符号 int32 值可以通过指定类型 DT_INT64 来精确表示,或者如果调用者在 output_types 属性中指定 DT_INT32 则使用二进制补码。

  • map 字段不直接解码。它们被视为相应条目类型的repeated 字段。 proto-compiler 定义每个映射字段的条目类型。 type-name 是字段名称,转换为 "CamelCase" 并附加 "Entry"。 tf.train.Features.FeatureEntry 消息是这些隐式 Entry 类型之一的示例。

  • enum 字段应读取为 int32。

支持二进制和文本原始序列化,并且可以使用format 属性进行选择。

descriptor_source 属性选择在查找 message_type 时要参考的协议说明符的来源。这可能是:

  • 一个空字符串或"local://",在这种情况下,协议说明符是为链接到二进制文件的 C++(不是 Python)原型定义创建的。

  • 一个文件,在这种情况下,协议说明符是从该文件创建的,该文件应包含序列化为字符串的FileDescriptorSet。注意:您可以使用协议编译器 protoc--descriptor_set_out--include_imports 选项构建 descriptor_source 文件。

  • 一个“字节://",其中协议说明符是从 <bytes> 创建的,它应该是序列化为字符串的 FileDescriptorSet

这是一个例子:

内部的Summary.Value proto 包含一个oneof {float simple_value; Image image; ...}

from google.protobuf import text_format

# A Summary.Value contains:oneof {float simple_value; Image image}
values = [
   "simple_value:2.2",
   "simple_value:1.2",
   "image { height:128 width:512 }",
   "image { height:256 width:256 }",]
values = [
   text_format.Parse(v, tf.compat.v1.Summary.Value()).SerializeToString()
   for v in values]

以下可以从序列化字符串中解码这两个字段:

sizes, [simple_value, image]  = tf.io.decode_proto(
 values,
 tf.compat.v1.Summary.Value.DESCRIPTOR.full_name,
 field_names=['simple_value', 'image'],
 output_types=[tf.float32, tf.string])

sizes 具有与输入相同的形状,在已解码的字段之间有一个附加轴。这里sizes的第一列是解码的simple_value字段的大小:

print(sizes)
tf.Tensor(
[[1 0]
 [1 0]
 [0 1]
 [0 1]], shape=(4, 2), dtype=int32)

每个结果张量都比输入byte-strings 多一个索引。每个结果张量的有效元素由 sizes 的相应列指示。无效元素用默认值填充:

print(simple_value)
tf.Tensor(
[[2.2]
 [1.2]
 [0. ]
 [0. ]], shape=(4, 1), dtype=float32)

嵌套的原型被提取为字符串张量:

print(image.dtype)
<dtype:'string'>
print(image.shape.as_list())
[4, 1]

要转换为 tf.RaggedTensor 表示,请使用:

tf.RaggedTensor.from_tensor(simple_value, lengths=sizes[:, 0]).to_list()
[[2.2], [1.2], [], []]

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.io.decode_proto。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。