获取现有的局部变量或创建一个新变量。
用法
tf.compat.v1.get_local_variable(
name, shape=None, dtype=None, initializer=None, regularizer=None,
trainable=False, collections=None, caching_device=None, partitioner=None,
validate_shape=True, use_resource=None, custom_getter=None, constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.compat.v1.VariableAggregation.NONE
)
参数
-
name
新变量或现有变量的名称。 -
shape
新变量或现有变量的形状。 -
dtype
新变量或现有变量的类型(默认为DT_FLOAT
)。 -
initializer
变量的初始化器(如果已创建)。可以是初始化对象或张量。如果它是张量,则必须知道其形状,除非 validate_shape 为 False。 -
regularizer
一个(张量 -> 张量或无)函数;将其应用于新创建的变量的结果将添加到集合tf.GraphKeys.REGULARIZATION_LOSSES
中,并可用于正则化。 -
collections
要将变量添加到的图形集合键列表。默认为[GraphKeys.LOCAL_VARIABLES]
(参见tf.Variable
)。 -
caching_device
可选的设备字符串或函数,说明应缓存变量以供读取的位置。默认为变量的设备。如果不是None
,则缓存在另一台设备上。典型用途是在使用变量的操作所在的设备上进行缓存,通过Switch
和其他条件语句进行重复数据删除。 -
partitioner
可选的可调用函数,接受要创建的变量的完全定义的TensorShape
和dtype
,并返回每个轴的分区列表(目前只能对一个轴进行分区)。 -
validate_shape
如果为 False,则允许使用未知形状的值初始化变量。如果为 True(默认),则必须知道 initial_value 的形状。为此,初始化器必须是张量而不是初始化器对象。 -
use_resource
如果为 False,则创建一个常规变量。如果为 true,则创建一个实验性的 ResourceVariable,而不是定义明确的语义。默认为 False(稍后将更改为 True)。当启用即刻执行时,此参数始终强制为 True。 -
custom_getter
可调用的,它将 true getter 作为第一个参数,并允许覆盖内部 get_variable 方法。的签名custom_getter
应该与此方法匹配,但最面向未来的版本将允许更改:def custom_getter(getter, *args, **kwargs)
.直接访问所有get_variable
参数也是允许的:def custom_getter(getter, name, *args, **kwargs)
.一个简单的身份自定义获取器,它简单地创建具有修改名称的变量是:def custom_getter(getter, name, *args, **kwargs): return getter(name + '_suffix', *args, **kwargs)
-
constraint
由Optimizer
更新后应用于变量的可选投影函数(例如,用于实现层权重的范数约束或值约束)。该函数必须将表示变量值的未投影张量作为输入,并返回投影值的张量(必须具有相同的形状)。在进行异步分布式训练时使用约束是不安全的。 -
synchronization
指示何时聚合分布式变量。接受的值是在类tf.VariableSynchronization
中定义的常量。默认情况下,同步设置为AUTO
,当前的DistributionStrategy
选择何时同步。 -
aggregation
指示如何聚合分布式变量。接受的值是在类tf.VariableAggregation
中定义的常量。
返回
-
创建的或现有的
Variable
(或PartitionedVariable
,如果使用了分区程序)。
抛出
-
ValueError
当创建新变量且未声明形状时,在变量创建期间违反重用时,或当initializer
dtype 和dtype
不匹配时。重用设置在variable_scope
内。
迁移到 TF2
警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。
虽然它是一个遗留的compat.v1
api,但tf.compat.v1.get_variable
主要与即刻执行和tf.function
兼容,但前提是你将它与tf.compat.v1.keras.utils.track_tf1_style_variables
装饰器结合使用。 (尽管它会表现得好像总是将重用设置为 AUTO_REUSE
。)
有关详细信息,请参阅模型迁移指南。
如果您不将其与 tf.compat.v1.keras.utils.track_tf1_style_variables
结合使用,则 get_variable
将在每次调用时创建一个全新的变量,并且永远不会重用变量,无论变量名称或 reuse
参数如何。
此符号的 TF2 等效项是 tf.Variable
,但请注意,使用 tf.Variable
时,您必须确保手动或通过 tf.Module
或 tf.keras.layers.Layer
机制跟踪变量(和正则化参数)。
迁移指南的一部分还提供了有关将这些用法增量迁移到 tf.Variable 的更多详细信息。
注意: partitioner
arg 与 TF2 行为不兼容,即使使用tf.compat.v1.keras.utils.track_tf1_style_variables.它可以通过使用来替换ParameterServerStrategy
及其分区器。见multi-gpu 迁移指南并且 ParameterServerStrategy 指导它参考以获取更多信息。
行为与 get_variable
中的相同,只是将变量添加到 LOCAL_VARIABLES
集合中并且将 trainable
设置为 False
。此函数以当前变量范围作为名称的前缀并执行重用检查。有关重用工作原理的详细说明,请参阅如何使用变量范围。这是一个基本示例:
def foo():
with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
v = tf.get_variable("v", [1])
return v
v1 = foo() # Creates v.
v2 = foo() # Gets the same, existing v.
assert v1 == v2
如果初始化程序是None
(默认值),则将使用在变量范围内传递的默认初始化程序。如果那个也是None
,则将使用glorot_uniform_initializer
。初始化器也可以是张量,在这种情况下,变量被初始化为此值和形状。
类似地,如果正则化器是None
(默认),则将使用在变量范围中传递的默认正则化器(如果也是None
,则默认情况下不执行正则化)。
如果提供了分区程序,则返回 PartitionedVariable
。作为 Tensor
访问此对象会返回沿分区轴连接的分片。
一些有用的分区器可用。例如,参见 variable_axis_size_partitioner
和 min_max_variable_partitioner
。
相关用法
- Python tf.compat.v1.get_variable_scope用法及代码示例
- Python tf.compat.v1.get_variable用法及代码示例
- Python tf.compat.v1.get_session_tensor用法及代码示例
- Python tf.compat.v1.get_session_handle用法及代码示例
- Python tf.compat.v1.gfile.Copy用法及代码示例
- Python tf.compat.v1.gather用法及代码示例
- Python tf.compat.v1.gfile.Exists用法及代码示例
- Python tf.compat.v1.gradients用法及代码示例
- Python tf.compat.v1.gfile.FastGFile.close用法及代码示例
- Python tf.compat.v1.gather_nd用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.layers.conv3d用法及代码示例
- Python tf.compat.v1.strings.length用法及代码示例
- Python tf.compat.v1.data.Dataset.snapshot用法及代码示例
- Python tf.compat.v1.data.experimental.SqlDataset.reduce用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.get_local_variable。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。