当前位置: 首页>>代码示例>>Python>>正文


Python datasets.IterableDataset类代码示例

本文整理汇总了Python中fuel.datasets.IterableDataset类的典型用法代码示例。如果您想了解fuel.datasets.IterableDataset的具体用法、使用方式或使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块fuel.datasets的用法示例。


在下文中一共展示了datasets.IterableDataset类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_main_loop

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_main_loop():
    """Run a two-epoch MainLoop over a 10-example stream and check its log.

    ``config.profile`` is switched on for the duration of the test and is
    restored in a ``finally`` clause, so a failing assertion cannot leak the
    modified global setting into other tests.
    """
    old_config_profile_value = config.profile
    config.profile = True
    try:
        main_loop = MainLoop(
            MockAlgorithm(), IterableDataset(range(10)).get_example_stream(),
            extensions=[WriteBatchExtension(), FinishAfter(after_n_epochs=2)])
        main_loop.run()
        # No model was passed, so the attribute must not exist.
        assert_raises(AttributeError, getattr, main_loop, 'model')

        assert main_loop.log.status['iterations_done'] == 20
        assert main_loop.log.status['_epoch_ends'] == [10, 20]
        assert len(main_loop.log) == 20
        for i in range(20):
            # Log entries are 1-based; the 10-example epoch repeats.
            assert main_loop.log[i + 1]['batch'] == {'data': i % 10}
    finally:
        config.profile = old_config_profile_value
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:19,代码来源:test_main_loop.py

示例2: test_training_interrupt

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_training_interrupt():
    """Check how a running main loop reacts to two consecutive SIGINTs.

    The test expects the child process to survive the first SIGINT and to
    have exited shortly after the second one. The child trains on an
    infinite ``count()`` stream, so it is always terminated and joined in a
    ``finally`` clause. Otherwise a failed assertion would leave a
    non-daemon process running forever.
    """
    def process_batch(batch):
        # Slow every batch down so the loop is still busy when signalled.
        time.sleep(0.1)

    algorithm = MockAlgorithm()
    algorithm.process_batch = process_batch

    main_loop = MockMainLoop(
        algorithm=algorithm,
        data_stream=IterableDataset(count()).get_example_stream(),
        extensions=[Printing()]
    )

    p = Process(target=main_loop.run)
    p.start()
    try:
        time.sleep(0.1)
        os.kill(p.pid, signal.SIGINT)
        time.sleep(0.1)
        # First interrupt: the process is expected to keep running.
        assert p.is_alive()
        os.kill(p.pid, signal.SIGINT)
        time.sleep(0.2)
        # Second interrupt: the process is expected to have exited.
        assert not p.is_alive()
    finally:
        if p.is_alive():
            p.terminate()
        p.join()
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:25,代码来源:test_main_loop.py

示例3: test_dataset_evaluators

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_dataset_evaluators():
    """Evaluate monitored quantities and a brick cost over one stream.

    Two identically configured ``CrossEntropy`` quantities are registered,
    so that any state left over from one would corrupt the other. The second
    one must agree with ``CategoricalCrossEntropy`` on the same data.
    """
    X = theano.tensor.vector('X')
    Y = theano.tensor.vector('Y')

    float_type = theano.config.floatX
    x_values = numpy.arange(1, 7, dtype=float_type).reshape(3, 2)
    y_values = numpy.arange(11, 17, dtype=float_type).reshape(3, 2)
    data_stream = IterableDataset(
        dict(X=x_values, Y=y_values)).get_example_stream()

    quantities = [
        CrossEntropy(requires=[X, Y], name="monitored_cross_entropy0"),
        CrossEntropy(requires=[X, Y], name="monitored_cross_entropy1"),
        CategoricalCrossEntropy().apply(X, Y),
    ]
    results = DatasetEvaluator(quantities).evaluate(data_stream)

    monitored = results['monitored_cross_entropy1']
    reference = results['categoricalcrossentropy_apply_cost']
    numpy.testing.assert_allclose(monitored, reference)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:22,代码来源:test_monitored_quantity.py

