用法
get_single_element(
name=None
)
参数
-
name
(可选。) tf.data 操作的名称。
返回
-
tf.Tensor
对象的嵌套结构,对应于dataset
的单个元素。
抛出
-
InvalidArgumentError
(在运行时)如果dataset
不包含恰好一个元素。
返回 dataset
的单个元素。
该函数使您能够在无状态的“tensor-in tensor-out”表达式中使用tf.data.Dataset
,而无需创建迭代器。这有助于在张量上使用优化的tf.data.Dataset
抽象来简化数据转换。
例如,让我们考虑一个preprocessing_fn
,它将原始特征作为输入并返回处理后的特征及其标签。
def preprocessing_fn(raw_feature):
# ... the raw_feature is preprocessed as per the use-case
return feature
raw_features = ... # input batch of BATCH_SIZE elements.
dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
.batch(BATCH_SIZE))
processed_features = dataset.get_single_element()
在上面的例子中,长度=BATCH_SIZE的raw_features
张量被转换为tf.data.Dataset
。接下来,使用preprocessing_fn
映射每个raw_feature
,并将处理后的特征分组为一个批次。最终的dataset
仅包含一个元素,它是所有已处理特征的批次。
注意:dataset
应该只包含一个元素。
现在,不是为dataset
创建迭代器并检索这批特征,而是使用tf.data.get_single_element()
函数跳过迭代器创建过程并直接输出这批特征。
当您的张量转换表示为 tf.data.Dataset
操作并且您希望在为模型提供服务时使用这些转换时,这可能特别有用。
喀拉斯
model = ... # A pre-built or custom model
class PreprocessingModel(tf.keras.Model):
def __init__(self, model):
super().__init__(self)
self.model = model
@tf.function(input_signature=[...])
def serving_fn(self, data):
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
ds = ds.batch(batch_size=BATCH_SIZE)
return tf.argmax(self.model(ds.get_single_element()), axis=-1)
preprocessing_model = PreprocessingModel(model)
your_exported_model_dir = ... # save the model to this path.
tf.saved_model.save(preprocessing_model, your_exported_model_dir,
signatures={'serving_default':preprocessing_model.serving_fn}
)
估计器
在估计器的情况下,您通常需要定义一个serving_input_fn
,这将需要模型在推理时处理特征。
def serving_input_fn():
raw_feature_spec = ... # Spec for the raw_features
input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
raw_feature_spec, default_batch_size=None)
)
serving_input_receiver = input_fn()
raw_features = serving_input_receiver.features
def preprocessing_fn(raw_feature):
# ... the raw_feature is preprocessed as per the use-case
return feature
dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
.batch(BATCH_SIZE))
processed_features = dataset.get_single_element()
# Please note that the value of `BATCH_SIZE` should be equal to
# the size of the leading dimension of `raw_features`. This ensures
# that `dataset` has only element, which is a pre-requisite for
# using `dataset.get_single_element()`.
return tf.estimator.export.ServingInputReceiver(
processed_features, serving_input_receiver.receiver_tensors)
estimator = ... # A pre-built or custom estimator
estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
相关用法
- Python tf.data.Dataset.group_by_window用法及代码示例
- Python tf.data.Dataset.take_while用法及代码示例
- Python tf.data.Dataset.cardinality用法及代码示例
- Python tf.data.Dataset.from_tensors用法及代码示例
- Python tf.data.Dataset.concatenate用法及代码示例
- Python tf.data.Dataset.unique用法及代码示例
- Python tf.data.Dataset.cache用法及代码示例
- Python tf.data.Dataset.unbatch用法及代码示例
- Python tf.data.Dataset.as_numpy_iterator用法及代码示例
- Python tf.data.Dataset.random用法及代码示例
- Python tf.data.Dataset.reduce用法及代码示例
- Python tf.data.Dataset.map用法及代码示例
- Python tf.data.Dataset.repeat用法及代码示例
- Python tf.data.Dataset.bucket_by_sequence_length用法及代码示例
- Python tf.data.Dataset.flat_map用法及代码示例
- Python tf.data.Dataset.choose_from_datasets用法及代码示例
- Python tf.data.Dataset.from_tensor_slices用法及代码示例
- Python tf.data.Dataset.with_options用法及代码示例
- Python tf.data.Dataset.take用法及代码示例
- Python tf.data.Dataset.snapshot用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.Dataset.get_single_element。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。