当前位置: 首页>>代码示例>>Python>>正文


Python graph_editor.sgv方法代码示例

本文整理汇总了Python中tensorflow.contrib.graph_editor.sgv方法的典型用法代码示例。如果您正苦于以下问题:Python graph_editor.sgv方法的具体用法?Python graph_editor.sgv怎么用?Python graph_editor.sgv使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.contrib.graph_editor的用法示例。


在下文中一共展示了graph_editor.sgv方法的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_reroute_can_modify

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_reroute_can_modify(self):
    graph = tf.Graph()
    # create a special graph where "a" is an ambiguous tensor. That is
    # it is both an input and an output of the ops in sgv0.
    with graph.as_default():
      a = tf.constant(1.0, shape=[2], name="a")
      b = tf.constant(2.0, shape=[2], name="b")
      c = tf.add(a, b, name="c")
      d = tf.add(a, c, name="d")

      e = tf.constant(1.0, shape=[2], name="e")
      f = tf.constant(2.0, shape=[2], name="f")
      g = tf.add(e, f, name="g")

    sgv0 = ge.sgv(a.op, b.op, c.op)
    sgv1 = ge.sgv(e.op, f.op)

    ge.reroute.swap_outputs(sgv0, sgv1)
    self.assertTrue(ge.matcher("g").input_ops("a", ge.matcher("c")
                                              .input_ops("a", "b"))(g.op))
    self.assertTrue(ge.matcher("d").input_ops("e", "f")(d.op)) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:23,代码来源:reroute_test.py

示例2: test_subgraph_remap

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_subgraph_remap(self):
    sgv = ge.sgv(self.c.op)
    self.assertEqual(list(sgv.outputs), [self.c])
    self.assertEqual(list(sgv.inputs), [self.a, self.b])

    sgv = sgv.remap_outputs_to_consumers()
    self.assertEqual(list(sgv.outputs), [self.c, self.c, self.c])
    sgv = sgv.remap_outputs_make_unique()
    self.assertEqual(list(sgv.outputs), [self.c])

    sgv = sgv.remap(new_input_indices=[], new_output_indices=[])
    self.assertEqual(len(sgv.inputs), 0)
    self.assertEqual(len(sgv.outputs), 0)
    sgv = sgv.remap_default()
    self.assertEqual(list(sgv.outputs), [self.c])
    self.assertEqual(list(sgv.inputs), [self.a, self.b]) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:18,代码来源:subgraph_test.py

示例3: test_multiswap

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_multiswap(self):
    with self.graph.as_default():
      a3 = tf.constant(3.0, shape=[2], name="a3")
    ge.reroute.swap(ge.sgv(a3.op).remap_outputs([0, 0]),
                    ge.sgv(self.a0.op, self.a1.op))
    self.assertTrue(ge.matcher("c0").input_ops("a3", "b0")(self.c0.op))
    self.assertTrue(ge.matcher("c1").input_ops("a3", "b1")(self.c1.op)) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:9,代码来源:reroute_test.py

示例4: test_subgraph

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_subgraph(self):
    sgv = ge.sgv(self.graph)
    self.assertEqual(list(sgv.outputs), [self.e, self.h])
    self.assertEqual(list(sgv.inputs), [])
    self.assertEqual(len(sgv.ops), 8)

    sgv = ge.sgv(self.f.op, self.g.op)
    self.assertEqual(list(sgv.outputs), [self.f, self.g])
    self.assertEqual(list(sgv.inputs), [self.c, self.d, self.a])

    sgv = ge.sgv_scope("foo/bar", graph=self.graph)
    self.assertEqual(list(sgv.ops),
                     [self.e.op, self.f.op, self.g.op, self.h.op]) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:15,代码来源:subgraph_test.py

示例5: test_remove_unused_ops

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_remove_unused_ops(self):
    sgv = ge.sgv(self.graph)
    self.assertEqual(list(sgv.outputs), [self.e, self.h])
    self.assertEqual(len(sgv.ops), 8)

    sgv = sgv.remap_outputs(new_output_indices=[1]).remove_unused_ops()
    self.assertEqual(list(sgv.outputs), [self.h])
    self.assertEqual(len(sgv.ops), 7) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:10,代码来源:subgraph_test.py

示例6: test_connect

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_connect(self):
    """Test for ge.connect."""
    with self.graph.as_default():
      x = tf.constant([1., 1.], shape=[2], name="x")
      y = tf.constant([2., 2.], shape=[2], name="y")
      z = tf.add(x, y, name="z")

    sgv = ge.sgv(x.op, y.op, z.op)
    ge.connect(sgv, ge.sgv(self.e.op).remap_inputs([0]))
    self.assertTrue(ge.matcher("^foo/bar/e$").input_ops("^z$", "foo/d$")
                    (self.e.op)) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:13,代码来源:edit_test.py

示例7: test_bypass

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_bypass(self):
    """Test for ge.bypass."""
    ge.bypass(ge.sgv(self.f.op).remap_inputs([0]))
    self.assertTrue(ge.matcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")
                    (self.h.op)) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:7,代码来源:edit_test.py

示例8: replace_read_ops

# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def replace_read_ops(loss_or_losses, var_list):
    """
    Replaces read ops of each variable in `vars` with new read ops obtained
    from `read_value()`, thus forcing to read the most up-to-date values of
    the variables (which might incur copies across devices).
    The graph is seeded from the tensor(s) `loss_or_losses`.
    """
    # ops between var ops and the loss
    ops = set(ge.get_walks_intersection_ops([var.op for var in var_list], loss_or_losses))
    if not ops:  # loss_or_losses doesn't depend on any var in var_list, so there is nothiing to replace
        return

    # filter out variables that are not involved in computing the loss
    var_list = [var for var in var_list if var.op in ops]

    for var in var_list:
        output, = var.op.outputs
        read_ops = set(output.consumers()) & ops
        for read_op in read_ops:
            with tf.name_scope('/'.join(read_op.name.split('/')[:-1])):
                with tf.device(read_op.device):
                    read_t, = read_op.outputs
                    consumer_ops = set(read_t.consumers()) & ops
                    # consumer_sgv might have multiple inputs, but we only care
                    # about replacing the input that is read_t
                    consumer_sgv = ge.sgv(consumer_ops)
                    consumer_sgv = consumer_sgv.remap_inputs([list(consumer_sgv.inputs).index(read_t)])
                    ge.connect(ge.sgv(var.read_value().op), consumer_sgv) 
开发者ID:alexlee-gk,项目名称:video_prediction,代码行数:30,代码来源:tf_utils.py


注:本文中的tensorflow.contrib.graph_editor.sgv方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。