This article collects typical usage examples of the chainer.dataset method in Python. If you have been wondering what chainer.dataset does, how to use it, or what it looks like in real code, the curated examples below may help. You can also explore further usage examples from the containing module, chainer.
The following presents 15 code examples of chainer.dataset, ordered by popularity by default.
Example 1: convert_sequence_chain
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def convert_sequence_chain(batch, device):

    def to_device_batch(batch):
        if device is None:
            return batch
        elif device < 0:
            return [chainer.dataset.to_device(device, x) for x in batch]
        else:
            # Concatenate the whole mini-batch, transfer it in one call, then
            # split it back into the original sequences on the device.
            xp = cuda.cupy.get_array_module(*batch)
            concat = xp.concatenate(batch, axis=0)
            sections = np.cumsum([len(x) for x in batch[:-1]], dtype='i')
            concat_dev = chainer.dataset.to_device(device, concat)
            batch_dev = cuda.cupy.split(concat_dev, sections)
            return batch_dev

    return [to_device_batch([x[i] for x in batch])
            for i in range(len(batch[0]))]
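A minimal usage sketch (not part of the original snippet): on CPU (device=None) the converter simply regroups the batch column-wise, so a toy batch of (source, target) ID arrays illustrates the output shape; the array contents below are made up.
import numpy as np

batch = [(np.array([1, 2, 3], np.int32), np.array([4, 5], np.int32)),
         (np.array([6], np.int32), np.array([7, 8, 9], np.int32))]
sources, targets = convert_sequence_chain(batch, device=None)
# sources -> [array([1, 2, 3]), array([6])]
# targets -> [array([4, 5]), array([7, 8, 9])]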
Example 2: get_val_data_iterator
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def get_val_data_iterator(data_dir,
                          batch_size,
                          num_workers,
                          num_classes):
    val_dir_path = os.path.join(data_dir, 'val')
    val_dataset = DirectoryParsingLabelDataset(val_dir_path)
    val_dataset_len = len(val_dataset)
    assert len(directory_parsing_label_names(val_dir_path)) == num_classes
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return val_iterator, val_dataset_len
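A hypothetical call, assuming an ImageNet-style directory layout with a val subdirectory under data_dir (the path and numbers below are placeholders):
val_iter, val_len = get_val_data_iterator(
    data_dir='data/imagenet', batch_size=64, num_workers=4, num_classes=1000)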
Example 3: get_val_data_iterator
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def get_val_data_iterator(dataset_name,
                          batch_size,
                          num_workers):
    if dataset_name == "CIFAR10":
        _, test_ds = cifar.get_cifar10()
    elif dataset_name == "CIFAR100":
        _, test_ds = cifar.get_cifar100()
    elif dataset_name == "SVHN":
        _, test_ds = svhn.get_svhn()
    else:
        raise Exception('Unrecognized dataset: {}'.format(dataset_name))
    val_dataset = test_ds
    val_dataset_len = len(val_dataset)
    val_iterator = iterators.MultiprocessIterator(
        dataset=val_dataset,
        batch_size=batch_size,
        repeat=False,
        shuffle=False,
        n_processes=num_workers,
        shared_mem=300000000)
    return val_iterator, val_dataset_len
Example 4: __init__
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def __init__(self, dataset, batch_size, repeat=True):
    self.dataset = dataset
    self.batch_size = batch_size  # batch size
    # Number of completed sweeps over the dataset. In this case, it is
    # incremented if every word is visited at least once after the last
    # increment.
    self.epoch = 0
    # True if the epoch is incremented at the last iteration.
    self.is_new_epoch = False
    self.repeat = repeat
    length = len(dataset)
    # Offsets maintain the position of each sequence in the mini-batch.
    self.offsets = [i * length // batch_size for i in range(batch_size)]
    # NOTE: this is not a count of parameter updates. It is just a count of
    # calls of ``__next__``.
    self.iteration = 0
    # use -1 instead of None internally
    self._previous_epoch_detail = -1.
Example 5: __next__
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def __next__(self):
    # This iterator returns a list representing a mini-batch. Each item
    # indicates a different position in the original sequence. Each item is
    # represented by a pair of two word IDs. The first word is at the
    # "current" position, while the second word is at the next position.
    # At each iteration, the iteration count is incremented, which pushes
    # forward the "current" position.
    length = len(self.dataset)
    if not self.repeat and self.iteration * self.batch_size >= length:
        # If not self.repeat, this iterator stops at the end of the first
        # epoch (i.e., when all words are visited once).
        raise StopIteration
    cur_words = self.get_words()
    self._previous_epoch_detail = self.epoch_detail
    self.iteration += 1
    next_words = self.get_words()

    epoch = self.iteration * self.batch_size // length
    self.is_new_epoch = self.epoch < epoch
    if self.is_new_epoch:
        self.epoch = epoch

    return list(zip(cur_words, next_words))
Example 6: serialize
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def serialize(self, serializer):
    # It is important to serialize the state to be recovered on resume.
    self.iteration = serializer('iteration', self.iteration)
    self.epoch = serializer('epoch', self.epoch)
    try:
        self._previous_epoch_detail = serializer(
            'previous_epoch_detail', self._previous_epoch_detail)
    except KeyError:
        # guess previous_epoch_detail for older versions
        self._previous_epoch_detail = self.epoch + \
            (self.current_position - self.batch_size) / len(self.dataset)
        if self.epoch_detail > 0:
            self._previous_epoch_detail = max(
                self._previous_epoch_detail, 0.)
        else:
            self._previous_epoch_detail = -1.
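Examples 4 through 6 are methods of the same parallel-sequential iterator and rely on members that are not shown here (get_words, epoch_detail, current_position). A sketch of what those members typically look like in this kind of iterator, offered as an assumption rather than the original code:
def get_words(self):
    # Pick the word at each offset, shifted by the current iteration count.
    return [self.dataset[(offset + self.iteration) % len(self.dataset)]
            for offset in self.offsets]

@property
def epoch_detail(self):
    # Floating-point view of the epoch, e.g. 1.5 means halfway through epoch 2.
    return self.iteration * self.batch_size / len(self.dataset)

@property
def current_position(self):
    # Position within the current epoch, measured in visited words.
    return self.iteration * self.batch_size % len(self.dataset)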
Example 7: _call_converter
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def _call_converter(converter, batch, device):
    # Calls the converter.
    # Converter can be either new-style (accepts chainer.backend.Device) or
    # old-style (accepts int as device).
    assert device is None or isinstance(device, backend.Device)

    if isinstance(converter, Converter):
        # New-style converter
        return converter(batch, device)

    # Old-style converter
    if device is None:
        return converter(batch, None)
    if device.xp is numpy:
        return converter(batch, -1)
    if device.xp is cuda.cupy:
        return converter(batch, device.device.id)

    raise RuntimeError(
        'Converter does not support ChainerX. '
        'Use chainer.dataset.converter decorator.')
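For reference, a new-style converter is normally created with the chainer.dataset.converter decorator mentioned in the error message; a minimal sketch (the function name and body are illustrative only):
import chainer

@chainer.dataset.converter()
def my_converter(batch, device):
    # Receives a chainer.backend.Device and moves every array in the batch to it.
    return chainer.dataset.concat_examples(batch, device)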
Example 8: fetch
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def fetch(self):
    """Fetch data.

    This method fetches all data of the dataset/view.
    Note that this method returns column-major data
    (i.e. :obj:`([a[0], ..., a[3]], ..., [c[0], ... c[3]])`,
    :obj:`{'a': [a[0], ..., a[3]], ..., 'c': [c[0], ..., c[3]]}`, or
    :obj:`[a[0], ..., a[3]]`).

    Returns:
        If :attr:`mode` is :class:`tuple`,
        this method returns a tuple of lists/arrays.
        If :attr:`mode` is :class:`dict`,
        this method returns a dict of lists/arrays.
    """
    examples = self.get_examples(None, None)
    if self.mode is tuple:
        return examples
    elif self.mode is dict:
        return dict(six.moves.zip(self.keys, examples))
    elif self.mode is None:
        return examples[0]
Example 9: convert
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def convert(self, data):
    """Convert fetched data.

    This method takes data fetched by :meth:`fetch` and
    pre-processes it before passing it to models.
    The default behaviour is to convert each column into an ndarray.
    This behaviour can be overridden by :meth:`with_converter`.
    If the dataset is constructed by :meth:`concat` or :meth:`join`,
    the converter of the first dataset is used.

    Args:
        data (tuple or dict): Data from :meth:`fetch`.

    Returns:
        A tuple or dict.
        Each value is an ndarray.
    """
    if isinstance(data, tuple):
        return tuple(_as_array(d) for d in data)
    elif isinstance(data, dict):
        return {k: _as_array(v) for k, v in data.items()}
    else:
        return _as_array(data)
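The helper _as_array is not shown in this excerpt; in Chainer's tabular-dataset code it essentially passes existing arrays through and wraps everything else with numpy.asarray (the real helper also accepts CuPy arrays). A rough sketch of that behaviour, given here as an assumption:
import numpy

def _as_array(data):
    # Leave ndarrays untouched; convert lists, tuples, and scalars.
    if isinstance(data, numpy.ndarray):
        return data
    return numpy.asarray(data)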
Example 10: _preprocess_cifar
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def _preprocess_cifar(images, labels, withlabel, ndim, scale, dtype):
    if ndim == 1:
        images = images.reshape(-1, 3072)
    elif ndim == 3:
        images = images.reshape(-1, 3, 32, 32)
    else:
        raise ValueError('invalid ndim for CIFAR dataset')
    dtype = chainer.get_dtype(dtype)
    images = images.astype(dtype)
    images *= scale / 255.

    if withlabel:
        labels = labels.astype(numpy.int32)
        return tuple_dataset.TupleDataset(images, labels)
    else:
        return images
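This helper sits behind the public chainer.datasets.get_cifar10 / get_cifar100 loaders, which expose the same options; for example, images scaled into [0, 1] with shape (N, 3, 32, 32):
import chainer

train, test = chainer.datasets.get_cifar10(withlabel=True, ndim=3, scale=1.0)
image, label = train[0]   # image.shape == (3, 32, 32), label is an int32 scalar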
Example 11: __call__
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def __call__(self, trainer):
    print('## Calculate BLEU')
    with chainer.no_backprop_mode():
        with chainer.using_config('train', False):
            references = []
            hypotheses = []
            for i in range(0, len(self.test_data), self.batch):
                sources, targets = zip(*self.test_data[i:i + self.batch])
                references.extend([[t.tolist()] for t in targets])
                sources = [
                    chainer.dataset.to_device(self.device, x) for x in sources]
                ys = [y.tolist()
                      for y in self.model.translate(sources, self.max_length)]
                hypotheses.extend(ys)

    bleu = bleu_score.corpus_bleu(
        references, hypotheses,
        smoothing_function=bleu_score.SmoothingFunction().method1) * 100
    print('BLEU:', bleu)
    reporter.report({self.key: bleu})
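The attributes used above (self.model, self.test_data, self.key, self.batch, self.device, self.max_length) come from the extension's constructor, which is not shown; a plausible constructor, modelled on Chainer's seq2seq example and given here as an assumption:
class CalculateBleu(chainer.training.Extension):

    trigger = 1, 'epoch'
    priority = chainer.training.PRIORITY_WRITER

    def __init__(self, model, test_data, key,
                 batch=100, device=-1, max_length=100):
        self.model = model
        self.test_data = test_data
        self.key = key
        self.batch = batch
        self.device = device
        self.max_length = max_length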
Example 12: setup_dataset
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def setup_dataset(mode, crop_dir, mask_dir=None, mean_mask_dir=None,
                  mean_grid_dir=None, trimap_dir=None, alpha_dir=None,
                  alpha_weight_dir=None):
    # Create dataset
    dataset = datasets.create(mode, crop_dir, mask_dir, mean_mask_dir,
                              mean_grid_dir, trimap_dir, alpha_dir,
                              alpha_weight_dir)
    # Create transform function
    transform = transforms.create(mode)
    transform_random = transforms.transform_random
    # Split into train and test
    train_raw, test_raw = datasets.split_dataset(dataset)
    # Increase data variety
    train_raw = chainer.datasets.TransformDataset(train_raw, transform_random)
    # Transform for network inputs
    train = chainer.datasets.TransformDataset(train_raw, transform)
    test = chainer.datasets.TransformDataset(test_raw, transform)
    return train, test
Example 13: _preprocess_mnist
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype,
                      rgb_format):
    images = raw['x']
    if ndim == 2:
        images = images.reshape(-1, 28, 28)
    elif ndim == 3:
        images = images.reshape(-1, 1, 28, 28)
        if rgb_format:
            images = np.broadcast_to(images,
                                     (len(images), 3) + images.shape[2:])
    elif ndim != 1:
        raise ValueError('invalid ndim for MNIST dataset')
    images = images.astype(image_dtype)
    images *= scale / 255.

    if withlabel:
        labels = raw['y'].astype(label_dtype)
        return tuple_dataset.TupleDataset(images, labels)
    else:
        return images
Example 14: preview_convert
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def preview_convert(iterator_a, iterator_b, g_a, g_b, device, gla, dst):
    @chainer.training.make_extension()
    def make_preview(trainer):
        with chainer.using_config('train', False):
            with chainer.no_backprop_mode():
                x_a = iterator_a.next()
                x_a = convert.concat_examples(x_a, device)
                x_a = chainer.Variable(x_a)

                x_b = iterator_b.next()
                x_b = convert.concat_examples(x_b, device)
                x_b = chainer.Variable(x_b)

                x_ab = g_a(x_a)
                x_ba = g_b(x_b)
                x_bab = g_a(x_ba)
                x_aba = g_b(x_ab)

                preview_dir = '{}/preview'.format(dst)
                if not os.path.exists(preview_dir):
                    os.makedirs(preview_dir)
                image_dir = '{}/image'.format(dst)
                if not os.path.exists(image_dir):
                    os.makedirs(image_dir)

                names = ['a', 'ab', 'aba', 'b', 'ba', 'bab']
                images = [x_a, x_ab, x_aba, x_b, x_ba, x_bab]
                for n, i in zip(names, images):
                    i = cp.asnumpy(i.data)[:, :, padding:-padding, :].reshape(1, -1, 128)
                    image.save(image_dir + '/{}{}.jpg'.format(trainer.updater.epoch, n), i)
                    w = np.concatenate([gla.inverse(_i) for _i in dataset.reverse(i)])
                    dataset.save(preview_dir + '/{}{}.wav'.format(trainer.updater.epoch, n), 16000, w)

    return make_preview
Example 15: parse_args
# Required module import: import chainer [as alias]
# Or: from chainer import dataset [as alias]
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-b', type=int, default=32,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--bproplen', '-l', type=int, default=35,
                        help='Number of words in each mini-batch '
                             '(= length of truncated BPTT)')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--gradclip', '-c', type=float, default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--test', action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.set_defaults(test=False)
    parser.add_argument('--hidden_size', type=int, default=300,
                        help='Number of LSTM units in each layer')
    parser.add_argument('--embed_size', type=int, default=300,
                        help='Size of embeddings')
    parser.add_argument('--model', '-m', default='model.npz',
                        help='Model file name to serialize')
    parser.add_argument('--glove', default='data/glove.6B.300d.txt',
                        help='Path to glove embedding file.')
    args = parser.parse_args()
    return args