当前位置: 首页>>代码示例>>Python>>正文


Python chainermn.scatter_dataset方法代码示例

本文整理汇总了 Python 中 chainermn.scatter_dataset 方法的典型用法代码示例。如果您正苦于以下问题：Python chainermn.scatter_dataset 方法的具体用法是什么？Python chainermn.scatter_dataset 怎么用？有哪些 Python chainermn.scatter_dataset 的使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 chainermn 的用法示例。


下文共展示了 chainermn.scatter_dataset 方法的 7 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: make_dataset

# 需要导入模块: import chainermn [as 别名]
# 或者: from chainermn import scatter_dataset [as 别名]
def make_dataset(self, stage_int):
    """Build the dataset for ``stage_int`` and scatter it when MPI is in use.

    On the master process a ``BaseDataset`` is created from the JSON config at
    ``FLAGS.dataset_config`` with a single ``"resize"`` transform to
    ``size x size``, where ``size = 4 * 2 ** ((stage_int + 1) // 2)``.
    Non-master processes start from ``None``. When ``self.use_mpi`` is true,
    the result is passed through ``chainermn.scatter_dataset`` over
    ``self.comm``.

    Args:
        stage_int: Integer stage index that determines the target resolution.

    Returns:
        The dataset, possibly scattered. A non-master process returns ``None``
        when MPI is disabled.
    """
    if self.is_master:
        size = 4 * (2 ** ((stage_int + 1) // 2))
        # Close the config file deterministically instead of leaking the
        # handle until garbage collection, as ``json.load(open(...))`` did.
        with open(FLAGS.dataset_config, 'r') as config_file:
            dataset_config = json.load(config_file)
        _dataset = BaseDataset(
            dataset_config,
            '%dx%d' % (size, size),
            [["resize", {"probability": 1, "width": size, "height": size, "resample_filter": "ANTIALIAS"}]]
        )
        self.print_log('Add (master) dataset for size {}'.format(size))
    else:
        _dataset = None
        self.print_log('Add (slave) dataset')

    if self.use_mpi:
        _dataset = chainermn.scatter_dataset(_dataset, self.comm)

    return _dataset
开发者ID:pfnet-research,项目名称:chainer-stylegan,代码行数:19,代码来源:train.py

示例2: scatter_large_data

# 需要导入模块: import chainermn [as 别名]
# 或者: from chainermn import scatter_dataset [as 别名]
def scatter_large_data(communicator):
    """Check that ``chainermn.scatter_dataset`` copes with a huge dataset.

    Rank 0 builds a list of 2,000,000,000 references to ``'test'``, which
    needs many gigabytes of memory. Other ranks pass an empty list. Every
    rank then asserts that the portion it receives is non-empty.
    """
    payload = ['test'] * 2000000000 if communicator.rank == 0 else []
    received = chainermn.scatter_dataset(payload, communicator)
    assert len(received) > 0
开发者ID:chainer,项目名称:chainer,代码行数:8,代码来源:test_scatter.py

示例3: _prepare_multinode_snapshot

# 需要导入模块: import chainermn [as 别名]
# 或者: from chainermn import scatter_dataset [as 别名]
def _prepare_multinode_snapshot(n, result):
    """Build a small MNIST trainer wrapped in a multi-node snapshot and run ``n`` updates.

    Args:
        n: Number of ``updater.update()`` calls made after the snapshot
            extension is initialized on the trainer.
        result: Output directory handed to ``Trainer(out=...)``.

    Returns:
        Tuple ``(updater, mn_snapshot, trainer)``.
    """
    hidden_units = 100
    batch_size = 10
    comm = create_communicator('naive')

    classifier = L.Classifier(MLP(hidden_units, 10))
    mn_optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    mn_optimizer.setup(classifier)

    # Only rank 0 loads MNIST; scatter_dataset then hands each rank its share.
    train_data = chainer.datasets.get_mnist()[0] if comm.rank == 0 else None
    train_data = chainermn.scatter_dataset(train_data, comm, shuffle=True)

    updater = StandardUpdater(
        chainer.iterators.SerialIterator(train_data, batch_size), mn_optimizer)
    trainer = Trainer(updater, out=result)

    # An empty replica-set list is passed to multi_node_snapshot.
    mn_snapshot = multi_node_snapshot(
        comm, extensions.snapshot(target=updater, autoload=True), [])
    mn_snapshot.initialize(trainer)

    for _step in range(n):
        updater.update()

    return updater, mn_snapshot, trainer
开发者ID:chainer,项目名称:chainer,代码行数:30,代码来源:test_multi_node_snapshot.py

示例4: setup_mnist_trainer

# 需要导入模块: import chainermn [as 别名]
# 或者: from chainermn import scatter_dataset [as 别名]
def setup_mnist_trainer(self, display_log=False, use_chx=False):
    """Create an MNIST training setup distributed over ``self.communicator``.

    Rank 0 loads MNIST. Both the train and test splits are then scattered
    (with shuffling) across the communicator.

    Args:
        display_log: Accepted for interface compatibility; not used here.
        use_chx: Forwarded to ``get_device(None, use_chx)`` to pick the
            device the model is moved to.

    Returns:
        Tuple ``(updater, optimizer, train_iter, test_iter, model)``.
    """
    hidden_units = 100
    batch_size = 100
    comm = self.communicator

    model = L.Classifier(MLP(hidden_units, 10))
    model.to_device(get_device(None, use_chx))

    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Only rank 0 reads the data; the other ranks get their shards from scatter.
    train, test = chainer.datasets.get_mnist() if comm.rank == 0 else (None, None)

    # Collective calls: every rank must issue them in the same order.
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm, shuffle=True)

    train_iter = chainer.iterators.SerialIterator(train, batch_size)
    test_iter = chainer.iterators.SerialIterator(
        test, batch_size, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer)
    return updater, optimizer, train_iter, test_iter, model
开发者ID:chainer,项目名称:chainer,代码行数:34,代码来源:test_checkpoint.py

示例5: check_scatter_dataset

# 需要导入模块: import chainermn [as 别名]
# 或者: from chainermn import scatter_dataset [as 别名]
def check_scatter_dataset(self, original_dataset, shuffle=False, root=0):
    """Scatter ``original_dataset`` from ``root`` and verify the shards on ``root``.

    Every rank calls ``chainermn.scatter_dataset``. The shards are then
    gathered back onto ``root``, which asserts that:

    * all shards have the same length;
    * that length is the smallest one covering the original dataset across
      ``self.mpi_comm.size`` ranks, i.e.
      ``(sub_size - 1) * size < len(original) <= sub_size * size``;
    * the set of shard elements equals the set of original elements.

    Args:
        original_dataset: Dataset to scatter. Only the value on ``root`` is
            used; every other rank replaces it with ``None``.
        shuffle: Forwarded to ``chainermn.scatter_dataset``.
        root: Rank that owns the dataset and performs the checks.
    """
    if self.communicator.rank != root:
        original_dataset = None
    my_dataset = chainermn.scatter_dataset(
        original_dataset, self.communicator,
        shuffle=shuffle, root=root)
    sub_datasets = self.communicator.gather_obj(my_dataset, root=root)

    # Only the root rank holds the gathered shards and the original dataset.
    if self.communicator.rank != root:
        return

    # Test the sizes
    sub_sizes = [len(sub_dataset) for sub_dataset in sub_datasets]
    self.assertEqual(len(set(sub_sizes)), 1)
    sub_size = sub_sizes[0]
    self.assertLessEqual(
        len(original_dataset), sub_size * self.mpi_comm.size)
    self.assertGreater(
        len(original_dataset), (sub_size - 1) * self.mpi_comm.size)

    # Test the content of scattered datasets.
    #
    # NOTE: The values in the original and joined datasets must be cast to
    # int to compare. There are 2 backgrounds on this issue.
    #
    # (1) numpy and cupy/chainerx behave differently on 1-element arrays.
    # Numpy implicitly converts a 1-element array element to a scalar.
    # type(numpy.array([1])[0])
    # =>  <class 'numpy.int64'>  # Scalar
    # type(chainerx.array([1])[0])
    # => <class 'chainerx.ndarray'>  # array of one element
    #
    # (2) Two different ChainerX arrays are never identical in the
    # context of `set()`.
    # set([chainerx.array([0]), chainerx.array([0])])
    # => {array([0], shape=(1,), dtype=int64, device='native:0'),
    #     array([0], shape=(1,), dtype=int64, device='native:0')}
    #
    # Flatten the shards in one linear pass. The previous
    # ``sum(lists, [])`` re-copied the accumulated list on every addition,
    # which is quadratic. Sets are compared on purpose: padding may repeat
    # elements.
    joined_values = {int(e)
                     for sub_dataset in sub_datasets
                     for e in sub_dataset[:]}
    original_values = {int(e) for e in original_dataset}
    self.assertEqual(joined_values, original_values)
开发者ID:chainer,项目名称:chainer,代码行数:45,代码来源:test_scatter.py

示例6: objective

# 需要导入模块: import chainermn [as 别名]
# 或者: from chainermn import scatter_dataset [as 别名]
def objective(trial, comm):
    """Train a trial-defined MNIST classifier with ChainerMN and return its accuracy.

    The architecture comes from ``create_model(trial)``. Worker 0 loads MNIST
    and takes fixed-seed (``RandomState(0)``) random subsets of
    ``N_TRAIN_EXAMPLES`` / ``N_VALID_EXAMPLES`` items. These are split evenly
    and distributed to all workers. A ``ChainerPruningExtension`` watches
    ``validation/main/accuracy`` every ``PRUNER_INTERVAL`` epochs.

    Args:
        trial: Optuna trial used to sample the architecture and for pruning.
        comm: ChainerMN communicator shared by all workers.

    Returns:
        The ``"main/accuracy"`` entry of a multi-node evaluation on the
        validation split after training.
    """
    # Sampled network, with its optimizer wrapped for multi-node training.
    classifier = L.Classifier(create_model(trial))
    local_optimizer = chainer.optimizers.MomentumSGD()
    local_optimizer.setup(classifier)
    mn_optimizer = chainermn.create_multi_node_optimizer(local_optimizer, comm)

    # Only worker 0 loads the whole dataset. Its subsets are then evenly split
    # and distributed to all workers.
    if comm.rank != 0:
        train = valid = None
    else:
        full_train, full_valid = chainer.datasets.get_mnist()
        rng = np.random.RandomState(0)
        # The two permutations share one RNG stream: train first, then valid.
        train = chainer.datasets.SubDataset(
            full_train, 0, N_TRAIN_EXAMPLES, order=rng.permutation(len(full_train))
        )
        valid = chainer.datasets.SubDataset(
            full_valid, 0, N_VALID_EXAMPLES, order=rng.permutation(len(full_valid))
        )

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE, shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(
        valid, BATCHSIZE, repeat=False, shuffle=False
    )

    trainer = chainer.training.Trainer(
        chainer.training.StandardUpdater(train_iter, mn_optimizer), (EPOCH, "epoch")
    )
    ext = chainer.training.extensions

    # Extension registration order is kept: pruner, evaluator, log report, progress bar.
    trainer.extend(
        optuna.integration.ChainerPruningExtension(
            trial, "validation/main/accuracy", (PRUNER_INTERVAL, "epoch")
        )
    )
    trainer.extend(
        chainermn.create_multi_node_evaluator(ext.Evaluator(valid_iter, classifier), comm)
    )
    trainer.extend(ext.LogReport(log_name=None))
    if comm.rank == 0:
        trainer.extend(ext.ProgressBar())

    # ChainerPruningExtension stops training by raising TrialPruned. Setting
    # show_loop_exception_msg=False keeps the trainer from printing a message
    # each time that happens.
    trainer.run(show_loop_exception_msg=False)

    final_evaluator = chainermn.create_multi_node_evaluator(
        ext.Evaluator(valid_iter, classifier), comm
    )
    return final_evaluator()["main/accuracy"]
开发者ID:optuna,项目名称:optuna,代码行数:61,代码来源:chainermn_integration.py

示例7: objective

# 需要导入模块: import chainermn [as 别名]
# 或者: from chainermn import scatter_dataset [as 别名]
def objective(trial, comm):
    """Train a trial-defined MNIST classifier with ChainerMN and return its accuracy.

    The architecture comes from ``create_model(trial)``. Worker 0 loads MNIST
    and takes fixed-seed (``RandomState(0)``) random subsets of
    ``N_TRAIN_EXAMPLES`` / ``N_VALID_EXAMPLES`` items. These are split evenly
    and distributed to all workers before training for ``EPOCH`` epochs.

    Args:
        trial: Optuna trial used to sample the architecture.
        comm: ChainerMN communicator shared by all workers.

    Returns:
        The ``"main/accuracy"`` entry of a multi-node evaluation on the
        validation split after training.
    """
    # Sampled network, with its optimizer wrapped for multi-node training.
    classifier = L.Classifier(create_model(trial))
    local_optimizer = chainer.optimizers.MomentumSGD()
    local_optimizer.setup(classifier)
    mn_optimizer = chainermn.create_multi_node_optimizer(local_optimizer, comm)

    # Only worker 0 loads the whole dataset. Its subsets are then evenly split
    # and distributed to all workers.
    if comm.rank != 0:
        train = valid = None
    else:
        full_train, full_valid = chainer.datasets.get_mnist()
        rng = np.random.RandomState(0)
        # The two permutations share one RNG stream: train first, then valid.
        train = chainer.datasets.SubDataset(
            full_train, 0, N_TRAIN_EXAMPLES, order=rng.permutation(len(full_train))
        )
        valid = chainer.datasets.SubDataset(
            full_valid, 0, N_VALID_EXAMPLES, order=rng.permutation(len(full_valid))
        )

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE, shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(
        valid, BATCHSIZE, repeat=False, shuffle=False
    )

    trainer = chainer.training.Trainer(
        chainer.training.StandardUpdater(train_iter, mn_optimizer), (EPOCH, "epoch")
    )
    ext = chainer.training.extensions
    if comm.rank == 0:
        trainer.extend(ext.ProgressBar())

    trainer.run()

    final_evaluator = chainermn.create_multi_node_evaluator(
        ext.Evaluator(valid_iter, classifier), comm
    )
    return final_evaluator()["main/accuracy"]
开发者ID:optuna,项目名称:optuna,代码行数:47,代码来源:chainermn_simple.py


注:本文中的chainermn.scatter_dataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。