當前位置: 首頁>>代碼示例>>Python>>正文


Python chainermn.scatter_dataset方法代碼示例

本文整理匯總了Python中chainermn.scatter_dataset方法的典型用法代碼示例。如果您正苦於以下問題:Python chainermn.scatter_dataset方法的具體用法?Python chainermn.scatter_dataset怎麼用?Python chainermn.scatter_dataset使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在chainermn的用法示例。


在下文中一共展示了chainermn.scatter_dataset方法的7個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: make_dataset

# 需要導入模塊: import chainermn [as 別名]
# 或者: from chainermn import scatter_dataset [as 別名]
def make_dataset(self, stage_int):
        """Create the dataset for a given stage and optionally scatter it.

        On the master process a ``BaseDataset`` is built from the JSON file
        named by ``FLAGS.dataset_config``, with images resized to a square
        whose side is ``4 * 2 ** ((stage_int + 1) // 2)``.  Every other
        process starts with ``None``.  When ``self.use_mpi`` is set, the
        dataset is distributed over ``self.comm`` with
        ``chainermn.scatter_dataset``.

        Args:
            stage_int: Integer stage index used to derive the image size.

        Returns:
            The (possibly scattered) dataset.  This is ``None`` on a
            non-master process when MPI is not in use.
        """
        if self.is_master:
            size = 4 * (2 ** ((stage_int + 1) // 2))
            # Read the config through a context manager so the file handle
            # is closed deterministically instead of being leaked.
            with open(FLAGS.dataset_config, 'r') as config_file:
                dataset_config = json.load(config_file)
            _dataset = BaseDataset(
                dataset_config,
                '%dx%d' % (size, size),
                [["resize", {"probability": 1, "width": size, "height": size, "resample_filter": "ANTIALIAS"}]]
            )
            self.print_log('Add (master) dataset for size {}'.format(size))
        else:
            # Non-master processes contribute nothing; they receive their
            # share from the scatter below (if MPI is enabled).
            _dataset = None
            self.print_log('Add (slave) dataset')

        if self.use_mpi:
            _dataset = chainermn.scatter_dataset(_dataset, self.comm)

        return _dataset
開發者ID:pfnet-research,項目名稱:chainer-stylegan,代碼行數:19,代碼來源:train.py

示例2: scatter_large_data

# 需要導入模塊: import chainermn [as 別名]
# 或者: from chainermn import scatter_dataset [as 別名]
def scatter_large_data(communicator):
    """Scatter a very large list from rank 0 and check every rank gets data.

    Args:
        communicator: Communicator whose ``rank`` selects the root process
            and which is handed to ``chainermn.scatter_dataset``.
    """
    # Only rank 0 owns the payload; all other ranks start from an empty list.
    payload = ['test'] * 2000000000 if communicator.rank == 0 else []
    received = chainermn.scatter_dataset(payload, communicator)
    assert len(received) > 0
開發者ID:chainer,項目名稱:chainer,代碼行數:8,代碼來源:test_scatter.py

示例3: _prepare_multinode_snapshot

# 需要導入模塊: import chainermn [as 別名]
# 或者: from chainermn import scatter_dataset [as 別名]
def _prepare_multinode_snapshot(n, result):
    """Build a trainer with a multi-node snapshot and run ``n`` updates.

    Args:
        n: Number of times ``updater.update()`` is called before returning.
        result: Output location passed to ``Trainer`` as ``out``.

    Returns:
        A tuple ``(updater, mn_snapshot, trainer)``.
    """
    batchsize = 10
    n_units = 100

    comm = create_communicator('naive')
    model = L.Classifier(MLP(n_units, 10))
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Only rank 0 loads MNIST; the other ranks hand ``None`` to the scatter.
    train = None
    if comm.rank == 0:
        train, _ = chainer.datasets.get_mnist()
    train = chainermn.scatter_dataset(train, comm, shuffle=True)

    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    updater = StandardUpdater(train_iter, optimizer)
    trainer = Trainer(updater, out=result)

    # Wrap a plain snapshot extension with an empty list of replica sets.
    base_snapshot = extensions.snapshot(target=updater, autoload=True)
    mn_snapshot = multi_node_snapshot(comm, base_snapshot, [])
    mn_snapshot.initialize(trainer)

    for _step in range(n):
        updater.update()

    return updater, mn_snapshot, trainer
開發者ID:chainer,項目名稱:chainer,代碼行數:30,代碼來源:test_multi_node_snapshot.py

示例4: setup_mnist_trainer

# 需要導入模塊: import chainermn [as 別名]
# 或者: from chainermn import scatter_dataset [as 別名]
def setup_mnist_trainer(self, display_log=False, use_chx=False):
        """Assemble model, optimizer, iterators and updater for MNIST.

        Args:
            display_log: Accepted for interface compatibility; not read in
                this body.
            use_chx: Forwarded to ``get_device`` to choose the device the
                model is moved to.

        Returns:
            A tuple ``(updater, optimizer, train_iter, test_iter, model)``.
        """
        n_units = 100
        batchsize = 100
        comm = self.communicator

        model = L.Classifier(MLP(n_units, 10))
        device = get_device(None, use_chx)
        model.to_device(device)

        optimizer = chainermn.create_multi_node_optimizer(
            chainer.optimizers.Adam(), comm)
        optimizer.setup(model)

        # Rank 0 is the only process that loads MNIST; everyone else
        # passes ``None`` and receives a share from the scatter.
        train, test = None, None
        if comm.rank == 0:
            train, test = chainer.datasets.get_mnist()

        train = chainermn.scatter_dataset(train, comm, shuffle=True)
        test = chainermn.scatter_dataset(test, comm, shuffle=True)

        train_iter = chainer.iterators.SerialIterator(train, batchsize)
        test_iter = chainer.iterators.SerialIterator(
            test, batchsize, repeat=False, shuffle=False)

        updater = training.StandardUpdater(train_iter, optimizer)

        return updater, optimizer, train_iter, test_iter, model
開發者ID:chainer,項目名稱:chainer,代碼行數:34,代碼來源:test_checkpoint.py

示例5: check_scatter_dataset

# 需要導入模塊: import chainermn [as 別名]
# 或者: from chainermn import scatter_dataset [as 別名]
def check_scatter_dataset(self, original_dataset, shuffle=False, root=0):
        """Scatter a dataset from ``root`` and verify sizes and contents.

        Args:
            original_dataset: Dataset to scatter; only the root process
                actually passes it to ``chainermn.scatter_dataset``.
            shuffle: Forwarded to ``chainermn.scatter_dataset``.
            root: Rank that owns the dataset and performs the checks.
        """
        comm = self.communicator
        is_root = comm.rank == root

        source = original_dataset if is_root else None
        my_dataset = chainermn.scatter_dataset(
            source, comm, shuffle=shuffle, root=root)
        sub_datasets = comm.gather_obj(my_dataset, root=root)

        # All assertions are made on the root process only.
        if not is_root:
            return

        # Size checks: every shard has the same length, and the shards
        # together are just large enough to hold the original dataset.
        shard_sizes = [len(shard) for shard in sub_datasets]
        self.assertEqual(len(set(shard_sizes)), 1)
        sub_size = shard_sizes[0]
        self.assertLessEqual(
            len(source), sub_size * self.mpi_comm.size)
        self.assertGreater(
            len(source), (sub_size - 1) * self.mpi_comm.size)

        # Content check: concatenate the shards back into one list.
        joined_dataset = []
        for shard in sub_datasets:
            joined_dataset = joined_dataset + shard[:]

        # NOTE: Elements on both sides are cast to int before comparing,
        # for two reasons.
        #
        # (1) numpy and cupy/chainerx treat indexing a 1-element array
        # differently: numpy yields a scalar, chainerx yields an array.
        # type(numpy.array([1])[0])
        # =>  <class 'numpy.int64'>  # Scalar
        # type(chainerx.array([1])[0])
        # => <class 'chainerx.ndarray'>  # array of one element
        #
        # (2) Distinct ChainerX arrays never collapse inside `set()`,
        # even when they hold the same value.
        # set([chainerx.array([0]), chainerx.array([0])])
        # => {array([0], shape=(1,), dtype=int64, device='native:0'),
        #     array([0], shape=(1,), dtype=int64, device='native:0')}

        joined_values = [int(item) for item in joined_dataset]
        original_values = [int(item) for item in source]
        self.assertEqual(set(joined_values), set(original_values))
開發者ID:chainer,項目名稱:chainer,代碼行數:45,代碼來源:test_scatter.py

示例6: objective

# 需要導入模塊: import chainermn [as 別名]
# 或者: from chainermn import scatter_dataset [as 別名]
def objective(trial, comm):
    """Train a sampled model with pruning and return validation accuracy.

    Args:
        trial: Trial object passed to ``create_model`` and to the pruning
            extension.
        comm: Communicator used for the multi-node optimizer, dataset
            scattering and the multi-node evaluator.

    Returns:
        The ``"main/accuracy"`` entry of the final evaluation report.
    """
    # Build the model for this trial.
    model = L.Classifier(create_model(trial))

    # Set up the local optimizer first, then wrap it for multi-node use.
    base_optimizer = chainer.optimizers.MomentumSGD()
    base_optimizer.setup(model)
    optimizer = chainermn.create_multi_node_optimizer(base_optimizer, comm)

    # Worker 0 alone loads the data; scatter_dataset then splits it evenly
    # across all workers.
    train = valid = None
    if comm.rank == 0:
        full_train, full_valid = chainer.datasets.get_mnist()
        rng = np.random.RandomState(0)
        train_order = rng.permutation(len(full_train))
        train = chainer.datasets.SubDataset(
            full_train, 0, N_TRAIN_EXAMPLES, order=train_order
        )
        valid_order = rng.permutation(len(full_valid))
        valid = chainer.datasets.SubDataset(
            full_valid, 0, N_VALID_EXAMPLES, order=valid_order
        )

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE, shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(
        valid, BATCHSIZE, repeat=False, shuffle=False
    )

    # Trainer that runs for EPOCH epochs.
    updater = chainer.training.StandardUpdater(train_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (EPOCH, "epoch"))

    # Pruning hook, checked every PRUNER_INTERVAL epochs on validation accuracy.
    pruning_extension = optuna.integration.ChainerPruningExtension(
        trial, "validation/main/accuracy", (PRUNER_INTERVAL, "epoch")
    )
    trainer.extend(pruning_extension)

    training_evaluator = chainer.training.extensions.Evaluator(valid_iter, model)
    trainer.extend(chainermn.create_multi_node_evaluator(training_evaluator, comm))

    trainer.extend(chainer.training.extensions.LogReport(log_name=None))

    if comm.rank == 0:
        trainer.extend(chainer.training.extensions.ProgressBar())

    # ChainerPruningExtension stops training by raising TrialPruned, and the
    # trainer would print a message each time; show_loop_exception_msg=False
    # suppresses those messages.
    trainer.run(show_loop_exception_msg=False)

    # Final evaluation with a fresh multi-node evaluator.
    final_evaluator = chainermn.create_multi_node_evaluator(
        chainer.training.extensions.Evaluator(valid_iter, model), comm
    )
    report = final_evaluator()

    return report["main/accuracy"]
開發者ID:optuna,項目名稱:optuna,代碼行數:61,代碼來源:chainermn_integration.py

示例7: objective

# 需要導入模塊: import chainermn [as 別名]
# 或者: from chainermn import scatter_dataset [as 別名]
def objective(trial, comm):
    """Train a sampled model across workers and return validation accuracy.

    Args:
        trial: Trial object passed to ``create_model``.
        comm: Communicator used for the multi-node optimizer, dataset
            scattering and the multi-node evaluator.

    Returns:
        The ``"main/accuracy"`` entry of the final evaluation report.
    """
    # Build the model for this trial.
    model = L.Classifier(create_model(trial))

    # Set up the local optimizer first, then wrap it for multi-node use.
    base_optimizer = chainer.optimizers.MomentumSGD()
    base_optimizer.setup(model)
    optimizer = chainermn.create_multi_node_optimizer(base_optimizer, comm)

    # Worker 0 alone loads the data; scatter_dataset then splits it evenly
    # across all workers.
    train = valid = None
    if comm.rank == 0:
        full_train, full_valid = chainer.datasets.get_mnist()
        rng = np.random.RandomState(0)
        train_order = rng.permutation(len(full_train))
        train = chainer.datasets.SubDataset(
            full_train, 0, N_TRAIN_EXAMPLES, order=train_order
        )
        valid_order = rng.permutation(len(full_valid))
        valid = chainer.datasets.SubDataset(
            full_valid, 0, N_VALID_EXAMPLES, order=valid_order
        )

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE, shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(
        valid, BATCHSIZE, repeat=False, shuffle=False
    )

    # Trainer that runs for EPOCH epochs.
    updater = chainer.training.StandardUpdater(train_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (EPOCH, "epoch"))

    if comm.rank == 0:
        trainer.extend(chainer.training.extensions.ProgressBar())

    trainer.run()

    # Final evaluation through a multi-node evaluator.
    final_evaluator = chainermn.create_multi_node_evaluator(
        chainer.training.extensions.Evaluator(valid_iter, model), comm
    )
    report = final_evaluator()

    return report["main/accuracy"]
開發者ID:optuna,項目名稱:optuna,代碼行數:47,代碼來源:chainermn_simple.py


注:本文中的chainermn.scatter_dataset方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。