当前位置: 首页>>代码示例>>Python>>正文


Python meta_graph.read_meta_graph_file方法代码示例

本文汇总了 Python 中 tensorflow.python.framework.meta_graph.read_meta_graph_file 方法的典型用法代码示例。如果您想了解 meta_graph.read_meta_graph_file 的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.python.framework.meta_graph 的其他用法示例。


下文共展示了 meta_graph.read_meta_graph_file 方法的 5 个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: initialize_variables

# 需要导入模块: from tensorflow.python.framework import meta_graph [as 别名]
# 或者: from tensorflow.python.framework.meta_graph import read_meta_graph_file [as 别名]
def initialize_variables(self, save_file=None):
    """Initialize all global variables, optionally restoring from a checkpoint.

    Every global variable is first initialized with its initializer. If
    ``save_file`` is given, the method tries to restore every variable via
    ``self.saver``. If that restore raises, it falls back to restoring only
    the variables whose op names appear as variable nodes in the
    checkpoint's ``.meta`` graph. All other variables keep their initial
    values.

    Args:
        save_file: Optional checkpoint prefix. For the partial-restore
            fallback, ``save_file + '.meta'`` must exist.
    """
    self.session.run(tf.global_variables_initializer())
    if save_file is None:
        return
    try:
        self.saver.restore(self.session, save_file)
    except Exception:
        # Partial restore: only restore variables that are in the save file;
        # otherwise keep their normal initialization. `Exception` (not a bare
        # except) so KeyboardInterrupt/SystemExit are not swallowed.
        from tensorflow.python.framework import meta_graph
        meta_graph_def = meta_graph.read_meta_graph_file(save_file + '.meta')
        # 'Variable' is the legacy variable op, 'VariableV2' its successor.
        variable_op_types = ('Variable', 'VariableV2')
        stored_var_names = {n.name
                            for n in meta_graph_def.graph_def.node
                            if n.op in variable_op_types}
        print(stored_var_names)
        var_list = [v for v in tf.global_variables()
                    if v.op.name in stored_var_names]
        # Re-initialize everything in case the failed full restore left some
        # variables partially assigned.
        self.session.run(tf.global_variables_initializer())
        # Overwrite the variables that exist in the save file. A throwaway
        # saver restores an older checkpoint into the current graph
        # definition. tf.train.Saver rejects an empty var_list, so only build
        # it when there is something to restore.
        if var_list:
            throwaway_saver = tf.train.Saver(var_list=var_list)
            throwaway_saver.restore(self.session, save_file)
开发者ID:llSourcell,项目名称:alphago_demo,代码行数:25,代码来源:policy.py

示例2: __init__

# 需要导入模块: from tensorflow.python.framework import meta_graph [as 别名]
# 或者: from tensorflow.python.framework.meta_graph import read_meta_graph_file [as 别名]
def __init__(self, X, KEEP, view, name, dir_path, node='logits:0', softmax=True):
    """Import the single checkpointed graph in ``dir_path`` under scope ``name``.

    Args:
        X: Tensor mapped onto the imported graph's ``images:0``.
        KEEP: Optional tensor mapped onto ``keep:0``. When None, ``keep:0``
            is not remapped.
        view: Stored as ``self.view``; not otherwise used here.
        name: Import scope for the graph; also passed as the Saver's name.
        dir_path: Directory that must contain exactly one ``*.meta`` file.
        node: Name of the tensor returned from the imported graph.
        softmax: If True, the returned tensor is passed through
            ``logits2prob``.

    Raises:
        AssertionError: If ``dir_path`` does not hold exactly one ``.meta``
            file.
    """
    self.name = name
    self.view = view
    paths = glob(os.path.join(dir_path, '*.meta'))
    # Explicit check instead of `assert` so it still runs under `python -O`.
    # AssertionError is kept so callers see the same exception type.
    if len(paths) != 1:
        raise AssertionError(
            'expected exactly one .meta file in %r, found %d'
            % (dir_path, len(paths)))
    # Checkpoint prefix: the .meta path with its extension stripped.
    path = os.path.splitext(paths[0])[0]
    mg = meta_graph.read_meta_graph_file(path + '.meta')
    input_map = {'images:0': X}
    if KEEP is not None:
        input_map['keep:0'] = KEEP
    fts, = tf.import_graph_def(mg.graph_def, name=name,
                               input_map=input_map,
                               return_elements=[node])
    if softmax:
        fts = logits2prob(fts)
    self.fts = fts
    # Presumably `name` makes the Saver resolve the saver_def ops under the
    # same import scope used above -- confirm against tf.train.Saver docs.
    self.saver = tf.train.Saver(saver_def=mg.saver_def, name=name)
    self.loader = lambda sess: self.saver.restore(sess, path)
