当前位置: 首页>>代码示例>>Python>>正文


Python datasets.IndexableDataset方法代码示例

本文整理汇总了Python中fuel.datasets.IndexableDataset方法的典型用法代码示例。如果您正苦于以下问题:Python datasets.IndexableDataset方法的具体用法?Python datasets.IndexableDataset怎么用?Python datasets.IndexableDataset使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在fuel.datasets的用法示例。


在下文中一共展示了datasets.IndexableDataset方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: setup_datastream

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def setup_datastream(path, batch_size, sort_batch_count, valid=False):
    A = numpy.load(os.path.join(path, ('valid_x_raw.npy' if valid else 'train_x_raw.npy')))
    B = numpy.load(os.path.join(path, ('valid_phn.npy' if valid else 'train_phn.npy')))
    C = numpy.load(os.path.join(path, ('valid_seq_to_phn.npy' if valid else 'train_seq_to_phn.npy')))

    D = [B[x[0]:x[1], 2] for x in C]

    ds = IndexableDataset({'input': A, 'output': D})
    stream = DataStream(ds, iteration_scheme=ShuffledExampleScheme(len(A)))

    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * sort_batch_count))
    comparison = _balanced_batch_helper(stream.sources.index('input'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size, num_examples=len(A)))
    stream = Padding(stream, mask_sources=['input', 'output'])

    return ds, stream 
开发者ID:thomasmesnard,项目名称:CTC-LSTM,代码行数:21,代码来源:timit.py

示例2: build_fuel

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def build_fuel(data):
    # create fuel dataset.
    dataset     = datasets.IndexableDataset(indexables=OrderedDict([('data', data)]))
    dataset.example_iteration_scheme \
                = schemes.ShuffledExampleScheme(dataset.num_examples)
    return dataset, len(data) 
开发者ID:memray,项目名称:seq2seq-keyphrase,代码行数:8,代码来源:build_dataset.py

示例3: build_data

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def build_data(data):
    # create fuel dataset.
    dataset = datasets.IndexableDataset(indexables=OrderedDict([('source', data['source']),
                                                                ('target', data['target']),
                                                                # ('target_c', data['target_c']),
                                                                ]))
    dataset.example_iteration_scheme \
        = schemes.ShuffledExampleScheme(dataset.num_examples)
    return dataset 
开发者ID:memray,项目名称:seq2seq-keyphrase,代码行数:11,代码来源:keyphrase_copynet.py

示例4: test_mean_aggregator

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_mean_aggregator():
    num_examples = 4
    batch_size = 2

    features = numpy.array([[0, 3],
                           [2, 9],
                           [2, 4],
                           [5, 1]], dtype=theano.config.floatX)

    dataset = IndexableDataset(OrderedDict([('features', features)]))

    data_stream = DataStream(dataset,
                             iteration_scheme=SequentialScheme(num_examples,
                                                               batch_size))

    x = tensor.matrix('features')
    y = (x**2).mean(axis=0)
    y.name = 'y'
    z = y.sum()
    z.name = 'z'

    y.tag.aggregation_scheme = Mean(y, 1.)
    z.tag.aggregation_scheme = Mean(z, 1.)

    assert_allclose(DatasetEvaluator([y]).evaluate(data_stream)['y'],
                    numpy.array([8.25, 26.75], dtype=theano.config.floatX))
    assert_allclose(DatasetEvaluator([z]).evaluate(data_stream)['z'],
                    numpy.array([35], dtype=theano.config.floatX)) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:30,代码来源:test_aggregation.py

示例5: test_getattr

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_getattr(self):
        assert_equal(getattr(IndexableDataset({'a': (1, 2)}), 'a'), (1, 2)) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:4,代码来源:test_datasets.py

示例6: test_value_error_get_data_state

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_value_error_get_data_state(self):
        assert_raises(
            ValueError, IndexableDataset([1, 2, 3]).get_data, True, [1, 2]) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:5,代码来源:test_datasets.py

示例7: test_value_error_get_data_none_request

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_value_error_get_data_none_request(self):
        assert_raises(
            ValueError, IndexableDataset([1, 2, 3]).get_data, None, None) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:5,代码来源:test_datasets.py

