当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.Dataset.get_single_element用法及代码示例


用法

get_single_element(
    name=None
)

参数

  • name (可选。) tf.data 操作的名称。

返回

  • tf.Tensor 对象的嵌套结构,对应于 dataset 的单个元素。

抛出

  • InvalidArgumentError (在运行时)如果 dataset 不包含恰好一个元素。

返回 dataset 的单个元素。

该函数使您能够在无状态的“tensor-in tensor-out”表达式中使用tf.data.Dataset,而无需创建迭代器。这有助于在张量上使用优化的tf.data.Dataset 抽象来简化数据转换。

例如,让我们考虑一个preprocessing_fn,它将原始特征作为输入并返回处理后的特征及其标签。

def preprocessing_fn(raw_feature):
  # ... the raw_feature is preprocessed as per the use-case
  return feature

raw_features = ...  # input batch of BATCH_SIZE elements.
dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
          .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
          .batch(BATCH_SIZE))

processed_features = dataset.get_single_element()

在上面的例子中,长度=BATCH_SIZE的raw_features张量被转换为tf.data.Dataset。接下来,使用preprocessing_fn 映射每个raw_feature,并将处理后的特征分组为一个批次。最终的dataset 仅包含一个元素,它是所有已处理特征的批次。

注意:dataset 应该只包含一个元素。

现在,不是为dataset创建迭代器并检索这批特征,而是使用tf.data.get_single_element()函数跳过迭代器创建过程并直接输出这批特征。

当您的张量转换表示为 tf.data.Dataset 操作并且您希望在为模型提供服务时使用这些转换时,这可能特别有用。

喀拉斯

model = ... # A pre-built or custom model

class PreprocessingModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__(self)
    self.model = model

  @tf.function(input_signature=[...])
  def serving_fn(self, data):
    ds = tf.data.Dataset.from_tensor_slices(data)
    ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
    ds = ds.batch(batch_size=BATCH_SIZE)
    return tf.argmax(self.model(ds.get_single_element()), axis=-1)

preprocessing_model = PreprocessingModel(model)
your_exported_model_dir = ... # save the model to this path.
tf.saved_model.save(preprocessing_model, your_exported_model_dir,
              signatures={'serving_default':preprocessing_model.serving_fn}
              )

估计器

在估计器的情况下,您通常需要定义一个serving_input_fn,这将需要模型在推理时处理特征。

def serving_input_fn():

  raw_feature_spec = ... # Spec for the raw_features
  input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  )
  serving_input_receiver = input_fn()
  raw_features = serving_input_receiver.features

  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
            .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
            .batch(BATCH_SIZE))

  processed_features = dataset.get_single_element()

  # Please note that the value of `BATCH_SIZE` should be equal to
  # the size of the leading dimension of `raw_features`. This ensures
  # that `dataset` has only element, which is a pre-requisite for
  # using `dataset.get_single_element()`.

  return tf.estimator.export.ServingInputReceiver(
      processed_features, serving_input_receiver.receiver_tensors)

estimator = ... # A pre-built or custom estimator
estimator.export_saved_model(your_exported_model_dir, serving_input_fn)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.Dataset.get_single_element。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。