当前位置: 首页>>代码示例>>Python>>正文


Python meta_graph.export_scoped_meta_graph函数代码示例

本文整理汇总了Python中tensorflow.python.framework.meta_graph.export_scoped_meta_graph函数的典型用法代码示例。如果您正苦于以下问题：Python export_scoped_meta_graph函数的具体用法？Python export_scoped_meta_graph怎么用？Python export_scoped_meta_graph使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了export_scoped_meta_graph函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _testExportImportAcrossScopes

  def _testExportImportAcrossScopes(self, graph_fn):
    """Round-trips a graph through a scoped export/import and compares it.

    The graph built by `graph_fn` under "dropA/dropB/keepA" is exported with
    export_scope "dropA/dropB" and re-imported under "importA". Its re-export
    must equal the export of a graph built directly under "importA/keepA".

    Args:
      graph_fn: A closure that creates a graph on the current scope.
    """
    source_graph = ops.Graph()
    with source_graph.as_default():
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn()
    scoped_def, _ = meta_graph.export_scoped_meta_graph(
        graph=source_graph, export_scope="dropA/dropB")

    round_tripped_graph = ops.Graph()
    with round_tripped_graph.as_default():
      meta_graph.import_scoped_meta_graph(scoped_def, import_scope="importA")

    reference_graph = ops.Graph()
    with reference_graph.as_default():
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn()

    actual_def, _ = meta_graph.export_scoped_meta_graph(
        graph=round_tripped_graph)
    expected_def, _ = meta_graph.export_scoped_meta_graph(
        graph=reference_graph)
    self.assertProtoEquals(expected_def, actual_def)
开发者ID:SylChan,项目名称:tensorflow,代码行数:25,代码来源:meta_graph_test.py

示例2: testScopedImportWithSelectedCollections

  def testScopedImportWithSelectedCollections(self):
    """Only collections accepted by restore_collections_predicate are imported.

    A single variable fills GLOBAL_VARIABLES and TRAINABLE_VARIABLES. The
    exported file is then imported under a scope once for each include/omit
    split of those two keys. Each included collection must be non-empty in
    that scope, and each omitted one must be empty.
    """
    meta_graph_filename = os.path.join(
        _TestDir("selected_collections_import"), "meta_graph.pb")
    global_key = ops.GraphKeys.GLOBAL_VARIABLES
    trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES

    source_graph = ops.Graph()
    # Variables are just a convenient way to fill two collections at once; the
    # behavior under test is not variable-specific.
    with source_graph.as_default():
      variables.Variable(initial_value=1.0, trainable=True)
    self.assertTrue(
        all([source_graph.get_collection(key)
             for key in [global_key, trainable_key]]))
    meta_graph.export_scoped_meta_graph(
        filename=meta_graph_filename, graph=source_graph)

    def _check_import(include_collection_keys, omit_collection_keys):
      # Imports with a predicate that keeps only `include_collection_keys`,
      # then checks the scoped collections for presence/absence.
      assert set(include_collection_keys).isdisjoint(omit_collection_keys)
      target_graph = ops.Graph()
      scope = "some_scope_name"

      def _should_restore(collection_key):
        if collection_key not in include_collection_keys:
          return False
        return collection_key not in omit_collection_keys

      meta_graph.import_scoped_meta_graph(
          meta_graph_filename,
          graph=target_graph,
          import_scope=scope,
          restore_collections_predicate=_should_restore)

      def _scoped_collection(key):
        return target_graph.get_collection(name=key, scope=scope)

      self.assertTrue(
          all([_scoped_collection(key) for key in include_collection_keys]))
      self.assertFalse(
          any([_scoped_collection(key) for key in omit_collection_keys]))

    _check_import(include_collection_keys=[global_key, trainable_key],
                  omit_collection_keys=[])
    _check_import(include_collection_keys=[global_key],
                  omit_collection_keys=[trainable_key])
    _check_import(include_collection_keys=[trainable_key],
                  omit_collection_keys=[global_key])
    _check_import(include_collection_keys=[],
                  omit_collection_keys=[global_key, trainable_key])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:59,代码来源:meta_graph_test.py

