本文整理汇总了Python中blocks.extensions.saveload.Checkpoint类的典型用法代码示例。如果您正苦于以下问题：Python saveload.Checkpoint类的具体用法？Python saveload.Checkpoint怎么用？Python saveload.Checkpoint使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块blocks.extensions.saveload的用法示例。
在下文中一共展示了saveload.Checkpoint类的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_add_list_condition
# 需要导入模块: from blocks.extensions import saveload [as 别名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 别名]
def test_add_list_condition():
    """Check that ``add_condition`` accepts a list of callback names.

    Registering one condition for two callbacks in a single call must leave
    as many stored conditions as registering each callback separately, and
    passing a bare string instead of a list must raise ``ValueError``.
    """
    record_name = 'notification_name'
    dest_args = ('dest_path.kl',)

    # Both callbacks in one call; the return value is used as the extension.
    batched = Checkpoint('extension_list').add_condition(
        ['before_first_epoch', 'after_epoch'],
        OnLogRecord(record_name),
        dest_args)

    # The same callbacks registered one at a time, each with a fresh predicate.
    one_by_one = Checkpoint('extension_iter')
    for callback in ('before_first_epoch', 'after_epoch'):
        one_by_one.add_condition(
            [callback],
            OnLogRecord(record_name),
            dest_args)

    assert len(batched._conditions) == len(one_by_one._conditions)

    # A plain string is rejected as the collection of callback names.
    assert_raises(ValueError, one_by_one.add_condition,
                  callbacks_names='after_epoch',
                  predicate=OnLogRecord(record_name),
                  arguments=dest_args)
示例2: track_best
# 需要导入模块: from blocks.extensions import saveload [as 别名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 别名]
def track_best(channel, save_path, choose_best=min):
    """Build extensions that checkpoint whenever ``channel`` hits a new best.

    Parameters
    ----------
    channel : str
        Name of the log record whose best value is tracked.
    save_path : str
        Path the checkpoint is written to.
    choose_best : callable, optional
        Picks the best of the values seen so far and is forwarded to
        ``TrackTheBest``. Use ``min`` (the default, matching the previous
        hard-coded behaviour) for losses and ``max`` for scores.

    Returns
    -------
    list
        ``[tracker, checkpoint]``. Keep this order: the checkpoint's
        predicate looks for the log record the tracker writes, so the tracker
        should run first (assuming the main loop runs extensions in list
        order).
    """
    # Build the notification name once and hand it to both extensions. The
    # checkpoint predicate then cannot drift from what the tracker writes,
    # instead of relying on TrackTheBest's default naming scheme.
    notification_name = '{0}_best_so_far'.format(channel)
    tracker = TrackTheBest(channel, notification_name=notification_name,
                           choose_best=choose_best)
    # after_training=False: do not save unconditionally at the end of
    # training; save only through the condition added below.
    checkpoint = saveload.Checkpoint(
        save_path, after_training=False, use_cpickle=True)
    checkpoint.add_condition(
        ["after_epoch"],
        predicate=predicates.OnLogRecord(notification_name))
    return [tracker, checkpoint]
示例3: run
# 需要导入模块: from blocks.extensions import saveload [as 别名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 别名]
def run():
    """Train the CelebA classifier.

    - Trains with Adam on the batch-normalized/dropout graph.
    - Updates batch-normalization population statistics with an exponential
      moving average.
    - Monitors the training cost once, after training.
    - Monitors validation cost and accuracy every 5 epochs.
    - Checkpoints to ``celeba_classifier.zip`` every 5 epochs, for at most
      50 epochs.
    """
    streams = create_celeba_streams(training_batch_size=100,
                                    monitoring_batch_size=500,
                                    include_targets=True)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]
    cg, bn_dropout_cg = create_training_computation_graphs()

    # Batch-normalization population statistics follow an exponential moving
    # average of the minibatch estimates.
    pop_updates = get_batch_normalization_updates(bn_dropout_cg)
    decay_rate = 0.05
    extra_updates = [
        (population, minibatch * decay_rate + population * (1 - decay_rate))
        for population, minibatch in pop_updates
    ]

    # Optimisation: Adam on the first output of the training graph, with the
    # population-statistic updates attached to every step.
    step_rule = Adam()
    algorithm = GradientDescent(cost=bn_dropout_cg.outputs[0],
                                parameters=bn_dropout_cg.parameters,
                                step_rule=step_rule)
    algorithm.add_updates(extra_updates)

    # Training-set monitoring runs only once, after training. It is named
    # after the algorithm is built, as in the original ordering.
    train_cost = bn_dropout_cg.outputs[0]
    train_cost.name = 'cost'
    train_monitoring = DataStreamMonitoring(
        [train_cost], train_monitor_stream, prefix="train",
        before_first_epoch=False, after_epoch=False, after_training=True,
        updates=extra_updates)

    # Validation monitoring of cost and accuracy every 5 epochs.
    valid_cost, valid_accuracy = cg.outputs
    valid_cost.name = 'cost'
    valid_accuracy.name = 'accuracy'
    valid_monitoring = DataStreamMonitoring(
        [valid_cost, valid_accuracy], valid_monitor_stream, prefix="valid",
        before_first_epoch=False, after_epoch=False, every_n_epochs=5)

    checkpoint = Checkpoint(
        'celeba_classifier.zip', every_n_epochs=5, use_cpickle=True)
    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=50),
        train_monitoring,
        valid_monitoring,
        checkpoint,
        Printing(),
        ProgressBar(),
    ]
    MainLoop(data_stream=main_loop_stream, algorithm=algorithm,
             extensions=extensions).run()
