當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.io.decode_proto用法及代碼示例


該操作將序列化協議緩衝區消息中的字段提取到張量中。

用法

tf.io.decode_proto(
    bytes, message_type, field_names, output_types,
    descriptor_source='local://', message_format='binary',
    sanitize=False, name=None
)

參數

  • bytes Tensor 類型為 string 。形狀為 batch_shape 的序列化原型的張量。
  • message_type 一個string。要解碼的原始消息類型的名稱。
  • field_names strings 的列表。包含原始字段名稱的字符串列表。擴展字段可以通過使用其全名來解碼,例如EXT_PACKAGE.EXT_FIELD_NAME。
  • output_types tf.DTypes 的列表。用於field_names 中各個字段的 TF 類型列表。
  • descriptor_source 可選的 string 。默認為 "local://" 。特殊值 local:// 或包含序列化 FileDescriptorSet 的文件的路徑。
  • message_format 可選的 string 。默認為 "binary"binarytext
  • sanitize 可選的 bool 。默認為 False 。是否清理結果。
  • name 操作的名稱(可選)。

返回

  • Tensor 對象(大小、值)的元組。
  • sizes Tensor 類型為 int32
  • values 類型為 output_typesTensor 對象的列表。

注意:此 API 專為正交性而不是 human-friendliness 而設計。它可用於手動解析輸入原型,但它旨在用於生成的代碼。

decode_proto 操作將序列化協議緩衝區消息中的字段提取到張量中。如果可能,field_names 中的字段將被解碼並轉換為相應的output_types

必須提供 message_type 名稱以提供字段名稱的上下文。可以在linked-in 說明符池或調用者使用descriptor_source 屬性提供的文件名中查找實際的消息說明符。

每個輸出張量都是一個密集張量。這意味著它被填充以保存在輸入小批量中看到的最大數量的重複元素。 (形狀也被填充一以防止尺寸為零)。小批量中每個示例的實際重複計數可以在 sizes 輸出中找到。在許多情況下,如果不考慮缺失值,decode_proto 的輸出會立即輸入 tf.squeeze。使用 tf.squeeze 時,始終顯式傳遞擠壓維度以避免意外。

在大多數情況下,Proto 字段類型和 TensorFlow dtypes 之間的映射很簡單。但是,有一些特殊情況:

  • 包含子消息或組的 proto 字段隻能轉換為DT_STRING(序列化的子消息)。這是為了降低 API 的複雜性。生成的字符串可用作decode_proto op 的另一個實例的輸入。

  • TensorFlow 缺乏對無符號整數的支持。操作將 uint64 類型表示為具有相同二進製補碼位模式的 DT_INT64(顯而易見的方式)。無符號 int32 值可以通過指定類型 DT_INT64 來精確表示,或者如果調用者在 output_types 屬性中指定 DT_INT32 則使用二進製補碼。

  • map 字段不直接解碼。它們被視為相應條目類型的repeated 字段。 proto-compiler 定義每個映射字段的條目類型。 type-name 是字段名稱,轉換為 "CamelCase" 並附加 "Entry"。 tf.train.Features.FeatureEntry 消息是這些隱式 Entry 類型之一的示例。

  • enum 字段應讀取為 int32。

支持二進製和文本原始序列化,並且可以使用format 屬性進行選擇。

descriptor_source 屬性選擇在查找 message_type 時要參考的協議說明符的來源。這可能是:

  • 一個空字符串或"local://",在這種情況下,協議說明符是為鏈接到二進製文件的 C++(不是 Python)原型定義創建的。

  • 一個文件,在這種情況下,協議說明符是從該文件創建的,該文件應包含序列化為字符串的FileDescriptorSet。注意:您可以使用協議編譯器 protoc--descriptor_set_out--include_imports 選項構建 descriptor_source 文件。

  • 一個“字節://",其中協議說明符是從 <bytes> 創建的,它應該是序列化為字符串的 FileDescriptorSet

這是一個例子:

內部的Summary.Value proto 包含一個oneof {float simple_value; Image image; ...}

from google.protobuf import text_format

# A Summary.Value contains:oneof {float simple_value; Image image}
values = [
   "simple_value:2.2",
   "simple_value:1.2",
   "image { height:128 width:512 }",
   "image { height:256 width:256 }",]
values = [
   text_format.Parse(v, tf.compat.v1.Summary.Value()).SerializeToString()
   for v in values]

以下可以從序列化字符串中解碼這兩個字段:

sizes, [simple_value, image]  = tf.io.decode_proto(
 values,
 tf.compat.v1.Summary.Value.DESCRIPTOR.full_name,
 field_names=['simple_value', 'image'],
 output_types=[tf.float32, tf.string])

sizes 具有與輸入相同的形狀,在已解碼的字段之間有一個附加軸。這裏sizes的第一列是解碼的simple_value字段的大小:

print(sizes)
tf.Tensor(
[[1 0]
 [1 0]
 [0 1]
 [0 1]], shape=(4, 2), dtype=int32)

每個結果張量都比輸入byte-strings 多一個索引。每個結果張量的有效元素由 sizes 的相應列指示。無效元素用默認值填充:

print(simple_value)
tf.Tensor(
[[2.2]
 [1.2]
 [0. ]
 [0. ]], shape=(4, 1), dtype=float32)

嵌套的原型被提取為字符串張量:

print(image.dtype)
<dtype:'string'>
print(image.shape.as_list())
[4, 1]

要轉換為 tf.RaggedTensor 表示,請使用:

tf.RaggedTensor.from_tensor(simple_value, lengths=sizes[:, 0]).to_list()
[[2.2], [1.2], [], []]

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.io.decode_proto。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。