本文整理汇总了Python中tensorflow.python.data.Dataset.range方法的典型用法代码示例。如果您正苦于以下问题：Python Dataset.range方法的具体用法？Python Dataset.range怎么用？Python Dataset.range使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类tensorflow.python.data.Dataset的用法示例。
在下文中一共展示了Dataset.range方法的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testNestedOutputs
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testNestedOutputs(self):
    """Iterate a zip-of-zip dataset and check every nested element.

    The dataset pairs ``range(4)`` with a zip of two more ``range(4)``
    datasets, so element ``i`` is expected to unwrap to ``(i, (i, i))``.
    """
    outer_range = Dataset.range(4)
    inner_pair = Dataset.zip((Dataset.range(4), Dataset.range(4)))
    ds = Dataset.zip((outer_range, inner_pair))

    seen = 0
    # The iterator yields a nested structure of Tensor objects, so each
    # leaf is converted with .numpy() before comparing to plain integers.
    for idx, elem in enumerate(datasets.Iterator(ds)):
        expected = (idx, (idx, idx))
        nested = elem[1]
        actual = (elem[0].numpy(), (nested[0].numpy(), nested[1].numpy()))
        self.assertEqual(actual, expected)
        seen += 1

    # Exactly four elements must have been produced.
    self.assertEqual(4, seen)
示例2: testMapAndFilter
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testMapAndFilter(self):
    """Square 0..7, keep only the even squares, and check the survivors."""

    def is_even(value):
        # Predicate built from tensor ops: value mod 2 == 0.
        remainder = math_ops.mod(value, 2)
        return math_ops.equal(remainder, 0)

    squares = Dataset.range(8).map(math_ops.square)
    iterator = datasets.Iterator(squares.filter(is_even))
    results = [elem.numpy() for elem in iterator]
    self.assertAllEqual([0, 4, 16, 36], results)
示例3: testSaveRestoreMultipleIterator
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testSaveRestoreMultipleIterator(self):
    """Save three iterators in one checkpoint and restore them together.

    Two iterators share a squared-and-batched dataset and a third walks
    ``Dataset.range(10)``. After advancing past the save point, restoring
    must put each iterator back at the position it had when saved.
    """
    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, 'ckpt')

    squared_batches = Dataset.from_tensor_slices(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    squared_batches = squared_batches.map(math_ops.square).batch(2)
    iterator_1 = datasets.Iterator(squared_batches)
    iterator_2 = datasets.Iterator(squared_batches)
    iterator_3 = datasets.Iterator(Dataset.range(10))

    checkpoint = checkpointable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)

    # Before saving: iterator_1 consumes one batch, iterator_2 none,
    # iterator_3 three elements.
    self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
    for expected in range(3):
        self.assertEqual(expected, iterator_3.get_next().numpy())
    save_path = checkpoint.save(ckpt_prefix)

    # Move iterator_2 and iterator_3 past their saved positions.
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())

    # After restoring, every iterator continues from its saved position.
    checkpoint.restore(save_path)
    self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
示例4: testOverrideThreadPool
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testOverrideThreadPool(self):
    """Check that a private thread pool bounds the threads running map fns.

    For each pool size, every element is mapped to the identifier of the
    thread that processed it, the identifiers are de-duplicated, and the
    number of distinct identifiers must be between 1 and the pool size.
    """

    def get_thread_id(_):
        # Python creates a dummy thread object to represent the current
        # thread when called from an "alien" thread (such as a
        # `PrivateThreadPool` thread in this case). It does not include
        # the TensorFlow-given display name, but it has a unique
        # identifier that maps one-to-one with the underlying OS thread.
        return np.array(threading.current_thread().ident).astype(np.int64)

    for num_threads in [1, 2, 4, 8, 16]:
        dataset = (
            Dataset.range(1000).map(
                lambda x: script_ops.py_func(
                    get_thread_id, [x], dtypes.int64),
                num_parallel_calls=32).apply(unique.unique()))

        dataset = threadpool.override_threadpool(
            dataset,
            threadpool.PrivateThreadPool(
                num_threads,
                display_name='private_thread_pool_%d' % num_threads))

        # Store plain Python ints rather than tensor objects: a set of
        # tensors compares object identity (or fails where tensors are
        # unhashable), which would not check the thread-id values.
        thread_ids = [
            int(next_element.numpy())
            for next_element in datasets.Iterator(dataset)
        ]

        self.assertEqual(len(thread_ids), len(set(thread_ids)))
        self.assertGreater(len(thread_ids), 0)
        # NOTE(mrry): We don't control the thread pool scheduling, and
        # so cannot guarantee that all of the threads in the pool will
        # perform work.
        self.assertLessEqual(len(thread_ids), num_threads)
