当前位置: 首页>>代码示例>>Python>>正文


Python extensions.snapshot方法代码示例

本文整理汇总了 Python 中 chainer.training.extensions.snapshot 方法的典型用法代码示例。如果您正苦于以下问题:Python extensions.snapshot 方法的具体用法是什么?Python extensions.snapshot 怎么用?有哪些 Python extensions.snapshot 的使用例子?那么这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 chainer.training.extensions 的用法示例。


下文一共展示了 extensions.snapshot 方法的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: test_on_error

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_on_error(self):
    """Check that a snapshot file appears when the training run fails.

    An extension that always raises is registered together with a
    ``snapshot_object`` extension created with ``snapshot_on_error=True``.
    The snapshot file must not exist before the run and must exist once
    the run has aborted with the raised exception.
    """

    class TheOnlyError(Exception):
        pass

    def exception_raiser(trainer):
        raise TheOnlyError()

    # Apply the decorator explicitly instead of using the ``@`` syntax.
    make_failing = training.make_extension(
        trigger=(1, 'iteration'), priority=100)
    exception_raiser = make_failing(exception_raiser)
    self.trainer.extend(exception_raiser)

    on_error_snapshot = extensions.snapshot_object(
        self.trainer, self.filename, snapshot_on_error=True)
    self.trainer.extend(on_error_snapshot)

    # Nothing has been written before the trainer starts.
    self.assertFalse(os.path.exists(self.filename))

    with self.assertRaises(TheOnlyError):
        self.trainer.run()

    # The failure itself must have produced the snapshot file.
    self.assertTrue(os.path.exists(self.filename))
开发者ID:chainer,项目名称:chainer,代码行数:22,代码来源:test_snapshot.py

示例2: test_smoke_wrapper

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_smoke_wrapper():
    """Smoke-test ``multi_node_snapshot`` wrapped around a default snapshot.

    The test is skipped when the communicator has fewer than two
    processes.  Ranks 0 and 2 are expected to be masters whose wrapped
    snapshot filename carries the rank as a suffix; every other rank is
    expected not to be a master.
    """
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    replica_sets = [[0, 1], ]
    snapshot = extensions.snapshot()
    # Computed before wrapping, from the un-wrapped snapshot's filename.
    expected_name = '{}.{}'.format(snapshot.filename, comm.rank)

    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    if comm.rank in (0, 2):
        assert mn_snapshot.is_master
        assert expected_name == mn_snapshot.snapshot.filename
    else:
        assert not mn_snapshot.is_master

    comm.finalize()
开发者ID:chainer,项目名称:chainer,代码行数:25,代码来源:test_multi_node_snapshot.py

示例3: test_snapshot_hdfs

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_snapshot_hdfs():
    """Write a trainer snapshot through a pfio HDFS handler and load it back.

    A snapshot of a mock-updater trainer is written into a scratch
    directory via ``SimpleWriter``; the test then checks that the file
    ``snapshot_iter_0`` is listed there and that ``load_snapshot`` can
    restore it into a second trainer.

    The scratch directory is removed in a ``finally`` clause so that a
    failing assertion or load does not leave it behind; a leftover
    directory would make the emptiness check fail on the next run.
    """
    trainer = chainer.testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with pfio.create_handler('hdfs') as fs:
        tmpdir = "some-pfio-tmp-dir"
        fs.makedirs(tmpdir, exist_ok=True)
        file_list = list(fs.list(tmpdir))
        # Checked before entering the try block: if the directory is not
        # empty it is left untouched rather than removed by the cleanup.
        assert len(file_list) == 0

        try:
            writer = SimpleWriter(tmpdir, fs=fs)
            snapshot = extensions.snapshot(writer=writer)
            snapshot(trainer)

            assert 'snapshot_iter_0' in fs.list(tmpdir)

            trainer2 = chainer.testing.get_trainer_with_mock_updater()
            load_snapshot(trainer2, tmpdir, fs=fs, fail_on_no_file=True)
        finally:
            # Cleanup must run even when one of the checks above fails.
            fs.remove(tmpdir, recursive=True)
开发者ID:pfnet,项目名称:pfio,代码行数:24,代码来源:test_snapshot.py