开发者ID:aaalgo,项目名称:plumo,代码行数:23,代码来源:process.py

示例3: __init__

# 需要导入模块: from tensorflow.python.framework import meta_graph [as 别名]
# 或者: from tensorflow.python.framework.meta_graph import read_meta_graph_file [as 别名]
def __init__(self, X, KEEP, view, name, dir_path, node='logits:0', softmax=True):
    """Import the single checkpointed graph in ``dir_path`` under scope ``name``.

    Prints ``dir_path`` and the matched ``.meta`` paths for debugging.

    Args:
        X: Tensor mapped onto the imported graph's ``images:0``.
        KEEP: Optional tensor mapped onto ``keep:0``. When None, ``keep:0``
            is not remapped.
        view: Stored as ``self.view``; not otherwise used here.
        name: Import scope for the graph; also passed as the Saver's name.
        dir_path: Directory that must contain exactly one ``*.meta`` file.
        node: Name of the tensor returned from the imported graph.
        softmax: If True, the returned tensor is passed through
            ``logits2prob``.

    Raises:
        AssertionError: If ``dir_path`` does not hold exactly one ``.meta``
            file.
    """
    self.name = name
    self.view = view
    # Single-argument print(...) behaves the same on Python 2 and 3; the
    # original `print x` statements are a SyntaxError on Python 3.
    print(dir_path)
    paths = glob(os.path.join(dir_path, '*.meta'))
    print(paths)
    # Explicit check instead of `assert` so it still runs under `python -O`.
    # AssertionError is kept so callers see the same exception type.
    if len(paths) != 1:
        raise AssertionError(
            'expected exactly one .meta file in %r, found %d'
            % (dir_path, len(paths)))
    # Checkpoint prefix: the .meta path with its extension stripped.
    path = os.path.splitext(paths[0])[0]
    mg = meta_graph.read_meta_graph_file(path + '.meta')
    input_map = {'images:0': X}
    if KEEP is not None:
        input_map['keep:0'] = KEEP
    fts, = tf.import_graph_def(mg.graph_def, name=name,
                               input_map=input_map,
                               return_elements=[node])
    if softmax:
        fts = logits2prob(fts)
    self.fts = fts
    # Presumably `name` makes the Saver resolve the saver_def ops under the
    # same import scope used above -- confirm against tf.train.Saver docs.
    self.saver = tf.train.Saver(saver_def=mg.saver_def, name=name)
    self.loader = lambda sess: self.saver.restore(sess, path)
开发者ID:aaalgo,项目名称:plumo,代码行数:25,代码来源:adsb3_cache_ft.py

示例4: _load_saved_model_from_session_bundle_path

