当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.Model.compile用法及代码示例


用法

compile(
    optimizer='rmsprop', loss=None, metrics=None, loss_weights=None,
    weighted_metrics=None, run_eagerly=None, steps_per_execution=None,
    jit_compile=None, **kwargs
)

参数

  • optimizer 字符串(优化器名称)或优化器实例。见tf.keras.optimizers
  • loss 损失函数。可能是一个字符串(损失函数的名称),或者一个 tf.keras.losses.Loss 实例。请参阅tf.keras.losses。损失函数是任何带有签名 loss = fn(y_true, y_pred) 的可调用函数,其中 y_true 是真实值,而 y_pred 是模型的预测。 y_true 应该具有形状 (batch_size, d0, .. dN) (在稀疏损失函数的情况下,例如稀疏分类交叉熵,它需要形状为 (batch_size, d0, .. dN-1) 的整数数组)。 y_pred 应该具有形状 (batch_size, d0, .. dN) 。损失函数应该返回一个浮点张量。如果使用自定义 Loss 实例并将缩减设置为 None ,则返回值具有形状 (batch_size, d0, .. dN-1) 即 per-sample 或 per-timestep 损失值;否则,它是一个标量。如果模型有多个输出,您可以通过传递字典或损失列表对每个输出使用不同的损失。除非指定了loss_weights,否则模型将最小化的损失值将是所有单个损失的总和。
  • metrics 模型在训练和测试期间要评估的指标列表。每个都可以是字符串(内置函数的名称)、函数或tf.keras.metrics.Metric 实例。请参阅tf.keras.metrics。通常,您将使用 metrics=['accuracy'] 。函数是任何带有签名 result = fn(y_true, y_pred) 的可调用函数。要为 multi-output 模型的不同输出指定不同的指标,您还可以传递字典,例如 metrics={'output_a':'accuracy', 'output_b':['accuracy', 'mse']} 。您还可以传递一个列表来为每个输出指定一个指标或指标列表,例如 metrics=[['accuracy'], ['accuracy', 'mse']]metrics=['accuracy', ['accuracy', 'mse']] 。当您传递字符串 'accuracy' 或 'acc' 时,我们会根据使用的损失函数和模型输出形状将其转换为 tf.keras.metrics.BinaryAccuracytf.keras.metrics.CategoricalAccuracytf.keras.metrics.SparseCategoricalAccuracy 之一。我们也对字符串'crossentropy' 和'ce' 进行类似的转换。
  • loss_weights 可选的列表或字典,指定标量系数(Python 浮点数)以加权不同模型输出的损失贡献。模型将最小化的损失值将是加权和所有个人损失,加权loss_weights系数。如果是列表,则预计与模型的输出有 1:1 的映射关系。如果是 dict,则应将输出名称(字符串)映射到标量系数。
  • weighted_metrics 在训练和测试期间由sample_weightclass_weight 评估和加权的指标列表。
  • run_eagerly 布尔。默认为 False 。如果 True ,这个 Model 的逻辑将不会被包装在 tf.function 中。建议将其保留为 None ,除非您的 Model 不能在 tf.function 内运行。使用 tf.distribute.experimental.ParameterServerStrategy 时不支持 run_eagerly=True
  • steps_per_execution Int. 默认为 1。每次 tf.function 调用期间要运行的批次数。在单个 tf.function 调用中运行多个批处理可以极大地提高 TPU 或具有大量 Python 开销的小型模型的性能。每次执行最多将运行一个完整的 epoch。如果传递的数字大于 epoch 的大小,则执行将被截断为 epoch 的大小。请注意,如果 steps_per_execution 设置为 N ,则 Callback.on_batch_beginCallback.on_batch_end 方法将仅在每个 N 批次(即在每个 tf.function 执行之前/之后)调用。
  • jit_compile 如果True,用 XLA 编译模型训练步骤。XLA是机器学习的优化编译器。jit_compile默认情况下未启用。此选项无法启用run_eagerly=True.注意jit_compile=True不一定适用于所有型号。有关支持的操作的更多信息,请参阅XLA 文档.另请参阅已知的 XLA 问题更多细节。
  • **kwargs 仅支持向后兼容的参数。

配置模型进行训练。

例子:

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.FalseNegatives()])

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.Model.compile。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。