当前位置: 首页>>代码示例>>Python>>正文


Python iterators.MultiprocessIterator方法代码示例

本文整理汇总了 Python 中 chainer.iterators.MultiprocessIterator 方法的典型用法代码示例。如果您想了解 iterators.MultiprocessIterator 的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 chainer.iterators 的其他用法示例。


在下文中一共展示了 iterators.MultiprocessIterator 方法的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的 Python 代码示例。

示例1: get_val_data_iterator

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def get_val_data_iterator(data_dir,
                          batch_size,
                          num_workers,
                          num_classes):
    """Build a single-pass, unshuffled iterator over the ``val`` split.

    Parameters
    ----------
    data_dir : str
        Root directory; samples are read from its ``val`` subdirectory.
    batch_size : int
        Number of samples per batch.
    num_workers : int
        Number of worker processes used by the iterator.
    num_classes : int
        Expected number of label names found under ``val``.

    Returns
    -------
    tuple
        ``(iterator, number_of_samples_in_val_split)``.
    """
    val_root = os.path.join(data_dir, 'val')
    dataset = DirectoryParsingLabelDataset(val_root)
    n_samples = len(dataset)
    # Sanity check: the directory layout must match the expected label count.
    assert len(directory_parsing_label_names(val_root)) == num_classes

    iterator = iterators.MultiprocessIterator(
        dataset=dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return iterator, n_samples
开发者ID:osmr,项目名称:imgclsmob,代码行数:21,代码来源:imagenet1k1.py

示例2: get_val_data_iterator

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def get_val_data_iterator(dataset_name,
                          batch_size,
                          num_workers):
    """Build a single-pass, unshuffled iterator over a dataset's test split.

    Parameters
    ----------
    dataset_name : str
        One of ``"CIFAR10"``, ``"CIFAR100"`` or ``"SVHN"``.
    batch_size : int
        Number of samples per batch.
    num_workers : int
        Number of worker processes used by the iterator.

    Returns
    -------
    tuple
        ``(iterator, number_of_samples_in_test_split)``.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not recognized. ``ValueError`` subclasses
        ``Exception``, so handlers written for the old bare ``Exception``
        still catch it.
    """
    # Only the test split is needed; the train split is discarded.
    if dataset_name == "CIFAR10":
        _, val_dataset = cifar.get_cifar10()
    elif dataset_name == "CIFAR100":
        _, val_dataset = cifar.get_cifar100()
    elif dataset_name == "SVHN":
        _, val_dataset = svhn.get_svhn()
    else:
        raise ValueError('Unrecognized dataset: {}'.format(dataset_name))

    val_dataset_len = len(val_dataset)

    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)

    return val_iterator, val_dataset_len
开发者ID:osmr,项目名称:imgclsmob,代码行数:27,代码来源:cifar1.py

示例3: get_data_iterators

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def get_data_iterators(batch_size,
                       num_workers):
    """Create single-pass training and validation iterators for CIFAR.

    Both iterators use ``repeat=False``; only the training one shuffles.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    num_workers : int
        Number of worker processes per iterator.

    Returns
    -------
    tuple
        ``(train_iterator, val_iterator)``.
    """
    def make_iterator(is_train):
        # The training split is shuffled; the validation split keeps its order.
        return iterators.MultiprocessIterator(
            dataset=PreprocessedCIFARDataset(train=is_train),
            batch_size=batch_size,
            repeat=False,
            shuffle=is_train,
            n_processes=num_workers)

    train_iterator = make_iterator(True)
    val_iterator = make_iterator(False)
    return train_iterator, val_iterator
开发者ID:osmr,项目名称:imgclsmob,代码行数:22,代码来源:cifar1.py

示例4: get_train_data_source

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def get_train_data_source(ds_metainfo,
                          batch_size,
                          num_workers):
    """Create the shuffled, single-pass training data source.

    Parameters
    ----------
    ds_metainfo : object
        Dataset metainfo providing ``train_transform``, ``dataset_class``,
        ``root_dir_path`` and ``update_from_dataset``.
    batch_size : int
        Number of samples per batch.
    num_workers : int
        Number of worker processes used by the iterator.

    Returns
    -------
    dict
        ``"iterator"``: the ``MultiprocessIterator``;
        ``"ds_len"``: number of samples in the training dataset.
    """
    train_transform = ds_metainfo.train_transform(ds_metainfo=ds_metainfo)
    train_dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="train",
        transform=train_transform)
    # NOTE(review): presumably lets the metainfo sync fields from the
    # constructed dataset — confirm in update_from_dataset.
    ds_metainfo.update_from_dataset(train_dataset)

    train_iterator = MultiprocessIterator(
        dataset=train_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=True,
        n_processes=num_workers,
        shared_mem=300000000)
    source = {
        "iterator": train_iterator,
        "ds_len": len(train_dataset),
    }
    return source
