當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.get_variable用法及代碼示例


使用這些參數獲取現有變量或創建一個新變量。

用法

tf.compat.v1.get_variable(
    name, shape=None, dtype=None, initializer=None, regularizer=None,
    trainable=None, collections=None, caching_device=None, partitioner=None,
    validate_shape=True, use_resource=None, custom_getter=None, constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.compat.v1.VariableAggregation.NONE
)

參數

  • name 新變量或現有變量的名稱。
  • shape 新變量或現有變量的形狀。
  • dtype 新變量或現有變量的類型(默認為 DT_FLOAT )。
  • initializer 變量的初始化器(如果已創建)。可以是初始化對象或張量。如果它是張量,則必須知道其形狀,除非 validate_shape 為 False。
  • regularizer 一個(張量 -> 張量或無)函數;將其應用於新創建的變量的結果將添加到集合 tf.GraphKeys.REGULARIZATION_LOSSES 中,並可用於正則化。
  • trainable 如果 True 還將變量添加到圖形集合 GraphKeys.TRAINABLE_VARIABLES (請參閱 tf.Variable )。
  • collections 要將變量添加到的圖形集合鍵列表。默認為 [GraphKeys.GLOBAL_VARIABLES](參見 tf.Variable )。
  • caching_device 可選的設備字符串或函數,說明應緩存變量以供讀取的位置。默認為變量的設備。如果不是 None ,則緩存在另一台設備上。典型用途是在使用變量的操作所在的設備上進行緩存,通過Switch 和其他條件語句進行重複數據刪除。
  • partitioner 可選的可調用函數,接受要創建的變量的完全定義的TensorShapedtype,並返回每個軸的分區列表(目前隻能對一個軸進行分區)。
  • validate_shape 如果為 False,則允許使用未知形狀的值初始化變量。如果為 True(默認),則必須知道 initial_value 的形狀。為此,初始化器必須是張量而不是初始化器對象。
  • use_resource 如果為 False,則創建一個常規變量。如果為 true,則創建一個實驗性的 ResourceVariable,而不是定義明確的語義。默認為 False(稍後將更改為 True)。當啟用即刻執行時,此參數始終強製為 True。
  • custom_getter 可調用的,它將 true getter 作為第一個參數,並允許覆蓋內部 get_variable 方法。的簽名custom_getter應該與此方法匹配,但最麵向未來的版本將允許更改:def custom_getter(getter, *args, **kwargs).直接訪問所有get_variable參數也是允許的:def custom_getter(getter, name, *args, **kwargs).一個簡單的身份自定義獲取器,它簡單地創建具有修改名稱的變量是:
    def custom_getter(getter, name, *args, **kwargs):
      return getter(name + '_suffix', *args, **kwargs)
  • constraint Optimizer 更新後應用於變量的可選投影函數(例如,用於實現層權重的範數約束或值約束)。該函數必須將表示變量值的未投影張量作為輸入,並返回投影值的張量(必須具有相同的形狀)。在進行異步分布式訓練時使用約束是不安全的。
  • synchronization 指示何時聚合分布式變量。接受的值是在類 tf.VariableSynchronization 中定義的常量。默認情況下,同步設置為AUTO,當前的DistributionStrategy 選擇何時同步。
  • aggregation 指示如何聚合分布式變量。接受的值是在類 tf.VariableAggregation 中定義的常量。

返回

  • 創建的或現有的 Variable (或 PartitionedVariable ,如果使用了分區程序)。

拋出

  • ValueError 當創建新變量且未聲明形狀時,在變量創建期間違反重用時,或當 initializer dtype 和 dtype 不匹配時。重用設置在 variable_scope 內。

遷移到 TF2

警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。

雖然它是一個遺留的compat.v1 api,但tf.compat.v1.get_variable 主要與即刻執行和tf.function 兼容,但前提是你將它與tf.compat.v1.keras.utils.track_tf1_style_variables 裝飾器結合使用。 (盡管它會表現得好像總是將重用設置為 AUTO_REUSE 。)

有關詳細信息,請參閱模型遷移指南。

如果您不將其與 tf.compat.v1.keras.utils.track_tf1_style_variables 結合使用,則 get_variable 將在每次調用時創建一個全新的變量,並且永遠不會重用變量,無論變量名稱或 reuse 參數如何。

此符號的 TF2 等效項是 tf.Variable ,但請注意,使用 tf.Variable 時,您必須確保手動或通過 tf.Moduletf.keras.layers.Layer 機製跟蹤變量(和正則化參數)。

遷移指南的一部分還提供了有關將這些用法增量遷移到 tf.Variable 的更多詳細信息。

注意: partitionerarg 與 TF2 行為不兼容,即使使用tf.compat.v1.keras.utils.track_tf1_style_variables.它可以通過使用來替換ParameterServerStrategy及其分區器。見multi-gpu 遷移指南並且 ParameterServerStrategy 指導它參考以獲取更多信息。

此函數以當前變量範圍作為名稱的前綴並執行重用檢查。有關重用工作原理的詳細說明,請參閱如何使用變量範圍。這是一個基本示例:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2

如果初始化程序是None(默認值),則將使用在變量範圍內傳遞的默認初始化程序。如果那個也是None,則將使用glorot_uniform_initializer。初始化器也可以是張量,在這種情況下,變量被初始化為此值和形狀。

類似地,如果正則化器是None(默認),則將使用在變量範圍中傳遞的默認正則化器(如果也是None,則默認情況下不執行正則化)。

如果提供了分區程序,則返回 PartitionedVariable。作為 Tensor 訪問此對象會返回沿分區軸連接的分片。

一些有用的分區器可用。例如,參見 variable_axis_size_partitionermin_max_variable_partitioner

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.get_variable。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。