返回一个 one-hot 张量。
用法
tf.one_hot(
indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None
)参数
-
indices索引的Tensor。 -
depth定义一个热维度的深度的标量。 -
on_value一个标量,定义在indices[j] = i时填充输出的值。 (默认值:1) -
off_value一个标量,定义在indices[j] != i时填充输出的值。 (默认:0) -
axis要填充的轴(默认值:-1,新的inner-most 轴)。 -
dtype输出张量的数据类型。 -
name操作的名称(可选)。
返回
-
outputone-hot 张量。
抛出
-
TypeError如果on_value或off_value的 dtype 不匹配dtype -
TypeError如果on_value和off_value的 dtype 彼此不匹配
indices 中的索引表示的位置取值 on_value ,而所有其他位置取值 off_value 。
on_value 和 off_value 必须具有匹配的数据类型。如果还提供了 dtype,则它们必须与 dtype 指定的数据类型相同。
如果未提供on_value,则默认值为1,类型为dtype
如果未提供off_value,则默认值为0,类型为dtype
如果输入 indices 排名为 N ,则输出排名为 N+1 。新轴在维度 axis 处创建(默认:新轴附加在末尾)。
如果 indices 是标量,则输出形状将是长度为 depth 的向量
如果 indices 是长度为 features 的向量,则输出形状将为:
features x depth if axis == -1
depth x features if axis == 0
如果 indices 是形状为 [batch, features] 的矩阵(批次),则输出形状将为:
batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0
如果 indices 是 RaggedTensor,则 'axis' 参数必须为正表示非参差不齐的轴。输出将等效于在 RaggedTensor 的值上应用 'one_hot',并根据结果创建新的 RaggedTensor。
如果未提供 dtype ,它将尝试假设 on_value 或 off_value 的数据类型,如果传入一个或两个。如果没有提供 on_value , off_value 或 dtype ,则 dtype将默认为值 tf.float32 。
注意:如果需要非数字数据类型输出(tf.string,tf.bool等),两者on_value和off_value 必须提供给one_hot.
例如:
indices = [0, 1, 2]
depth = 3
tf.one_hot(indices, depth) # output:[3 x 3]
# [[1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]]
indices = [0, 2, -1, 1]
depth = 3
tf.one_hot(indices, depth,
on_value=5.0, off_value=0.0,
axis=-1) # output:[4 x 3]
# [[5.0, 0.0, 0.0], # one_hot(0)
# [0.0, 0.0, 5.0], # one_hot(2)
# [0.0, 0.0, 0.0], # one_hot(-1)
# [0.0, 5.0, 0.0]] # one_hot(1)
indices = [[0, 2], [1, -1]]
depth = 3
tf.one_hot(indices, depth,
on_value=1.0, off_value=0.0,
axis=-1) # output:[2 x 2 x 3]
# [[[1.0, 0.0, 0.0], # one_hot(0)
# [0.0, 0.0, 1.0]], # one_hot(2)
# [[0.0, 1.0, 0.0], # one_hot(1)
# [0.0, 0.0, 0.0]]] # one_hot(-1)
indices = tf.ragged.constant([[0, 1], [2]])
depth = 3
tf.one_hot(indices, depth) # output:[2 x None x 3]
# [[[1., 0., 0.],
# [0., 1., 0.]],
# [[0., 0., 1.]]]
相关用法
- Python tf.ones_like用法及代码示例
- Python tf.ones_initializer.from_config用法及代码示例
- Python tf.ones用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.math.special.fresnel_cos用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.compat.v1.layers.conv3d用法及代码示例
- Python tf.Variable.__lt__用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.one_hot。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