开发者ID:osmr,项目名称:imgclsmob,代码行数:23,代码来源:dataset_utils.py

示例5: get_test_data_source

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def get_test_data_source(ds_metainfo,
                         batch_size,
                         num_workers):
    """Create the unshuffled, single-pass test data source.

    Parameters
    ----------
    ds_metainfo : object
        Dataset metainfo providing ``test_transform``, ``dataset_class``,
        ``root_dir_path`` and ``update_from_dataset``.
    batch_size : int
        Number of samples per batch.
    num_workers : int
        Number of worker processes used by the iterator.

    Returns
    -------
    dict
        ``"iterator"``: the ``MultiprocessIterator``;
        ``"ds_len"``: number of samples in the test dataset.
    """
    test_transform = ds_metainfo.test_transform(ds_metainfo=ds_metainfo)
    test_dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="test",
        transform=test_transform)
    # NOTE(review): presumably lets the metainfo sync fields from the
    # constructed dataset — confirm in update_from_dataset.
    ds_metainfo.update_from_dataset(test_dataset)

    test_iterator = MultiprocessIterator(
        dataset=test_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    source = {
        "iterator": test_iterator,
        "ds_len": len(test_dataset),
    }
    return source
开发者ID:osmr,项目名称:imgclsmob,代码行数:23,代码来源:dataset_utils.py

示例6: test_iterator_not_repeat_not_even

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_not_repeat_not_even(self):
        """A non-repeating iterator over 5 items with batch size 2 yields
        batches of 2, 2 and 1 items, then raises StopIteration."""
        dataset = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(
            dataset, 2, repeat=False, **self.options)

        self.assertAlmostEqual(it.epoch_detail, 0 / 5)
        self.assertIsNone(it.previous_epoch_detail)

        # (epoch_detail, previous_epoch_detail) expected after each next().
        expected_details = [
            (2 / 5, 0 / 5),
            (4 / 5, 2 / 5),
            (5 / 5, 4 / 5),
        ]
        drawn = []
        for current, previous in expected_details:
            drawn.append(it.next())
            self.assertAlmostEqual(it.epoch_detail, current)
            self.assertAlmostEqual(it.previous_epoch_detail, previous)
        self.assertRaises(StopIteration, it.next)

        first, second, last = drawn
        self.assertEqual(len(last), 1)
        self.assertEqual(sorted(first + second + last), dataset)
开发者ID:chainer,项目名称:chainer,代码行数:22,代码来源:test_multiprocess_iterator.py

示例7: get_data_iterators

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def get_data_iterators(data_dir,
                       batch_size,
                       num_workers,
                       num_classes,
                       input_image_size=224,
                       resize_inv_factor=0.875):
    """Create single-pass training and validation iterators from a directory.

    Parameters
    ----------
    data_dir : str
        Root directory containing ``train`` and ``val`` subdirectories.
    batch_size : int
        Number of samples per batch.
    num_workers : int
        Number of worker processes per iterator.
    num_classes : int
        Expected number of label names found under each split directory.
    input_image_size : int, default 224
        Passed to ``PreprocessedDataset`` as ``crop_size``.
    resize_inv_factor : float, default 0.875
        Must be positive. ``ceil(input_image_size / resize_inv_factor)`` is
        passed as ``scale_size`` (256 with the defaults).

    Returns
    -------
    tuple
        ``(train_iterator, val_iterator)``; only the training one shuffles.
    """
    assert (resize_inv_factor > 0.0)
    scale_size = int(math.ceil(float(input_image_size) / resize_inv_factor))

    def load_split(split_name):
        # Build the dataset for one split, then verify its label count.
        split_dir = os.path.join(data_dir, split_name)
        split_dataset = PreprocessedDataset(
            root=split_dir,
            scale_size=scale_size,
            crop_size=input_image_size)
        assert len(directory_parsing_label_names(split_dir)) == num_classes
        return split_dataset

    def make_iterator(dataset, shuffle):
        # Both iterators are single-pass; shuffling differs per split.
        return iterators.MultiprocessIterator(
            dataset=dataset,
            batch_size=batch_size,
            repeat=False,
            shuffle=shuffle,
            n_processes=num_workers)

    train_dataset = load_split('train')
    val_dataset = load_split('val')
    return make_iterator(train_dataset, True), make_iterator(val_dataset, False)
开发者ID:osmr,项目名称:imgclsmob,代码行数:40,代码来源:imagenet1k1.py

示例8: test_iterator_repeat

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_repeat(self):
        """A repeating iterator over six items with batch size 2 yields three
        two-item list batches per epoch; the third one closes the epoch."""
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for epoch in range(3):
            self.assertEqual(it.epoch, epoch)
            self.assertAlmostEqual(it.epoch_detail, epoch + 0 / 6)
            if epoch == 0:
                # Nothing has been drawn yet.
                self.assertIsNone(it.previous_epoch_detail)
            else:
                self.assertAlmostEqual(
                    it.previous_epoch_detail, epoch - 2 / 6)

            drawn = []
            for step in range(3):
                batch = it.next()
                drawn.append(batch)
                self.assertEqual(len(batch), 2)
                self.assertIsInstance(batch, list)
                if step < 2:
                    self.assertFalse(it.is_new_epoch)
                else:
                    self.assertTrue(it.is_new_epoch)
                    self.assertEqual(
                        sorted(drawn[0] + drawn[1] + drawn[2]), dataset)
                self.assertAlmostEqual(
                    it.epoch_detail, epoch + (2 * step + 2) / 6)
                self.assertAlmostEqual(
                    it.previous_epoch_detail, epoch + 2 * step / 6)
开发者ID:chainer,项目名称:chainer,代码行数:31,代码来源:test_multiprocess_iterator.py

示例9: test_iterator_list_type

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_list_type(self):
        """List-typed samples come back as lists that still hold their arrays.

        Runs three epochs over six ``[index, array]`` samples with batch size
        2. It checks the epoch bookkeeping after every batch and checks that
        each epoch returns every sample with its original array.
        """
        dataset = [[i, numpy.zeros((10,)) + i] for i in range(6)]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for epoch in range(3):
            self.assertEqual(it.epoch, epoch)
            self.assertAlmostEqual(it.epoch_detail, epoch)
            if epoch:
                self.assertAlmostEqual(
                    it.previous_epoch_detail, epoch - 2 / 6)
            else:
                # Nothing has been drawn yet.
                self.assertIsNone(it.previous_epoch_detail)

            arrays_by_index = {}
            for step in range(3):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                # Only the third batch of an epoch completes it.
                if step == 2:
                    self.assertTrue(it.is_new_epoch)
                else:
                    self.assertFalse(it.is_new_epoch)
                drawn_before = 3 * epoch + step
                self.assertAlmostEqual(
                    it.epoch_detail, (drawn_before + 1) * 2 / 6)
                self.assertAlmostEqual(
                    it.previous_epoch_detail, drawn_before * 2 / 6)
                for sample in batch:
                    self.assertIsInstance(sample, list)
                    self.assertIsInstance(sample[1], numpy.ndarray)
                    arrays_by_index[sample[0]] = sample[1]

            self.assertEqual(len(arrays_by_index), len(dataset))
            for index, array in six.iteritems(arrays_by_index):
                numpy.testing.assert_allclose(dataset[index][1], array)
开发者ID:chainer,项目名称:chainer,代码行数:32,代码来源:test_multiprocess_iterator.py

示例10: test_iterator_tuple_type

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_tuple_type(self):
        """Tuple-typed samples come back as tuples that still hold their arrays.

        Runs three epochs over six ``(index, array)`` samples with batch size
        2. It checks the epoch bookkeeping after every batch and checks that
        each epoch returns every sample with its original array.
        """
        dataset = [(i, numpy.zeros((10,)) + i) for i in range(6)]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for epoch in range(3):
            self.assertEqual(it.epoch, epoch)
            self.assertAlmostEqual(it.epoch_detail, epoch)
            if epoch:
                self.assertAlmostEqual(
                    it.previous_epoch_detail, epoch - 2 / 6)
            else:
                # Nothing has been drawn yet.
                self.assertIsNone(it.previous_epoch_detail)

            arrays_by_index = {}
            for step in range(3):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                # Only the third batch of an epoch completes it.
                if step == 2:
                    self.assertTrue(it.is_new_epoch)
                else:
                    self.assertFalse(it.is_new_epoch)
                drawn_before = 3 * epoch + step
                self.assertAlmostEqual(
                    it.epoch_detail, (drawn_before + 1) * 2 / 6)
                self.assertAlmostEqual(
                    it.previous_epoch_detail, drawn_before * 2 / 6)
                for sample in batch:
                    self.assertIsInstance(sample, tuple)
                    self.assertIsInstance(sample[1], numpy.ndarray)
                    arrays_by_index[sample[0]] = sample[1]

            self.assertEqual(len(arrays_by_index), len(dataset))
            for index, array in six.iteritems(arrays_by_index):
                numpy.testing.assert_allclose(dataset[index][1], array)
开发者ID:chainer,项目名称:chainer,代码行数:32,代码来源:test_multiprocess_iterator.py

示例11: test_iterator_dict_type

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_dict_type(self):
        """Dict-typed samples come back as dicts that still hold their arrays.

        Each sample is a one-entry dict ``{index: array}``. The test runs three
        epochs with batch size 2, checks the epoch bookkeeping after every
        batch, and checks that each epoch returns every sample's array.
        """
        dataset = [{i: numpy.zeros((10,)) + i} for i in range(6)]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)
        for epoch in range(3):
            self.assertEqual(it.epoch, epoch)
            self.assertAlmostEqual(it.epoch_detail, epoch)
            if epoch:
                self.assertAlmostEqual(
                    it.previous_epoch_detail, epoch - 2 / 6)
            else:
                # Nothing has been drawn yet.
                self.assertIsNone(it.previous_epoch_detail)

            arrays_by_key = {}
            for step in range(3):
                batch = it.next()
                self.assertEqual(len(batch), 2)
                # Only the third batch of an epoch completes it.
                if step == 2:
                    self.assertTrue(it.is_new_epoch)
                else:
                    self.assertFalse(it.is_new_epoch)
                drawn_before = 3 * epoch + step
                self.assertAlmostEqual(
                    it.epoch_detail, (drawn_before + 1) * 2 / 6)
                self.assertAlmostEqual(
                    it.previous_epoch_detail, drawn_before * 2 / 6)
                for sample in batch:
                    self.assertIsInstance(sample, dict)
                    # Each sample dict has a single key: take the first one.
                    key = list(sample)[0]
                    array = sample[key]
                    self.assertIsInstance(array, numpy.ndarray)
                    arrays_by_key[key] = array

            self.assertEqual(len(arrays_by_key), len(dataset))
            for key, array in six.iteritems(arrays_by_key):
                original = dataset[key][list(dataset[key])[0]]
                numpy.testing.assert_allclose(original, array)
