本文整理汇总了Python中chainer.training.extensions.snapshot方法的典型用法代码示例。如果您正苦于以下问题:Python extensions.snapshot方法的具体用法?Python extensions.snapshot怎么用?Python extensions.snapshot使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类chainer.training.extensions的用法示例。
在下文中一共展示了extensions.snapshot方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_on_error
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_on_error(self):
    """A snapshot_on_error extension writes its file only after the run fails."""

    class TheOnlyError(Exception):
        pass

    @training.make_extension(trigger=(1, 'iteration'), priority=100)
    def exception_raiser(trainer):
        # Abort training immediately on the first iteration.
        raise TheOnlyError()

    self.trainer.extend(exception_raiser)

    saver = extensions.snapshot_object(
        self.trainer, self.filename, snapshot_on_error=True)
    self.trainer.extend(saver)

    # Nothing on disk before the (failing) run...
    self.assertFalse(os.path.exists(self.filename))
    with self.assertRaises(TheOnlyError):
        self.trainer.run()
    # ...but the error triggered a snapshot.
    self.assertTrue(os.path.exists(self.filename))
示例2: test_smoke_wrapper
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_smoke_wrapper():
    """Smoke test: multi_node_snapshot assigns master/replica roles per rank."""
    replica_sets = [[0, 1], ]
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    snapshot = extensions.snapshot()
    filename = '{}.{}'.format(snapshot.filename, comm.rank)
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)

    # Rank 0 leads the explicit [0, 1] replica set; rank 1 is its replica.
    # Ranks outside any listed set behave as rank 2 (master) / others (not).
    if comm.rank == 0:
        assert mn_snapshot.is_master
        assert filename == mn_snapshot.snapshot.filename
    elif comm.rank == 1:
        assert not mn_snapshot.is_master
    elif comm.rank == 2:
        assert mn_snapshot.is_master
        assert filename == mn_snapshot.snapshot.filename
    else:
        assert not mn_snapshot.is_master
    comm.finalize()
示例3: test_snapshot_hdfs
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_snapshot_hdfs():
    """Round-trip a trainer snapshot through an HDFS-backed pfio handler."""
    trainer = chainer.testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with pfio.create_handler('hdfs') as fs:
        tmpdir = "some-pfio-tmp-dir"
        fs.makedirs(tmpdir, exist_ok=True)
        # The scratch directory must start out empty.
        assert not list(fs.list(tmpdir))

        writer = SimpleWriter(tmpdir, fs=fs)
        snapshot = extensions.snapshot(writer=writer)
        snapshot(trainer)
        assert 'snapshot_iter_0' in fs.list(tmpdir)

        # A fresh trainer must be able to load what was just written.
        trainer2 = chainer.testing.get_trainer_with_mock_updater()
        load_snapshot(trainer2, tmpdir, fs=fs, fail_on_no_file=True)

        # Cleanup
        fs.remove(tmpdir, recursive=True)
示例4: add_default_arguments
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def add_default_arguments(parser):
    """Register the training options shared by all experiments on *parser*.

    Returns the same parser to allow call chaining.
    """
    # Required positional / core options.
    parser.add_argument("log_dir", help='directory where generated models and logs shall be stored')
    parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, required=True,
                        help="Number of images per training batch")
    parser.add_argument('-g', '--gpus', type=int, nargs="*", default=[], help="Ids of GPU to use [default: (use cpu)]")
    parser.add_argument('-e', '--epochs', type=int, default=20, help="Number of epochs to train [default: 20]")
    parser.add_argument('-r', '--resume', help="path to previously saved state of trained model from which training shall resume")

    # Snapshot / logging cadence.
    parser.add_argument('-si', '--snapshot-interval', dest='snapshot_interval', type=int, default=20000,
                        help="number of iterations after which a snapshot shall be taken [default: 20000]")
    parser.add_argument('-ln', '--log-name', dest='log_name', default='training', help="name of the log folder")
    parser.add_argument('-lr', '--learning-rate', dest='learning_rate', type=float, default=0.01,
                        help="initial learning rate [default: 0.01]")
    parser.add_argument('-li', '--log-interval', dest='log_interval', type=int, default=100,
                        help="number of iterations after which an update shall be logged [default: 100]")
    parser.add_argument('--lr-step', dest='learning_rate_step_size', type=float, default=0.1,
                        help="Step size for decreasing learning rate [default: 0.1]")

    # Evaluation and regularisation.
    parser.add_argument('-t', '--test-interval', dest='test_interval', type=int, default=1000,
                        help="number of iterations after which testing should be performed [default: 1000]")
    parser.add_argument('--test-iterations', dest='test_iterations', type=int, default=200,
                        help="number of test iterations [default: 200]")
    parser.add_argument("-dr", "--dropout-ratio", dest='dropout_ratio', default=0.5, type=float,
                        help="ratio for dropout layers")
    return parser
