當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch SummaryWriter.add_pr_curve用法及代碼示例


本文簡要介紹python語言中 torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve 的用法。

用法:

add_pr_curve(tag, labels, predictions, global_step=None, num_thresholds=127, weights=None, walltime=None)

參數

  • tag(string) -數據標識符

  • labels(torch.Tensor,numpy.array, 或者字符串/blob 名稱) -地麵實況數據。每個元素的二進製標簽。

  • predictions(torch.Tensor,numpy.array, 或者字符串/blob 名稱) -元素被分類為真的概率。值應該在 [0, 1]

  • global_step(int) -要記錄的全局步長 值

  • num_thresholds(int) -用於繪製曲線的閾值數。

  • walltime(float) -事件紀元後的可選覆蓋默認 walltime (time.time()) 秒

添加精確召回曲線。繪製 precision-recall 曲線可讓您了解模型在不同閾值設置下的性能。通過此函數,您可以為每個目標提供真實標簽 (T/F) 和預測置信度(通常是模型的輸出)。 TensorBoard UI 將讓您以交互方式選擇閾值。

例子:

from torch.utils.tensorboard import SummaryWriter
import numpy as np
labels = np.random.randint(2, size=100)  # binary label
predictions = np.random.rand(100)
writer = SummaryWriter()
writer.add_pr_curve('pr_curve', labels, predictions, 0)
writer.close()

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。