示例8: test_pickling

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_pickling(self):
        cPickle.loads(cPickle.dumps(IndexableDataset({'a': (1, 2)}))) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:4,代码来源:test_datasets.py

示例9: test_num_examples

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_num_examples():
    assert_raises(ValueError, IterableDataset,
                  {'features': range(10), 'targets': range(7)})
    dataset = IterableDataset({'features': range(7),
                               'targets': range(7)})
    assert dataset.num_examples == 7
    dataset = IterableDataset(repeat(1))
    assert numpy.isnan(dataset.num_examples)
    x = numpy.random.rand(5, 3)
    y = numpy.random.rand(5, 4)
    dataset = IndexableDataset({'features': x, 'targets': y})
    assert dataset.num_examples == 5
    assert_raises(ValueError, IndexableDataset,
                  {'features': x, 'targets': y[:4]}) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:16,代码来源:test_datasets.py

示例10: test_axis_labels_on_produces_batches

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_axis_labels_on_produces_batches(self):
        dataset = IndexableDataset(numpy.eye(2))
        axis_labels = {'data': ('batch', 'features')}
        dataset.axis_labels = axis_labels
        stream = DataStream(dataset, iteration_scheme=SequentialScheme(2, 2))
        assert_equal(stream.axis_labels, axis_labels) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:8,代码来源:test_streams.py

示例11: test_flatten_batches

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_flatten_batches(self):
        wrapper = Flatten(
            DataStream(IndexableDataset(self.data),
                       iteration_scheme=SequentialScheme(4, 2)),
            which_sources=('features',))
        assert_equal(
            list(wrapper.get_epoch_iterator()),
            [(numpy.ones((2, 4)), numpy.array([[0], [1]])),
             (numpy.ones((2, 4)), numpy.array([[0], [1]]))]) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:11,代码来源:test_transformers.py

示例12: test_axis_labels_on_flatten_batches

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_axis_labels_on_flatten_batches(self):
        wrapper = Flatten(
            DataStream(IndexableDataset(self.data),
                       iteration_scheme=SequentialScheme(4, 2),
                       axis_labels={'features': ('batch', 'width', 'height'),
                                    'targets': ('batch', 'index')}),
            which_sources=('features',))
        assert_equal(wrapper.axis_labels, {'features': ('batch', 'feature'),
                                           'targets': ('batch', 'index')}) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:11,代码来源:test_transformers.py

示例13: test_axis_labels_on_flatten_batches_with_none

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_axis_labels_on_flatten_batches_with_none(self):
        wrapper = Flatten(
            DataStream(IndexableDataset(self.data),
                       iteration_scheme=SequentialScheme(4, 2),
                       axis_labels={'features': None,
                                    'targets': ('batch', 'index')}),
            which_sources=('features',))
        assert_equal(wrapper.axis_labels, {'features': None,
                                           'targets': ('batch', 'index')}) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:11,代码来源:test_transformers.py

示例14: test_axis_labels_on_flatten_examples

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_axis_labels_on_flatten_examples(self):
        wrapper = Flatten(
            DataStream(IndexableDataset(self.data),
                       iteration_scheme=SequentialExampleScheme(4),
                       axis_labels={'features': ('batch', 'width', 'height'),
                                    'targets': ('batch', 'index')}),
            which_sources=('features',))
        assert_equal(wrapper.axis_labels, {'features': ('feature',),
                                           'targets': ('index',)}) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:11,代码来源:test_transformers.py

示例15: test_filter_batches

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IndexableDataset [as 别名]
def test_filter_batches(self):
        data = [1, 2, 3, 4]
        data_filtered = [([3, 4],)]
        stream = DataStream(IndexableDataset(data),
                            iteration_scheme=SequentialScheme(4, 2))
        wrapper = Filter(stream, lambda d: d[0][0] % 3 == 0)
        assert_equal(list(wrapper.get_epoch_iterator()), data_filtered) 
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:9,代码来源:test_transformers.py


注:本文中的fuel.datasets.IndexableDataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。