示例4: test_dataset_evaluators

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_dataset_evaluators():
    """Check DatasetEvaluator on a brick's auxiliary monitoring variables.

    The stream yields two batches of different sizes (2x2 and 3x2), so the
    mean over all elements (8.5) differs from the mean of per-batch means
    (7.5). A stream whose source name does not match the graph input must be
    rejected.
    """
    X = theano.tensor.matrix('X')
    brick = TestBrick(name='test_brick')
    Y = brick.apply(X)
    graph = ComputationGraph([Y])
    monitor_variables = list(graph.auxiliary_variables)
    validator = DatasetEvaluator(monitor_variables)

    data = [numpy.arange(1, 5, dtype=theano.config.floatX).reshape(2, 2),
            numpy.arange(10, 16, dtype=theano.config.floatX).reshape(3, 2)]
    data_stream = IterableDataset(dict(X=data)).get_example_stream()

    values = validator.evaluate(data_stream)
    assert values['test_brick_apply_V_squared'] == 4
    numpy.testing.assert_allclose(
        values['test_brick_apply_mean_row_mean'], numpy.vstack(data).mean())
    per_batch_mean = numpy.mean([batch.mean() for batch in data])
    numpy.testing.assert_allclose(
        values['test_brick_apply_mean_batch_element'], per_batch_mean)

    # Build the mismatched stream outside the assert_raises block, so that
    # only the evaluate() call can produce the expected exception.
    bad_stream = IterableDataset(dict(X2=data)).get_example_stream()
    with assert_raises(Exception) as ar:
        validator.evaluate(bad_stream)
    assert "Not all data sources" in ar.exception.args[0]
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:26,代码来源:test_evaluators.py

示例5: test_training_resumption

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_training_resumption():
    """Stop a MainLoop after 14 batches, then resume it until batch 27.

    The scenario runs twice: once resuming the same object and once
    resuming a copy that went through a pickle round-trip.
    """
    def do_test(with_serialization):
        stream = IterableDataset(range(10)).get_example_stream()
        loop = MainLoop(
            MockAlgorithm(), stream,
            extensions=[WriteBatchExtension(),
                        FinishAfter(after_n_batches=14)])
        loop.run()
        assert loop.log.status['iterations_done'] == 14

        if with_serialization:
            loop = cPickle.loads(cPickle.dumps(loop))

        # Add a new stopping condition to the (possibly restored) extension.
        finishers = [extension for extension in loop.extensions
                     if isinstance(extension, FinishAfter)]
        finish_after = unpack(finishers, singleton=True)
        finish_after.add_condition(
            ["after_batch"],
            predicate=lambda log: log.status['iterations_done'] == 27)
        loop.run()

        status = loop.log.status
        assert status['iterations_done'] == 27
        assert status['epochs_done'] == 2
        for iteration in range(1, 28):
            expected = {"data": (iteration - 1) % 10}
            assert loop.log[iteration]['batch'] == expected

    for with_serialization in (False, True):
        do_test(with_serialization)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:29,代码来源:test_main_loop.py

示例6: test_shared_variable_modifier_two_parameters

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_shared_variable_modifier_two_parameters():
    """SharedVariableModifier with a two-argument function decays the rate.

    The modifier's function ignores its first argument and multiplies the
    learning rate by 0.2. The expected value assumes it is applied once per
    batch, so after one epoch the rate must be
    ``0.001 * 0.2 ** n_batches``.
    """
    weights = numpy.array([-1, 1], dtype=theano.config.floatX)
    features = [numpy.array(f, dtype=theano.config.floatX)
                for f in [[1, 2], [3, 4], [5, 6]]]
    targets = [(weights * f).sum() for f in features]
    # An example stream yields one example per batch.
    n_batches = len(features)
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    W = shared_floatx([0, 0], name='W')
    cost = ((x * W).sum() - y) ** 2
    cost.name = 'cost'

    step_rule = Scale(0.001)
    sgd = GradientDescent(cost=cost, parameters=[W],
                          step_rule=step_rule)
    modifier = SharedVariableModifier(
        step_rule.learning_rate,
        lambda _, val: numpy.cast[theano.config.floatX](val * 0.2))
    main_loop = MainLoop(
        model=None, data_stream=dataset.get_example_stream(),
        algorithm=sgd,
        extensions=[FinishAfter(after_n_epochs=1), modifier])

    main_loop.run()

    new_value = step_rule.learning_rate.get_value()
    # The expected value is ~8e-6. An absolute tolerance of 1e-5 would accept
    # nearly any value close to zero (e.g. one extra decay step), so compare
    # relatively instead.
    assert_allclose(new_value,
                    0.001 * 0.2 ** n_batches,
                    rtol=1e-4)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:33,代码来源:test_training.py

