当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.train.import_meta_graph用法及代码示例


重新创建保存在 MetaGraphDef 原型中的图形。

用法

tf.compat.v1.train.import_meta_graph(
    meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs
)

参数

  • meta_graph_or_file MetaGraphDef 协议缓冲区或包含 MetaGraphDef 的文件名(包括路径)。
  • clear_devices 导入期间是否清除OperationTensor 的设备字段。
  • import_scope 可选 string 。要添加的名称范围。仅在从协议缓冲区初始化时使用。
  • **kwargs 可选的键控参数。

返回

  • 一个由saver_defMetaGraphDef或无。

    如果 MetaGraphDef 中不存在变量(即没有要恢复的变量),则返回 None 值。

抛出

  • RuntimeError 如果在启用即刻执行的情况下调用。

此函数将 MetaGraphDef 协议缓冲区作为输入。如果参数是包含 MetaGraphDef 协议缓冲区的文件,它会根据文件内容构造协议缓冲区。然后,该函数将 graph_def 字段中的所有节点添加到当前图形,重新创建所有集合,并返回从 saver_def 字段构造的保护程序。

export_meta_graph() 结合使用,此函数可用于

  • 将图形与其他 Python 对象(例如 QueueRunner , Variable)序列化为 MetaGraphDef

  • 从保存的图表和检查点重新开始训练。

  • 从保存的图形和检查点运行推理。

...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.compat.v1.add_to_collection('train_op', train_op)
sess = tf.compat.v1.Session()
for step in range(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # Saves checkpoint, which by default also exports a meta_graph
        # named 'my-model-global_step.meta'.
        saver.save(sess, 'my-model', global_step=step)

稍后我们可以从保存的 meta_graph 继续训练,而无需从头开始构建模型。

with tf.Session() as sess:
  new_saver =
  tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() returns a list. In this example we only want
  # the first one.
  train_op = tf.get_collection('train_op')[0]
  for step in range(1000000):
    sess.run(train_op)

注意:从保存的 meta_graph 重新开始训练仅在设备分配未更改的情况下有效。

例子:

还可以存储变量、占位符和独立操作,如下例所示。

# Saving contents and operations.
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.math.multiply(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result)
saver.save(sess, "./model_ex1")

稍后可以恢复此模型并加载内容。

# Restoring variables and running operations.
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0":12.0, "v2:0":3.3})
print(result)

eager模式兼容性

不支持导出/导入元图。启用即刻执行时不存在图表。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.import_meta_graph。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。