本文整理汇总了Python中tensorflow.contrib.graph_editor.sgv方法的典型用法代码示例。如果您正苦于以下问题：Python graph_editor.sgv方法的具体用法？Python graph_editor.sgv怎么用？Python graph_editor.sgv使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.contrib.graph_editor的用法示例。
在下文中一共展示了graph_editor.sgv方法的8个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_reroute_can_modify
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_reroute_can_modify(self):
    """Check ge.reroute.swap_outputs on a graph with an ambiguous tensor.

    Per the original note, "a" is both an input and an output of the ops in
    the first subgraph view. After swapping the outputs of the two views,
    "g" is expected to take "a" and "c" (itself fed by "a" and "b") as
    inputs, and "d" is expected to take "e" and "f".
    """
    graph = tf.Graph()
    with graph.as_default():
        # First group: a, b -> c, plus d = a + c.
        t_a = tf.constant(1.0, shape=[2], name="a")
        t_b = tf.constant(2.0, shape=[2], name="b")
        t_c = tf.add(t_a, t_b, name="c")
        t_d = tf.add(t_a, t_c, name="d")
        # Second group: e, f -> g.
        t_e = tf.constant(1.0, shape=[2], name="e")
        t_f = tf.constant(2.0, shape=[2], name="f")
        t_g = tf.add(t_e, t_f, name="g")

    left_view = ge.sgv(t_a.op, t_b.op, t_c.op)
    right_view = ge.sgv(t_e.op, t_f.op)
    ge.reroute.swap_outputs(left_view, right_view)

    c_fed_by_a_and_b = ge.matcher("c").input_ops("a", "b")
    self.assertTrue(ge.matcher("g").input_ops("a", c_fed_by_a_and_b)(t_g.op))
    self.assertTrue(ge.matcher("d").input_ops("e", "f")(t_d.op))
示例2: test_subgraph_remap
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_subgraph_remap(self):
    """Exercise the remap helpers on a subgraph view built around op c."""
    view = ge.sgv(self.c.op)
    expected_inputs = [self.a, self.b]
    self.assertEqual(list(view.outputs), [self.c])
    self.assertEqual(list(view.inputs), expected_inputs)

    # One output entry per consumer of c; the fixture is expected to give
    # c three consumers.
    view = view.remap_outputs_to_consumers()
    self.assertEqual(list(view.outputs), [self.c] * 3)

    # Collapse the duplicated entries back into one.
    view = view.remap_outputs_make_unique()
    self.assertEqual(list(view.outputs), [self.c])

    # Empty index lists drop every input and every output.
    view = view.remap(new_input_indices=[], new_output_indices=[])
    self.assertEqual(len(view.inputs), 0)
    self.assertEqual(len(view.outputs), 0)

    # remap_default is expected to restore the original boundary.
    view = view.remap_default()
    self.assertEqual(list(view.outputs), [self.c])
    self.assertEqual(list(view.inputs), expected_inputs)
示例3: test_multiswap
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_multiswap(self):
    """Swap a3's output, listed twice, with the view over a0 and a1."""
    with self.graph.as_default():
        a3 = tf.constant(3.0, shape=[2], name="a3")
        # remap_outputs([0, 0]) lists a3's output 0 twice so the view can be
        # swapped against the two-op (a0, a1) view.
        a3_twice = ge.sgv(a3.op).remap_outputs([0, 0])
        ge.reroute.swap(a3_twice, ge.sgv(self.a0.op, self.a1.op))

    # c0 and c1 are both expected to take a3 as their first input afterwards.
    expectations = (
        ("c0", "b0", self.c0),
        ("c1", "b1", self.c1),
    )
    for op_name, other_input, tensor in expectations:
        self.assertTrue(
            ge.matcher(op_name).input_ops("a3", other_input)(tensor.op))
示例4: test_subgraph
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_subgraph(self):
    """Build subgraph views from a whole graph, from ops, and from a scope."""
    # View over the whole graph: no inputs, 8 ops, outputs e and h.
    whole = ge.sgv(self.graph)
    self.assertEqual(list(whole.outputs), [self.e, self.h])
    self.assertEqual(list(whole.inputs), [])
    self.assertEqual(len(whole.ops), 8)

    # View over two ops: its inputs are the tensors feeding f and g.
    pair = ge.sgv(self.f.op, self.g.op)
    self.assertEqual(list(pair.outputs), [self.f, self.g])
    self.assertEqual(list(pair.inputs), [self.c, self.d, self.a])

    # View over every op under the "foo/bar" scope.
    scoped = ge.sgv_scope("foo/bar", graph=self.graph)
    expected_ops = [t.op for t in (self.e, self.f, self.g, self.h)]
    self.assertEqual(list(scoped.ops), expected_ops)
