當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.metrics.Metric用法及代碼示例


封裝度量邏輯和狀態。

繼承自:LayerModule

用法

tf.keras.metrics.Metric(
    name=None, dtype=None, **kwargs
)

參數

  • name (可選)指標實例的字符串名稱。
  • dtype (可選)度量結果的數據類型。
  • **kwargs 附加層關鍵字參數。

單機使用:

m = SomeMetric(...)
for input in ...:
  m.update_state(input)
print('Final result:', m.result().numpy())

compile() API 的用法:

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)

model.fit(dataset, epochs=10)

由子類實現:

  • __init__() :所有狀態變量都應在此方法中通過調用 self.add_weight() 來創建,例如:self.var = self.add_weight(...)
  • update_state() :對狀態變量進行所有更新,例如:self.var.assign_add(...)。
  • result() :計算並返返回自狀態變量的度量的標量值或標量值字典。

示例子類實現:

class BinaryTruePositives(tf.keras.metrics.Metric):

  def __init__(self, name='binary_true_positives', **kwargs):
    super(BinaryTruePositives, self).__init__(name=name, **kwargs)
    self.true_positives = self.add_weight(name='tp', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.bool)

    values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
    values = tf.cast(values, self.dtype)
    if sample_weight is not None:
      sample_weight = tf.cast(sample_weight, self.dtype)
      sample_weight = tf.broadcast_to(sample_weight, values.shape)
      values = tf.multiply(values, sample_weight)
    self.true_positives.assign_add(tf.reduce_sum(values))

  def result(self):
    return self.true_positives

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.metrics.Metric。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。