當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.one_hot用法及代碼示例


返回一個 one-hot 張量。

用法

tf.one_hot(
    indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None
)

參數

  • indices 索引的Tensor
  • depth 定義一個熱維度的深度的標量。
  • on_value 一個標量,定義在 indices[j] = i 時填充輸出的值。 (默認值:1)
  • off_value 一個標量,定義在 indices[j] != i 時填充輸出的值。 (默認:0)
  • axis 要填充的軸(默認值:-1,新的inner-most 軸)。
  • dtype 輸出張量的數據類型。
  • name 操作的名稱(可選)。

返回

  • output one-hot 張量。

拋出

  • TypeError 如果 on_valueoff_value 的 dtype 不匹配 dtype
  • TypeError 如果 on_valueoff_value 的 dtype 彼此不匹配

另見tf.filltf.eye

indices 中的索引表示的位置取值 on_value ,而所有其他位置取值 off_value

on_valueoff_value 必須具有匹配的數據類型。如果還提供了 dtype,則它們必須與 dtype 指定的數據類型相同。

如果未提供on_value,則默認值為1,類型為dtype

如果未提供off_value,則默認值為0,類型為dtype

如果輸入 indices 排名為 N ,則輸出排名為 N+1 。新軸在維度 axis 處創建(默認:新軸附加在末尾)。

如果 indices 是標量,則輸出形狀將是長度為 depth 的向量

如果 indices 是長度為 features 的向量,則輸出形狀將為:

features x depth if axis == -1
  depth x features if axis == 0

如果 indices 是形狀為 [batch, features] 的矩陣(批次),則輸出形狀將為:

batch x features x depth if axis == -1
  batch x depth x features if axis == 1
  depth x batch x features if axis == 0

如果 indices 是 RaggedTensor,則 'axis' 參數必須為正表示非參差不齊的軸。輸出將等效於在 RaggedTensor 的值上應用 'one_hot',並根據結果創建新的 RaggedTensor。

如果未提供 dtype ,它將嘗試假設 on_valueoff_value 的數據類型,如果傳入一個或兩個。如果沒有提供 on_value , off_valuedtype ,則 dtype將默認為值 tf.float32

注意:如果需要非數字數據類型輸出(tf.string,tf.bool等),兩者on_valueoff_value 必須提供給one_hot.

例如:

indices = [0, 1, 2]
depth = 3
tf.one_hot(indices, depth)  # output:[3 x 3]
# [[1., 0., 0.],
#  [0., 1., 0.],
#  [0., 0., 1.]]

indices = [0, 2, -1, 1]
depth = 3
tf.one_hot(indices, depth,
           on_value=5.0, off_value=0.0,
           axis=-1)  # output:[4 x 3]
# [[5.0, 0.0, 0.0],  # one_hot(0)
#  [0.0, 0.0, 5.0],  # one_hot(2)
#  [0.0, 0.0, 0.0],  # one_hot(-1)
#  [0.0, 5.0, 0.0]]  # one_hot(1)

indices = [[0, 2], [1, -1]]
depth = 3
tf.one_hot(indices, depth,
           on_value=1.0, off_value=0.0,
           axis=-1)  # output:[2 x 2 x 3]
# [[[1.0, 0.0, 0.0],   # one_hot(0)
#   [0.0, 0.0, 1.0]],  # one_hot(2)
#  [[0.0, 1.0, 0.0],   # one_hot(1)
#   [0.0, 0.0, 0.0]]]  # one_hot(-1)

indices = tf.ragged.constant([[0, 1], [2]])
depth = 3
tf.one_hot(indices, depth)  # output:[2 x None x 3]
# [[[1., 0., 0.],
#   [0., 1., 0.]],
#  [[0., 0., 1.]]]

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.one_hot。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。