當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.metrics.CosineSimilarity用法及代碼示例


計算標簽和預測之間的餘弦相似度。

繼承自:MeanMetricWrapperMeanMetricLayerModule

用法

tf.keras.metrics.CosineSimilarity(
    name='cosine_similarity', dtype=None, axis=-1
)

參數

  • name (可選)指標實例的字符串名稱。
  • dtype (可選)度量結果的數據類型。
  • axis (可選)默認為 -1。計算餘弦相似度的維度。

cosine similarity = (a . b) / ||a|| ||b||

請參閱:餘弦相似度。

此度量標準在數據流上保持predictionslabels 之間的平均餘弦相似度。

單機使用:

# l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
# l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
# l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
# result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
#        = ((0. + 0.) +  (0.5 + 0.5)) / 2
m = tf.keras.metrics.CosineSimilarity(axis=1)
m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]])
m.result().numpy()
0.49999997
m.reset_state()
m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]],
               sample_weight=[0.3, 0.7])
m.result().numpy()
0.6999999

compile() API 的用法:

model.compile(
    optimizer='sgd',
    loss='mse',
    metrics=[tf.keras.metrics.CosineSimilarity(axis=1)])

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.metrics.CosineSimilarity。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。