本文整理汇总了Python中tensorflow.contrib.data.python.ops.optimization.assert_next函数的典型用法代码示例。如果您正苦于以下问题:Python assert_next函数的具体用法?Python assert_next怎么用?Python assert_next使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了assert_next函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testFilterFusion
def testFilterFusion(self, map_function, predicates):
  """Verifies that chained `filter` transformations are fused into one.

  Args:
    map_function: function applied via `Dataset.map`.
    predicates: iterable of predicate functions, each applied via
      `Dataset.filter`; the `filter_fusion` rewrite should collapse them.
  """
  # assert_next checks the *optimized* op sequence: all filters must have
  # been fused into a single Filter node between Map and Prefetch.
  dataset = dataset_ops.Dataset.range(5).apply(
      optimization.assert_next(["Map", "Filter",
                                "Prefetch"])).map(map_function)
  for predicate in predicates:
    dataset = dataset.filter(predicate)
  dataset = dataset.prefetch(0).apply(
      optimization.optimize(["filter_fusion"]))
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    for x in range(5):
      r = map_function(x)
      # Re-evaluate every predicate in Python to decide whether element x
      # should have survived the fused filter.
      filtered = False
      for predicate in predicates:
        if isinstance(r, tuple):
          b = predicate(*r)  # Pass tuple as multiple arguments.
        else:
          b = predicate(r)
        if not sess.run(b):
          filtered = True
          break
      if not filtered:
        result = sess.run(get_next)
        self.assertAllEqual(r, result)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
示例2: testMapFilterFusion
def testMapFilterFusion(self, function, predicate):
  """Verifies map+filter fusion produces a FilterByLastComponent node.

  Args:
    function: function applied via `Dataset.map`.
    predicate: predicate applied via `Dataset.filter`; after the
      `map_and_filter_fusion` rewrite the pair should appear as
      Map -> FilterByLastComponent in the optimized graph.
  """
  dataset = dataset_ops.Dataset.range(10).apply(
      optimization.assert_next(
          ["Map",
           "FilterByLastComponent"])).map(function).filter(predicate).apply(
               optimization.optimize(["map_and_filter_fusion"]))
  # Shared helper (defined elsewhere in this class) checks output values.
  self._testMapAndFilter(dataset, function, predicate)
