當前位置: 首頁>>代碼示例>>Python>>正文


Python saveload.Checkpoint方法代碼示例

本文整理匯總了Python中blocks.extensions.saveload.Checkpoint方法的典型用法代碼示例。如果您正苦於以下問題:Python saveload.Checkpoint方法的具體用法?Python saveload.Checkpoint怎麼用?Python saveload.Checkpoint使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在blocks.extensions.saveload的用法示例。


在下文中一共展示了saveload.Checkpoint方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: test_add_list_condition

# 需要導入模塊: from blocks.extensions import saveload [as 別名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 別名]
def test_add_list_condition():
    """Check ``add_condition`` with a list of callback names.

    One checkpoint registers two callback names in a single call; a second
    checkpoint registers the same two names in separate calls. Both must end
    up with the same number of conditions. Finally, passing a bare string as
    ``callbacks_names`` is asserted to raise ``ValueError``.
    """
    # Both callback names registered at once.
    combined = Checkpoint('extension_list').add_condition(
        ['before_first_epoch', 'after_epoch'],
        OnLogRecord('notification_name'),
        ('dest_path.kl',))

    # The same callback names registered one call at a time.
    separate = Checkpoint('extension_iter')
    for callback_name in ('before_first_epoch', 'after_epoch'):
        separate.add_condition(
            [callback_name],
            OnLogRecord('notification_name'),
            ('dest_path.kl',))

    assert len(combined._conditions) == len(separate._conditions)

    # A plain string (not wrapped in a list) is expected to be rejected.
    assert_raises(
        ValueError,
        separate.add_condition,
        callbacks_names='after_epoch',
        predicate=OnLogRecord('notification_name'),
        arguments=('dest_path.kl',))
開發者ID:rizar,項目名稱:attention-lvcsr,代碼行數:21,代碼來源:test_extensions.py

示例2: track_best

# 需要導入模塊: from blocks.extensions import saveload [as 別名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 別名]
def track_best(channel, save_path):
    """Build the extensions that checkpoint on a "best so far" log record.

    Parameters
    ----------
    channel
        Passed to ``TrackTheBest`` (with ``choose_best=min``) and used to
        build the log-record name ``'<channel>_best_so_far'``.
    save_path
        Destination handed to ``saveload.Checkpoint``.

    Returns
    -------
    list
        ``[tracker, checkpoint]``, in that order.
    """
    best_tracker = TrackTheBest(channel, choose_best=min)
    saver = saveload.Checkpoint(
        save_path, after_training=False, use_cpickle=True)

    # The checkpoint condition is keyed on this log-record name.
    record_name = '{0}_best_so_far'.format(channel)
    saver.add_condition(
        ["after_epoch"],
        predicate=predicates.OnLogRecord(record_name))

    return [best_tracker, saver]
開發者ID:johnarevalo,項目名稱:blocks-char-rnn,代碼行數:9,代碼來源:utils.py

示例3: run

# 需要導入模塊: from blocks.extensions import saveload [as 別名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 別名]
def run():
    """Assemble the training main loop and run it.

    Builds the CelebA data streams, the computation graphs, a gradient
    descent algorithm using Adam, train/validation monitoring, and a
    checkpoint written to ``'celeba_classifier.zip'`` every 5 epochs, then
    runs the main loop, which is set to finish after 50 epochs.
    """
    streams = create_celeba_streams(training_batch_size=100,
                                    monitoring_batch_size=500,
                                    include_targets=True)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = (
        streams[0], streams[1], streams[2])

    cg, bn_dropout_cg = create_training_computation_graphs()

    # Updates for the batch normalization population statistics: each one is
    # moved towards its paired value with an exponential moving average.
    decay_rate = 0.05
    extra_updates = []
    for p_var, m_var in get_batch_normalization_updates(bn_dropout_cg):
        averaged = m_var * decay_rate + p_var * (1 - decay_rate)
        extra_updates.append((p_var, averaged))

    # Training algorithm.
    optimizer = Adam()
    algorithm = GradientDescent(cost=bn_dropout_cg.outputs[0],
                                parameters=bn_dropout_cg.parameters,
                                step_rule=optimizer)
    algorithm.add_updates(extra_updates)

    # Training-set monitoring: only the cost, only after training.
    train_cost = bn_dropout_cg.outputs[0]
    train_cost.name = 'cost'
    train_monitoring = DataStreamMonitoring(
        [train_cost], train_monitor_stream, prefix="train",
        before_first_epoch=False, after_epoch=False, after_training=True,
        updates=extra_updates)

    # Validation-set monitoring: cost and accuracy every 5 epochs.
    valid_cost, valid_accuracy = cg.outputs
    valid_cost.name = 'cost'
    valid_accuracy.name = 'accuracy'
    valid_monitoring = DataStreamMonitoring(
        [valid_cost, valid_accuracy], valid_monitor_stream, prefix="valid",
        before_first_epoch=False, after_epoch=False, every_n_epochs=5)

    # Periodic checkpoint.
    checkpoint = Checkpoint(
        'celeba_classifier.zip', every_n_epochs=5, use_cpickle=True)

    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=50),
        train_monitoring,
        valid_monitoring,
        checkpoint,
        Printing(),
        ProgressBar(),
    ]
    main_loop = MainLoop(data_stream=main_loop_stream, algorithm=algorithm,
                         extensions=extensions)
    main_loop.run()
