當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.estimator.add_metrics用法及代碼示例

創建一個具有給定指標的新 tf.estimator.Estimator

用法

tf.estimator.add_metrics(
    estimator, metric_fn
)

參數

  • estimator tf.estimator.Estimator 對象。
  • metric_fn 應遵循以下簽名的函數:
    • Args:隻能以任意順序有以下四個參數:
      • 預測:預測 TensorTensor 的字典由給定的 estimator 創建。
      • 特征:輸入由input_fn創建的Tensor對象的dict,作為參數提供給estimator.evaluate
      • 標簽:標簽Tensorinput_fn創建的Tensor的字典,作為參數提供給estimator.evaluate
      • estimator 的 config:config 屬性。
      • 返回:按名稱鍵入的度量結果的字典。最終指標是此指標和 estimator's 現有指標的聯合。如果此和estimator 的現有指標之間存在名稱衝突,這將覆蓋現有指標。 dict 的值是調用度量函數的結果,即 (metric_tensor, update_op) 元組。

返回

例子:

def my_auc(labels, predictions):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
    return {'auc':auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)

使用特征的自定義指標的示例用法:

def my_auc(labels, predictions, features):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
                            sample_weight=features['weight'])
    return {'auc':auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.estimator.add_metrics。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。