当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.distribute.TPUStrategy.scope用法及代码示例


用法

scope()

返回

  • 上下文管理器。

上下文管理器使策略成为当前策略并分配变量。

该方法返回一个上下文管理器,用法如下:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Variable created inside scope:
with strategy.scope():
  mirrored_variable = tf.Variable(1.)
mirrored_variable
MirroredVariable:{
  0:<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  1:<tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
# Variable created outside scope:
regular_variable = tf.Variable(1.)
regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

输入 Strategy.scope 会发生什么?

  • strategy 作为"current" 策略安装在全局上下文中。在此范围内,tf.distribute.get_strategy() 现在将返回此策略。在此范围之外,它返回默认的no-op 策略。
  • 进入范围也进入“cross-replica上下文”。有关cross-replica 和副本上下文的说明,请参见tf.distribute.StrategyExtended
  • scope 内的变量创建被策略拦截。每个策略都定义了它希望如何影响变量的创建。 MirroredStrategy , TPUStrategyMultiWorkerMiroredStrategy 等同步策略创建在每个副本上复制的变量,而 ParameterServerStrategy 在参数服务器上创建变量。这是使用自定义 tf.variable_creator_scope 完成的。
  • 在某些策略中,还可以输入默认设备范围:在MultiWorkerMiroredStrategy中,在每个worker上输入默认设备范围"/CPU:0"。

注意:进入范围不会自动分配计算,除非是像 keras 这样的高级训练框架model.fit.如果你不使用model.fit,你需要使用strategy.run用于显式分发该计算的 API。请参阅中的示例自定义训练循环教程.

什么应该在范围内,什么应该在范围之外?

对于范围内需要发生的事情有许多要求。但是,在我们有关于正在使用哪种策略的信息的地方,我们经常为用户输入范围,因此他们不必明确地这样做(即,在范围内或范围外调用它们都可以)。

  • 任何创建应该是分布式变量的变量都必须在 strategy.scope 中调用。这可以通过在作用域上下文中直接调用变量创建函数来完成,或者通过依赖另一个 API(如 strategy.runkeras.Model.fit 来自动为您输入它)来完成。在范围之外创建的任何变量都不会被分发,并且可能会影响性能。在 TF 中创建变量的一些常见对象是模型、优化器、度量。此类对象应始终在范围内初始化,任何可能延迟创建变量的函数(例如 Model.__call__() 、跟踪 tf.function 等)都应类似地在范围内调用。变量创建的另一个来源可以是检查点恢复 - 当变量被延迟创建时。请注意,在策略中创建的任何变量都会捕获策略信息。因此,在strategy.scope 之外读取和写入这些变量也可以无缝工作,而无需用户进入范围。
  • 一些策略 API(如 strategy.runstrategy.reduce )需要在策略的范围内,会自动进入范围,这意味着在使用这些 API 时,您不需要自己显式地进入范围。
  • 当在 strategy.scope 中创建 tf.keras.Model 时,模型对象会捕获范围信息。当调用model.compile , model.fit等高级训练框架方法时,会自动进入捕获的范围,并使用相关策略分发训练等。详细示例见分布式keras教程。警告:仅调用model(..) 不会自动进入捕获的范围——只有高级训练框架API 支持此行为:model.compile , model.fit , model.evaluate , model.predictmodel.save 都可以在范围内或范围外调用。
  • 以下内容可以在范围内或范围外:
    • 创建输入数据集
    • 定义 tf.function 代表您的训练步骤
    • 保存 API,例如 tf.saved_model.save 。加载会创建变量,因此如果您想以分布式方式训练模型,则应该在范围内。
    • 检查点保存。如上所述 - checkpoint.restore 如果创建变量,有时可能需要在范围内。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.TPUStrategy.scope。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。