当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.train.init_from_checkpoint用法及代码示例


替换 tf.Variable 初始化程序,以便它们从检查点文件加载。

用法

tf.compat.v1.train.init_from_checkpoint(
    ckpt_dir_or_file, assignment_map
)

参数

  • ckpt_dir_or_file 带有检查点文件或检查点路径的目录。
  • assignment_map 字典或键值对列表,其中键是检查点中变量的名称,值是当前变量或当前变量的名称(在默认图中)。

抛出

  • ValueError 如果当前图中缺少变量,或者检查点中缺少检查点或张量。

迁移到 TF2

警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。

tf.compat.v1.train.init_from_checkpoint 不推荐用于恢复 TF2 中的变量值。

要恢复 TF2 中的检查点,请使用 tf.keras.Model.load_weightstf.train.Checkpoint.restore 。这些 API 使用基于对象的检查点方法,而 tf.compat.v1.init_from_checkpoint 依赖于基于 more-fragile variable-name 的检查点方法。 TF2 中没有基于对象的等效项init_from_checkpoint

请立即使用基于对象的 API 重写您的检查点,有关详细信息,请参阅迁移指南。

您可以使用 tf.train.Checkpoint.restoretf.keras.Model.load_weights 加载由 tf.compat.v1.train.Saver 编写的基于名称的检查点。但是,您可能必须更改模型中的变量名称以匹配基于名称的检查点中的变量名称,可以使用 tf.train.list_variables(path) 查看。

另一种选择是创建一个assignment_map,将基于名称的检查点中的变量名称映射到模型中的变量,例如:

{
    'sequential/dense/bias':model.variables[0],
    'sequential/dense/kernel':model.variables[1]
}

并使用tf.compat.v1.train.init_from_checkpoint恢复基于名称的检查点。

恢复后,使用 tf.train.Checkpoint.savetf.keras.Model.save_weights 重新编码您的检查点。

值不会立即加载,而是在初始化程序运行时加载(通常通过运行 tf.compat.v1.global_variables_initializer 操作)。

注意:这会覆盖指定变量的默认初始化操作并重新定义 dtype。

分配映射支持以下语法:

  • 'checkpoint_scope_name/':'scope_name/' - 将从checkpoint_scope_name 加载当前scope_name 中具有匹配张量名称的所有变量。
  • 'checkpoint_scope_name/some_other_variable':'scope_name/variable_name' - 将从 checkpoint_scope_name/some_other_variable 初始化 scope_name/variable_name 变量。
  • 'scope_variable_name':variable - 将从检查点使用张量 'scope_variable_name' 初始化给定的 tf.Variable 对象。
  • 'scope_variable_name':list(variable) - 将从检查点使用张量 'scope_variable_name' 初始化分区变量列表。
  • '/':'scope_name/' - 将从检查点的根目录加载当前 scope_name 中的所有变量(例如,无范围)。

支持加载到分区变量中,表示为 '<variable>/part_<part #>'

分配映射可以是字典,也可以是对列表。后者对于从检查点中的同一变量初始化当前图中的多个变量是必要的。

例子:

# Say, '/tmp/model.ckpt' has the following tensors:
#  -- name='old_scope_1/var1', shape=[20, 2]
#  -- name='old_scope_1/var2', shape=[50, 4]
#  -- name='old_scope_2/var3', shape=[100, 100]

# Create new model's variables
with tf.compat.v1.variable_scope('new_scope_1'):
  var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                         initializer=tf.compat.v1.zeros_initializer())
with tf.compat.v1.variable_scope('new_scope_2'):
  var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                         initializer=tf.compat.v1.zeros_initializer())
  # Partition into 5 variables along the first axis.
  var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                         initializer=tf.compat.v1.zeros_initializer(),
                         partitioner=lambda shape, dtype:[5, 1])

# Initialize all variables in `new_scope_1` from `old_scope_1`.
init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/':'new_scope_1/'})

# Use names to specify which variables to initialize from checkpoint.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_1/var1':'new_scope_1/var1',
                      'old_scope_1/var2':'new_scope_2/var2'})

# Or use tf.Variable objects to identify what to initialize.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_1/var1':var1,
                      'old_scope_1/var2':var2})

# Initialize partitioned variables using variable's name
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_2/var3':'new_scope_2/var3'})

# Or specify the list of tf.Variable objects.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_2/var3':var3._get_variable_list()})

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.init_from_checkpoint。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。