示例7: setup_mainloop

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def setup_mainloop(extension):
    """Build a one-epoch MainLoop used by the progress bar tests.

    The loop optimizes a two-element shared vector ``W`` with
    ``GradientDescent`` and a ``Scale(1e-3)`` step rule on the cost
    ``sum((x - W) ** 2)``. It reads an example stream of 101 identical
    feature vectors and registers ``extension`` after a
    ``FinishAfter(after_n_epochs=1)`` extension.

    """
    # 101 examples on purpose: since progressbar2 3.6.0 the ``maxval`` kwarg
    # was replaced by ``max_value`` (default 100). If ``maxval`` is still
    # used by accident, the bar receives an out-of-range value and the test
    # fails.
    float_type = theano.config.floatX
    features = [numpy.array([1, 2], dtype=float_type) for _ in range(101)]
    dataset = IterableDataset(dict(features=features))

    W = shared_floatx([0, 0], name='W')
    x = tensor.vector('features')
    cost = tensor.sum((x - W) ** 2)
    cost.name = "cost"

    algorithm = GradientDescent(
        cost=cost, parameters=[W], step_rule=Scale(1e-3))

    stream = dataset.get_example_stream()
    extensions = [FinishAfter(after_n_epochs=1), extension]
    return MainLoop(model=None, data_stream=stream,
                    algorithm=algorithm, extensions=extensions)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:33,代码来源:test_progressbar.py

示例8: setUp

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def setUp(self):
        """Create a three-element list and a DataStream over it."""
        self.data = list(range(1, 4))
        self.stream = DataStream(IterableDataset(self.data))
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:5,代码来源:test_datasets.py

示例9: test_default_transformer

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_default_transformer(self):
        """A dataset's apply_default_transformer can wrap a plain stream.

        The subclass below doubles every source value. The expected output
        assumes ``self.data == [1, 2, 3]``.
        """
        def double_sources(sources):
            return tuple(2 * value for value in sources)

        class DoublingDataset(IterableDataset):
            def apply_default_transformer(self, stream):
                return Mapping(stream, double_sources)

        dataset = DoublingDataset(self.data)
        transformed = dataset.apply_default_transformer(DataStream(dataset))
        expected = [(2,), (4,), (6,)]
        assert_equal(list(transformed.get_epoch_iterator()), expected)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:10,代码来源:test_datasets.py

示例10: test_no_axis_labels

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_no_axis_labels(self):
        """Without ``axis_labels``, the dataset's axis_labels is None."""
        dataset = IterableDataset(self.data)
        assert dataset.axis_labels is None
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:4,代码来源:test_datasets.py

示例11: test_axis_labels

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_axis_labels(self):
        """``axis_labels`` given to the constructor are exposed unchanged."""
        labels = {'data': ('batch',)}
        labelled = IterableDataset(self.data, axis_labels=labels)
        assert labelled.axis_labels == labels
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:6,代码来源:test_datasets.py

示例12: test_filter_sources

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_filter_sources(self):
        """filter_sources keeps only the sources selected at construction."""
        sources_data = OrderedDict([('1', [1, 2]), ('2', [3, 4])])
        dataset = IterableDataset(sources_data, sources=('1',))
        filtered = dataset.filter_sources(([1, 2], [3, 4]))
        assert_equal(filtered, ([1, 2],))
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:6,代码来源:test_datasets.py

示例13: test_value_error_on_non_iterable_dict

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_value_error_on_non_iterable_dict(self):
        """A dict whose values are not iterable is rejected with ValueError."""
        non_iterable = {'x': None, 'y': None}
        assert_raises(ValueError, IterableDataset, non_iterable)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:4,代码来源:test_datasets.py

示例14: test_value_error_get_data_none_state

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_value_error_get_data_none_state(self):
        """get_data raises ValueError when called with a None state."""
        dataset = IterableDataset([1, 2, 3])
        assert_raises(ValueError, dataset.get_data, None, None)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:5,代码来源:test_datasets.py

示例15: test_value_error_get_data_request

# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_value_error_get_data_request(self):
        """get_data raises ValueError when passed a request (here ``True``)."""
        dataset = IterableDataset([1, 2, 3])
        state = [1, 2, 3]
        assert_raises(ValueError, dataset.get_data, state, True)
开发者ID:rizar,项目名称:attention-lvcsr,代码行数:5,代码来源:test_datasets.py


注:本文中的fuel.datasets.IterableDataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。