开发者ID:chainer,项目名称:chainer,代码行数:35,代码来源:test_multiprocess_iterator.py

示例12: test_iterator_repeat_not_even

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_repeat_not_even(self):
        """Five batches of size 2 from a repeating 5-item dataset contain
        every item exactly twice."""
        dataset = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(dataset, 2, **self.options)

        drawn = [item for _ in range(5) for item in it.next()]
        self.assertEqual(sorted(drawn), sorted(dataset * 2))
开发者ID:chainer,项目名称:chainer,代码行数:8,代码来源:test_multiprocess_iterator.py

示例13: test_iterator_shuffle_divisible

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_shuffle_divisible(self):
        """Two consecutive full-dataset batches must not be in the same order.

        NOTE(review): relies on ``self.options`` enabling shuffling — confirm
        in the test parameterization.
        """
        dataset = list(range(10))
        it = iterators.MultiprocessIterator(
            dataset, 10, **self.options)
        first = it.next()
        second = it.next()
        self.assertNotEqual(first, second)
开发者ID:chainer,项目名称:chainer,代码行数:7,代码来源:test_multiprocess_iterator.py

示例14: test_iterator_shuffle_nondivisible

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_iterator_shuffle_nondivisible(self):
        """With batch size 3 over 10 items, the first ten drawn items must
        differ in order from the next ten.

        NOTE(review): relies on ``self.options`` enabling shuffling — confirm
        in the test parameterization.
        """
        dataset = list(range(10))
        it = iterators.MultiprocessIterator(
            dataset, 3, **self.options)
        drawn = [item for _ in range(7) for item in it.next()]
        first_ten, next_ten = drawn[0:10], drawn[10:20]
        self.assertNotEqual(first_ten, next_ten)
