本文整理汇总了Python中fuel.datasets.IterableDataset类的典型用法代码示例。如果您正苦于以下问题：Python datasets.IterableDataset的具体用法？Python datasets.IterableDataset怎么用？Python datasets.IterableDataset使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块fuel.datasets的用法示例。
在下文中一共展示了datasets.IterableDataset类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_main_loop
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_main_loop():
    """Run a MainLoop over a 10-example stream for two epochs and check its log.

    ``config.profile`` is switched on for the duration of the test. It is
    restored in a ``finally`` clause so that a failing assertion cannot
    leave the global setting modified for other tests.
    """
    old_config_profile_value = config.profile
    config.profile = True
    try:
        main_loop = MainLoop(
            MockAlgorithm(), IterableDataset(range(10)).get_example_stream(),
            extensions=[WriteBatchExtension(), FinishAfter(after_n_epochs=2)])
        main_loop.run()
        # No model was given to MainLoop, so reading it must fail.
        assert_raises(AttributeError, getattr, main_loop, 'model')
        # 10 examples per epoch * 2 epochs = 20 iterations.
        assert main_loop.log.status['iterations_done'] == 20
        assert main_loop.log.status['_epoch_ends'] == [10, 20]
        assert len(main_loop.log) == 20
        for i in range(20):
            # The entry for the i-th processed example is stored at key
            # i + 1; the stream repeats 0..9 in each epoch.
            assert main_loop.log[i + 1]['batch'] == {'data': i % 10}
    finally:
        config.profile = old_config_profile_value
示例2: test_training_interrupt
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_training_interrupt():
    """Check that training survives one SIGINT but stops after a second one.

    The main loop runs in a child process over an endless ``count()``
    stream. The child is always cleaned up in a ``finally`` clause, so a
    failing assertion does not leave an orphaned training process behind.
    """
    def process_batch(batch):
        # Slow each batch down so the child is mid-training when signalled.
        time.sleep(0.1)

    algorithm = MockAlgorithm()
    algorithm.process_batch = process_batch
    main_loop = MockMainLoop(
        algorithm=algorithm,
        data_stream=IterableDataset(count()).get_example_stream(),
        extensions=[Printing()]
    )
    p = Process(target=main_loop.run)
    p.start()
    try:
        time.sleep(0.1)
        os.kill(p.pid, signal.SIGINT)
        time.sleep(0.1)
        # A single SIGINT must not stop training immediately.
        assert p.is_alive()
        os.kill(p.pid, signal.SIGINT)
        # join() returns as soon as the child exits, so a generous bound
        # costs nothing when the test passes but avoids a timing-dependent
        # failure when the machine is slow.
        p.join(timeout=1)
        assert not p.is_alive()
    finally:
        if p.is_alive():
            p.terminate()
        p.join()
示例3: test_dataset_evaluators
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_dataset_evaluators():
    """Evaluate two identical custom quantities and a brick cost on one stream.

    Both ``CrossEntropy`` quantities are built from the same inputs, so each
    must agree with the ``CategoricalCrossEntropy`` cost. Checking both of
    them, rather than only the second, is what actually verifies that the
    two aggregators do not share or leak state.
    """
    X = theano.tensor.vector('X')
    Y = theano.tensor.vector('Y')
    data = [numpy.arange(1, 7, dtype=theano.config.floatX).reshape(3, 2),
            numpy.arange(11, 17, dtype=theano.config.floatX).reshape(3, 2)]
    # Iterating a (3, 2) array yields its rows, so each example is a
    # length-2 vector, matching the ``vector`` inputs X and Y.
    data_stream = IterableDataset(dict(X=data[0],
                                       Y=data[1])).get_example_stream()
    validator = DatasetEvaluator([
        CrossEntropy(requires=[X, Y],
                     name="monitored_cross_entropy0"),
        # to test two same quantities and make sure that state will be reset
        CrossEntropy(requires=[X, Y],
                     name="monitored_cross_entropy1"),
        CategoricalCrossEntropy().apply(X, Y), ])
    values = validator.evaluate(data_stream)
    expected = values['categoricalcrossentropy_apply_cost']
    for name in ("monitored_cross_entropy0", "monitored_cross_entropy1"):
        numpy.testing.assert_allclose(values[name], expected)