示例3: testClearDevices

  def testClearDevices(self):
    """Device placements are dropped on export or on import when requested.

    Three round trips are checked: clearing on export from a Graph, clearing
    on export from a GraphDef, and clearing only on import. In every case the
    imported ops "a", "b" and "matmul" must carry an empty device string.
    """
    graph1 = ops.Graph()
    with graph1.as_default():
      with ops.device("/device:CPU:0"):
        a = variables.Variable(
            constant_op.constant(1.0, shape=[2, 2]), name="a")
      with ops.device("/job:ps/replica:0/task:0/gpu:0"):
        b = variables.Variable(
            constant_op.constant(2.0, shape=[2, 2]), name="b")
      with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
        math_ops.matmul(a, b, name="matmul")

    # Placements as reported by the source graph.
    source_devices = (
        ("a", "/device:CPU:0"),
        ("b", "/job:ps/replica:0/task:0/device:GPU:0"),
        ("matmul", "/job:localhost/replica:0/task:0/device:CPU:0"),
    )
    for op_name, device in source_devices:
      self.assertEqual(device, str(graph1.as_graph_element(op_name).device))

    def _assert_devices_cleared(exported_def, clear_on_import):
      # Imports `exported_def` into a fresh graph and checks that no op
      # retains a device placement.
      imported = ops.Graph()
      with imported.as_default():
        meta_graph.import_scoped_meta_graph(
            exported_def, clear_devices=clear_on_import)
      for op_name, _ in source_devices:
        self.assertEqual("", str(imported.as_graph_element(op_name).device))

    # Cleared on export, starting from the Graph object.
    exported, _ = meta_graph.export_scoped_meta_graph(
        graph=graph1, clear_devices=True)
    _assert_devices_cleared(exported, clear_on_import=False)

    # Cleared on export, starting from a serialized GraphDef.
    exported, _ = meta_graph.export_scoped_meta_graph(
        graph_def=graph1.as_graph_def(), clear_devices=True)
    _assert_devices_cleared(exported, clear_on_import=False)

    # Kept on export, cleared on import.
    exported, _ = meta_graph.export_scoped_meta_graph(
        graph=graph1, clear_devices=False)
    _assert_devices_cleared(exported, clear_on_import=True)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:55,代码来源:meta_graph_test.py

示例4: testMetricsCollection

  def testMetricsCollection(self):
    """Round-trips a graph whose LOCAL_VARIABLES collection holds metric vars.

    Covers three cases:
      1. Export a graph that ran `metrics.mean`.
      2. Import that export and build a local-variables initializer.
      3. Import a checked-in legacy MetaGraphDef whose "local_variables"
         collection holds plain tensors, for which no initializer can be
         built.
    """

    def _enqueue_vector(sess, queue, values, shape=None):
      # Defaults to a (1, len(values)) row, matching the queue's (1, 2)
      # component shape for the two-element vectors used below.
      if not shape:
        shape = (1, len(values))
      dtype = queue.dtypes[0]
      sess.run(
          queue.enqueue(constant_op.constant(
              values, dtype=dtype, shape=shape)))

    meta_graph_filename = os.path.join(
        _TestDir("metrics_export"), "meta_graph.pb")

    graph = ops.Graph()
    with self.session(graph=graph) as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      _, update_op = metrics.mean(values)

      initializer = variables.local_variables_initializer()
      self.evaluate(initializer)
      self.evaluate(update_op)

    meta_graph.export_scoped_meta_graph(
        filename=meta_graph_filename, graph=graph)

    # Importing a MetaGraphDef with a LOCAL_VARIABLES collection must allow
    # the local-variables initializer to be built and run.
    graph = ops.Graph()
    with self.session(graph=graph):
      meta_graph.import_scoped_meta_graph(meta_graph_filename)
      initializer = variables.local_variables_initializer()
      self.evaluate(initializer)

    # An older MetaGraphDef stores "local_variables" as a node_list. Its import
    # succeeds, but the collection holds Tensors, so building an initializer
    # from it fails.
    graph = ops.Graph()
    with self.session(graph=graph):
      meta_graph.import_scoped_meta_graph(
          test.test_src_dir_path(
              "python/framework/testdata/metrics_export_meta_graph.pb"))
      self.assertEqual(len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)),
                       2)
      # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
      with self.assertRaisesRegex(
          AttributeError, "'Tensor' object has no attribute 'initializer'"):
        variables.local_variables_initializer()