# 需要导入模块: from tensorflow.python.framework import meta_graph [as 别名]
# 或者: from tensorflow.python.framework.meta_graph import read_meta_graph_file [as 别名]
def _load_saved_model_from_session_bundle_path(export_dir, target, config):
  """Loads a legacy Exporter/SessionBundle export and up-converts it.

  Reads the legacy MetaGraphDef from `export_dir` and converts the legacy
  signatures stored in its CollectionDef into `signature_def` entries. It then
  loads the session using that up-converted MetaGraphDef.

  Signature keys:
    * The default signature is stored under DEFAULT_SERVING_SIGNATURE_DEF_KEY.
    * The named signature is stored under the same key, or under that key
      plus "_from_named" when a default signature also exists.

  Args:
    export_dir: Directory holding the files written by the legacy exporter.
    target: Execution engine to connect to; see `target` in `tf.Session()`.
    config: `ConfigProto` with session configuration options; see `config` in
      `tf.Session()`.

  Returns:
    A `(session, metagraph_def)` pair:
      * `session` is a TensorFlow session created from the variable files.
      * `metagraph_def` is the `MetaGraphDef` loaded into that session, with
        `signature_def` populated from the legacy signatures. It can be used
        to look up signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: If the legacy metagraph already contains a `signature_def`
      and therefore cannot be up-converted.
  """
  legacy_metagraph_path = os.path.join(
      export_dir, legacy_constants.META_GRAPH_DEF_FILENAME)
  metagraph_def = meta_graph.read_meta_graph_file(legacy_metagraph_path)

  # Never overwrite signature defs that are already present.
  if metagraph_def.signature_def:
    raise RuntimeError("Legacy graph contains signature def, unable to "
                       "up-convert.")

  default_sig, named_sig = _convert_signatures_to_signature_defs(metagraph_def)
  serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  signature_map = metagraph_def.signature_def
  if default_sig:
    signature_map[serving_key].CopyFrom(default_sig)
  if named_sig:
    # Avoid clobbering the default signature when both are present.
    named_key = serving_key + "_from_named" if default_sig else serving_key
    signature_map[named_key].CopyFrom(named_sig)

  # Reload from the up-converted metagraph. The session's Graph is built from
  # the MetaGraphDef it was loaded with, so reusing an older session would
  # leave the Graph out of sync with the new signature defs.
  session, loaded_metagraph_def = session_bundle.load_session_bundle_from_path(
      export_dir, target, config, meta_graph_def=metagraph_def)
  return session, loaded_metagraph_def
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:54,代码来源:bundle_shim.py

示例5: testConvertSignaturesToSignatureDefs

# 需要导入模块: from tensorflow.python.framework import meta_graph [as 别名]
# 或者: from tensorflow.python.framework.meta_graph import read_meta_graph_file [as 别名]
def testConvertSignaturesToSignatureDefs(self):
    """Legacy SessionBundle signatures convert to SignatureDefs.

    Covers three variants of the bundled Signatures proto: default plus
    named, default only, and named only.
    """
    base_path = test.test_src_dir_path(SESSION_BUNDLE_PATH)
    meta_graph_filename = os.path.join(base_path,
                                       constants.META_GRAPH_DEF_FILENAME)
    metagraph_def = meta_graph.read_meta_graph_file(meta_graph_filename)

    # Both signatures present: default -> regression, named -> predict.
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(default_signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertEqual(len(default_signature_def.inputs), 1)
    self.assertEqual(len(default_signature_def.outputs), 1)
    self.assertProtoEquals(
        default_signature_def.inputs[signature_constants.REGRESS_INPUTS],
        meta_graph_pb2.TensorInfo(name="tf_example:0"))
    self.assertProtoEquals(
        default_signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
        meta_graph_pb2.TensorInfo(name="Identity:0"))
    self.assertEqual(named_signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertEqual(len(named_signature_def.inputs), 1)
    self.assertEqual(len(named_signature_def.outputs), 1)
    self.assertProtoEquals(
        named_signature_def.inputs["x"], meta_graph_pb2.TensorInfo(name="x:0"))
    self.assertProtoEquals(
        named_signature_def.outputs["y"], meta_graph_pb2.TensorInfo(name="y:0"))

    # Unpack the bundled Signatures proto so that modified copies can be
    # packed back into the collection.
    collection_def = metagraph_def.collection_def
    signatures_proto = manifest_pb2.Signatures()
    signatures = collection_def[constants.SIGNATURES_KEY].any_list.value[0]
    signatures.Unpack(signatures_proto)
    # Copy taken before anything is cleared; used for the named-only case.
    named_only_signatures_proto = manifest_pb2.Signatures()
    named_only_signatures_proto.CopyFrom(signatures_proto)

    # Default signature only.
    default_only_signatures_proto = manifest_pb2.Signatures()
    default_only_signatures_proto.CopyFrom(signatures_proto)
    default_only_signatures_proto.ClearField("named_signatures")
    metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
        0].Pack(default_only_signatures_proto)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(default_signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertIsNone(named_signature_def)

    # Named signatures only.
    named_only_signatures_proto.ClearField("default_signature")
    metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
        0].Pack(named_only_signatures_proto)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(named_signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertIsNone(default_signature_def)
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:56,代码来源:bundle_shim_test.py


注：本文中的 tensorflow.python.framework.meta_graph.read_meta_graph_file 方法示例由纯净天空整理自 GitHub、MSDocs 等开源代码及文档管理平台，相关代码片段筛选自众多开发者贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的许可证（License）；未经允许，请勿转载。