解包用户提供的数据元组。
用法
tf.keras.utils.unpack_x_y_sample_weight(
data
)
参数
-
data
(x,)
,(x, y)
或(x, y, sample_weight)
形式的元组。
返回
-
未打包的元组,如果未提供
y
和sample_weight
,则为None
s。
这是在覆盖 Model.train_step
、 Model.test_step
或 Model.predict_step
时使用的便利实用程序。该实用程序可以轻松支持 (x,)
, (x, y)
或 (x, y, sample_weight)
形式的数据。
单机使用:
features_batch = tf.ones((10, 5))
labels_batch = tf.zeros((10, 5))
data = (features_batch, labels_batch)
# `y` and `sample_weight` will default to `None` if not provided.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
sample_weight is None
True
覆盖 Model.train_step
中的示例:
class MyModel(tf.keras.Model):
def train_step(self, data):
# If `sample_weight` is not provided, all samples will be weighted
# equally.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(
y, y_pred, sample_weight, regularization_losses=self.losses)
trainable_variables = self.trainable_variables
gradients = tape.gradient(loss, trainable_variables)
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
self.compiled_metrics.update_state(y, y_pred, sample_weight)
return {m.name:m.result() for m in self.metrics}
相关用法
- Python tf.keras.utils.custom_object_scope用法及代码示例
- Python tf.keras.utils.deserialize_keras_object用法及代码示例
- Python tf.keras.utils.array_to_img用法及代码示例
- Python tf.keras.utils.get_file用法及代码示例
- Python tf.keras.utils.experimental.DatasetCreator用法及代码示例
- Python tf.keras.utils.set_random_seed用法及代码示例
- Python tf.keras.utils.timeseries_dataset_from_array用法及代码示例
- Python tf.keras.utils.plot_model用法及代码示例
- Python tf.keras.utils.get_custom_objects用法及代码示例
- Python tf.keras.utils.pack_x_y_sample_weight用法及代码示例
- Python tf.keras.utils.img_to_array用法及代码示例
- Python tf.keras.utils.image_dataset_from_directory用法及代码示例
- Python tf.keras.utils.get_registered_object用法及代码示例
- Python tf.keras.utils.SidecarEvaluator用法及代码示例
- Python tf.keras.utils.to_categorical用法及代码示例
- Python tf.keras.utils.load_img用法及代码示例
- Python tf.keras.utils.text_dataset_from_directory用法及代码示例
- Python tf.keras.utils.SequenceEnqueuer用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.utils.unpack_x_y_sample_weight。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。