本文整理汇总了Python中cntk.io.MinibatchSource.get_checkpoint_state方法的典型用法代码示例。如果您正苦于以下问题：Python MinibatchSource.get_checkpoint_state方法的具体用法？Python MinibatchSource.get_checkpoint_state怎么用？Python MinibatchSource.get_checkpoint_state使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类cntk.io.MinibatchSource的用法示例。
在下文中一共展示了MinibatchSource.get_checkpoint_state方法的1个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_user_deserializer_sequence_mode
# 需要导入模块: from cntk.io import MinibatchSource [as 别名]
# 或者: from cntk.io.MinibatchSource import get_checkpoint_state [as 别名]
def test_user_deserializer_sequence_mode():
    """Check per-value sequence counts from a user-deserializer MinibatchSource.

    Four scenarios are exercised, each over 15 chunks produced by
    ``GenDeserializer``:

    * big chunks (``num_sequences=100``, 2 sweeps), read in order;
    * the same, randomized with a 5-chunk randomization window;
    * single-sequence chunks (``num_sequences=1``, 3 sweeps), read in order;
    * the same, randomized with a 5-chunk randomization window.

    In every scenario the expected count per value is
    ``num_sequences * max_sweeps`` (200 and 3 respectively). The helper also
    interleaves checkpoint/restore cycles while reading.
    """
    streams = [StreamInformation('x', 0, 'dense', np.float32, (2, 3)),
               StreamInformation('y', 1, 'sparse', np.float32, (3,))]

    minibatch_size = 20

    def run_minibatch_source(minibatch_source, num_chunks, num_sequences_per_value):
        """Drain ``minibatch_source`` and assert per-value sequence counts.

        On every loop iteration where the counter is 1 mod 10, a checkpoint
        is taken, three minibatches are read and discarded, and the
        checkpoint is restored. Because the final counts must match exactly,
        the discarded data has to be delivered again after the restore.

        :param minibatch_source: source to read until it returns an empty
            minibatch.
        :param num_chunks: length of the count histograms. Assumes the value
            read from each sequence lies in ``[0, num_chunks)`` — TODO confirm
            against GenDeserializer.
        :param num_sequences_per_value: expected count for every value.
        :raises AssertionError: if any count differs from the expectation.
        """
        sequence_x_values = np.zeros(num_chunks, dtype=np.int32)
        sequence_y_values = np.zeros(num_chunks, dtype=np.int32)
        # Loop-invariant: describes the sparse 'y' stream for as_sequences, so
        # build it once instead of once per minibatch.
        y_variable = C.sequence.input_variable((3,), is_sparse=True)
        mb_count = 0
        while True:
            if mb_count % 10 == 1:  # perform checkpointing
                checkpoint_state = minibatch_source.get_checkpoint_state()
                for _ in range(3):
                    minibatch_source.next_minibatch(minibatch_size)
                minibatch_source.restore_from_checkpoint(checkpoint_state)
                mb_count += 1
                continue

            mb = minibatch_source.next_minibatch(minibatch_size)
            mb_count += 1
            if not mb:
                break

            # NOTE(review): presumably the first element of each sequence
            # encodes the id of its chunk — confirm against GenDeserializer.
            for sequence in mb[minibatch_source.streams.x].asarray():
                sequence_x_values[int(sequence[0][0][0])] += 1

            for sequence in mb[minibatch_source.streams.y].as_sequences(y_variable):
                sequence_y_values[int(sequence.toarray()[0][0])] += 1

            # Drop the reference before the next read (presumably so the
            # minibatch's memory can be released early).
            mb = None

        expected_values = np.full(num_chunks, fill_value=num_sequences_per_value, dtype=np.int32)
        np.testing.assert_array_equal(sequence_x_values, expected_values)
        np.testing.assert_array_equal(sequence_y_values, expected_values)

    # Big chunks
    d = GenDeserializer(stream_infos=streams, num_chunks=15,
                        num_sequences=100, max_sequence_len=10)
    mbs = MinibatchSource([d], randomize=False, max_sweeps=2)
    # Restore from a checkpoint taken before any read has happened.
    state = mbs.get_checkpoint_state()
    mbs.restore_from_checkpoint(state)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=200)

    # Randomized
    mbs = MinibatchSource([d], randomize=True, max_sweeps=2, randomization_window_in_chunks=5)
    state = mbs.get_checkpoint_state()
    mbs.restore_from_checkpoint(state)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=200)

    # Small chunks of 1
    d = GenDeserializer(stream_infos=streams, num_chunks=15,
                        num_sequences=1, max_sequence_len=10)
    mbs = MinibatchSource([d], randomize=False, max_sweeps=3)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=3)

    # Randomized
    mbs = MinibatchSource([d], randomize=True, max_sweeps=3, randomization_window_in_chunks=5)
    run_minibatch_source(mbs, num_chunks=15, num_sequences_per_value=3)