当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python tf.estimator.add_metrics用法及代码示例

创建一个具有给定指标的新 tf.estimator.Estimator

用法

tf.estimator.add_metrics(
    estimator, metric_fn
)

参数

  • estimator tf.estimator.Estimator 对象。
  • metric_fn 应遵循以下签名的函数:
    • Args:只能以任意顺序有以下四个参数:
      • 预测:预测 TensorTensor 的字典由给定的 estimator 创建。
      • 特征:输入由input_fn创建的Tensor对象的dict,作为参数提供给estimator.evaluate
      • 标签:标签Tensorinput_fn创建的Tensor的字典,作为参数提供给estimator.evaluate
      • estimator 的 config:config 属性。
      • 返回:按名称键入的度量结果的字典。最终指标是此指标和 estimator's 现有指标的联合。如果此和estimator 的现有指标之间存在名称冲突,这将覆盖现有指标。 dict 的值是调用度量函数的结果,即 (metric_tensor, update_op) 元组。

返回

例子:

def my_auc(labels, predictions):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
    return {'auc':auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)

使用特征的自定义指标的示例用法:

def my_auc(labels, predictions, features):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
                            sample_weight=features['weight'])
    return {'auc':auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.estimator.add_metrics。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。