返回一個 one-hot 張量。
用法
tf.one_hot(
indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None
)
參數
-
indices
索引的Tensor
。 -
depth
定義一個熱維度的深度的標量。 -
on_value
一個標量,定義在indices[j] = i
時填充輸出的值。 (默認值:1) -
off_value
一個標量,定義在indices[j] != i
時填充輸出的值。 (默認:0) -
axis
要填充的軸(默認值:-1,新的inner-most 軸)。 -
dtype
輸出張量的數據類型。 -
name
操作的名稱(可選)。
返回
-
output
one-hot 張量。
拋出
-
TypeError
如果on_value
或off_value
的 dtype 不匹配dtype
-
TypeError
如果on_value
和off_value
的 dtype 彼此不匹配
indices
中的索引表示的位置取值 on_value
,而所有其他位置取值 off_value
。
on_value
和 off_value
必須具有匹配的數據類型。如果還提供了 dtype
,則它們必須與 dtype
指定的數據類型相同。
如果未提供on_value
,則默認值為1
,類型為dtype
如果未提供off_value
,則默認值為0
,類型為dtype
如果輸入 indices
排名為 N
,則輸出排名為 N+1
。新軸在維度 axis
處創建(默認:新軸附加在末尾)。
如果 indices
是標量,則輸出形狀將是長度為 depth
的向量
如果 indices
是長度為 features
的向量,則輸出形狀將為:
features x depth if axis == -1
depth x features if axis == 0
如果 indices
是形狀為 [batch, features]
的矩陣(批次),則輸出形狀將為:
batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0
如果 indices
是 RaggedTensor,則 'axis' 參數必須為正表示非參差不齊的軸。輸出將等效於在 RaggedTensor 的值上應用 'one_hot',並根據結果創建新的 RaggedTensor。
如果未提供 dtype
,它將嘗試假設 on_value
或 off_value
的數據類型,如果傳入一個或兩個。如果沒有提供 on_value
, off_value
或 dtype
,則 dtype
將默認為值 tf.float32
。
注意:如果需要非數字數據類型輸出(tf.string
,tf.bool
等),兩者on_value
和off_value
必須提供給one_hot
.
例如:
indices = [0, 1, 2]
depth = 3
tf.one_hot(indices, depth) # output:[3 x 3]
# [[1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]]
indices = [0, 2, -1, 1]
depth = 3
tf.one_hot(indices, depth,
on_value=5.0, off_value=0.0,
axis=-1) # output:[4 x 3]
# [[5.0, 0.0, 0.0], # one_hot(0)
# [0.0, 0.0, 5.0], # one_hot(2)
# [0.0, 0.0, 0.0], # one_hot(-1)
# [0.0, 5.0, 0.0]] # one_hot(1)
indices = [[0, 2], [1, -1]]
depth = 3
tf.one_hot(indices, depth,
on_value=1.0, off_value=0.0,
axis=-1) # output:[2 x 2 x 3]
# [[[1.0, 0.0, 0.0], # one_hot(0)
# [0.0, 0.0, 1.0]], # one_hot(2)
# [[0.0, 1.0, 0.0], # one_hot(1)
# [0.0, 0.0, 0.0]]] # one_hot(-1)
indices = tf.ragged.constant([[0, 1], [2]])
depth = 3
tf.one_hot(indices, depth) # output:[2 x None x 3]
# [[[1., 0., 0.],
# [0., 1., 0.]],
# [[0., 0., 1.]]]
相關用法
- Python tf.ones_like用法及代碼示例
- Python tf.ones_initializer.from_config用法及代碼示例
- Python tf.ones用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代碼示例
- Python tf.summary.scalar用法及代碼示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代碼示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代碼示例
- Python tf.raw_ops.TPUReplicatedInput用法及代碼示例
- Python tf.raw_ops.Bitcast用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代碼示例
- Python tf.compat.v1.Variable.eval用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代碼示例
- Python tf.math.special.fresnel_cos用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.compat.v1.layers.conv3d用法及代碼示例
- Python tf.Variable.__lt__用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.one_hot。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。