

Python chainermn.create_multi_node_evaluator Method Code Examples

This article collects typical usage examples of the chainermn.create_multi_node_evaluator method in Python. If you are wondering how chainermn.create_multi_node_evaluator is used in practice, the selected code examples here may help. You can also explore further usage examples of chainermn, the module that contains this method.


Three code examples of chainermn.create_multi_node_evaluator are shown below, sorted by popularity by default.

Example 1: get_example

# Required import: import chainermn [as alias]
# Or: from chainermn import create_multi_node_evaluator [as alias]
import random

def get_example(self, i):
        # Reads the i-th image/label pair and returns a preprocessed image.
        # Applies the following preprocessing steps:
        #     - Cropping (random or center rectangular)
        #     - Random flip
        #     - Scaling to [0, 1] value
        crop_size = self.crop_size

        image, label = self.base[i]
        _, h, w = image.shape

        if self.random:
            # Randomly crop a region and flip the image
            top = random.randint(0, h - crop_size - 1)
            left = random.randint(0, w - crop_size - 1)
            if random.randint(0, 1):
                image = image[:, :, ::-1]
        else:
            # Crop the center
            top = (h - crop_size) // 2
            left = (w - crop_size) // 2
        bottom = top + crop_size
        right = left + crop_size

        image = image[:, top:bottom, left:right]
        image -= self.mean[:, top:bottom, left:right]
        image *= (1.0 / 255.0)  # Scale to [0, 1]
        return image, label


# chainermn.create_multi_node_evaluator can also be used with user-customized
# evaluator classes that inherit from chainer.training.extensions.Evaluator.
Developer: chainer, Project: chainer, Lines of code: 34, Source: train_imagenet.py
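
The crop bounds in `get_example` are easy to get wrong at the edges, so here is a minimal standalone sketch of only the window computation. The helper name `crop_window` and its `randomize` and `rng` parameters are hypothetical and do not appear in train_imagenet.py; the sketch simply mirrors the bounds used above, where `random.randint` is inclusive on both ends.

```python
import random


def crop_window(h, w, crop_size, randomize=False, rng=random):
    """Return (top, left, bottom, right) for a square crop.

    Hypothetical helper that mirrors the bounds used in get_example.
    """
    if crop_size > h or crop_size > w:
        raise ValueError("crop_size must not exceed the image size")
    if randomize:
        if crop_size >= h or crop_size >= w:
            raise ValueError("random crop needs crop_size < h and crop_size < w")
        # randint is inclusive, so the largest offset is h - crop_size - 1,
        # the same upper bound as in the example above.
        top = rng.randint(0, h - crop_size - 1)
        left = rng.randint(0, w - crop_size - 1)
    else:
        # Center crop.
        top = (h - crop_size) // 2
        left = (w - crop_size) // 2
    return top, left, top + crop_size, left + crop_size


print(crop_window(256, 256, 224))
```

Passing a seeded `random.Random` instance as `rng` makes the random branch reproducible, which can be convenient when checking the preprocessing in isolation.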

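Example 2: objective

# Required import: import chainermn [as alias]
# Or: from chainermn import create_multi_node_evaluator [as alias]
import chainer
import chainer.links as L
import chainermn
import numpy as np
import optuna
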
def objective(trial, comm):
    # Sample an architecture.
    model = L.Classifier(create_model(trial))

    # Setup optimizer.
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(model)
    optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)

    # Setup dataset and iterator. Only worker 0 loads the whole dataset.
    # The dataset of worker 0 is evenly split and distributed to all workers.
    if comm.rank == 0:
        train, valid = chainer.datasets.get_mnist()
        rng = np.random.RandomState(0)
        train = chainer.datasets.SubDataset(
            train, 0, N_TRAIN_EXAMPLES, order=rng.permutation(len(train))
        )
        valid = chainer.datasets.SubDataset(
            valid, 0, N_VALID_EXAMPLES, order=rng.permutation(len(valid))
        )
    else:
        train, valid = None, None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE, shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(valid, BATCHSIZE, repeat=False, shuffle=False)

    # Setup trainer.
    updater = chainer.training.StandardUpdater(train_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (EPOCH, "epoch"))

    # Add Chainer extension for pruners.
    trainer.extend(
        optuna.integration.ChainerPruningExtension(
            trial, "validation/main/accuracy", (PRUNER_INTERVAL, "epoch")
        )
    )
    evaluator = chainer.training.extensions.Evaluator(valid_iter, model)
    trainer.extend(chainermn.create_multi_node_evaluator(evaluator, comm))
    log_report_extension = chainer.training.extensions.LogReport(log_name=None)
    trainer.extend(log_report_extension)

    if comm.rank == 0:
        trainer.extend(chainer.training.extensions.ProgressBar())

    # Run training.
    # Set show_loop_exception_msg to False to suppress messages about the
    # TrialPruned exception. ChainerPruningExtension raises TrialPruned to stop
    # training, and the trainer prints a message every time it receives one.
    trainer.run(show_loop_exception_msg=False)

    # Evaluate.
    evaluator = chainer.training.extensions.Evaluator(valid_iter, model)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    report = evaluator()

    return report["main/accuracy"] 
Developer: optuna, Project: optuna, Lines of code: 61, Source: chainermn_integration.py
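
For intuition about what the wrapped evaluator hands back in `report`, the sketch below simulates an aggregation step in plain Python: each worker evaluates its own shard, and the per-worker metrics are then averaged. This is an illustration under stated assumptions, not ChainerMN's code. `average_reports` is a hypothetical name, and a real run would exchange values through the communicator rather than through a Python list.

```python
def average_reports(local_reports):
    """Average the metric dicts produced by each (simulated) worker."""
    if not local_reports:
        raise ValueError("need at least one report")
    size = len(local_reports)
    keys = local_reports[0].keys()
    # Unweighted mean over workers for every metric key.
    return {k: sum(r[k] for r in local_reports) / size for k in keys}


# Two simulated workers, each reporting metrics for its own shard.
reports = [
    {"main/accuracy": 0.75, "main/loss": 0.5},
    {"main/accuracy": 0.25, "main/loss": 1.5},
]
global_report = average_reports(reports)
print(global_report["main/accuracy"])
```

An unweighted mean over workers equals the dataset-wide metric only when every shard holds the same number of examples, so an even split of the validation set matters for this kind of aggregation.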

Example 3: objective

# Required import: import chainermn [as alias]
# Or: from chainermn import create_multi_node_evaluator [as alias]
import chainer
import chainer.links as L
import chainermn
import numpy as np

def objective(trial, comm):
    # Sample an architecture.
    model = L.Classifier(create_model(trial))

    # Setup optimizer.
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(model)
    optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)

    # Setup dataset and iterator. Only worker 0 loads the whole dataset.
    # The dataset of worker 0 is evenly split and distributed to all workers.
    if comm.rank == 0:
        train, valid = chainer.datasets.get_mnist()
        rng = np.random.RandomState(0)
        train = chainer.datasets.SubDataset(
            train, 0, N_TRAIN_EXAMPLES, order=rng.permutation(len(train))
        )
        valid = chainer.datasets.SubDataset(
            valid, 0, N_VALID_EXAMPLES, order=rng.permutation(len(valid))
        )
    else:
        train, valid = None, None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE, shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(valid, BATCHSIZE, repeat=False, shuffle=False)

    # Setup trainer.
    updater = chainer.training.StandardUpdater(train_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (EPOCH, "epoch"))

    if comm.rank == 0:
        trainer.extend(chainer.training.extensions.ProgressBar())

    # Run training.
    trainer.run()

    # Evaluate.
    evaluator = chainer.training.extensions.Evaluator(valid_iter, model)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
    report = evaluator()

    return report["main/accuracy"] 
Developer: optuna, Project: optuna, Lines of code: 47, Source: chainermn_simple.py
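
Both Optuna examples load MNIST on worker 0 only and rely on `chainermn.scatter_dataset` to hand each worker a share. As a rough mental model (and not ChainerMN's actual partitioning logic), an even split of `n` examples over `size` workers might be sketched as follows; `shard_bounds` is a hypothetical helper introduced only for illustration.

```python
def shard_bounds(n, size):
    """Return [(start, stop), ...] with shard lengths differing by at most 1."""
    if size <= 0:
        raise ValueError("size must be positive")
    base, extra = divmod(n, size)
    bounds = []
    start = 0
    for rank in range(size):
        # The first `extra` ranks take one additional example each.
        stop = start + base + (1 if rank < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds


print(shard_bounds(10, 4))
```

Each worker would then build its iterator over only its own `(start, stop)` range, which is why the examples can construct `SerialIterator` objects on every rank after the scatter.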


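Note: The chainermn.create_multi_node_evaluator examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers, and copyright of the source code belongs to the original authors. For distribution and use, refer to each project's License. Do not reproduce without permission.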