当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.one_hot用法及代码示例


返回一个 one-hot 张量。

用法

tf.one_hot(
    indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None
)

参数

  • indices 索引的Tensor
  • depth 定义一个热维度的深度的标量。
  • on_value 一个标量,定义在 indices[j] = i 时填充输出的值。 (默认值:1)
  • off_value 一个标量,定义在 indices[j] != i 时填充输出的值。 (默认:0)
  • axis 要填充的轴(默认值:-1,新的inner-most 轴)。
  • dtype 输出张量的数据类型。
  • name 操作的名称(可选)。

返回

  • output one-hot 张量。

抛出

  • TypeError 如果 on_valueoff_value 的 dtype 不匹配 dtype
  • TypeError 如果 on_valueoff_value 的 dtype 彼此不匹配

另见tf.filltf.eye

indices 中的索引表示的位置取值 on_value ,而所有其他位置取值 off_value

on_valueoff_value 必须具有匹配的数据类型。如果还提供了 dtype,则它们必须与 dtype 指定的数据类型相同。

如果未提供on_value,则默认值为1,类型为dtype

如果未提供off_value,则默认值为0,类型为dtype

如果输入 indices 排名为 N ,则输出排名为 N+1 。新轴在维度 axis 处创建(默认:新轴附加在末尾)。

如果 indices 是标量,则输出形状将是长度为 depth 的向量

如果 indices 是长度为 features 的向量,则输出形状将为:

features x depth if axis == -1
  depth x features if axis == 0

如果 indices 是形状为 [batch, features] 的矩阵(批次),则输出形状将为:

batch x features x depth if axis == -1
  batch x depth x features if axis == 1
  depth x batch x features if axis == 0

如果 indices 是 RaggedTensor,则 'axis' 参数必须为正表示非参差不齐的轴。输出将等效于在 RaggedTensor 的值上应用 'one_hot',并根据结果创建新的 RaggedTensor。

如果未提供 dtype ,它将尝试假设 on_valueoff_value 的数据类型,如果传入一个或两个。如果没有提供 on_value , off_valuedtype ,则 dtype将默认为值 tf.float32

注意:如果需要非数字数据类型输出(tf.string,tf.bool等),两者on_valueoff_value 必须提供给one_hot.

例如:

indices = [0, 1, 2]
depth = 3
tf.one_hot(indices, depth)  # output:[3 x 3]
# [[1., 0., 0.],
#  [0., 1., 0.],
#  [0., 0., 1.]]

indices = [0, 2, -1, 1]
depth = 3
tf.one_hot(indices, depth,
           on_value=5.0, off_value=0.0,
           axis=-1)  # output:[4 x 3]
# [[5.0, 0.0, 0.0],  # one_hot(0)
#  [0.0, 0.0, 5.0],  # one_hot(2)
#  [0.0, 0.0, 0.0],  # one_hot(-1)
#  [0.0, 5.0, 0.0]]  # one_hot(1)

indices = [[0, 2], [1, -1]]
depth = 3
tf.one_hot(indices, depth,
           on_value=1.0, off_value=0.0,
           axis=-1)  # output:[2 x 2 x 3]
# [[[1.0, 0.0, 0.0],   # one_hot(0)
#   [0.0, 0.0, 1.0]],  # one_hot(2)
#  [[0.0, 1.0, 0.0],   # one_hot(1)
#   [0.0, 0.0, 0.0]]]  # one_hot(-1)

indices = tf.ragged.constant([[0, 1], [2]])
depth = 3
tf.one_hot(indices, depth)  # output:[2 x None x 3]
# [[[1., 0., 0.],
#   [0., 1., 0.]],
#  [[0., 0., 1.]]]

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.one_hot。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。