用法
scope()
返回
- 上下文管理器。
上下文管理器使策略成为当前策略并分配变量。
该方法返回一个上下文管理器,用法如下:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Variable created inside scope:
with strategy.scope():
mirrored_variable = tf.Variable(1.)
mirrored_variable
MirroredVariable:{
0:<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
1:<tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
# Variable created outside scope:
regular_variable = tf.Variable(1.)
regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
输入 Strategy.scope 会发生什么?
strategy
作为"current" 策略安装在全局上下文中。在此范围内,tf.distribute.get_strategy()
现在将返回此策略。在此范围之外,它返回默认的no-op 策略。- 进入范围也进入“cross-replica上下文”。有关cross-replica 和副本上下文的说明,请参见
tf.distribute.StrategyExtended
。 scope
内的变量创建被策略拦截。每个策略都定义了它希望如何影响变量的创建。MirroredStrategy
,TPUStrategy
和MultiWorkerMiroredStrategy
等同步策略创建在每个副本上复制的变量,而ParameterServerStrategy
在参数服务器上创建变量。这是使用自定义tf.variable_creator_scope
完成的。- 在某些策略中,还可以输入默认设备范围:在
MultiWorkerMiroredStrategy
中,在每个worker上输入默认设备范围"/CPU:0"。
注意:进入范围不会自动分配计算,除非是像 keras 这样的高级训练框架model.fit
.如果你不使用model.fit
,你需要使用strategy.run
用于显式分发该计算的 API。请参阅中的示例自定义训练循环教程.
什么应该在范围内,什么应该在范围之外?
对于范围内需要发生的事情有许多要求。但是,在我们有关于正在使用哪种策略的信息的地方,我们经常为用户输入范围,因此他们不必明确地这样做(即,在范围内或范围外调用它们都可以)。
- 任何创建应该是分布式变量的变量都必须在
strategy.scope
中调用。这可以通过在作用域上下文中直接调用变量创建函数来完成,或者通过依赖另一个 API(如strategy.run
或keras.Model.fit
来自动为您输入它)来完成。在范围之外创建的任何变量都不会被分发,并且可能会影响性能。在 TF 中创建变量的一些常见对象是模型、优化器、度量。此类对象应始终在范围内初始化,任何可能延迟创建变量的函数(例如Model.__call__()
、跟踪tf.function
等)都应类似地在范围内调用。变量创建的另一个来源可以是检查点恢复 - 当变量被延迟创建时。请注意,在策略中创建的任何变量都会捕获策略信息。因此,在strategy.scope
之外读取和写入这些变量也可以无缝工作,而无需用户进入范围。 - 一些策略 API(如
strategy.run
和strategy.reduce
)需要在策略的范围内,会自动进入范围,这意味着在使用这些 API 时,您不需要自己显式地进入范围。 - 当在
strategy.scope
中创建tf.keras.Model
时,模型对象会捕获范围信息。当调用model.compile
,model.fit
等高级训练框架方法时,会自动进入捕获的范围,并使用相关策略分发训练等。详细示例见分布式keras教程。警告:仅调用model(..)
不会自动进入捕获的范围——只有高级训练框架API 支持此行为:model.compile
,model.fit
,model.evaluate
,model.predict
和model.save
都可以在范围内或范围外调用。 - 以下内容可以在范围内或范围外:
- 创建输入数据集
- 定义
tf.function
代表您的训练步骤 - 保存 API,例如
tf.saved_model.save
。加载会创建变量,因此如果您想以分布式方式训练模型,则应该在范围内。 - 检查点保存。如上所述 -
checkpoint.restore
如果创建变量,有时可能需要在范围内。
相关用法
- Python tf.distribute.TPUStrategy.experimental_assign_to_logical_device用法及代码示例
- Python tf.distribute.TPUStrategy.reduce用法及代码示例
- Python tf.distribute.TPUStrategy.experimental_replicate_to_logical_devices用法及代码示例
- Python tf.distribute.TPUStrategy.experimental_split_to_logical_devices用法及代码示例
- Python tf.distribute.TPUStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.distribute.TPUStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.distribute.TPUStrategy.gather用法及代码示例
- Python tf.distribute.TPUStrategy.run用法及代码示例
- Python tf.distribute.TPUStrategy用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.distribute.experimental_set_strategy用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.gather用法及代码示例
- Python tf.distribute.cluster_resolver.TFConfigClusterResolver用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy用法及代码示例
- Python tf.distribute.NcclAllReduce用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.distribute.experimental.rpc.Server.create用法及代码示例
- Python tf.distribute.experimental.MultiWorkerMirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.distribute.OneDeviceStrategy.gather用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.distribute.TPUStrategy.scope。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。