用法
compile(
optimizer='rmsprop', loss=None, metrics=None, loss_weights=None,
weighted_metrics=None, run_eagerly=None, steps_per_execution=None,
jit_compile=None, **kwargs
)
参数
-
optimizer
字符串(优化器名称)或优化器实例。见tf.keras.optimizers
。 -
loss
损失函数。可能是一个字符串(损失函数的名称),或者一个tf.keras.losses.Loss
实例。请参阅tf.keras.losses
。损失函数是任何带有签名loss = fn(y_true, y_pred)
的可调用函数,其中y_true
是真实值,而y_pred
是模型的预测。y_true
应该具有形状(batch_size, d0, .. dN)
(在稀疏损失函数的情况下,例如稀疏分类交叉熵,它需要形状为(batch_size, d0, .. dN-1)
的整数数组)。y_pred
应该具有形状(batch_size, d0, .. dN)
。损失函数应该返回一个浮点张量。如果使用自定义Loss
实例并将缩减设置为None
,则返回值具有形状(batch_size, d0, .. dN-1)
即 per-sample 或 per-timestep 损失值;否则,它是一个标量。如果模型有多个输出,您可以通过传递字典或损失列表对每个输出使用不同的损失。除非指定了loss_weights
,否则模型将最小化的损失值将是所有单个损失的总和。 -
metrics
模型在训练和测试期间要评估的指标列表。每个都可以是字符串(内置函数的名称)、函数或tf.keras.metrics.Metric
实例。请参阅tf.keras.metrics
。通常,您将使用metrics=['accuracy']
。函数是任何带有签名result = fn(y_true, y_pred)
的可调用函数。要为 multi-output 模型的不同输出指定不同的指标,您还可以传递字典,例如metrics={'output_a':'accuracy', 'output_b':['accuracy', 'mse']}
。您还可以传递一个列表来为每个输出指定一个指标或指标列表,例如metrics=[['accuracy'], ['accuracy', 'mse']]
或metrics=['accuracy', ['accuracy', 'mse']]
。当您传递字符串 'accuracy' 或 'acc' 时,我们会根据使用的损失函数和模型输出形状将其转换为tf.keras.metrics.BinaryAccuracy
、tf.keras.metrics.CategoricalAccuracy
、tf.keras.metrics.SparseCategoricalAccuracy
之一。我们也对字符串'crossentropy' 和'ce' 进行类似的转换。 -
loss_weights
可选的列表或字典,指定标量系数(Python 浮点数)以加权不同模型输出的损失贡献。模型将最小化的损失值将是加权和所有个人损失,加权loss_weights
系数。如果是列表,则预计与模型的输出有 1:1 的映射关系。如果是 dict,则应将输出名称(字符串)映射到标量系数。 -
weighted_metrics
在训练和测试期间由sample_weight
或class_weight
评估和加权的指标列表。 -
run_eagerly
布尔。默认为False
。如果True
,这个Model
的逻辑将不会被包装在tf.function
中。建议将其保留为None
,除非您的Model
不能在tf.function
内运行。使用tf.distribute.experimental.ParameterServerStrategy
时不支持run_eagerly=True
。 -
steps_per_execution
Int. 默认为 1。每次tf.function
调用期间要运行的批次数。在单个tf.function
调用中运行多个批处理可以极大地提高 TPU 或具有大量 Python 开销的小型模型的性能。每次执行最多将运行一个完整的 epoch。如果传递的数字大于 epoch 的大小,则执行将被截断为 epoch 的大小。请注意,如果steps_per_execution
设置为N
,则Callback.on_batch_begin
和Callback.on_batch_end
方法将仅在每个N
批次(即在每个tf.function
执行之前/之后)调用。 -
jit_compile
如果True
,用 XLA 编译模型训练步骤。XLA是机器学习的优化编译器。jit_compile
默认情况下未启用。此选项无法启用run_eagerly=True
.注意jit_compile=True
不一定适用于所有型号。有关支持的操作的更多信息,请参阅XLA 文档.另请参阅已知的 XLA 问题更多细节。 -
**kwargs
仅支持向后兼容的参数。
配置模型进行训练。
例子:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.BinaryAccuracy(),
tf.keras.metrics.FalseNegatives()])
相关用法
- Python tf.keras.Model.compute_loss用法及代码示例
- Python tf.keras.Model.compute_metrics用法及代码示例
- Python tf.keras.Model.reset_metrics用法及代码示例
- Python tf.keras.Model.save_spec用法及代码示例
- Python tf.keras.Model.save用法及代码示例
- Python tf.keras.Model用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代码示例
- Python tf.keras.metrics.SparseCategoricalCrossentropy.merge_state用法及代码示例
- Python tf.keras.metrics.sparse_categorical_accuracy用法及代码示例
- Python tf.keras.layers.Dropout用法及代码示例
- Python tf.keras.activations.softplus用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.Model.compile。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。