开发者ID:aeverall,项目名称:tensorflow,代码行数:53,代码来源:meta_graph_test.py

示例5: testNoVariables

  def testNoVariables(self):
    """Exports and re-imports a variable-free graph via its collections.

    The placeholder and the output tensor are recorded in the "input_tensor"
    and "output_tensor" collections. After import into a fresh graph, both are
    fetched back from those collections and must compute the same value.
    """
    export_path = os.path.join(_TestDir("no_variables"), "metafile")

    feed_value = -10  # Arbitrary input value for feed_dict.

    source_graph = ops.Graph()
    with self.session(graph=source_graph) as sess:
      # Minimal graph: a placeholder plus a constant offset, no variables.
      x = array_ops.placeholder(dtypes.float32, shape=[], name="input")
      offset = constant_op.constant(42, dtype=dtypes.float32, name="offset")
      y = math_ops.add(x, offset, name="add_offset")

      # Record the endpoints in collections so they survive the round trip.
      ops.add_to_collection("input_tensor", x)
      ops.add_to_collection("output_tensor", y)

      original_output = sess.run(y, {x: feed_value})
      self.assertEqual(original_output, 32)

      exported_def, exported_vars = meta_graph.export_scoped_meta_graph(
          filename=export_path,
          graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
          collection_list=["input_tensor", "output_tensor"],
          saver_def=None)
      info = exported_def.meta_info_def
      self.assertTrue(exported_def.HasField("meta_info_def"))
      self.assertNotEqual(info.tensorflow_version, "")
      self.assertNotEqual(info.tensorflow_git_version, "")
      self.assertEqual({}, exported_vars)

    fresh_graph = ops.Graph()
    with self.session(graph=fresh_graph) as sess:
      meta_graph.import_scoped_meta_graph(export_path)

      # Re-export the imported state and compare it with the first export.
      reexported_def, _ = meta_graph.export_scoped_meta_graph(
          filename=export_path + "_new")
      test_util.assert_meta_graph_protos_equal(self, exported_def,
                                               reexported_def)

      # The collections still resolve to usable tensors in the new graph.
      restored_x = ops.get_collection("input_tensor")[0]
      restored_y = ops.get_collection("output_tensor")[0]
      restored_output = sess.run(restored_y, {restored_x: feed_value})
      self.assertEqual(restored_output, original_output)
开发者ID:aeverall,项目名称:tensorflow,代码行数:52,代码来源:meta_graph_test.py

示例6: testDefaultAttrStripping

  def testDefaultAttrStripping(self):
    """Verifies that default attributes are stripped from a graph def.

    The Complex op has two defaulted attrs: "T" (float32) and "Tout"
    (complex64). With float32 inputs both attrs equal their defaults, so they
    are removed when stripping is enabled and kept when it is disabled. With
    float64 inputs they differ from the defaults and survive stripping.
    """

    def _export_complex_node(strip):
      # Exports the current default graph and returns (meta_graph_def,
      # NodeDef of the op named "complex").
      exported, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=strip)
      node = test_util.get_node_def_from_graph("complex", exported.graph_def)
      return exported, node

    with self.cached_session():
      real = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
      imag = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real, imag, name="complex")

      # Stripping enabled: attrs equal to their defaults disappear.
      exported, node = _export_complex_node(True)
      for attr_name in ("T", "Tout"):
        self.assertNotIn(attr_name, node.attr)
      self.assertTrue(exported.meta_info_def.stripped_default_attrs)

      # Stripping disabled: every attr is kept.
      exported, node = _export_complex_node(False)
      for attr_name in ("T", "Tout"):
        self.assertIn(attr_name, node.attr)
      self.assertFalse(exported.meta_info_def.stripped_default_attrs)

    # Non-default attr values (float64 -> complex128) are never stripped.
    with self.session(graph=ops.Graph()):
      real = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
      imag = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
      math_ops.complex(real, imag, name="complex")
      exported, node = _export_complex_node(True)
      self.assertEqual(node.attr["T"].type, dtypes.float64)
      self.assertEqual(node.attr["Tout"].type, dtypes.complex128)
      self.assertTrue(exported.meta_info_def.stripped_default_attrs)
开发者ID:aeverall,项目名称:tensorflow,代码行数:51,代码来源:meta_graph_test.py

