Python chainer.dataset Code Examples

This article collects typical usage examples of chainer.dataset in Python. If you are wondering what chainer.dataset provides, how to use it, or what it looks like in real code, the curated examples below may help. You can also explore further usage examples from the chainer package.


The following presents 15 code examples involving chainer.dataset, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.

Example 1: convert_sequence_chain

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet also needs: import numpy as np; from chainer.backends import cuda
def convert_sequence_chain(batch, device):
    def to_device_batch(batch):
        if device is None:
            return batch
        elif device < 0:
            return [chainer.dataset.to_device(device, x) for x in batch]
        else:
            xp = cuda.cupy.get_array_module(*batch)
            concat = xp.concatenate(batch, axis=0)
            sections = np.cumsum([len(x) for x in batch[:-1]], dtype='i')
            concat_dev = chainer.dataset.to_device(device, concat)
            batch_dev = cuda.cupy.split(concat_dev, sections)
            return batch_dev

    return [to_device_batch([x[i] for x in batch])
            for i in range(len(batch[0]))] 
Developer: pfnet-research, Project: contextual_augmentation, Lines: 18, Source: chain_utils.py
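For reference, here is a minimal CPU-side usage sketch; the batch layout (a list of per-example tuples of variable-length arrays) is an assumption inferred from the function's indexing:

import numpy as np

# Hypothetical batch: two examples, each a (source, target) pair of
# variable-length int32 sequences.
batch = [
    (np.array([1, 2, 3], np.int32), np.array([4, 5], np.int32)),
    (np.array([6, 7], np.int32), np.array([8, 9, 10], np.int32)),
]

# device=None keeps everything on the host; a non-negative device id
# would move the concatenated arrays to that GPU and split them back.
sources, targets = convert_sequence_chain(batch, device=None)
print(len(sources), len(targets))  # 2 2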

Example 2: get_val_data_iterator

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet also needs: import os; from chainer import iterators;
# from chainercv.datasets import DirectoryParsingLabelDataset, directory_parsing_label_names
def get_val_data_iterator(data_dir,
                          batch_size,
                          num_workers,
                          num_classes):

    val_dir_path = os.path.join(data_dir, 'val')
    val_dataset = DirectoryParsingLabelDataset(val_dir_path)
    val_dataset_len = len(val_dataset)
    assert len(directory_parsing_label_names(val_dir_path)) == num_classes

    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)

    return val_iterator, val_dataset_len 
Developer: osmr, Project: imgclsmob, Lines: 21, Source: imagenet1k1.py
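A hedged consumption sketch (the data directory and sizes are placeholders; DirectoryParsingLabelDataset and directory_parsing_label_names come from ChainerCV):

from chainer.dataset import concat_examples

# Hypothetical paths/sizes; data_dir must contain a 'val' subdirectory
# with one folder per class.
val_iter, val_len = get_val_data_iterator(
    data_dir='path/to/imagenet', batch_size=32, num_workers=4,
    num_classes=1000)

for batch in val_iter:
    images, labels = concat_examples(batch)  # stack examples into ndarrays
    # ... evaluate the model on images/labels ...
val_iter.finalize()  # shut down the worker processes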

Example 3: get_val_data_iterator

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet also needs: from chainer import iterators; from chainer.datasets import cifar, svhn
def get_val_data_iterator(dataset_name,
                          batch_size,
                          num_workers):

    if dataset_name == "CIFAR10":
        _, test_ds = cifar.get_cifar10()
    elif dataset_name == "CIFAR100":
        _, test_ds = cifar.get_cifar100()
    elif dataset_name == "SVHN":
        _, test_ds = svhn.get_svhn()
    else:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))

    val_dataset = test_ds
    val_dataset_len = len(val_dataset)

    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)

    return val_iterator, val_dataset_len 
Developer: osmr, Project: imgclsmob, Lines: 27, Source: cifar1.py

Example 4: __init__

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This is a method of the custom ParallelSequentialIterator class in the original example.
def __init__(self, dataset, batch_size, repeat=True):
        self.dataset = dataset
        self.batch_size = batch_size  # batch size
        # Number of completed sweeps over the dataset. In this case, it is
        # incremented if every word is visited at least once after the last
        # increment.
        self.epoch = 0
        # True if the epoch is incremented at the last iteration.
        self.is_new_epoch = False
        self.repeat = repeat
        length = len(dataset)
        # Offsets maintain the position of each sequence in the mini-batch.
        self.offsets = [i * length // batch_size for i in range(batch_size)]
        # NOTE: this is not a count of parameter updates. It is just a count of
        # calls of ``__next__``.
        self.iteration = 0
        # use -1 instead of None internally
        self._previous_epoch_detail = -1. 
Developer: chainer, Project: chainer, Lines: 20, Source: train_ptb_custom_loop.py

Example 5: __next__

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This is a method of the same ParallelSequentialIterator class as Example 4.
def __next__(self):
        # This iterator returns a list representing a mini-batch. Each item
        # indicates a different position in the original sequence. Each item is
        # represented by a pair of two word IDs. The first word is at the
        # "current" position, while the second word at the next position.
        # At each iteration, the iteration count is incremented, which pushes
        # forward the "current" position.
        length = len(self.dataset)
        if not self.repeat and self.iteration * self.batch_size >= length:
            # If not self.repeat, this iterator stops at the end of the first
            # epoch (i.e., when all words are visited once).
            raise StopIteration
        cur_words = self.get_words()
        self._previous_epoch_detail = self.epoch_detail
        self.iteration += 1
        next_words = self.get_words()

        epoch = self.iteration * self.batch_size // length
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch

        return list(zip(cur_words, next_words)) 
Developer: chainer, Project: chainer, Lines: 25, Source: train_ptb_custom_loop.py
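__next__ relies on get_words and epoch_detail, which this excerpt omits. In the original Chainer PTB example they look roughly like the following (reproduced from the upstream example; treat it as a sketch):

def get_words(self):
    # Return one word per mini-batch position; each offset walks
    # through the dataset independently and wraps around at the end.
    return [self.dataset[(offset + self.iteration) % len(self.dataset)]
            for offset in self.offsets]

@property
def epoch_detail(self):
    # Floating-point epoch progress, e.g. 1.5 means halfway through
    # the second epoch.
    return self.iteration * self.batch_size / len(self.dataset)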

Example 6: serialize

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This is a method of the same ParallelSequentialIterator class as Example 4.
def serialize(self, serializer):
        # It is important to serialize the state to be recovered on resume.
        self.iteration = serializer('iteration', self.iteration)
        self.epoch = serializer('epoch', self.epoch)
        try:
            self._previous_epoch_detail = serializer(
                'previous_epoch_detail', self._previous_epoch_detail)
        except KeyError:
            # guess previous_epoch_detail for older version
            self._previous_epoch_detail = self.epoch + \
                (self.current_position - self.batch_size) / len(self.dataset)
            if self.epoch_detail > 0:
                self._previous_epoch_detail = max(
                    self._previous_epoch_detail, 0.)
            else:
                self._previous_epoch_detail = -1. 
Developer: chainer, Project: chainer, Lines: 18, Source: train_ptb_custom_loop.py

Example 7: _call_converter

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet (from chainer/dataset/convert.py) also needs: import numpy;
# from chainer import backend; from chainer.backends import cuda; and the
# Converter class defined in the same module.
def _call_converter(converter, batch, device):
    # Calls the converter.
    # Converter can be either new-style (accepts chainer.backend.Device) or
    # old-style (accepts int as device).
    assert device is None or isinstance(device, backend.Device)

    if isinstance(converter, Converter):
        # New-style converter
        return converter(batch, device)

    # Old-style converter
    if device is None:
        return converter(batch, None)
    if device.xp is numpy:
        return converter(batch, -1)
    if device.xp is cuda.cupy:
        return converter(batch, device.device.id)
    raise RuntimeError(
        'Converter does not support ChainerX. '
        'Use chainer.dataset.converter decorator.') 
Developer: chainer, Project: chainer, Lines: 22, Source: convert.py
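The error message refers to the chainer.dataset.converter decorator; a minimal new-style converter could look like this (a sketch assuming Chainer >= 6, where the decorator and chainer.backend.Device.send are available):

import numpy
import chainer

@chainer.dataset.converter()
def my_converter(batch, device):
    # Here device is always a chainer.backend.Device (never an int),
    # so the converter also works with ChainerX backends.
    return device.send(numpy.asarray(batch))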

Example 8: fetch

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This is a method of chainer's TabularDataset class.
def fetch(self):
        """Fetch data.

        This method fetches all data of the dataset/view.
        Note that this method returns a column-major data
        (i.e. :obj:`([a[0], ..., a[3]], ..., [c[0], ..., c[3]])`,
        :obj:`{'a': [a[0], ..., a[3]], ..., 'c': [c[0], ..., c[3]]}`, or
        :obj:`[a[0], ..., a[3]]`).

        Returns:
            If :attr:`mode` is :class:`tuple`,
            this method returns a tuple of lists/arrays.
            If :attr:`mode` is :class:`dict`,
            this method returns a dict of lists/arrays.
            If :attr:`mode` is ``None``,
            this method returns a single list/array.
        """
        examples = self.get_examples(None, None)
        if self.mode is tuple:
            return examples
        elif self.mode is dict:
            return dict(six.moves.zip(self.keys, examples))
        elif self.mode is None:
            return examples[0] 
Developer: chainer, Project: chainer, Lines: 24, Source: tabular_dataset.py
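A minimal sketch of the column-major contract, assuming Chainer >= 7 where chainer.dataset.TabularDataset is public (XYDataset and its columns are made up for illustration):

import numpy
import chainer

class XYDataset(chainer.dataset.TabularDataset):
    # Two named columns, 'x' and 'y', exposed in tuple mode.

    def __len__(self):
        return 4

    @property
    def keys(self):
        return ('x', 'y')

    @property
    def mode(self):
        return tuple

    def get_examples(self, indices, key_indices):
        if indices is None:
            indices = slice(None)
        if key_indices is None:
            key_indices = range(len(self.keys))
        cols = (numpy.arange(4), numpy.arange(4) ** 2)
        return tuple(cols[i][indices] for i in key_indices)

print(XYDataset().fetch())  # (array([0, 1, 2, 3]), array([0, 1, 4, 9]))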

Example 9: convert

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This is a method of chainer's TabularDataset class; _as_array is a helper
# defined in the same module.
def convert(self, data):
        """Convert fetched data.

        This method takes data fetched by :meth:`fetch` and
        pre-process them before passing them to models.
        The default behaviour is converting each column into an ndarray.
        This behaviour can be overridden by :meth:`with_converter`.
        If the dataset is constructed by :meth:`concat` or :meth:`join`,
        the converter of the first dataset is used.

        Args:
            data (tuple or dict): Data from :meth:`fetch`.

        Returns:
            A tuple or dict.
            Each value is an ndarray.
        """
        if isinstance(data, tuple):
            return tuple(_as_array(d) for d in data)
        elif isinstance(data, dict):
            return {k: _as_array(v) for k, v in data.items()}
        else:
            return _as_array(data) 
Developer: chainer, Project: chainer, Lines: 25, Source: tabular_dataset.py

Example 10: _preprocess_cifar

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet also needs: import numpy; from chainer.datasets import tuple_dataset
def _preprocess_cifar(images, labels, withlabel, ndim, scale, dtype):
    if ndim == 1:
        images = images.reshape(-1, 3072)
    elif ndim == 3:
        images = images.reshape(-1, 3, 32, 32)
    else:
        raise ValueError('invalid ndim for CIFAR dataset')
    dtype = chainer.get_dtype(dtype)
    images = images.astype(dtype)
    images *= scale / 255.

    if withlabel:
        labels = labels.astype(numpy.int32)
        return tuple_dataset.TupleDataset(images, labels)
    else:
        return images 
Developer: chainer, Project: chainer, Lines: 18, Source: cifar.py
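In practice this helper is reached through the public loaders chainer.datasets.get_cifar10 / get_cifar100, for example:

from chainer.datasets import get_cifar10

# ndim=3 keeps the (channel, height, width) layout; scale=1.0 maps
# pixel values into [0, 1].
train, test = get_cifar10(withlabel=True, ndim=3, scale=1.0)
image, label = train[0]
print(image.shape, image.dtype)  # (3, 32, 32) float32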

Example 11: __call__

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This is a method of a BLEU-evaluation trainer extension; it also needs:
# from chainer import reporter; from nltk.translate import bleu_score
def __call__(self, trainer):
        print('## Calculate BLEU')
        with chainer.no_backprop_mode():
            with chainer.using_config('train', False):
                references = []
                hypotheses = []
                for i in range(0, len(self.test_data), self.batch):
                    sources, targets = zip(*self.test_data[i:i + self.batch])
                    references.extend([[t.tolist()] for t in targets])

                    sources = [
                        chainer.dataset.to_device(self.device, x) for x in sources]
                    ys = [y.tolist()
                          for y in self.model.translate(sources, self.max_length)]
                    hypotheses.extend(ys)

        bleu = bleu_score.corpus_bleu(
            references, hypotheses,
            smoothing_function=bleu_score.SmoothingFunction().method1) * 100
        print('BLEU:', bleu)
        reporter.report({self.key: bleu}) 
Developer: soskek, Project: convolutional_seq2seq, Lines: 23, Source: seq2seq.py

Example 12: setup_dataset

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# Note: datasets and transforms below are project-local modules of
# portrait_matting, not chainer.datasets.
def setup_dataset(mode, crop_dir, mask_dir=None, mean_mask_dir=None,
                  mean_grid_dir=None, trimap_dir=None, alpha_dir=None,
                  alpha_weight_dir=None):
    # Create dataset
    dataset = datasets.create(mode, crop_dir, mask_dir, mean_mask_dir,
                              mean_grid_dir, trimap_dir, alpha_dir,
                              alpha_weight_dir)

    # Create transform function
    transform = transforms.create(mode)
    transform_random = transforms.transform_random

    # Split into train and test
    train_raw, test_raw = datasets.split_dataset(dataset)

    # Increase data variety
    train_raw = chainer.datasets.TransformDataset(train_raw, transform_random)

    # Transform for network inputs
    train = chainer.datasets.TransformDataset(train_raw, transform)
    test = chainer.datasets.TransformDataset(test_raw, transform)

    return train, test 
Developer: takiyu, Project: portrait_matting, Lines: 25, Source: train.py
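The datasets and transforms used above are project-local, but the chaining pattern itself relies only on chainer.datasets.TransformDataset; a self-contained sketch with stand-in transforms:

from chainer.datasets import TransformDataset, TupleDataset

base = TupleDataset([1, 2, 3], [10, 20, 30])

def augment(example):
    # Stand-in for a random-augmentation transform: perturb the input only.
    x, y = example
    return x + 0.5, y

def to_input(example):
    # Stand-in for the network-input transform.
    x, y = example
    return float(x), float(y)

train = TransformDataset(TransformDataset(base, augment), to_input)
print(train[0])  # (1.5, 10.0)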

Example 13: _preprocess_mnist

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet also needs: import numpy as np; from chainer.datasets import tuple_dataset
def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb_format):
    images = raw['x']
    if ndim == 2:
        images = images.reshape(-1, 28, 28)
    elif ndim == 3:
        images = images.reshape(-1, 1, 28, 28)
        if rgb_format:
            images = np.broadcast_to(images, (len(images), 3) + images.shape[2:])
    elif ndim != 1:
        raise ValueError('invalid ndim for MNIST dataset')
    images = images.astype(image_dtype)
    images *= scale / 255.

    if withlabel:
        labels = raw['y'].astype(label_dtype)
        return tuple_dataset.TupleDataset(images, labels)
    else:
        return images 
Developer: aws, Project: sagemaker-chainer-container, Lines: 20, Source: single_machine_custom_loop.py
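The corresponding public entry point is chainer.datasets.get_mnist, for example:

from chainer.datasets import get_mnist

# ndim=3 yields (1, 28, 28) images; rgb_format=True would broadcast
# them to three channels.
train, test = get_mnist(withlabel=True, ndim=3, scale=1.0)
image, label = train[0]
print(image.shape, image.dtype)  # (1, 28, 28) float32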

Example 14: preview_convert

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet also assumes: import os; import numpy as np; import cupy as cp;
# from chainer.dataset import convert; plus the project's dataset and image
# modules and a module-level padding constant.
def preview_convert(iterator_a, iterator_b, g_a, g_b, device, gla, dst):
    @chainer.training.make_extension()
    def make_preview(trainer):
        with chainer.using_config('train', False):
            with chainer.no_backprop_mode():
                x_a = iterator_a.next()
                x_a = convert.concat_examples(x_a, device)
                x_a = chainer.Variable(x_a)

                x_b = iterator_b.next()
                x_b = convert.concat_examples(x_b, device)
                x_b = chainer.Variable(x_b)

                x_ab = g_a(x_a)
                x_ba = g_b(x_b)

                x_bab = g_a(x_ba)
                x_aba = g_b(x_ab)

                preview_dir = '{}/preview'.format(dst)
                if not os.path.exists(preview_dir):
                    os.makedirs(preview_dir)
                image_dir = '{}/image'.format(dst)
                if not os.path.exists(image_dir):
                    os.makedirs(image_dir)

                names = ['a', 'ab', 'aba', 'b', 'ba', 'bab']
                images = [x_a, x_ab, x_aba, x_b, x_ba, x_bab]
                for n, i in zip(names, images):
                    i = cp.asnumpy(i.data)[:, :, padding:-padding, :].reshape(1, -1, 128)
                    image.save(image_dir + '/{}{}.jpg'.format(trainer.updater.epoch, n), i)
                    w = np.concatenate([gla.inverse(_i) for _i in dataset.reverse(i)])
                    dataset.save(preview_dir + '/{}{}.wav'.format(trainer.updater.epoch, n), 16000, w)

    return make_preview 
Developer: pstuvwx, Project: Deep_VoiceChanger, Lines: 37, Source: trainer.py

Example 15: parse_args

# Required module: import chainer [as alias]
# Or: from chainer import dataset [as alias]
# This snippet also needs: import argparse
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-b', type=int, default=32,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--bproplen', '-l', type=int, default=35,
                        help='Number of words in each mini-batch '
                             '(= length of truncated BPTT)')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--gradclip', '-c', type=float, default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--test', action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.set_defaults(test=False)
    parser.add_argument('--hidden_size', type=int, default=300,
                        help='Number of LSTM units in each layer')
    parser.add_argument('--embed_size', type=int, default=300,
                        help='Size of embeddings')
    parser.add_argument('--model', '-m', default='model.npz',
                        help='Model file name to serialize')
    parser.add_argument('--glove', default='data/glove.6B.300d.txt',
                        help='Path to glove embedding file.')
    args = parser.parse_args()
    return args 
Developer: Pinafore, Project: qb, Lines: 32, Source: main.py


Note: The chainer.dataset examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub/MSDocs. The snippets are drawn from open-source projects by their respective contributors, and copyright remains with the original authors. Consult each project's license before redistributing or using the code; do not reproduce without permission.