示例5: parse_args
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def parse_args():
    """Parse command-line options for the language-model training script."""
    parser = argparse.ArgumentParser()
    # Optimisation hyper-parameters.
    parser.add_argument('--batch_size', '-b', type=int, default=32,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--bproplen', '-l', type=int, default=35,
                        help='Number of words in each mini-batch '
                             '(= length of truncated BPTT)')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--gradclip', '-c', type=float, default=5,
                        help='Gradient norm threshold to clip')
    # I/O locations.
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--test', action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.set_defaults(test=False)
    # Model size and serialisation.
    parser.add_argument('--hidden_size', type=int, default=300,
                        help='Number of LSTM units in each layer')
    parser.add_argument('--embed_size', type=int, default=300,
                        help='Size of embeddings')
    parser.add_argument('--model', '-m', default='model.npz',
                        help='Model file name to serialize')
    parser.add_argument('--glove', default='data/glove.6B.300d.txt',
                        help='Path to glove embedding file.')
    return parser.parse_args()
示例6: parse_args
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def parse_args(generators, discriminators, updaters):
    """Parse options for adversarial semantic-segmentation training.

    The keys of *generators*, *discriminators* and *updaters* become the
    valid choices for the corresponding architecture options.
    """
    parser = argparse.ArgumentParser(description='Semantic Segmentation using Adversarial Networks')
    # Architecture selection — choices come from the caller-supplied registries.
    parser.add_argument('--generator', choices=generators.keys(), default='fcn32s',
                        help='Generator(segmentor) architecture')
    parser.add_argument('--discriminator', choices=discriminators.keys(), default='largefov',
                        help='Discriminator architecture')
    parser.add_argument('--updater', choices=updaters.keys(), default='gan',
                        help='Updater')
    # Pretrained weights.
    parser.add_argument('--initgen_path', default='pretrained_model/vgg16.npz',
                        help='Pretrained model of generator')
    parser.add_argument('--initdis_path', default=None,
                        help='Pretrained model of discriminator')
    # Training schedule.
    parser.add_argument('--batchsize', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    parser.add_argument('--iteration', '-i', type=int, default=100000,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='snapshot',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    # Reporting cadence.
    parser.add_argument('--evaluate_interval', type=int, default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=10,
                        help='Interval of displaying log to console')
    return parser.parse_args()
示例7: main
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def main():
    """Parse MNIST example options and hand them to ``main_impl``."""
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    # Core training options.
    parser.add_argument('--batchsize', '-b', type=int, default=7,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    # Export / debugging knobs.
    parser.add_argument('--onnx', default='',
                        help='Export ONNX model')
    parser.add_argument('--model', '-m', default='model.npz',
                        help='Model file name to serialize')
    parser.add_argument('--timeout', type=int, default=0,
                        help='Enable timeout')
    parser.add_argument('--trace', default='',
                        help='Enable tracing')
    parser.add_argument('--run_training', action='store_true',
                        help='Run training')
    main_impl(parser.parse_args())
示例8: test_call
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_call(self):
    """The writer fires only on calls where the condition returns True."""
    target = mock.MagicMock()
    condition = mock.MagicMock(side_effect=[True, False])
    writer = mock.MagicMock()
    snapshot = extensions.snapshot(target=target, condition=condition, writer=writer)

    trainer = mock.MagicMock()
    snapshot(trainer)
    snapshot(trainer)

    # Condition consulted on both calls; only the first one wrote a file.
    assert condition.call_count == 2
    assert writer.call_count == 1
示例9: test_savefun_and_writer_exclusive
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_savefun_and_writer_exclusive(self):
    """Passing both ``savefun`` and ``writer`` must raise TypeError."""

    def savefun(*args, **kwargs):
        # Must never be reached — construction itself should fail.
        assert False

    writer = extensions.snapshot_writers.SimpleWriter()

    with pytest.raises(TypeError):
        extensions.snapshot(savefun=savefun, writer=writer)

    trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(trainer, savefun=savefun, writer=writer)
示例10: test_save_file
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_save_file(self):
    """snapshot_object with an explicit SimpleWriter creates the target file."""
    writer = extensions.snapshot_writers.SimpleWriter()
    saver = extensions.snapshot_object(self.trainer, 'myfile.dat',
                                       writer=writer)
    saver(self.trainer)
    self.assertTrue(os.path.exists('myfile.dat'))
示例11: test_clean_up_tempdir
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_clean_up_tempdir(self):
    """Writing a snapshot must not leave temporary files in the cwd."""
    saver = extensions.snapshot_object(self.trainer, 'myfile.dat')
    saver(self.trainer)
    leftovers = [name for name in os.listdir('.')
                 if name.startswith('tmpmyfile.dat')]
    self.assertEqual(len(leftovers), 0)
示例12: test_remove_stale_snapshots
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_remove_stale_snapshots(self):
    """Only the newest ``n_retains`` snapshots survive; older ones are pruned."""
    fmt = 'snapshot_iter_{.updater.iteration}'
    retain = 3
    snapshot = extensions.snapshot(filename=fmt, n_retains=retain,
                                   autoload=False)

    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = self.path
    trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)

    class TimeStampUpdater():
        # Pruning is mtime-based, so force strictly increasing timestamps.
        t = time.time() - 100
        name = 'ts_updater'
        priority = 1  # This must be called after snapshot taken

        def __call__(self, _trainer):
            filename = os.path.join(_trainer.out, fmt.format(_trainer))
            self.t += 1
            # For filesystems that does low timestamp precision
            os.utime(filename, (self.t, self.t))

    trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
    trainer.run()
    assert 10 == trainer.updater.iteration
    assert trainer._done

    pattern = os.path.join(trainer.out, "snapshot_iter_*")
    found = [os.path.basename(path) for path in glob.glob(pattern)]
    assert retain == len(found)
    found.sort()
    # snapshot_iter_(8, 9, 10) expected
    expected = sorted('snapshot_iter_{}'.format(i) for i in range(8, 11))
    assert expected == found

    # Autoload against the retained snapshots must not raise.
    trainer2 = testing.get_trainer_with_mock_updater()
    trainer2.out = self.path
    assert not trainer2._done
    snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
    # Just making sure no error occurs
    snapshot2.initialize(trainer2)
示例13: test_callable_filename
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_callable_filename():
    """Callable filenames are wrapped so each rank formats its own name."""
    replica_sets = [[0, 1], ]
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    def filename_fun(t):
        return 'deadbeef-{.updater.iteration}'.format(t)

    snapshot = extensions.snapshot(filename=filename_fun)
    trainer = mock.MagicMock()
    filename = '{}.{}'.format(filename_fun(trainer), comm.rank)
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)

    # Masters (rank 0 of the [0, 1] set, and each singleton rank >= 2,
    # e.g. rank 2) keep a per-rank filename callable; replicas do not.
    if comm.rank in (0, 2):
        assert mn_snapshot.is_master
        assert filename == mn_snapshot.snapshot.filename(trainer)
    else:
        assert not mn_snapshot.is_master
    comm.finalize()
示例14: test_smoke_multinode_snapshot
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def test_smoke_multinode_snapshot():
    """Only the master rank's condition and writer are ever invoked."""
    target = mock.MagicMock()
    condition = mock.MagicMock(side_effect=[True, False])
    writer = mock.MagicMock()
    snapshot = extensions.snapshot(target=target, condition=condition, writer=writer)
    trainer = mock.MagicMock()

    comm = create_communicator('naive')
    mn_snapshot = multi_node_snapshot(comm, snapshot, [])
    mn_snapshot.initialize(trainer)
    mn_snapshot(trainer)
    mn_snapshot(trainer)
    mn_snapshot.finalize()

    if comm.rank == 0:
        # Master behaves exactly like a plain snapshot extension.
        assert mn_snapshot.is_master
        assert condition.call_count == 2
        assert writer.call_count == 1
    else:
        # Replicas never touch the condition or the writer.
        assert not mn_snapshot.is_master
        assert condition.call_count == 0
        assert writer.call_count == 0
    comm.finalize()
示例15: _prepare_multinode_snapshot
# Required import: from chainer.training import extensions
# Or: from chainer.training.extensions import snapshot
def _prepare_multinode_snapshot(n, result):
    """Build a distributed MNIST trainer, run *n* updates, and return
    ``(updater, mn_snapshot, trainer)`` for multi-node snapshot tests."""
    n_units = 100
    batchsize = 10

    comm = create_communicator('naive')
    model = L.Classifier(MLP(n_units, 10))
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Rank 0 loads MNIST; the dataset is then scattered to all ranks.
    if comm.rank == 0:
        train, _ = chainer.datasets.get_mnist()
    else:
        train, _ = None, None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)

    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    updater = StandardUpdater(train_iter, optimizer)
    trainer = Trainer(updater, out=result)

    snapshot = extensions.snapshot(target=updater, autoload=True)
    mn_snapshot = multi_node_snapshot(comm, snapshot, [])
    mn_snapshot.initialize(trainer)

    for _ in range(n):
        updater.update()
    return updater, mn_snapshot, trainer