当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.data.Dataset.map用法及代码示例


用法

map(
    map_func, num_parallel_calls=None, deterministic=None, name=None
)

参数

  • map_func 将数据集元素映射到另一个数据集元素的函数。
  • num_parallel_calls (可选。)tf.int64 标量 tf.Tensor ,表示要并行异步处理的数量元素。如果未指定,元素将按顺序处理。如果使用值tf.data.AUTOTUNE,则并行调用的数量将根据可用 CPU 动态设置。
  • deterministic (可选。)指定 num_parallel_calls 时,如果指定了此布尔值( TrueFalse ),它将控制转换生成元素的顺序。如果设置为 False ,则允许转换产生无序元素,以用确定性换取性能。如果未指定,则 tf.data.Options.deterministic 选项(默认为 True)控制行为。
  • name (可选。) tf.data 操作的名称。

返回

  • Dataset 一个Dataset

跨此数据集的元素映射map_func

此转换将 map_func 应用于此数据集的每个元素,并返回包含转换后元素的新数据集,其顺序与它们在输入中出现的顺序相同。 map_func 可用于更改数据集元素的值和结构。此处记录了支持的结构构造。

例如,map 可用于将每个元素加 1,或投影元素组件的子集。

dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x:x + 1)
list(dataset.as_numpy_iterator())
[2, 3, 4, 5, 6]

map_func 的输入签名由该数据集中每个元素的结构决定。

dataset = Dataset.range(5)
# `map_func` takes a single argument of type `tf.Tensor` with the same
# shape and dtype.
result = dataset.map(lambda x:x + 1)
# Each element is a tuple containing two `tf.Tensor` objects.
elements = [(1, "foo"), (2, "bar"), (3, "baz")]
dataset = tf.data.Dataset.from_generator(
    lambda:elements, (tf.int32, tf.string))
# `map_func` takes two arguments of type `tf.Tensor`. This function
# projects out just the first component.
result = dataset.map(lambda x_int, y_str:x_int)
list(result.as_numpy_iterator())
[1, 2, 3]
# Each element is a dictionary mapping strings to `tf.Tensor` objects.
elements =  ([{"a":1, "b":"foo"},
              {"a":2, "b":"bar"},
              {"a":3, "b":"baz"}])
dataset = tf.data.Dataset.from_generator(
    lambda:elements, {"a":tf.int32, "b":tf.string})
# `map_func` takes a single argument of type `dict` with the same keys
# as the elements.
result = dataset.map(lambda d:str(d["a"]) + d["b"])

map_func 返回的一个或多个值决定了返回数据集中每个元素的结构。

dataset = tf.data.Dataset.range(3)
# `map_func` returns two `tf.Tensor` objects.
def g(x):
  return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), dtype=tf.string, name=None))
# Python primitives, lists, and NumPy arrays are implicitly converted to
# `tf.Tensor`.
def h(x):
  return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
result = dataset.map(h)
result.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, name=None))
# `map_func` can return nested structures.
def i(x):
  return (37.0, [42, 16]), "foo"
result = dataset.map(i)
result.element_spec
((TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
 TensorSpec(shape=(), dtype=tf.string, name=None))

map_func 可以接受作为参数并返回任何类型的数据集元素。

请注意,无论定义 map_func 的上下文如何(eager vs. graph),tf.data 都会跟踪函数并将其作为图执行。要在函数内部使用 Python 代码,您有几个选项:

1)依靠 AutoGraph 将 Python 代码转换为等效的图计算。这种方法的缺点是 AutoGraph 可以转换部分但不是全部 Python 代码。

2) 使用 tf.py_function ,它允许您编写任意 Python 代码,但通常会导致比 1) 更差的性能。例如:

d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
# transform a string tensor to upper case string using a Python function
def upper_case_fn(t:tf.Tensor):
  return t.numpy().decode('utf-8').upper()
d = d.map(lambda x:tf.py_function(func=upper_case_fn,
          inp=[x], Tout=tf.string))
list(d.as_numpy_iterator())
[b'HELLO', b'WORLD']

3) 使用 tf.numpy_function ,它还允许您编写任意 Python 代码。请注意,tf.py_function 接受 tf.Tensortf.numpy_function 接受 numpy 数组并仅返回 numpy 数组。例如:

d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
def upper_case_fn(t:np.ndarray):
  return t.decode('utf-8').upper()
d = d.map(lambda x:tf.numpy_function(func=upper_case_fn,
          inp=[x], Tout=tf.string))
list(d.as_numpy_iterator())
[b'HELLO', b'WORLD']

请注意,使用 tf.numpy_functiontf.py_function 通常会排除并行执行用户定义转换的可能性(因为 Python GIL)。

性能通常可以通过设置num_parallel_calls 来提高,这样map 将使用多个线程来处理元素。如果不需要确定性顺序,设置 deterministic=False 也可以提高性能。

dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x:x + 1,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False)

如果 deterministic=True ,则此转换产生的元素顺序是确定性的。如果 map_func 包含有状态操作和 num_parallel_calls > 1 ,则访问该状态的顺序是未定义的,因此无论 deterministic 标志值如何,输出元素的值都可能不是确定性的。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.data.Dataset.map。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。