示例7: testWhileLoopGradients

  def testWhileLoopGradients(self):
    """Gradients through an imported while loop match the original graph's.

    A while loop built under "export" is exported with export_scope "export".
    It is then re-imported under "import", re-exported to check proto
    equality, and differentiated again. The gradient values must agree.
    """
    with ops.Graph().as_default():
      with ops.name_scope("export"):
        var = variables.Variable(0.)
        var_name = var.name
        _, output = control_flow_ops.while_loop(
            lambda i, x: i < 5,
            lambda i, x: (i + 1, x + math_ops.cast(i, dtypes.float32)),
            [0, var])
        output_name = output.name

      # MetaGraphDef of the loop, restricted to the "export" scope.
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          export_scope="export")

      # Reference gradient, computed on the original graph.
      init_op = variables.global_variables_initializer()
      grad = gradients_impl.gradients([output], [var])
      with session.Session():
        self.evaluate(init_op)
        expected_grad_value = self.evaluate(grad)

    with ops.Graph().as_default():
      meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import")

      # Re-exporting the imported scope must reproduce the original proto.
      new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          export_scope="import")
      test_util.assert_meta_graph_protos_equal(
          self, meta_graph_def, new_meta_graph_def)

      current_graph = ops.get_default_graph()
      var = current_graph.get_tensor_by_name(
          "import/" + var_name.replace("export/", ""))
      output = current_graph.get_tensor_by_name(
          "import/" + output_name.replace("export/", ""))
      grad = gradients_impl.gradients([output], [var])

      init_op = variables.global_variables_initializer()

      with session.Session():
        self.evaluate(init_op)
        self.assertEqual(expected_grad_value, self.evaluate(grad))
开发者ID:aeverall,项目名称:tensorflow,代码行数:50,代码来源:meta_graph_test.py

示例8: _testExportImportAcrossScopes

  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests exporting and importing a graph across scopes.

    The graph is built under "dropA/dropB/keepA", exported with export_scope
    "dropA/dropB", and imported under "importA". Its re-export must equal the
    export of a graph built directly under "importA/keepA".

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """

    def _blank_shared_names(meta_graph_def):
      # shared_name values are orthogonal to scopes; blank any non-empty ones
      # so they do not affect the comparison.
      for node in meta_graph_def.graph_def.node:
        value = node.attr.get("shared_name", None)
        if value and value.HasField("s") and value.s:
          node.attr["shared_name"].s = b""

    with ops.Graph().as_default() as original_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    exported_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
        graph=original_graph, export_scope="dropA/dropB")

    with ops.Graph().as_default() as imported_graph:
      meta_graph.import_scoped_meta_graph(
          exported_meta_graph_def, import_scope="importA")

    with ops.Graph().as_default() as expected_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

      if use_resource:
        # Bringing in a collection that contains ResourceVariables adds ops
        # to the graph, so add the same ops to the expected graph.
        keys = sorted([ops.GraphKeys.GLOBAL_VARIABLES,
                       ops.GraphKeys.TRAINABLE_VARIABLES])
        for collection_key in keys:
          for var in expected_graph.get_collection(collection_key):
            var._read_variable_op()  # pylint: disable=protected-access

    result, _ = meta_graph.export_scoped_meta_graph(graph=imported_graph)
    expected, _ = meta_graph.export_scoped_meta_graph(graph=expected_graph)

    if use_resource:
      _blank_shared_names(result)
      _blank_shared_names(expected)

    self.assertProtoEquals(expected, result)
开发者ID:autodrive,项目名称:tensorflow,代码行数:48,代码来源:meta_graph_test.py

示例9: testClearDevices

  def testClearDevices(self):
    """Importing with clear_devices=True drops every op's device placement."""
    graph1 = tf.Graph()
    with graph1.as_default():
      with tf.device("/device:CPU:0"):
        a = tf.Variable(tf.constant(1.0, shape=[2, 2]), name="a")
      with tf.device("/job:ps/replica:0/task:0/gpu:0"):
        b = tf.Variable(tf.constant(2.0, shape=[2, 2]), name="b")
      with tf.device("/job:localhost/replica:0/task:0/cpu:0"):
        tf.matmul(a, b, name="matmul")

    # Placements as reported by the source graph.
    source_devices = (
        ("a", "/device:CPU:0"),
        ("b", "/job:ps/replica:0/task:0/device:GPU:0"),
        ("matmul", "/job:localhost/replica:0/task:0/device:CPU:0"),
    )
    for op_name, device in source_devices:
      self.assertEqual(device, str(graph1.as_graph_element(op_name).device))

    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1)

    graph2 = tf.Graph()
    with graph2.as_default():
      meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=True)

    for op_name, _ in source_devices:
      self.assertEqual("", str(graph2.as_graph_element(op_name).device))
