当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.get_variable用法及代码示例


使用这些参数获取现有变量或创建一个新变量。

用法

tf.compat.v1.get_variable(
    name, shape=None, dtype=None, initializer=None, regularizer=None,
    trainable=None, collections=None, caching_device=None, partitioner=None,
    validate_shape=True, use_resource=None, custom_getter=None, constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.compat.v1.VariableAggregation.NONE
)

参数

  • name 新变量或现有变量的名称。
  • shape 新变量或现有变量的形状。
  • dtype 新变量或现有变量的类型(默认为 DT_FLOAT )。
  • initializer 变量的初始化器(如果已创建)。可以是初始化对象或张量。如果它是张量,则必须知道其形状,除非 validate_shape 为 False。
  • regularizer 一个(张量 -> 张量或无)函数;将其应用于新创建的变量的结果将添加到集合 tf.GraphKeys.REGULARIZATION_LOSSES 中,并可用于正则化。
  • trainable 如果 True 还将变量添加到图形集合 GraphKeys.TRAINABLE_VARIABLES (请参阅 tf.Variable )。
  • collections 要将变量添加到的图形集合键列表。默认为 [GraphKeys.GLOBAL_VARIABLES](参见 tf.Variable )。
  • caching_device 可选的设备字符串或函数,说明应缓存变量以供读取的位置。默认为变量的设备。如果不是 None ,则缓存在另一台设备上。典型用途是在使用变量的操作所在的设备上进行缓存,通过Switch 和其他条件语句进行重复数据删除。
  • partitioner 可选的可调用函数,接受要创建的变量的完全定义的TensorShapedtype,并返回每个轴的分区列表(目前只能对一个轴进行分区)。
  • validate_shape 如果为 False,则允许使用未知形状的值初始化变量。如果为 True(默认),则必须知道 initial_value 的形状。为此,初始化器必须是张量而不是初始化器对象。
  • use_resource 如果为 False,则创建一个常规变量。如果为 true,则创建一个实验性的 ResourceVariable,而不是定义明确的语义。默认为 False(稍后将更改为 True)。当启用即刻执行时,此参数始终强制为 True。
  • custom_getter 可调用的,它将 true getter 作为第一个参数,并允许覆盖内部 get_variable 方法。的签名custom_getter应该与此方法匹配,但最面向未来的版本将允许更改:def custom_getter(getter, *args, **kwargs).直接访问所有get_variable参数也是允许的:def custom_getter(getter, name, *args, **kwargs).一个简单的身份自定义获取器,它简单地创建具有修改名称的变量是:
    def custom_getter(getter, name, *args, **kwargs):
      return getter(name + '_suffix', *args, **kwargs)
  • constraint Optimizer 更新后应用于变量的可选投影函数(例如,用于实现层权重的范数约束或值约束)。该函数必须将表示变量值的未投影张量作为输入,并返回投影值的张量(必须具有相同的形状)。在进行异步分布式训练时使用约束是不安全的。
  • synchronization 指示何时聚合分布式变量。接受的值是在类 tf.VariableSynchronization 中定义的常量。默认情况下,同步设置为AUTO,当前的DistributionStrategy 选择何时同步。
  • aggregation 指示如何聚合分布式变量。接受的值是在类 tf.VariableAggregation 中定义的常量。

返回

  • 创建的或现有的 Variable (或 PartitionedVariable ,如果使用了分区程序)。

抛出

  • ValueError 当创建新变量且未声明形状时,在变量创建期间违反重用时,或当 initializer dtype 和 dtype 不匹配时。重用设置在 variable_scope 内。

迁移到 TF2

警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。

虽然它是一个遗留的compat.v1 api,但tf.compat.v1.get_variable 主要与即刻执行和tf.function 兼容,但前提是你将它与tf.compat.v1.keras.utils.track_tf1_style_variables 装饰器结合使用。 (尽管它会表现得好像总是将重用设置为 AUTO_REUSE 。)

有关详细信息,请参阅模型迁移指南。

如果您不将其与 tf.compat.v1.keras.utils.track_tf1_style_variables 结合使用,则 get_variable 将在每次调用时创建一个全新的变量,并且永远不会重用变量,无论变量名称或 reuse 参数如何。

此符号的 TF2 等效项是 tf.Variable ,但请注意,使用 tf.Variable 时,您必须确保手动或通过 tf.Moduletf.keras.layers.Layer 机制跟踪变量(和正则化参数)。

迁移指南的一部分还提供了有关将这些用法增量迁移到 tf.Variable 的更多详细信息。

注意: partitionerarg 与 TF2 行为不兼容,即使使用tf.compat.v1.keras.utils.track_tf1_style_variables.它可以通过使用来替换ParameterServerStrategy及其分区器。见multi-gpu 迁移指南并且 ParameterServerStrategy 指导它参考以获取更多信息。

此函数以当前变量范围作为名称的前缀并执行重用检查。有关重用工作原理的详细说明,请参阅如何使用变量范围。这是一个基本示例:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2

如果初始化程序是None(默认值),则将使用在变量范围内传递的默认初始化程序。如果那个也是None,则将使用glorot_uniform_initializer。初始化器也可以是张量,在这种情况下,变量被初始化为此值和形状。

类似地,如果正则化器是None(默认),则将使用在变量范围中传递的默认正则化器(如果也是None,则默认情况下不执行正则化)。

如果提供了分区程序,则返回 PartitionedVariable。作为 Tensor 访问此对象会返回沿分区轴连接的分片。

一些有用的分区器可用。例如,参见 variable_axis_size_partitionermin_max_variable_partitioner

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.get_variable。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。