开发者ID:chainer,项目名称:chainer,代码行数:8,代码来源:test_multiprocess_iterator.py

示例15: test_copy_not_repeat

# 需要导入模块: from chainer import iterators [as 别名]
# 或者: from chainer.iterators import MultiprocessIterator [as 别名]
def test_copy_not_repeat(self):
        """A shallow copy of a fresh non-repeating iterator yields the whole
        dataset independently of the original iterator."""
        dataset = [1, 2, 3, 4, 5]
        it = iterators.MultiprocessIterator(
            dataset, 2, repeat=False, **self.options)
        copy_it = copy.copy(it)

        def assert_drains_once(iterator):
            # Three batches cover all five items; after that, next() keeps
            # raising StopIteration.
            drawn = [item for _ in range(3) for item in iterator.next()]
            self.assertEqual(sorted(drawn), dataset)
            for _ in range(2):
                self.assertRaises(StopIteration, iterator.next)

        assert_drains_once(it)
        # NOTE(review): presumably drops the original iterator (and whatever
        # worker resources it holds) before the copy is used — confirm.
        it = None
        assert_drains_once(copy_it)
开发者ID:chainer,项目名称:chainer,代码行数:17,代码来源:test_multiprocess_iterator.py


注：本文中的 chainer.iterators.MultiprocessIterator 方法示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。