开发者ID:caikehe,项目名称:tensorflow,代码行数:25,代码来源:meta_graph_test.py

示例10: testPotentialCycle

  def testPotentialCycle(self):
    """Importing a scope that reads a tensor from outside it needs input_map.

    Only "hidden1" is exported, but the relu inside it consumes the outer
    matmul's output. The import without an input_map is expected to fail with
    "Graph contains unbound inputs". Supplying "$unbound_inputs_MatMul" via
    `input_map` is expected to succeed.
    """
    graph1 = ops.Graph()
    with graph1.as_default():
      a = constant_op.constant(1.0, shape=[2, 2])
      b = constant_op.constant(2.0, shape=[2, 2])
      matmul = math_ops.matmul(a, b)
      with ops.name_scope("hidden1"):
        c = nn_ops.relu(matmul)
        d = constant_op.constant(3.0, shape=[2, 2])
        matmul = math_ops.matmul(c, d)

    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
        export_scope="hidden1", graph=graph1)

    graph2 = ops.Graph()
    with graph2.as_default():
      # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
      with self.assertRaisesRegex(ValueError, "Graph contains unbound inputs"):
        meta_graph.import_scoped_meta_graph(
            orig_meta_graph, import_scope="new_hidden1")

      meta_graph.import_scoped_meta_graph(
          orig_meta_graph,
          import_scope="new_hidden1",
          input_map={
              "$unbound_inputs_MatMul": constant_op.constant(
                  4.0, shape=[2, 2])
          })
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:27,代码来源:meta_graph_test.py

示例11: testDefaultAttrStrippingNestedFunctions

  def testDefaultAttrStrippingNestedFunctions(self):
    """Verifies that default attributes are stripped from function node defs.

    The Complex op sits two Defun levels deep (f1 calls f0). With
    strip_default_attrs=True, its "T" and "Tout" attrs must be absent from the
    NodeDef stored in the graph's function library.
    """
    with self.cached_session():

      @function.Defun(dtypes.float32, dtypes.float32)
      def f0(i, j):
        return math_ops.complex(i, j, name="double_nested_complex")

      @function.Defun(dtypes.float32, dtypes.float32)
      def f1(i, j):
        return f0(i, j)

      _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
      meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
          graph_def=ops.get_default_graph().as_graph_def(),
          strip_default_attrs=True)

      # First NodeDef, across all library functions, whose name starts with
      # "double_nested_complex"; None if there is no such node.
      complex_node = next(
          (node_def
           for function_def in meta_graph_def.graph_def.library.function
           for node_def in function_def.node_def
           if node_def.name.startswith("double_nested_complex")),
          None)

      self.assertIsNotNone(complex_node)
      for attr_name in ("T", "Tout"):
        self.assertNotIn(attr_name, complex_node.attr)
      self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:meta_graph_test.py

示例12: testSummaryWithFamilyMetaGraphExport

  def testSummaryWithFamilyMetaGraphExport(self):
    """Scoped import renames summary ops but leaves their tags unchanged.

    Summaries created under 'outer' are exported with export_scope='outer' and
    imported under 'new_outer'. The op names move to 'new_outer/...', while the
    evaluated Summary protos still carry tags with the original 'outer' prefix.
    """
    # assertEquals is a deprecated alias (removed in Python 3.12); use
    # assertEqual throughout.
    with ops.name_scope('outer'):
      i = constant_op.constant(11)
      summ = summary_lib.scalar('inner', i)
      self.assertEqual(summ.op.name, 'outer/inner')
      summ_f = summary_lib.scalar('inner', i, family='family')
      self.assertEqual(summ_f.op.name, 'outer/family/inner')

    metagraph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='outer')

    with ops.Graph().as_default() as g:
      meta_graph.import_scoped_meta_graph(metagraph_def, graph=g,
                                          import_scope='new_outer')
      # The summaries should exist, but with outer scope renamed.
      new_summ = g.get_tensor_by_name('new_outer/inner:0')
      new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')

      # However, the tags are unaffected.
      with self.cached_session() as s:
        new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
        new_summ_pb = summary_pb2.Summary()
        new_summ_pb.ParseFromString(new_summ_str)
        self.assertEqual('outer/inner', new_summ_pb.value[0].tag)
        new_summ_f_pb = summary_pb2.Summary()
        new_summ_f_pb.ParseFromString(new_summ_f_str)
        self.assertEqual('family/outer/family/inner',
                         new_summ_f_pb.value[0].tag)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:27,代码来源:summary_test.py

