當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.compat.v1.metrics.accuracy用法及代碼示例

計算 predictions 匹配 labels 的頻率。

用法

tf.compat.v1.metrics.accuracy(
    labels, predictions, weights=None, metrics_collections=None,
    updates_collections=None, name=None
)

參數

  • labels 基本事實值,一個 Tensor,其形狀與 predictions 匹配。
  • predictions 預測值,任何形狀的Tensor
  • weights 可選Tensor,其秩為0,或與labels相同的秩,並且必須可廣播到labels(即,所有維度必須是1,或與相應的labels維度相同) .
  • metrics_collections accuracy 應添加到的可選集合列表。
  • updates_collections update_op 應添加到的可選集合列表。
  • name 可選的 variable_scope 名稱。

返回

  • accuracy 一個 Tensor 代表準確度,total 的值除以 count
  • update_op 適當增加 totalcount 變量並且其值與 accuracy 匹配的操作。

拋出

  • ValueError 如果 predictionslabels 的形狀不匹配,或者如果 weights 不是 None 並且其形狀與 predictions 不匹配,或者如果 metrics_collectionsupdates_collections 不是列表或元組。
  • RuntimeError 如果啟用了即刻執行。

遷移到 TF2

警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。

tf.compat.v1.metrics.accuracy 與即刻執行或 tf.function 不兼容。請改用tf.keras.metrics.Accuracy 進行 TF2 遷移。實例化tf.keras.metrics.Accuracy對象後,可以先調用update_state()方法記錄預測/標簽,然後調用result()方法即刻獲取準確率。您還可以在調用compile 方法時將其附加到 Keras 模型。有關詳細信息,請參閱本指南。

到原生 TF2 的結構映射

前:

accuracy, update_op = tf.compat.v1.metrics.accuracy(
  labels=labels,
  predictions=predictions,
  weights=weights,
  metrics_collections=metrics_collections,
  update_collections=update_collections,
  name=name)

後:

m = tf.keras.metrics.Accuracy(
   name=name,
   dtype=None)

 m.update_state(
 y_true=labels,
 y_pred=predictions,
 sample_weight=weights)

 accuracy = m.result()

如何映射參數

TF1 參數名稱 TF2 參數名稱 注意
label y_true update_state() 方法中
predictions y_true update_state() 方法中
weights sample_weight update_state() 方法中
metrics_collections 不支持 應顯式跟蹤指標或使用 Keras API,例如add_metric,而不是通過集合
updates_collections 不支持 -
name name 在構造函數中

使用示例之前和之後

前:

g = tf.Graph()
with g.as_default():
  logits = [1, 2, 3]
  labels = [0, 2, 3]
  acc, acc_op = tf.compat.v1.metrics.accuracy(logits, labels)
  global_init = tf.compat.v1.global_variables_initializer()
  local_init = tf.compat.v1.local_variables_initializer()
sess = tf.compat.v1.Session(graph=g)
sess.run([global_init, local_init])
print(sess.run([acc, acc_op]))
[0.0, 0.66667]

後:

m = tf.keras.metrics.Accuracy()
m.update_state([1, 2, 3], [0, 2, 3])
m.result().numpy()
0.66667
# Used within Keras model
model.compile(optimizer='sgd',
              loss='mse',
              metrics=[tf.keras.metrics.Accuracy()])

accuracy 函數創建兩個局部變量 totalcount 用於計算 predictions 匹配 labels 的頻率。該頻率最終返回為 accuracy :一個冪等運算,隻需將 total 除以 count

為了估計數據流上的度量,該函數創建一個 update_op 操作來更新這些變量並返回 accuracy 。在內部,is_correct 操作計算 Tensor,其中元素 1.0 與 predictionslabels 的對應元素匹配,否則為 0.0。然後 update_optotalweightsis_correct 的乘積的減和相加,並用 weights 的減和後增加 count

如果 weightsNone ,則權重默認為 1。使用權重 0 來屏蔽值。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.metrics.accuracy。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。