创建一个具有给定指标的新 tf.estimator.Estimator
。
用法
tf.estimator.add_metrics(
estimator, metric_fn
)
参数
-
estimator
tf.estimator.Estimator
对象。 -
metric_fn
应遵循以下签名的函数:- Args:只能以任意顺序有以下四个参数:
- 预测:预测
Tensor
或Tensor
的字典由给定的estimator
创建。 - 特征:输入由
input_fn
创建的Tensor
对象的dict
,作为参数提供给estimator.evaluate
。 - 标签:标签
Tensor
或input_fn
创建的Tensor
的字典,作为参数提供给estimator.evaluate
。 estimator
的 config:config 属性。- 返回:按名称键入的度量结果的字典。最终指标是此指标和
estimator's
现有指标的联合。如果此和estimator
的现有指标之间存在名称冲突,这将覆盖现有指标。 dict 的值是调用度量函数的结果,即(metric_tensor, update_op)
元组。
- 预测:预测
- Args:只能以任意顺序有以下四个参数:
返回
-
一个新的
tf.estimator.Estimator
,它具有原始指标与给定指标的联合。
例子:
def my_auc(labels, predictions):
auc_metric = tf.keras.metrics.AUC(name="my_auc")
auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
return {'auc':auc_metric}
estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
使用特征的自定义指标的示例用法:
def my_auc(labels, predictions, features):
auc_metric = tf.keras.metrics.AUC(name="my_auc")
auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
sample_weight=features['weight'])
return {'auc':auc_metric}
estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
相关用法
- Python tf.estimator.TrainSpec用法及代码示例
- Python tf.estimator.LogisticRegressionHead用法及代码示例
- Python tf.estimator.MultiHead用法及代码示例
- Python tf.estimator.PoissonRegressionHead用法及代码示例
- Python tf.estimator.WarmStartSettings用法及代码示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代码示例
- Python tf.estimator.RunConfig用法及代码示例
- Python tf.estimator.MultiLabelHead用法及代码示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代码示例
- Python tf.estimator.BaselineEstimator用法及代码示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代码示例
- Python tf.estimator.Estimator用法及代码示例
- Python tf.estimator.experimental.LinearSDCA用法及代码示例
- Python tf.estimator.experimental.RNNClassifier用法及代码示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代码示例
- Python tf.estimator.LinearRegressor用法及代码示例
- Python tf.estimator.LinearEstimator用法及代码示例
- Python tf.estimator.DNNClassifier用法及代码示例
- Python tf.estimator.BaselineClassifier用法及代码示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.add_metrics。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。