当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch SummaryWriter.add_pr_curve用法及代码示例


本文简要介绍python语言中 torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve 的用法。

用法:

add_pr_curve(tag, labels, predictions, global_step=None, num_thresholds=127, weights=None, walltime=None)

参数

  • tag(string) -数据标识符

  • labels(torch.Tensor,numpy.array, 或者字符串/blob 名称) -地面实况数据。每个元素的二进制标签。

  • predictions(torch.Tensor,numpy.array, 或者字符串/blob 名称) -元素被分类为真的概率。值应该在 [0, 1]

  • global_step(int) -要记录的全局步长 值

  • num_thresholds(int) -用于绘制曲线的阈值数。

  • walltime(float) -事件纪元后的可选覆盖默认 walltime (time.time()) 秒

添加精确召回曲线。绘制 precision-recall 曲线可让您了解模型在不同阈值设置下的性能。通过此函数,您可以为每个目标提供真实标签 (T/F) 和预测置信度(通常是模型的输出)。 TensorBoard UI 将让您以交互方式选择阈值。

例子:

from torch.utils.tensorboard import SummaryWriter
import numpy as np
labels = np.random.randint(2, size=100)  # binary label
predictions = np.random.rand(100)
writer = SummaryWriter()
writer.add_pr_curve('pr_curve', labels, predictions, 0)
writer.close()

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。