当前位置: 首页>>代码示例>>Python>>正文


Python MinibatchSource.get_checkpoint_state方法代码示例

本文整理汇总了Python中cntk.io.MinibatchSource.get_checkpoint_state方法的典型用法代码示例。如果您正苦于以下问题:Python MinibatchSource.get_checkpoint_state方法的具体用法?Python MinibatchSource.get_checkpoint_state怎么用?Python MinibatchSource.get_checkpoint_state使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在cntk.io.MinibatchSource的用法示例。


在下文中一共展示了MinibatchSource.get_checkpoint_state方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_user_deserializer_sequence_mode

# 需要导入模块: from cntk.io import MinibatchSource [as 别名]
# 或者: from cntk.io.MinibatchSource import get_checkpoint_state [as 别名]
def test_user_deserializer_sequence_mode():
    import scipy.sparse as sp
    streams = [StreamInformation('x', 0, 'dense', np.float32, (2, 3)), 
               StreamInformation('y', 1, 'sparse', np.float32, (3,))]

    def run_minibatch_source(minibatch_source, num_chunks, num_sequences_per_value):
        sequence_x_values = np.zeros(num_chunks, dtype=np.int32)
        sequence_y_values = np.zeros(num_chunks, dtype=np.int32)
        mb_count = 0
        while True:
            if mb_count % 10 == 1: # perform checkpointing
                checkpoint_state = minibatch_source.get_checkpoint_state()
                for i in range(3): 
                    minibatch_source.next_minibatch(20)
                minibatch_source.restore_from_checkpoint(checkpoint_state)
                mb_count +=1
                continue            

            mb = minibatch_source.next_minibatch(20)
            mb_count += 1
            if not mb:
                break

            for sequence in mb[minibatch_source.streams.x].asarray():
                sequence_x_values[int(sequence[0][0][0])] +=1

            for sequence in mb[minibatch_source.streams.y].as_sequences(C.sequence.input_variable((3,), True)):             
                sequence_y_values[int(sequence.toarray()[0][0])] += 1
            mb = None

        expected_values = np.full(num_chunks, fill_value=num_sequences_per_value, dtype=np.int32)
        assert (sequence_x_values == expected_values).all()
        assert (sequence_y_values == expected_values).all()

    # Big chunks
    d = GenDeserializer(stream_infos=streams, num_chunks=15, 
                        num_sequences=100, max_sequence_len=10)
    mbs = MinibatchSource([d], randomize=False, max_sweeps=2)
    state = mbs.get_checkpoint_state()
    mbs.restore_from_checkpoint(state)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=200)
    # Randomized
    mbs = MinibatchSource([d], randomize=True, max_sweeps=2, randomization_window_in_chunks=5)
    state = mbs.get_checkpoint_state()
    mbs.restore_from_checkpoint(state)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=200)

    # Small chunks of 1
    d = GenDeserializer(stream_infos=streams, num_chunks=15,
                        num_sequences=1, max_sequence_len=10)
    mbs = MinibatchSource([d], randomize=False, max_sweeps=3)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=3)
    # Randomized
    mbs = MinibatchSource([d], randomize=True, max_sweeps=3, randomization_window_in_chunks=5)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=3)
开发者ID:AllanYiin,项目名称:CNTK,代码行数:57,代码来源:io_tests.py


注:本文中的cntk.io.MinibatchSource.get_checkpoint_state方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。