當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.data.Dataset.get_single_element用法及代碼示例


用法

get_single_element(
    name=None
)

參數

  • name (可選。) tf.data 操作的名稱。

返回

  • tf.Tensor 對象的嵌套結構,對應於 dataset 的單個元素。

拋出

  • InvalidArgumentError (在運行時)如果 dataset 不包含恰好一個元素。

返回 dataset 的單個元素。

該函數使您能夠在無狀態的“tensor-in tensor-out”表達式中使用tf.data.Dataset,而無需創建迭代器。這有助於在張量上使用優化的tf.data.Dataset 抽象來簡化數據轉換。

例如,讓我們考慮一個preprocessing_fn,它將原始特征作為輸入並返回處理後的特征及其標簽。

def preprocessing_fn(raw_feature):
  # ... the raw_feature is preprocessed as per the use-case
  return feature

raw_features = ...  # input batch of BATCH_SIZE elements.
dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
          .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
          .batch(BATCH_SIZE))

processed_features = dataset.get_single_element()

在上麵的例子中,長度=BATCH_SIZE的raw_features張量被轉換為tf.data.Dataset。接下來,使用preprocessing_fn 映射每個raw_feature,並將處理後的特征分組為一個批次。最終的dataset 僅包含一個元素,它是所有已處理特征的批次。

注意:dataset 應該隻包含一個元素。

現在,不是為dataset創建迭代器並檢索這批特征,而是使用tf.data.get_single_element()函數跳過迭代器創建過程並直接輸出這批特征。

當您的張量轉換表示為 tf.data.Dataset 操作並且您希望在為模型提供服務時使用這些轉換時,這可能特別有用。

喀拉斯

model = ... # A pre-built or custom model

class PreprocessingModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__(self)
    self.model = model

  @tf.function(input_signature=[...])
  def serving_fn(self, data):
    ds = tf.data.Dataset.from_tensor_slices(data)
    ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
    ds = ds.batch(batch_size=BATCH_SIZE)
    return tf.argmax(self.model(ds.get_single_element()), axis=-1)

preprocessing_model = PreprocessingModel(model)
your_exported_model_dir = ... # save the model to this path.
tf.saved_model.save(preprocessing_model, your_exported_model_dir,
              signatures={'serving_default':preprocessing_model.serving_fn}
              )

估計器

在估計器的情況下,您通常需要定義一個serving_input_fn,這將需要模型在推理時處理特征。

def serving_input_fn():

  raw_feature_spec = ... # Spec for the raw_features
  input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  )
  serving_input_receiver = input_fn()
  raw_features = serving_input_receiver.features

  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
            .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
            .batch(BATCH_SIZE))

  processed_features = dataset.get_single_element()

  # Please note that the value of `BATCH_SIZE` should be equal to
  # the size of the leading dimension of `raw_features`. This ensures
  # that `dataset` has only element, which is a pre-requisite for
  # using `dataset.get_single_element()`.

  return tf.estimator.export.ServingInputReceiver(
      processed_features, serving_input_receiver.receiver_tensors)

estimator = ... # A pre-built or custom estimator
estimator.export_saved_model(your_exported_model_dir, serving_input_fn)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.Dataset.get_single_element。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。