示例4: run
# 需要导入模块: from blocks.extensions import saveload [as 别名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 别名]
def run(discriminative_regularization=True):
    """Train the CelebA VAE, optionally with discriminative regularization.

    - Trains only the encoder/decoder bricks and the returned variance
      parameters, using Adam.
    - Updates batch-normalization population statistics with an exponential
      moving average.
    - Monitors the training and validation bound terms every 5 epochs.
    - Checkpoints every 5 epochs, for at most 75 epochs.

    Parameters
    ----------
    discriminative_regularization : bool, optional
        Forwarded to ``create_training_computation_graphs``. It also selects
        the checkpoint file: ``celeba_vae_regularization.zip`` when true,
        ``celeba_vae_no_regularization.zip`` otherwise.
    """
    streams = create_celeba_streams(training_batch_size=100,
                                    monitoring_batch_size=500,
                                    include_targets=False)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]
    rval = create_training_computation_graphs(discriminative_regularization)
    cg, bn_cg, variance_parameters = rval

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    # With allow_duplicates=True the same (population, estimate) pair may be
    # reported more than once, so repeats are dropped. This is done in
    # first-occurrence order rather than via set(): set iteration order
    # depends on id-based hashes and would change the update list between
    # runs.
    pop_updates = []
    seen = set()
    for update in get_batch_normalization_updates(bn_cg,
                                                  allow_duplicates=True):
        if update not in seen:
            seen.add(update)
            pop_updates.append(update)
    decay_rate = 0.05
    extra_updates = [(p, m * decay_rate + p * (1 - decay_rate))
                     for p, m in pop_updates]

    # Only the encoder/decoder bricks (plus the variance parameters) are
    # trained.
    model = Model(bn_cg.outputs[0])
    selector = Selector(
        find_bricks(
            model.top_bricks,
            lambda brick: brick.name in ('encoder_convnet', 'encoder_mlp',
                                         'decoder_convnet', 'decoder_mlp')))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Prepare algorithm
    step_rule = Adam()
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=step_rule)
    algorithm.add_updates(extra_updates)

    # Prepare monitoring. Index 0 holds the batch-normalized (training) graph
    # quantities and index 1 the inference graph quantities. The
    # reconstruction average is negated before monitoring.
    monitored_quantities_list = []
    for graph in [bn_cg, cg]:
        cost, kl_term, reconstruction_term = graph.outputs
        cost.name = 'nll_upper_bound'
        avg_kl_term = kl_term.mean(axis=0)
        avg_kl_term.name = 'avg_kl_term'
        avg_reconstruction_term = -reconstruction_term.mean(axis=0)
        avg_reconstruction_term.name = 'avg_reconstruction_term'
        monitored_quantities_list.append(
            [cost, avg_kl_term, avg_reconstruction_term])
    train_monitoring = DataStreamMonitoring(
        monitored_quantities_list[0], train_monitor_stream, prefix="train",
        updates=extra_updates, after_epoch=False, before_first_epoch=False,
        every_n_epochs=5)
    valid_monitoring = DataStreamMonitoring(
        monitored_quantities_list[1], valid_monitor_stream, prefix="valid",
        after_epoch=False, before_first_epoch=False, every_n_epochs=5)

    # Prepare checkpoint
    save_path = 'celeba_vae_{}regularization.zip'.format(
        '' if discriminative_regularization else 'no_')
    checkpoint = Checkpoint(save_path, every_n_epochs=5, use_cpickle=True)
    extensions = [Timing(), FinishAfter(after_n_epochs=75), train_monitoring,
                  valid_monitoring, checkpoint, Printing(), ProgressBar()]
    main_loop = MainLoop(data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)
    main_loop.run()