當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.utils.unpack_x_y_sample_weight用法及代碼示例


解包用戶提供的數據元組。

用法

tf.keras.utils.unpack_x_y_sample_weight(
    data
)

參數

  • data (x,) , (x, y)(x, y, sample_weight) 形式的元組。

返回

  • 未打包的元組,如果未提供 ysample_weight,則為 None s。

這是在覆蓋 Model.train_stepModel.test_stepModel.predict_step 時使用的便利實用程序。該實用程序可以輕鬆支持 (x,) , (x, y)(x, y, sample_weight) 形式的數據。

單機使用:

features_batch = tf.ones((10, 5))
labels_batch = tf.zeros((10, 5))
data = (features_batch, labels_batch)
# `y` and `sample_weight` will default to `None` if not provided.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
sample_weight is None
True

覆蓋 Model.train_step 中的示例:

class MyModel(tf.keras.Model):

  def train_step(self, data):
    # If `sample_weight` is not provided, all samples will be weighted
    # equally.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses)
      trainable_variables = self.trainable_variables
      gradients = tape.gradient(loss, trainable_variables)
      self.optimizer.apply_gradients(zip(gradients, trainable_variables))

    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name:m.result() for m in self.metrics}

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.utils.unpack_x_y_sample_weight。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。