開發者ID:vdumoulin,項目名稱:discgen,代碼行數:51,代碼來源:train_celeba_classifier.py

示例4: run

# 需要導入模塊: from blocks.extensions import saveload [as 別名]
# 或者: from blocks.extensions.saveload import Checkpoint [as 別名]
def run(discriminative_regularization=True):
    """Assemble the training main loop and run it.

    Parameters
    ----------
    discriminative_regularization : bool, optional
        Forwarded to ``create_training_computation_graphs``. It also selects
        the checkpoint file name: ``'celeba_vae_regularization.zip'`` when
        truthy, ``'celeba_vae_no_regularization.zip'`` otherwise.

    The main loop is set to finish after 75 epochs; monitoring and
    checkpointing are both configured with ``every_n_epochs=5``.
    """
    streams = create_celeba_streams(training_batch_size=100,
                                    monitoring_batch_size=500,
                                    include_targets=False)
    first_three = streams[:3]
    main_loop_stream, train_monitor_stream, valid_monitor_stream = first_three

    cg, bn_cg, variance_parameters = create_training_computation_graphs(
        discriminative_regularization)

    # Updates for the batch normalization population statistics: each one is
    # moved towards its paired value with an exponential moving average.
    # Duplicate update pairs are dropped by going through a set.
    pop_updates = list(
        set(get_batch_normalization_updates(bn_cg, allow_duplicates=True)))
    decay_rate = 0.05
    extra_updates = []
    for p_var, m_var in pop_updates:
        averaged = m_var * decay_rate + p_var * (1 - decay_rate)
        extra_updates.append((p_var, averaged))

    # Only parameters of these top-level bricks (plus the variance
    # parameters) are handed to the optimizer.
    selected_brick_names = ('encoder_convnet', 'encoder_mlp',
                            'decoder_convnet', 'decoder_mlp')

    def is_selected_brick(brick):
        return brick.name in selected_brick_names

    model = Model(bn_cg.outputs[0])
    selector = Selector(find_bricks(model.top_bricks, is_selected_brick))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Training algorithm.
    optimizer = Adam()
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=optimizer)
    algorithm.add_updates(extra_updates)

    def named_quantities(graph):
        # Name the three graph outputs (or their batch means) for monitoring.
        nll, kl, reconstruction = graph.outputs
        nll.name = 'nll_upper_bound'
        mean_kl = kl.mean(axis=0)
        mean_kl.name = 'avg_kl_term'
        mean_reconstruction = -reconstruction.mean(axis=0)
        mean_reconstruction.name = 'avg_reconstruction_term'
        return [nll, mean_kl, mean_reconstruction]

    # bn_cg feeds the training monitor, cg the validation monitor.
    train_quantities, valid_quantities = [
        named_quantities(graph) for graph in (bn_cg, cg)]

    train_monitoring = DataStreamMonitoring(
        train_quantities, train_monitor_stream, prefix="train",
        updates=extra_updates, after_epoch=False, before_first_epoch=False,
        every_n_epochs=5)
    valid_monitoring = DataStreamMonitoring(
        valid_quantities, valid_monitor_stream, prefix="valid",
        after_epoch=False, before_first_epoch=False, every_n_epochs=5)

    # Periodic checkpoint; the file name reflects the regularization flag.
    if discriminative_regularization:
        name_infix = ''
    else:
        name_infix = 'no_'
    save_path = 'celeba_vae_{}regularization.zip'.format(name_infix)
    checkpoint = Checkpoint(save_path, every_n_epochs=5, use_cpickle=True)

    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=75),
        train_monitoring,
        valid_monitoring,
        checkpoint,
        Printing(),
        ProgressBar(),
    ]
    main_loop = MainLoop(data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)
    main_loop.run()
開發者ID:vdumoulin,項目名稱:discgen,代碼行數:62,代碼來源:train_celeba_vae.py


注:本文中的blocks.extensions.saveload.Checkpoint方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。