示例5: testGetNextOneShotIterator
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testGetNextOneShotIterator(self):
    """Pull four elements from a one-shot iterator, then expect exhaustion."""
    one_shot = Dataset.range(4).make_one_shot_iterator()
    for expected in range(4):
        self.assertEqual(expected, one_shot.get_next().numpy())
    # A fifth request must raise OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
        one_shot.get_next()
示例6: testGetNext
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testGetNext(self):
    """Pull four elements via get_next(), then expect OutOfRangeError."""
    it = datasets.Iterator(Dataset.range(4))
    for expected in (0, 1, 2, 3):
        value = it.get_next().numpy()
        self.assertEqual(expected, value)
    # The dataset holds only four elements, so one more call must fail.
    with self.assertRaises(errors.OutOfRangeError):
        it.get_next()
示例7: testPyFunc
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testPyFunc(self):
    """Map dataset elements through a Python function via py_func."""

    def add_one_to_each(values):
        # Returns the incremented values wrapped in an outer list.
        incremented = [v + 1 for v in values]
        return [incremented]

    def apply_py_func(elem):
        # Each element is passed to the Python function as a one-item list.
        return script_ops.py_func(add_one_to_each, [[elem]], dtypes.int64)

    ds = Dataset.range(4).map(apply_py_func)
    results = [elem.numpy() for elem in datasets.Iterator(ds)]
    self.assertAllEqual([[1], [2], [3], [4]], results)
示例8: testMultipleIteratorsOnTheSameDataset
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testMultipleIteratorsOnTheSameDataset(self):
    """Two iterators over one dataset each yield the full sequence."""
    ds = Dataset.range(4)
    # Both iterators are created before either one is consumed.
    first = datasets.Iterator(ds)
    second = datasets.Iterator(ds)
    for iterator in (first, second):
        values = [elem.numpy() for elem in iterator]
        self.assertAllEqual([0, 1, 2, 3], values)
示例9: testRestoreInReconstructedIterator
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testRestoreInReconstructedIterator(self):
    """Resume iteration through freshly built iterators across 5 rounds.

    Each round builds a new iterator, restores the latest checkpoint in
    the directory, reads two elements, and saves again. The values read
    must continue from where the previous round stopped.
    """
    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, 'ckpt')
    dataset = Dataset.range(10)

    for round_idx in range(5):
        iterator = datasets.Iterator(dataset)
        checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
        latest = checkpoint_management.latest_checkpoint(ckpt_dir)
        checkpoint.restore(latest)
        for offset in range(2):
            self.assertEqual(
                round_idx * 2 + offset, iterator.get_next().numpy())
        checkpoint.save(file_prefix=ckpt_prefix)
示例10: testRestoreExhaustedIterator
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testRestoreExhaustedIterator(self):
    """Restore an iterator after it has consumed its final element.

    The checkpoint is taken after two of three elements. The third is
    read, the checkpoint restored, and the third must be readable again.
    """
    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, 'ckpt')
    iterator = datasets.Iterator(Dataset.range(3))
    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)

    for expected in range(2):
        self.assertEqual(expected, iterator.get_next().numpy())
    save_path = checkpoint.save(ckpt_prefix)

    # Consume the last element, then rewind to the saved position.
    self.assertEqual(2, iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertEqual(2, iterator.get_next().numpy())
示例11: testBasic
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testBasic(self):
    """Iterate Dataset.range(4) through an explicit datasets.Iterator."""
    iterator = datasets.Iterator(Dataset.range(4))
    values = [tensor.numpy() for tensor in iterator]
    self.assertAllEqual([0, 1, 2, 3], values)
示例12: testBasicImplicitIterator
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testBasicImplicitIterator(self):
    """Iterate Dataset.range(4) directly, with no explicit iterator."""
    values = [tensor.numpy() for tensor in Dataset.range(4)]
    self.assertAllEqual([0, 1, 2, 3], values)
示例13: testBasicOneShotIterator
# 需要导入模块: from tensorflow.python.data import Dataset [as 别名]
# 或者: from tensorflow.python.data.Dataset import range [as 别名]
def testBasicOneShotIterator(self):
    """Iterate Dataset.range(4) through a one-shot iterator."""
    one_shot = Dataset.range(4).make_one_shot_iterator()
    values = [tensor.numpy() for tensor in one_shot]
    self.assertAllEqual([0, 1, 2, 3], values)