創建一個具有給定指標的新 tf.estimator.Estimator
。
用法
tf.estimator.add_metrics(
estimator, metric_fn
)
參數
-
estimator
tf.estimator.Estimator
對象。 -
metric_fn
應遵循以下簽名的函數:- Args:隻能以任意順序有以下四個參數:
- 預測:預測
Tensor
或Tensor
的字典由給定的estimator
創建。 - 特征:輸入由
input_fn
創建的Tensor
對象的dict
,作為參數提供給estimator.evaluate
。 - 標簽:標簽
Tensor
或input_fn
創建的Tensor
的字典,作為參數提供給estimator.evaluate
。 estimator
的 config:config 屬性。- 返回:按名稱鍵入的度量結果的字典。最終指標是此指標和
estimator's
現有指標的聯合。如果此和estimator
的現有指標之間存在名稱衝突,這將覆蓋現有指標。 dict 的值是調用度量函數的結果,即(metric_tensor, update_op)
元組。
- 預測:預測
- Args:隻能以任意順序有以下四個參數:
返回
-
一個新的
tf.estimator.Estimator
,它具有原始指標與給定指標的聯合。
例子:
def my_auc(labels, predictions):
auc_metric = tf.keras.metrics.AUC(name="my_auc")
auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
return {'auc':auc_metric}
estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
使用特征的自定義指標的示例用法:
def my_auc(labels, predictions, features):
auc_metric = tf.keras.metrics.AUC(name="my_auc")
auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
sample_weight=features['weight'])
return {'auc':auc_metric}
estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
相關用法
- Python tf.estimator.TrainSpec用法及代碼示例
- Python tf.estimator.LogisticRegressionHead用法及代碼示例
- Python tf.estimator.MultiHead用法及代碼示例
- Python tf.estimator.PoissonRegressionHead用法及代碼示例
- Python tf.estimator.WarmStartSettings用法及代碼示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代碼示例
- Python tf.estimator.RunConfig用法及代碼示例
- Python tf.estimator.MultiLabelHead用法及代碼示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代碼示例
- Python tf.estimator.BaselineEstimator用法及代碼示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代碼示例
- Python tf.estimator.Estimator用法及代碼示例
- Python tf.estimator.experimental.LinearSDCA用法及代碼示例
- Python tf.estimator.experimental.RNNClassifier用法及代碼示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代碼示例
- Python tf.estimator.LinearRegressor用法及代碼示例
- Python tf.estimator.LinearEstimator用法及代碼示例
- Python tf.estimator.DNNClassifier用法及代碼示例
- Python tf.estimator.BaselineClassifier用法及代碼示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.estimator.add_metrics。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。