示例4: test_dataset_evaluators
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_dataset_evaluators():
    """Evaluate a brick's auxiliary variables over a stream of whole matrices.

    Each element of ``data`` is an entire matrix, so the example stream
    yields two inputs with 2 and 3 rows. Because their sizes differ, the
    mean over all rows differs from the mean of the per-matrix means, and
    the test checks each aggregated value against the matching formula.
    It also checks that a stream lacking the ``X`` source is rejected.
    """
    X = theano.tensor.matrix('X')
    brick = TestBrick(name='test_brick')
    Y = brick.apply(X)
    monitor_variables = list(ComputationGraph([Y]).auxiliary_variables)
    validator = DatasetEvaluator(monitor_variables)

    float_type = theano.config.floatX
    first_matrix = numpy.arange(1, 5, dtype=float_type).reshape(2, 2)
    second_matrix = numpy.arange(10, 16, dtype=float_type).reshape(3, 2)
    data = [first_matrix, second_matrix]
    values = validator.evaluate(
        IterableDataset(dict(X=data)).get_example_stream())

    assert values['test_brick_apply_V_squared'] == 4
    mean_over_all_rows = numpy.vstack(data).mean()
    numpy.testing.assert_allclose(
        values['test_brick_apply_mean_row_mean'], mean_over_all_rows)
    mean_of_matrix_means = numpy.mean([matrix.mean() for matrix in data])
    numpy.testing.assert_allclose(
        values['test_brick_apply_mean_batch_element'], mean_of_matrix_means)

    # A stream whose only source is 'X2' does not provide 'X'.
    with assert_raises(Exception) as ar:
        wrong_stream = IterableDataset(dict(X2=data)).get_example_stream()
        validator.evaluate(wrong_stream)
    assert "Not all data sources" in ar.exception.args[0]
示例5: test_training_resumption
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_training_resumption():
    """Stop training after 14 batches, then resume it until 27 batches.

    The scenario runs twice: once resuming the same MainLoop object, and
    once resuming a copy that was round-tripped through cPickle.
    """
    def do_test(with_serialization):
        stream = IterableDataset(range(10)).get_example_stream()
        loop = MainLoop(
            MockAlgorithm(), stream,
            extensions=[WriteBatchExtension(),
                        FinishAfter(after_n_batches=14)])
        loop.run()
        assert loop.log.status['iterations_done'] == 14

        if with_serialization:
            loop = cPickle.loads(cPickle.dumps(loop))

        # Reuse the single FinishAfter extension with an extra stop condition.
        finish_extensions = [extension for extension in loop.extensions
                             if isinstance(extension, FinishAfter)]
        finish_after = unpack(finish_extensions, singleton=True)

        def reached_27(log):
            return log.status['iterations_done'] == 27

        finish_after.add_condition(["after_batch"], predicate=reached_27)
        loop.run()

        assert loop.log.status['iterations_done'] == 27
        assert loop.log.status['epochs_done'] == 2
        for iteration in range(27):
            expected_batch = {"data": iteration % 10}
            assert loop.log[iteration + 1]['batch'] == expected_batch

    for with_serialization in (False, True):
        do_test(with_serialization)
示例6: test_shared_variable_modifier_two_parameters
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_shared_variable_modifier_two_parameters():
    """Check that a two-argument SharedVariableModifier decays the learning rate.

    The modifier's function takes two arguments and multiplies the current
    learning rate by 0.2. After one epoch over three examples, the expected
    value is 0.001 * 0.2 ** 3, which implies the modifier presumably fires
    once per batch.
    """
    float_type = theano.config.floatX
    weights = numpy.array([-1, 1], dtype=float_type)
    features = [numpy.array(pair, dtype=float_type)
                for pair in ([1, 2], [3, 4], [5, 6])]
    targets = [(weights * feature).sum() for feature in features]
    n_batches = len(features)
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    W = shared_floatx([0, 0], name='W')
    cost = ((x * W).sum() - y) ** 2
    cost.name = 'cost'

    step_rule = Scale(0.001)
    sgd = GradientDescent(cost=cost, parameters=[W], step_rule=step_rule)

    def decay(_, value):
        return numpy.cast[float_type](value * 0.2)

    modifier = SharedVariableModifier(step_rule.learning_rate, decay)
    main_loop = MainLoop(
        model=None,
        data_stream=dataset.get_example_stream(),
        algorithm=sgd,
        extensions=[FinishAfter(after_n_epochs=1), modifier])
    main_loop.run()

    final_rate = step_rule.learning_rate.get_value()
    assert_allclose(final_rate, 0.001 * 0.2 ** n_batches, atol=1e-5)