示例4: add_default_arguments

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def add_default_arguments(parser):
    """Register the common training command-line options on ``parser``.

    One positional argument (``log_dir``) and a set of optional arguments
    are added in a fixed order; ``--batch-size`` is the only required
    option.

    Args:
        parser: An ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same ``parser`` object, to allow call chaining.
    """
    # (flags, keyword options) in registration order; the order determines
    # how the options are listed in the generated help text.
    argument_specs = (
        (("log_dir",),
         dict(help='directory where generated models and logs shall be stored')),
        (('-b', '--batch-size'),
         dict(dest='batch_size', type=int, required=True,
              help="Number of images per training batch")),
        (('-g', '--gpus'),
         dict(type=int, nargs="*", default=[],
              help="Ids of GPU to use [default: (use cpu)]")),
        (('-e', '--epochs'),
         dict(type=int, default=20,
              help="Number of epochs to train [default: 20]")),
        (('-r', '--resume'),
         dict(help="path to previously saved state of trained model from which training shall resume")),
        (('-si', '--snapshot-interval'),
         dict(dest='snapshot_interval', type=int, default=20000,
              help="number of iterations after which a snapshot shall be taken [default: 20000]")),
        (('-ln', '--log-name'),
         dict(dest='log_name', default='training',
              help="name of the log folder")),
        (('-lr', '--learning-rate'),
         dict(dest='learning_rate', type=float, default=0.01,
              help="initial learning rate [default: 0.01]")),
        (('-li', '--log-interval'),
         dict(dest='log_interval', type=int, default=100,
              help="number of iterations after which an update shall be logged [default: 100]")),
        (('--lr-step',),
         dict(dest='learning_rate_step_size', type=float, default=0.1,
              help="Step size for decreasing learning rate [default: 0.1]")),
        (('-t', '--test-interval'),
         dict(dest='test_interval', type=int, default=1000,
              help="number of iterations after which testing should be performed [default: 1000]")),
        (('--test-iterations',),
         dict(dest='test_iterations', type=int, default=200,
              help="number of test iterations [default: 200]")),
        (("-dr", "--dropout-ratio"),
         dict(dest='dropout_ratio', default=0.5, type=float,
              help="ratio for dropout layers")),
    )

    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser
开发者ID:Bartzi,项目名称:see,代码行数:26,代码来源:train_utils.py