示例3: testLatencyStatsOptimization
def testLatencyStatsOptimization(self):
  """Verifies `latency_all_edges` inserts a LatencyStats node per edge.

  The assert_next list expects a LatencyStats op after each of the three
  dataset ops (TensorDataset, Map, Prefetch), then confirms each recorded
  exactly one latency sample in the aggregated summary.
  """
  stats_aggregator = stats_ops.StatsAggregator()
  dataset = dataset_ops.Dataset.from_tensors(1).apply(
      optimization.assert_next(
          ["LatencyStats", "Map", "LatencyStats", "Prefetch",
           "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
               optimization.optimize(["latency_all_edges"])).apply(
                   stats_ops.set_stats_aggregator(stats_aggregator))
  iterator = dataset.make_initializable_iterator()
  get_next = iterator.get_next()
  summary_t = stats_aggregator.get_summary()
  with self.test_session() as sess:
    sess.run(iterator.initializer)
    # Single element: 1 squared.
    self.assertEqual(1 * 1, sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
    summary_str = sess.run(summary_t)
    # One latency sample should be recorded on each instrumented edge.
    self._assertSummaryHasCount(summary_str,
                                "record_latency_TensorDataset/_1", 1)
    self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
                                1)
    self._assertSummaryHasCount(summary_str,
                                "record_latency_PrefetchDataset/_6", 1)
示例4: testHoisting
def testHoisting(self, function, will_optimize):
  """Verifies `hoist_random_uniform` rewrites eligible map functions.

  Args:
    function: map function under test.
    will_optimize: if True, the rewrite should hoist the random op out of
      the map, producing a Zip[0] -> Map structure; otherwise the graph
      keeps a plain Map.
  """
  dataset = dataset_ops.Dataset.range(5).apply(
      optimization.assert_next(
          ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
  dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))
  # Shared helper (defined elsewhere in this class) checks the output.
  self._testDataset(dataset)
示例5: testAssertNext
def testAssertNext(self):
  """Checks that a satisfied assert_next is transparent to the pipeline."""
  # The asserted transformation ("Map") matches the actual next op, so the
  # pipeline runs normally and yields its single element.
  dataset = dataset_ops.Dataset.from_tensors(0).apply(
      optimization.assert_next(["Map"])).map(lambda x: x)
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    self.assertEqual(0, sess.run(get_next))
示例6: testAssertNextShort
def testAssertNextShort(self):
  """Checks assert_next fails when more ops are asserted than exist."""
  # Two transformations are asserted but only one (Map) follows, so
  # iterating must raise InvalidArgumentError.
  dataset = dataset_ops.Dataset.from_tensors(0).apply(
      optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        "Asserted next 2 transformations but encountered only 1."):
      sess.run(get_next)
示例7: testStatefulFunctionOptimization
def testStatefulFunctionOptimization(self):
  """Checks map_and_batch_fusion applies even with a stateful map fn."""
  # The map function uses random_uniform (stateful); the fused MapAndBatch
  # node must still be produced and the pipeline must run.
  dataset = dataset_ops.Dataset.range(10).apply(
      optimization.assert_next([
          "MapAndBatch"
      ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply(
          optimization.optimize(["map_and_batch_fusion"]))
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    # Values are random, so only check that evaluation succeeds.
    sess.run(get_next)
示例8: testOptimization
def testOptimization(self):
  """Checks map_and_batch_fusion fuses map+batch and preserves values."""
  dataset = dataset_ops.Dataset.range(10).apply(
      optimization.assert_next(
          ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
              optimization.optimize(["map_and_batch_fusion"]))
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    # One batch of all ten squares, then end of sequence.
    self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
示例9: testAssertSuffixInvalid
def testAssertSuffixInvalid(self):
  """Checks assert_next fails when the asserted op name does not match."""
  # "Whoops" is asserted but the actual next transformation is Map, so
  # iterating must raise InvalidArgumentError naming the mismatch.
  dataset = dataset_ops.Dataset.from_tensors(0).apply(
      optimization.assert_next(["Whoops"])).map(lambda x: x)
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        "Asserted Whoops transformation at offset 0 but encountered "
        "Map transformation instead."):
      sess.run(get_next)
示例10: testAdditionalInputs
def testAdditionalInputs(self):
  """Checks hoist_random_uniform works when the map fn captures a tensor."""
  a = constant_op.constant(1, dtype=dtypes.float32)
  b = constant_op.constant(0, dtype=dtypes.float32)
  some_tensor = math_ops.mul(a, b)

  def random_with_capture(_):
    # Captures `some_tensor` from the enclosing graph in addition to using
    # a stateful random op.
    return some_tensor + random_ops.random_uniform(
        [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)

  # The rewrite should still hoist the random op: Zip[0] -> Map.
  dataset = dataset_ops.Dataset.range(5).apply(
      optimization.assert_next(
          ["Zip[0]", "Map"])).map(random_with_capture).apply(
              optimization.optimize(["hoist_random_uniform"]))
  self._testDataset(dataset)
示例11: testAdditionalInputs
def testAdditionalInputs(self):
  """Checks map_and_filter_fusion is skipped when the predicate captures."""
  a = constant_op.constant(3, dtype=dtypes.int64)
  b = constant_op.constant(4, dtype=dtypes.int64)
  some_tensor = math_ops.mul(a, b)
  function = lambda x: x * x

  def predicate(y):
    # Captures `some_tensor` from the enclosing graph.
    return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)

  # We are currently not supporting functions with additional inputs, so
  # the optimized graph must keep separate Map and Filter nodes.
  dataset = dataset_ops.Dataset.range(10).apply(
      optimization.assert_next(
          ["Map", "Filter"])).map(function).filter(predicate).apply(
              optimization.optimize(["map_and_filter_fusion"]))
  self._testMapAndFilter(dataset, function, predicate)
示例12: testMapParallelization
def testMapParallelization(self, function, should_optimize):
  """Verifies `map_parallelization` rewrites Map into ParallelMap.

  Args:
    function: map function under test.
    should_optimize: if True, the optimized graph should contain a
      ParallelMap node; otherwise a plain Map must remain.
  """
  next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
  dataset = dataset_ops.Dataset.range(5).apply(
      optimization.assert_next(next_nodes)).map(function).apply(
          optimization.optimize(["map_parallelization"]))
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    for x in range(5):
      result = sess.run(get_next)
      # No need to run the pipeline if it was not optimized. Also the results
      # might be hard to check because of random.
      if not should_optimize:
        return
      r = function(x)
      self.assertAllEqual(r, result)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
示例13: testMapFusion
def testMapFusion(self, functions):
  """Verifies that chained `map` transformations are fused into one.

  Args:
    functions: iterable of map functions chained via `Dataset.map`; the
      `map_fusion` rewrite should collapse them into a single Map node.
  """
  dataset = dataset_ops.Dataset.range(5).apply(
      optimization.assert_next(["Map", "Prefetch"]))
  for function in functions:
    dataset = dataset.map(function)
  dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    for x in range(5):
      result = sess.run(get_next)
      # Compose the functions in Python to compute the expected value.
      r = x
      for function in functions:
        if isinstance(r, tuple):
          r = function(*r)  # Pass tuple as multiple arguments.
        else:
          r = function(r)
      self.assertAllEqual(r, result)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
示例14: testNoopElimination
def testNoopElimination(self):
  """Verifies `noop_elimination` removes no-op transformations.

  take(-1), skip(0), repeat(1) and prefetch(0) are no-ops and should be
  eliminated; repeat(some_tensor) and skip(5) are real work (the repeat
  count is a tensor, not a static 1) and must survive as
  FiniteRepeat/FiniteSkip, alongside the two remaining Prefetch nodes.
  """
  a = constant_op.constant(1, dtype=dtypes.int64)
  b = constant_op.constant(2, dtype=dtypes.int64)
  some_tensor = math_ops.mul(a, b)

  dataset = dataset_ops.Dataset.range(5)
  dataset = dataset.apply(
      optimization.assert_next(
          ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
  dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
      0).repeat(1).prefetch(0)
  dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    # range(5) repeated twice then skip(5) leaves exactly [0, 1, 2, 3, 4].
    for x in range(5):
      result = sess.run(get_next)
      self.assertAllEqual(result, x)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
示例15: _make_dataset
def _make_dataset(node_names):
  """Builds the test pipeline with an assert_next on `node_names`.

  NOTE(review): closure over `base_dataset`, `map_fn`, `num_parallel_calls`
  and `batch_size` from an enclosing scope not shown here — confirm against
  the surrounding function.
  """
  return base_dataset.apply(optimization.assert_next(node_names)).map(
      map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)