當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.data.TFRecordDataset.map用法及代碼示例


用法

map(
    map_func, num_parallel_calls=None, deterministic=None, name=None
)

參數

  • map_func 將數據集元素映射到另一個數據集元素的函數。
  • num_parallel_calls (可選。)tf.int64 標量 tf.Tensor ,表示要並行異步處理的數量元素。如果未指定,元素將按順序處理。如果使用值tf.data.AUTOTUNE,則並行調用的數量將根據可用 CPU 動態設置。
  • deterministic (可選。)指定 num_parallel_calls 時,如果指定了此布爾值( TrueFalse ),它將控製轉換生成元素的順序。如果設置為 False ,則允許轉換產生無序元素,以用確定性換取性能。如果未指定,則 tf.data.Options.deterministic 選項(默認為 True)控製行為。
  • name (可選。) tf.data 操作的名稱。

返回

  • Dataset 一個Dataset

跨此數據集的元素映射map_func

此轉換將 map_func 應用於此數據集的每個元素,並返回包含轉換後元素的新數據集,其順序與它們在輸入中出現的順序相同。 map_func 可用於更改數據集元素的值和結構。此處記錄了支持的結構構造。

例如,map 可用於將每個元素加 1,或投影元素組件的子集。

dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x:x + 1)
list(dataset.as_numpy_iterator())
[2, 3, 4, 5, 6]

map_func 的輸入簽名由該數據集中每個元素的結構決定。

dataset = Dataset.range(5)
# `map_func` takes a single argument of type `tf.Tensor` with the same
# shape and dtype.
result = dataset.map(lambda x:x + 1)
# Each element is a tuple containing two `tf.Tensor` objects.
elements = [(1, "foo"), (2, "bar"), (3, "baz")]
dataset = tf.data.Dataset.from_generator(
    lambda:elements, (tf.int32, tf.string))
# `map_func` takes two arguments of type `tf.Tensor`. This function
# projects out just the first component.
result = dataset.map(lambda x_int, y_str:x_int)
list(result.as_numpy_iterator())
[1, 2, 3]
# Each element is a dictionary mapping strings to `tf.Tensor` objects.
elements =  ([{"a":1, "b":"foo"},
              {"a":2, "b":"bar"},
              {"a":3, "b":"baz"}])
dataset = tf.data.Dataset.from_generator(
    lambda:elements, {"a":tf.int32, "b":tf.string})
# `map_func` takes a single argument of type `dict` with the same keys
# as the elements.
result = dataset.map(lambda d:str(d["a"]) + d["b"])

map_func 返回的一個或多個值決定了返回數據集中每個元素的結構。

dataset = tf.data.Dataset.range(3)
# `map_func` returns two `tf.Tensor` objects.
def g(x):
  return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), dtype=tf.string, name=None))
# Python primitives, lists, and NumPy arrays are implicitly converted to
# `tf.Tensor`.
def h(x):
  return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
result = dataset.map(h)
result.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, name=None))
# `map_func` can return nested structures.
def i(x):
  return (37.0, [42, 16]), "foo"
result = dataset.map(i)
result.element_spec
((TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
 TensorSpec(shape=(), dtype=tf.string, name=None))

map_func 可以接受作為參數並返回任何類型的數據集元素。

請注意,無論定義 map_func 的上下文如何(eager vs. graph),tf.data 都會跟蹤函數並將其作為圖執行。要在函數內部使用 Python 代碼,您有幾個選項:

1)依靠 AutoGraph 將 Python 代碼轉換為等效的圖計算。這種方法的缺點是 AutoGraph 可以轉換部分但不是全部 Python 代碼。

2) 使用 tf.py_function ,它允許您編寫任意 Python 代碼,但通常會導致比 1) 更差的性能。例如:

d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
# transform a string tensor to upper case string using a Python function
def upper_case_fn(t:tf.Tensor):
  return t.numpy().decode('utf-8').upper()
d = d.map(lambda x:tf.py_function(func=upper_case_fn,
          inp=[x], Tout=tf.string))
list(d.as_numpy_iterator())
[b'HELLO', b'WORLD']

3) 使用 tf.numpy_function ,它還允許您編寫任意 Python 代碼。請注意,tf.py_function 接受 tf.Tensortf.numpy_function 接受 numpy 數組並僅返回 numpy 數組。例如:

d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
def upper_case_fn(t:np.ndarray):
  return t.decode('utf-8').upper()
d = d.map(lambda x:tf.numpy_function(func=upper_case_fn,
          inp=[x], Tout=tf.string))
list(d.as_numpy_iterator())
[b'HELLO', b'WORLD']

請注意,使用 tf.numpy_functiontf.py_function 通常會排除並行執行用戶定義轉換的可能性(因為 Python GIL)。

性能通常可以通過設置num_parallel_calls 來提高,這樣map 將使用多個線程來處理元素。如果不需要確定性順序,設置 deterministic=False 也可以提高性能。

dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x:x + 1,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False)

如果 deterministic=True ,則此轉換產生的元素順序是確定性的。如果 map_func 包含有狀態操作和 num_parallel_calls > 1 ,則訪問該狀態的順序是未定義的,因此無論 deterministic 標誌值如何,輸出元素的值都可能不是確定性的。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.data.TFRecordDataset.map。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。