示例5: parse_args

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse.  When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument function.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-b', type=int, default=32,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--bproplen', '-l', type=int, default=35,
                        help='Number of words in each mini-batch '
                             '(= length of truncated BPTT)')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    # argparse does not run ``type`` on non-string defaults, so the default
    # is written as a float to match values given on the command line.
    parser.add_argument('--gradclip', '-c', type=float, default=5.0,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    # ``store_true`` already defaults to False; no set_defaults() needed.
    parser.add_argument('--test', action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.add_argument('--hidden_size', type=int, default=300,
                        help='Number of LSTM units in each layer')
    parser.add_argument('--embed_size', type=int, default=300,
                        help='Size of embeddings')
    parser.add_argument('--model', '-m', default='model.npz',
                        help='Model file name to serialize')
    parser.add_argument('--glove', default='data/glove.6B.300d.txt',
                        help='Path to glove embedding file.')
    return parser.parse_args(argv)
开发者ID:Pinafore,项目名称:qb,代码行数:32,代码来源:main.py

示例6: parse_args

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def parse_args(generators, discriminators, updaters, argv=None):
    """Parse the command-line options of the adversarial training script.

    Args:
        generators: Mapping whose keys are the accepted ``--generator``
            values.
        discriminators: Mapping whose keys are the accepted
            ``--discriminator`` values.
        updaters: Mapping whose keys are the accepted ``--updater`` values.
        argv: Optional list of argument strings to parse.  When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is
            the behaviour of the original three-argument function.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Semantic Segmentation using Adversarial Networks')
    parser.add_argument('--generator', choices=generators.keys(), default='fcn32s',
                        help='Generator(segmentor) architecture')
    parser.add_argument('--discriminator', choices=discriminators.keys(), default='largefov',
                        help='Discriminator architecture')
    parser.add_argument('--updater', choices=updaters.keys(), default='gan',
                        help='Updater')
    parser.add_argument('--initgen_path', default='pretrained_model/vgg16.npz',
                        help='Pretrained model of generator')
    parser.add_argument('--initdis_path', default=None,
                        help='Pretrained model of discriminator')
    parser.add_argument('--batchsize', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    # The help text previously described epochs ("sweeps over the
    # dataset"), which does not match an option named --iteration.
    parser.add_argument('--iteration', '-i', type=int, default=100000,
                        help='Number of iterations to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='snapshot',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--evaluate_interval', type=int, default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=10,
                        help='Interval of displaying log to console')
    return parser.parse_args(argv)
开发者ID:oyam,项目名称:Semantic-Segmentation-using-Adversarial-Networks,代码行数:31,代码来源:02-train.py

示例7: main

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def main():
    """Parse the MNIST example's command-line options and run ``main_impl``.

    The parsed ``argparse.Namespace`` is handed to ``main_impl`` unchanged.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    add = parser.add_argument  # short alias; registration order is kept

    add('--batchsize', '-b', type=int, default=7,
        help='Number of images in each mini-batch')
    add('--epoch', '-e', type=int, default=20,
        help='Number of sweeps over the dataset to train')
    add('--frequency', '-f', type=int, default=-1,
        help='Frequency of taking a snapshot')
    add('--gpu', '-g', type=int, default=-1,
        help='GPU ID (negative value indicates CPU)')
    add('--out', '-o', default='result',
        help='Directory to output the result')
    add('--resume', '-r', default='',
        help='Resume the training from snapshot')
    add('--unit', '-u', type=int, default=1000, help='Number of units')
    # --noplot stores False into ``plot``.
    add('--noplot', dest='plot', action='store_false',
        help='Disable PlotReport extension')
    add('--onnx', default='', help='Export ONNX model')
    add('--model', '-m', default='model.npz',
        help='Model file name to serialize')
    add('--timeout', type=int, default=0, help='Enable timeout')
    add('--trace', default='', help='Enable tracing')
    add('--run_training', action='store_true', help='Run training')

    main_impl(parser.parse_args())
开发者ID:pfnet-research,项目名称:chainer-compiler,代码行数:33,代码来源:gen_mnist_mlp.py

示例8: test_call

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_call(self):
    """Check that the snapshot writer is invoked only when ``condition`` holds.

    The condition mock yields ``True`` then ``False``; after calling the
    extension twice the condition must have been consulted twice and the
    writer called once.
    """
    target = mock.MagicMock()
    condition = mock.MagicMock(side_effect=[True, False])
    writer = mock.MagicMock()
    trainer = mock.MagicMock()
    snapshot = extensions.snapshot(
        target=target, condition=condition, writer=writer)

    for _ in range(2):
        snapshot(trainer)

    assert condition.call_count == 2
    assert writer.call_count == 1
开发者ID:chainer,项目名称:chainer,代码行数:13,代码来源:test_snapshot.py

示例9: test_savefun_and_writer_exclusive

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_savefun_and_writer_exclusive(self):
    """Check that passing both ``savefun`` and ``writer`` raises ``TypeError``.

    Both ``extensions.snapshot`` and ``extensions.snapshot_object`` are
    exercised with the conflicting pair of arguments.
    """
    def savefun(*args, **kwargs):
        # Must never be reached: construction is expected to fail first.
        assert False

    conflicting = dict(
        savefun=savefun,
        writer=extensions.snapshot_writers.SimpleWriter(),
    )

    with pytest.raises(TypeError):
        extensions.snapshot(**conflicting)

    trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(trainer, **conflicting)
开发者ID:chainer,项目名称:chainer,代码行数:13,代码来源:test_snapshot.py

示例10: test_save_file

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_save_file(self):
    """Check that ``snapshot_object`` with a ``SimpleWriter`` creates the file.

    After one invocation of the extension, ``myfile.dat`` must exist
    (the path is checked relative to the current working directory).
    """
    simple_writer = extensions.snapshot_writers.SimpleWriter()
    take_snapshot = extensions.snapshot_object(
        self.trainer, 'myfile.dat', writer=simple_writer)

    take_snapshot(self.trainer)

    file_exists = os.path.exists('myfile.dat')
    self.assertTrue(file_exists)
开发者ID:chainer,项目名称:chainer,代码行数:9,代码来源:test_snapshot.py

示例11: test_clean_up_tempdir

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_clean_up_tempdir(self):
    """Check that taking a snapshot leaves no ``tmpmyfile.dat*`` entries.

    After one invocation of the extension, the current directory must not
    contain any entry whose name starts with ``tmpmyfile.dat``.
    """
    take_snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
    take_snapshot(self.trainer)

    leftovers = []
    for entry in os.listdir('.'):
        if entry.startswith('tmpmyfile.dat'):
            leftovers.append(entry)

    self.assertEqual(len(leftovers), 0)
开发者ID:chainer,项目名称:chainer,代码行数:9,代码来源:test_snapshot.py

示例12: test_remove_stale_snapshots

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_remove_stale_snapshots(self):
    """Check that only the newest ``n_retains`` snapshot files are kept.

    A snapshot is taken on every iteration with ``n_retains=3``.  A helper
    extension stamps each new snapshot file with an increasing timestamp.
    After the run (which is expected to stop at iteration 10) exactly the
    files for iterations 8, 9 and 10 must remain.  Finally a fresh
    trainer initializes an ``autoload=True`` snapshot on the same
    directory, only to verify that this does not raise.
    """
    fmt = 'snapshot_iter_{.updater.iteration}'
    retain = 3

    class TimeStampUpdater():
        # Starts in the past and is advanced by one second per call.
        t = time.time() - 100
        name = 'ts_updater'
        priority = 1  # This must be called after snapshot taken

        def __call__(self, _trainer):
            self.t += 1
            snapshot_path = os.path.join(_trainer.out, fmt.format(_trainer))
            # Explicit timestamps, for filesystems with low timestamp
            # precision.
            os.utime(snapshot_path, (self.t, self.t))

    snapshot = extensions.snapshot(
        filename=fmt, n_retains=retain, autoload=False)

    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = self.path
    trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)
    trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
    trainer.run()
    assert 10 == trainer.updater.iteration
    assert trainer._done

    matches = glob.glob(os.path.join(trainer.out, "snapshot_iter_*"))
    found = sorted(os.path.basename(match) for match in matches)
    assert retain == len(found)
    # snapshot_iter_(8, 9, 10) expected; sorted as strings so the order
    # matches that of ``found``.
    expected = sorted('snapshot_iter_{}'.format(i) for i in range(8, 11))
    assert expected == found

    trainer2 = testing.get_trainer_with_mock_updater()
    trainer2.out = self.path
    assert not trainer2._done
    snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
    # Just making sure no error occurs
    snapshot2.initialize(trainer2)
开发者ID:chainer,项目名称:chainer,代码行数:43,代码来源:test_snapshot.py

示例13: test_callable_filename

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_callable_filename():
    """Smoke-test ``multi_node_snapshot`` with a callable snapshot filename.

    The test is skipped when the communicator has fewer than two
    processes.  On ranks 0 and 2 the wrapper is expected to be a master
    whose wrapped filename callable returns the original name with the
    rank appended; every other rank is expected not to be a master.
    """
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    def filename_fun(t):
        return 'deadbeef-{.updater.iteration}'.format(t)

    replica_sets = [[0, 1], ]
    snapshot = extensions.snapshot(filename=filename_fun)

    trainer = mock.MagicMock()
    expected_name = '{}.{}'.format(filename_fun(trainer), comm.rank)

    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    if comm.rank in (0, 2):
        assert mn_snapshot.is_master
        assert expected_name == mn_snapshot.snapshot.filename(trainer)
    else:
        assert not mn_snapshot.is_master

    comm.finalize()
开发者ID:chainer,项目名称:chainer,代码行数:30,代码来源:test_multi_node_snapshot.py

示例14: test_smoke_multinode_snapshot

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def test_smoke_multinode_snapshot():
    """Smoke-test the full call cycle of ``multi_node_snapshot``.

    The wrapper is initialized, called twice and finalized with an empty
    replica-set list.  Rank 0 is expected to be the master and to have
    consulted the condition twice and written once; all other ranks are
    expected to have touched neither the condition nor the writer.
    """
    target = mock.MagicMock()
    condition = mock.MagicMock(side_effect=[True, False])
    writer = mock.MagicMock()
    snapshot = extensions.snapshot(
        target=target, condition=condition, writer=writer)
    trainer = mock.MagicMock()

    comm = create_communicator('naive')
    mn_snapshot = multi_node_snapshot(comm, snapshot, [])

    mn_snapshot.initialize(trainer)
    for _ in range(2):
        mn_snapshot(trainer)
    mn_snapshot.finalize()

    if comm.rank != 0:
        assert not mn_snapshot.is_master
        assert condition.call_count == 0
        assert writer.call_count == 0
    else:
        assert mn_snapshot.is_master
        assert condition.call_count == 2
        assert writer.call_count == 1

    comm.finalize()
开发者ID:chainer,项目名称:chainer,代码行数:28,代码来源:test_multi_node_snapshot.py

示例15: _prepare_multinode_snapshot

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import snapshot [as 别名]
def _prepare_multinode_snapshot(n, result):
    """Build an MNIST trainer with a multi-node snapshot and run ``n`` updates.

    Args:
        n: Number of times ``updater.update()`` is called after the
            snapshot extension has been initialized.
        result: Output directory passed to ``Trainer`` as ``out``.

    Returns:
        A ``(updater, mn_snapshot, trainer)`` tuple.
    """
    hidden_units = 100
    minibatch_size = 10

    comm = create_communicator('naive')
    model = L.Classifier(MLP(hidden_units, 10))
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Only rank 0 loads the dataset; it is then scattered to all ranks.
    train = None
    if comm.rank == 0:
        train, _ = chainer.datasets.get_mnist()
    train = chainermn.scatter_dataset(train, comm, shuffle=True)

    train_iter = chainer.iterators.SerialIterator(train, minibatch_size)
    updater = StandardUpdater(train_iter, optimizer)
    trainer = Trainer(updater, out=result)

    snapshot = extensions.snapshot(target=updater, autoload=True)
    mn_snapshot = multi_node_snapshot(comm, snapshot, [])
    mn_snapshot.initialize(trainer)

    for _ in range(n):
        updater.update()

    return updater, mn_snapshot, trainer
开发者ID:chainer,项目名称:chainer,代码行数:30,代码来源:test_multi_node_snapshot.py


注:本文中的chainer.training.extensions.snapshot方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。