当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.metrics.CosineSimilarity用法及代码示例


计算标签和预测之间的余弦相似度。

继承自:MeanMetricWrapperMeanMetricLayerModule

用法

tf.keras.metrics.CosineSimilarity(
    name='cosine_similarity', dtype=None, axis=-1
)

参数

  • name (可选)指标实例的字符串名称。
  • dtype (可选)度量结果的数据类型。
  • axis (可选)默认为 -1。计算余弦相似度的维度。

cosine similarity = (a . b) / ||a|| ||b||

请参阅:余弦相似度。

此度量标准在数据流上保持predictionslabels 之间的平均余弦相似度。

单机使用:

# l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
# l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
# l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
# result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
#        = ((0. + 0.) +  (0.5 + 0.5)) / 2
m = tf.keras.metrics.CosineSimilarity(axis=1)
m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]])
m.result().numpy()
0.49999997
m.reset_state()
m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]],
               sample_weight=[0.3, 0.7])
m.result().numpy()
0.6999999

compile() API 的用法:

model.compile(
    optimizer='sgd',
    loss='mse',
    metrics=[tf.keras.metrics.CosineSimilarity(axis=1)])

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.metrics.CosineSimilarity。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。