This article collects typical usage examples of the chainer.iterators.MultiprocessIterator method in Python. If you are wondering what iterators.MultiprocessIterator does, how to call it, or what it looks like in real code, the curated examples here should help. You can also explore further usage examples from its containing module, chainer.iterators.
The sections below show 15 code examples of iterators.MultiprocessIterator, sorted by popularity by default.
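Before the examples, here is a minimal, self-contained sketch of constructing and consuming a MultiprocessIterator. The toy dataset, batch size, and worker count are illustrative assumptions, not values taken from the examples below.

import numpy
from chainer import iterators

# A dataset only needs __len__ and __getitem__, so a plain list works here.
dataset = [numpy.full((3,), i, dtype=numpy.float32) for i in range(10)]

# Worker processes prefetch examples in the background. shuffle and repeat
# default to True; repeat=False limits the iterator to a single pass.
it = iterators.MultiprocessIterator(dataset, batch_size=4,
                                    repeat=False, shuffle=False,
                                    n_processes=2)

for batch in it:                        # each batch is a list of examples
    print(len(batch), it.epoch_detail)  # epoch_detail tracks fractional progress

it.finalize()                           # shut down the worker processes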
Example 1: get_val_data_iterator

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def get_val_data_iterator(data_dir,
                          batch_size,
                          num_workers,
                          num_classes):
    val_dir_path = os.path.join(data_dir, 'val')
    val_dataset = DirectoryParsingLabelDataset(val_dir_path)
    val_dataset_len = len(val_dataset)
    assert(len(directory_parsing_label_names(val_dir_path)) == num_classes)
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,          # single pass, suitable for evaluation
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)  # shared-memory buffer size per example, in bytes
    return val_iterator, val_dataset_len
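Because the iterator above is built with repeat=False, a caller can simply drain it once per evaluation run. The sketch below shows one possible consumer; the predict callable and the accuracy bookkeeping are hypothetical placeholders, not part of the original example.

def evaluate(val_iterator, val_dataset_len, predict):
    # predict is assumed to map a list of (image, label) pairs to predicted labels.
    correct = 0
    for batch in val_iterator:               # stops after one pass (repeat=False)
        labels = [label for _, label in batch]
        preds = predict(batch)
        correct += sum(int(p == t) for p, t in zip(preds, labels))
    val_iterator.reset()                     # rewind so the iterator can be reused
    return correct / val_dataset_len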
Example 2: get_val_data_iterator

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def get_val_data_iterator(dataset_name,
                          batch_size,
                          num_workers):
    if dataset_name == "CIFAR10":
        _, test_ds = cifar.get_cifar10()
    elif dataset_name == "CIFAR100":
        _, test_ds = cifar.get_cifar100()
    elif dataset_name == "SVHN":
        _, test_ds = svhn.get_svhn()
    else:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))
    val_dataset = test_ds
    val_dataset_len = len(val_dataset)
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return val_iterator, val_dataset_len
Example 3: get_data_iterators

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def get_data_iterators(batch_size,
                       num_workers):
    train_dataset = PreprocessedCIFARDataset(train=True)
    train_iterator = iterators.MultiprocessIterator(
        dataset=train_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=True,
        n_processes=num_workers)
    val_dataset = PreprocessedCIFARDataset(train=False)
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers)
    return train_iterator, val_iterator
Example 4: get_train_data_source

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def get_train_data_source(ds_metainfo,
                          batch_size,
                          num_workers):
    transform = ds_metainfo.train_transform(ds_metainfo=ds_metainfo)
    dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="train",
        transform=transform)
    ds_metainfo.update_from_dataset(dataset)
    iterator = MultiprocessIterator(
        dataset=dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=True,
        n_processes=num_workers,
        shared_mem=300000000)
    return {
        # "transform": transform,
        "iterator": iterator,
        "ds_len": len(dataset)
    }
Example 5: get_test_data_source

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def get_test_data_source(ds_metainfo,
                         batch_size,
                         num_workers):
    transform = ds_metainfo.test_transform(ds_metainfo=ds_metainfo)
    dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="test",
        transform=transform)
    ds_metainfo.update_from_dataset(dataset)
    iterator = MultiprocessIterator(
        dataset=dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return {
        # "transform": transform,
        "iterator": iterator,
        "ds_len": len(dataset)
    }
Example 6: test_iterator_not_repeat_not_even

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_not_repeat_not_even(self):
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    # epoch_detail reports fractional progress through the current epoch.
    self.assertAlmostEqual(it.epoch_detail, 0 / 5)
    self.assertIsNone(it.previous_epoch_detail)
    batch1 = it.next()
    self.assertAlmostEqual(it.epoch_detail, 2 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 0 / 5)
    batch2 = it.next()
    self.assertAlmostEqual(it.epoch_detail, 4 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 2 / 5)
    batch3 = it.next()
    self.assertAlmostEqual(it.epoch_detail, 5 / 5)
    self.assertAlmostEqual(it.previous_epoch_detail, 4 / 5)
    self.assertRaises(StopIteration, it.next)
    # With repeat=False the last batch only holds the leftover example.
    self.assertEqual(len(batch3), 1)
    self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
Example 7: get_data_iterators

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def get_data_iterators(data_dir,
                       batch_size,
                       num_workers,
                       num_classes,
                       input_image_size=224,
                       resize_inv_factor=0.875):
    assert (resize_inv_factor > 0.0)
    resize_value = int(math.ceil(float(input_image_size) / resize_inv_factor))
    train_dir_path = os.path.join(data_dir, 'train')
    train_dataset = PreprocessedDataset(
        root=train_dir_path,
        scale_size=resize_value,
        crop_size=input_image_size)
    assert(len(directory_parsing_label_names(train_dir_path)) == num_classes)
    val_dir_path = os.path.join(data_dir, 'val')
    val_dataset = PreprocessedDataset(
        root=val_dir_path,
        scale_size=resize_value,
        crop_size=input_image_size)
    assert (len(directory_parsing_label_names(val_dir_path)) == num_classes)
    train_iterator = iterators.MultiprocessIterator(
        dataset=train_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=True,
        n_processes=num_workers)
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers)
    return train_iterator, val_iterator
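Note that both iterators here use repeat=False, so a multi-epoch training loop has to rewind the training iterator between passes. A minimal sketch of that pattern, where train_one_batch is a hypothetical placeholder for the actual update step:

def train(train_iterator, num_epochs, train_one_batch):
    for epoch in range(num_epochs):
        for batch in train_iterator:    # one full pass over the training set
            train_one_batch(batch)
        train_iterator.reset()          # rewind the iterator for the next epoch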
Example 8: test_iterator_repeat

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_repeat(self):
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)
        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
        self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)
Example 9: test_iterator_list_type

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_list_type(self):
    dataset = [[i, numpy.zeros((10,)) + i] for i in range(6)]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        batches = {}
        for j in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            if j != 2:
                self.assertFalse(it.is_new_epoch)
            else:
                self.assertTrue(it.is_new_epoch)
            self.assertAlmostEqual(
                it.epoch_detail, (3 * i + j + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, (3 * i + j) * 2 / 6)
            for x in batch:
                self.assertIsInstance(x, list)
                self.assertIsInstance(x[1], numpy.ndarray)
                batches[x[0]] = x[1]
        self.assertEqual(len(batches), len(dataset))
        for k, v in six.iteritems(batches):
            numpy.testing.assert_allclose(dataset[k][1], v)
Example 10: test_iterator_tuple_type

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_tuple_type(self):
    dataset = [(i, numpy.zeros((10,)) + i) for i in range(6)]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        batches = {}
        for j in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            if j != 2:
                self.assertFalse(it.is_new_epoch)
            else:
                self.assertTrue(it.is_new_epoch)
            self.assertAlmostEqual(
                it.epoch_detail, (3 * i + j + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, (3 * i + j) * 2 / 6)
            for x in batch:
                self.assertIsInstance(x, tuple)
                self.assertIsInstance(x[1], numpy.ndarray)
                batches[x[0]] = x[1]
        self.assertEqual(len(batches), len(dataset))
        for k, v in six.iteritems(batches):
            numpy.testing.assert_allclose(dataset[k][1], v)
Example 11: test_iterator_dict_type

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_dict_type(self):
    dataset = [{i: numpy.zeros((10,)) + i} for i in range(6)]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    for i in range(3):
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)
        batches = {}
        for j in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)
            if j != 2:
                self.assertFalse(it.is_new_epoch)
            else:
                self.assertTrue(it.is_new_epoch)
            self.assertAlmostEqual(
                it.epoch_detail, (3 * i + j + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, (3 * i + j) * 2 / 6)
            for x in batch:
                self.assertIsInstance(x, dict)
                k = tuple(x)[0]
                v = x[k]
                self.assertIsInstance(v, numpy.ndarray)
                batches[k] = v
        self.assertEqual(len(batches), len(dataset))
        for k, v in six.iteritems(batches):
            x = dataset[k][tuple(dataset[k])[0]]
            numpy.testing.assert_allclose(x, v)
Example 12: test_iterator_repeat_not_even

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_repeat_not_even(self):
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(dataset, 2, **self.options)
    batches = sum([it.next() for _ in range(5)], [])
    self.assertEqual(sorted(batches), sorted(dataset * 2))
Example 13: test_iterator_shuffle_divisible

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_shuffle_divisible(self):
    dataset = list(range(10))
    it = iterators.MultiprocessIterator(
        dataset, 10, **self.options)
    self.assertNotEqual(it.next(), it.next())
Example 14: test_iterator_shuffle_nondivisible

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_iterator_shuffle_nondivisible(self):
    dataset = list(range(10))
    it = iterators.MultiprocessIterator(
        dataset, 3, **self.options)
    out = sum([it.next() for _ in range(7)], [])
    self.assertNotEqual(out[0:10], out[10:20])
Example 15: test_copy_not_repeat

# Required import: from chainer import iterators [as alias]
# Or alternatively: from chainer.iterators import MultiprocessIterator [as alias]
def test_copy_not_repeat(self):
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultiprocessIterator(
        dataset, 2, repeat=False, **self.options)
    copy_it = copy.copy(it)              # the copy keeps its own iteration state
    batches = sum([it.next() for _ in range(3)], [])
    self.assertEqual(sorted(batches), dataset)
    for _ in range(2):
        self.assertRaises(StopIteration, it.next)
    it = None
    batches = sum([copy_it.next() for _ in range(3)], [])
    self.assertEqual(sorted(batches), dataset)
    for _ in range(2):
        self.assertRaises(StopIteration, copy_it.next)
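One practical point the tests above do not exercise: MultiprocessIterator starts background worker processes, so long-running scripts normally shut an iterator down explicitly once it is no longer needed. A minimal sketch, reusing the toy list dataset from the tests:

from chainer import iterators

dataset = [1, 2, 3, 4, 5]
it = iterators.MultiprocessIterator(dataset, 2, repeat=False, shuffle=False)
try:
    batches = [batch for batch in it]   # drain the single pass
finally:
    it.finalize()                       # terminate the worker processes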