示例5: test_remove_unused_ops
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_remove_unused_ops(self):
    """Keeping only output h and pruning is expected to leave 7 ops."""
    full_view = ge.sgv(self.graph)
    self.assertEqual(list(full_view.outputs), [self.e, self.h])
    self.assertEqual(len(full_view.ops), 8)

    # Keep output index 1 (h) only, then drop ops that no longer feed it.
    h_only = full_view.remap_outputs(new_output_indices=[1])
    pruned = h_only.remove_unused_ops()
    self.assertEqual(list(pruned.outputs), [self.h])
    self.assertEqual(len(pruned.ops), 7)
示例6: test_connect
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_connect(self):
    """Test for ge.connect.

    Builds z = x + y, then connects that subgraph's output to input 0 of
    foo/bar/e. Afterwards e is expected to take z and foo/d as inputs.
    """
    with self.graph.as_default():
        x = tf.constant([1., 1.], shape=[2], name="x")
        y = tf.constant([2., 2.], shape=[2], name="y")
        z = tf.add(x, y, name="z")
        sgv = ge.sgv(x.op, y.op, z.op)
        # Restrict e's view to its input 0 so only that input is rewired.
        ge.connect(sgv, ge.sgv(self.e.op).remap_inputs([0]))
    # Anchor every pattern at both ends, like "^foo/bar/e$" and "^z$". An
    # end-only anchor ("foo/d$") would also accept any op whose name merely
    # ends in "foo/d".
    e_matcher = ge.matcher("^foo/bar/e$").input_ops("^z$", "^foo/d$")
    self.assertTrue(e_matcher(self.e.op))
示例7: test_bypass
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def test_bypass(self):
    """Test for ge.bypass.

    Bypasses op foo/bar/f through its input 0. Afterwards foo/bar/h is
    expected to take foo/c and foo/bar/g as inputs.
    """
    ge.bypass(ge.sgv(self.f.op).remap_inputs([0]))
    # Anchor both input patterns at both ends, consistent with "^foo/bar/h$"
    # and "^foo/c$", so only the op named exactly foo/bar/g can match.
    h_matcher = ge.matcher("^foo/bar/h$").input_ops("^foo/c$", "^foo/bar/g$")
    self.assertTrue(h_matcher(self.h.op))
示例8: replace_read_ops
# 需要导入模块: from tensorflow.contrib import graph_editor [as 别名]
# 或者: from tensorflow.contrib.graph_editor import sgv [as 别名]
def replace_read_ops(loss_or_losses, var_list):
    """Replace the read ops of each variable in `var_list` with fresh reads.

    Each read op of a variable that lies on a graph walk between the variable
    and `loss_or_losses` is replaced by a new op obtained from
    `var.read_value()`. The new op is created on the same device and in the
    same name scope as the read op it replaces. This forces the consumers to
    read the most up-to-date values of the variables, which might incur
    copies across devices.

    Args:
        loss_or_losses: tensor(s) or op(s) seeding the backward graph walk.
        var_list: variables whose reads should be replaced. Variables that
            `loss_or_losses` does not depend on are ignored.

    Returns:
        None. The graph is rewired in place via `ge.connect`.
    """
    # ops between var ops and the loss
    ops = set(ge.get_walks_intersection_ops(
        [var.op for var in var_list], loss_or_losses))
    if not ops:
        # loss_or_losses doesn't depend on any var in var_list, so there is
        # nothing to replace
        return

    # filter out variables that are not involved in computing the loss
    var_list = [var for var in var_list if var.op in ops]

    for var in var_list:
        output, = var.op.outputs
        read_ops = set(output.consumers()) & ops
        for read_op in read_ops:
            read_t, = read_op.outputs
            consumer_ops = set(read_t.consumers()) & ops
            if not consumer_ops:
                # read_t itself seeds the backward walk (e.g. it is the loss),
                # so there is no downstream input to reroute. Without this
                # guard, list.index below would raise ValueError.
                continue
            # consumer_sgv might have multiple inputs, but we only care
            # about replacing the input that is read_t
            consumer_sgv = ge.sgv(consumer_ops)
            consumer_sgv = consumer_sgv.remap_inputs(
                [list(consumer_sgv.inputs).index(read_t)])

            # A trailing '/' re-enters the read op's existing scope verbatim.
            # Without it, TF creates a uniquified sibling scope (e.g. "w_1"),
            # and the scope is also resolved relative to the caller's current
            # scope. An empty name resets to the root scope for top-level ops.
            scope = read_op.name.rpartition('/')[0]
            with tf.name_scope(scope + '/' if scope else ''):
                with tf.device(read_op.device):
                    new_read_op = var.read_value().op
            ge.connect(ge.sgv(new_read_op), consumer_sgv)