示例7: setup_mainloop
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def setup_mainloop(extension):
    """Build a minimal MainLoop for progress bar tests.

    The loop runs for one epoch of gradient descent on a tiny quadratic cost
    and has ``extension`` registered after a ``FinishAfter`` extension.

    Args:
        extension: the extension under test.

    Returns:
        The configured, not-yet-run MainLoop.
    """
    # Since progressbar2 3.6.0, the `maxval` keyword was replaced by
    # `max_value`, whose default is 100. The stream deliberately holds 101
    # examples: if `maxval` were still passed by accident, the progress bar
    # would receive a value beyond its default maximum and the test would fail.
    n_examples = 101
    features = [numpy.array([1, 2], dtype=theano.config.floatX)
                for _ in range(n_examples)]
    dataset = IterableDataset(dict(features=features))

    W = shared_floatx([0, 0], name='W')
    x = tensor.vector('features')
    cost = tensor.sum((x - W) ** 2)
    cost.name = "cost"
    algorithm = GradientDescent(cost=cost, parameters=[W],
                                step_rule=Scale(1e-3))

    return MainLoop(
        model=None,
        data_stream=dataset.get_example_stream(),
        algorithm=algorithm,
        extensions=[FinishAfter(after_n_epochs=1), extension])
示例8: setUp
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def setUp(self):
    """Create the three-element fixture list and a DataStream over it."""
    self.data = list(range(1, 4))
    dataset = IterableDataset(self.data)
    self.stream = DataStream(dataset)
示例9: test_default_transformer
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_default_transformer(self):
    """A subclass-defined apply_default_transformer is applied to a stream."""
    def double_sources(sources):
        return tuple(2 * source for source in sources)

    class DoublingDataset(IterableDataset):
        def apply_default_transformer(self, stream):
            return Mapping(stream, double_sources)

    dataset = DoublingDataset(self.data)
    transformed = dataset.apply_default_transformer(DataStream(dataset))
    assert_equal(list(transformed.get_epoch_iterator()),
                 [(2,), (4,), (6,)])
示例10: test_no_axis_labels
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_no_axis_labels(self):
    """Without an ``axis_labels`` argument, the dataset reports None."""
    dataset = IterableDataset(self.data)
    assert dataset.axis_labels is None
示例11: test_axis_labels
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_axis_labels(self):
    """``axis_labels`` passed to the constructor is reported back equal."""
    expected_labels = {'data': ('batch',)}
    dataset = IterableDataset(self.data, axis_labels=expected_labels)
    assert dataset.axis_labels == expected_labels
示例12: test_filter_sources
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_filter_sources(self):
    """filter_sources keeps only the data for the selected source '1'."""
    all_sources = OrderedDict([('1', [1, 2]), ('2', [3, 4])])
    dataset = IterableDataset(all_sources, sources=('1',))
    filtered = dataset.filter_sources(([1, 2], [3, 4]))
    assert_equal(filtered, ([1, 2],))
示例13: test_value_error_on_non_iterable_dict
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_value_error_on_non_iterable_dict(self):
    """A dict whose values are not iterable is rejected with ValueError."""
    non_iterable_values = {'x': None, 'y': None}
    assert_raises(ValueError, IterableDataset, non_iterable_values)
示例14: test_value_error_get_data_none_state
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_value_error_get_data_none_state(self):
    """get_data raises ValueError when called with a ``None`` state."""
    dataset = IterableDataset([1, 2, 3])
    assert_raises(ValueError, dataset.get_data, None, None)
示例15: test_value_error_get_data_request
# 需要导入模块: from fuel import datasets [as 别名]
# 或者: from fuel.datasets import IterableDataset [as 别名]
def test_value_error_get_data_request(self):
    """get_data raises ValueError when given a non-None request."""
    dataset = IterableDataset([1, 2, 3])
    state = [1, 2, 3]
    request = True
    assert_raises(ValueError, dataset.get_data, state, request)