示例13: testImportsUsingSameScopeName

 def testImportsUsingSameScopeName(self):
   """Importing twice with the same import_scope uniquifies the second scope.

   The first import under "s" is expected to place the variable at "s/v:0",
   and the second import, which asks for "s" again, at "s_1/v:0". Both
   returned dicts must have exactly one entry, keyed by the exported name
   "v:0".
   """
   with ops.Graph().as_default():
     variables.Variable(0, name="v")
     meta_graph_def, _ = meta_graph.export_scoped_meta_graph()
   with ops.Graph().as_default():
     # Expected scope suffix for the first and second import, in order; both
     # imports go into the same graph.
     for suffix in ["", "_1"]:
       imported_variables = meta_graph.import_scoped_meta_graph(
           meta_graph_def, import_scope="s")
       self.assertEqual(len(imported_variables), 1)
       self.assertEqual(list(imported_variables.keys())[0], "v:0")
       self.assertEqual(list(imported_variables.values())[0].name,
                        "s" + suffix + "/v:0")
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:meta_graph_test.py

示例14: _testExportImportAcrossScopes

  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests exporting and importing a graph across scopes.

    The graph is built under "dropA/dropB/keepA", exported with export_scope
    "dropA/dropB", and imported under "importA". Its re-export must equal the
    export of a graph built directly under "importA/keepA".

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """

    def _blank_shared_names(meta_graph_def):
      # shared_name values are orthogonal to scopes and are not updated on
      # export/import; blank any non-empty ones before comparing.
      for node in meta_graph_def.graph_def.node:
        value = node.attr.get("shared_name", None)
        if value and value.HasField("s") and value.s:
          node.attr["shared_name"].s = b""

    source_graph = ops.Graph()
    with source_graph.as_default():
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    scoped_def, _ = meta_graph.export_scoped_meta_graph(
        graph=source_graph, export_scope="dropA/dropB")

    round_tripped_graph = ops.Graph()
    with round_tripped_graph.as_default():
      meta_graph.import_scoped_meta_graph(scoped_def, import_scope="importA")

    reference_graph = ops.Graph()
    with reference_graph.as_default():
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

    result, _ = meta_graph.export_scoped_meta_graph(graph=round_tripped_graph)
    expected, _ = meta_graph.export_scoped_meta_graph(graph=reference_graph)

    if use_resource:
      _blank_shared_names(result)
      _blank_shared_names(expected)

    test_util.assert_meta_graph_protos_equal(self, expected, result)
开发者ID:aeverall,项目名称:tensorflow,代码行数:38,代码来源:meta_graph_test.py

示例15: testScopedImportUnderNameScope

  def testScopedImportUnderNameScope(self):
    """An import_scope nests beneath the name scope active at import time."""
    source = ops.Graph()
    with source.as_default():
      variables.Variable(initial_value=1.0, trainable=True, name="myvar")
    exported, _ = meta_graph.export_scoped_meta_graph(graph=source)

    target = ops.Graph()
    with target.as_default(), ops.name_scope("foo"):
      imported_variables = meta_graph.import_scoped_meta_graph(
          exported, import_scope="bar")
      self.assertEqual(len(imported_variables), 1)
      (only_var,) = imported_variables.values()
      self.assertEqual(only_var.name, "foo/bar/myvar:0")
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:14,代码来源:meta_graph_test.py


注:本文中的tensorflow.python.framework.meta_graph.export_scoped_meta_graph函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。