当前位置: 首页>>代码示例>>Python>>正文


Python Events.ITERATION_COMPLETED属性代码示例

本文整理汇总了 Python 中 ignite.engine.Events.ITERATION_COMPLETED 属性的典型用法代码示例。如果您想了解以下问题：Events.ITERATION_COMPLETED 属性的具体用法是什么？它该怎么用？有哪些实际使用的例子？那么这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在的 ignite.engine.Events 的其他用法示例。


在下文中一共展示了Events.ITERATION_COMPLETED属性的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: make_engine

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def make_engine(generator_name, config, device=torch.device("cuda")):
    """Build an ignite ``Engine`` that runs a generator and saves visualisation images.

    Args:
        generator_name: key into ``IMPLEMENTED_GENERATOR``; the mapped value is the
            module name whose ``make_generator`` factory is imported.
        config: configuration passed to ``make_generator``; ``config["output"]`` is
            the directory the visualisation images are written to.
        device: device each batch is moved to; also passed to ``make_generator``.

    Returns:
        Engine: engine whose step generates images for a batch and whose
        ``ITERATION_COMPLETED`` handler saves one image file per sample.

    Raises:
        RuntimeError: if ``generator_name`` is not a key of ``IMPLEMENTED_GENERATOR``.
    """
    # Guard only the dictionary lookup: a KeyError raised while *importing* the
    # generator module must not be misreported as an unknown generator name.
    try:
        module_name = IMPLEMENTED_GENERATOR[generator_name]
    except KeyError as err:
        raise RuntimeError("not implemented generator <{}>".format(generator_name)) from err
    make_generator = import_module(module_name).make_generator
    generate = make_generator(config, device)

    def _step(engine, batch):
        batch = convert_tensor(batch, device)
        generated_images = generate(batch)
        # output = ((condition paths, target paths), (condition imgs, target imgs, generated imgs))
        return (batch["condition_path"], batch["target_path"]), \
               (batch["condition_img"], batch["target_img"], generated_images)

    engine = Engine(_step)
    ProgressBar(ncols=0).attach(engine)

    @engine.on(Events.ITERATION_COMPLETED)
    def save(e):
        names, images = e.state.output
        condition_names, target_names = names
        # One file per sample in the batch; the sample's images from every tensor in
        # `images` go into a single grid with nrow=len(images).
        for i in range(images[0].size(0)):
            image_name = os.path.join(
                config["output"], "{}___{}_vis.jpg".format(condition_names[i], target_names[i]))
            save_image([imgs.data[i] for imgs in images], image_name,
                       nrow=len(images), normalize=True, padding=0)

    return engine
开发者ID:budui,项目名称:Human-Pose-Transfer,代码行数:27,代码来源:generate.py

示例2: add_tensorboard_handler

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def add_tensorboard_handler(tensorboard_folder, engine, every_iteration=False):
    """
    Log the latest value of every key in ``engine.state.epoch_history`` to
    TensorBoard whenever ``ValidationEvents.VALIDATION_COMPLETED`` fires.

    All handlers registered here share a single ``SummaryWriter``. It is created
    on first use and flushed after every batch of writes. Previously a new writer
    was opened on every event and never closed, which leaked writers.

    Args:
        tensorboard_folder (str): Where the tensorboard logs should go.
        engine (ignite.Engine): The engine to log.
        every_iteration (bool, optional): Whether to also log the latest value of
          every key in ``engine.state.iter_history`` at every iteration.
    """
    writer = None

    def _get_writer():
        # Lazily create one shared writer for both handlers.
        nonlocal writer
        if writer is None:
            writer = SummaryWriter(tensorboard_folder)
        return writer

    @engine.on(ValidationEvents.VALIDATION_COMPLETED)
    def log_to_tensorboard(engine):
        tb_writer = _get_writer()
        for key in engine.state.epoch_history:
            tb_writer.add_scalar(
                key, engine.state.epoch_history[key][-1], engine.state.epoch)
        tb_writer.flush()

    if every_iteration:
        @engine.on(Events.ITERATION_COMPLETED)
        def log_iteration_to_tensorboard(engine):
            tb_writer = _get_writer()
            for key in engine.state.iter_history:
                tb_writer.add_scalar(
                    key, engine.state.iter_history[key][-1], engine.state.iteration)
            tb_writer.flush()
开发者ID:nussl,项目名称:nussl,代码行数:27,代码来源:trainer.py

示例3: test_tb_image_shape

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_tb_image_shape(self, shape):
        """Run ``TensorBoardImageHandler`` on random array pairs of the given ``shape``
        and check that it writes output into its (initially absent) log directory.
        """
        tempdir = tempfile.mkdtemp()
        # Remove the directory again, presumably so the handler itself has to create it.
        shutil.rmtree(tempdir, ignore_errors=True)

        try:
            # set up engine: every iteration outputs a constant zero tensor of shape (1, 1, 10, 10)
            def _train_func(engine, batch):
                return torch.zeros((1, 1, 10, 10))

            engine = Engine(_train_func)

            # set up testing handler
            stats_handler = TensorBoardImageHandler(log_dir=tempdir)
            engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler)

            # 10 items, each a pair of random arrays of shape (4, *shape)
            data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape)))
            engine.run(data, epoch_length=10, max_epochs=1)

            self.assertTrue(os.path.exists(tempdir))
            # Glob the directory *contents*: glob.glob(tempdir) only matches the
            # directory itself, so that check could never fail once it existed.
            self.assertTrue(len(glob.glob(os.path.join(tempdir, "*"))) > 0)
        finally:
            # Clean up even when an assertion above fails.
            shutil.rmtree(tempdir, ignore_errors=True)
开发者ID:Project-MONAI,项目名称:MONAI,代码行数:22,代码来源:test_handler_tb_image.py

示例4: attach

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def attach(self, engine, metric_names=None, output_transform=None):
        """
        Register the progress bar's handlers on an engine.

        ``self._update`` is added for ``Events.ITERATION_COMPLETED`` (with
        ``metric_names`` and ``output_transform`` as extra arguments), and
        ``self._close`` for ``Events.EPOCH_COMPLETED``.

        Args:
            engine (Engine): engine object.
            metric_names (list, optional): list of the metrics names to log as the bar progresses
            output_transform (callable, optional): a function to select what you want to print from the engine's
                output. This function may return either a dictionary with entries in the format of ``{name: value}``,
                or a single scalar, which will be displayed with the default name `output`.

        Raises:
            TypeError: if ``metric_names`` is given but is not a list, or if
                ``output_transform`` is given but is not callable.
        """
        names_ok = metric_names is None or isinstance(metric_names, list)
        if not names_ok:
            message = "metric_names should be a list, got {} instead.".format(type(metric_names))
            raise TypeError(message)

        transform_ok = output_transform is None or callable(output_transform)
        if not transform_ok:
            message = "output_transform should be a function, got {} instead.".format(type(output_transform))
            raise TypeError(message)

        add_handler = engine.add_event_handler
        add_handler(Events.ITERATION_COMPLETED, self._update, metric_names, output_transform)
        add_handler(Events.EPOCH_COMPLETED, self._close)
开发者ID:leokarlin,项目名称:LaSO,代码行数:22,代码来源:tqdm_logger.py

示例5: test_concepts_snippet_warning

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_concepts_snippet_warning():
    """Run a DeterministicEngine over an endless random-data generator while a
    handler re-seeds torch's global RNG every 3 iterations."""

    def random_train_data_generator():
        # infinite stream of 1-element integer tensors drawn from [0, 100)
        while True:
            yield torch.randint(0, 100, size=(1,))

    def print_train_data(engine, batch):
        state = engine.state
        print("train", state.epoch, state.iteration, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def user_handler(_):
        # handler synchronizes the random state, then consumes one random draw
        torch.manual_seed(12)
        _unused_draw = torch.rand(1)

    trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_deterministic.py

示例6: test_reset_should_terminate

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_reset_should_terminate():
    """A terminate() issued in one run must not leak into the next run: both runs
    stop at exactly iteration 10."""
    stop_iteration = 10

    # no-op update step; only the engine's iteration bookkeeping is under test
    engine = Engine(lambda engine, batch: None)

    @engine.on(Events.ITERATION_COMPLETED)
    def terminate_on_iteration_10(engine):
        if engine.state.iteration == stop_iteration:
            engine.terminate()

    for _ in range(2):
        engine.run([0] * 20)
        assert engine.state.iteration == stop_iteration
开发者ID:pytorch,项目名称:ignite,代码行数:18,代码来源:test_engine.py

示例7: test_run_finite_iterator_no_epoch_length_2

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_run_finite_iterator_no_epoch_length_2():
    # FR: https://github.com/pytorch/ignite/issues/871
    known_size = 11
    num_epochs = 5

    def finite_size_data_iter(size):
        yield from range(size)

    bc = BatchChecker(data=list(range(known_size)))
    engine = Engine(lambda e, b: bc.check(b))

    # The generator yields only `known_size` items, so hand the engine a fresh one
    # after each full pass.
    @engine.on(Events.ITERATION_COMPLETED(every=known_size))
    def restart_iter():
        engine.state.dataloader = finite_size_data_iter(known_size)

    engine.run(finite_size_data_iter(known_size), max_epochs=num_epochs)

    assert engine.state.epoch == num_epochs
    assert engine.state.iteration == num_epochs * known_size
开发者ID:pytorch,项目名称:ignite,代码行数:23,代码来源:test_engine.py

示例8: test_pbar_wrong_events_order

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_pbar_wrong_events_order():
    """ProgressBar.attach rejects (event, closing event) pairs in the wrong order
    and filtered closing events."""

    engine = Engine(update_fn)
    pbar = ProgressBar()

    order_error = "should be called before closing event"
    misordered_pairs = [
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    ]
    for event, closing_event in misordered_pairs:
        with pytest.raises(ValueError, match=order_error):
            pbar.attach(engine, event_name=event, closing_event_name=closing_event)

    filtered_closing_event = Events.EPOCH_COMPLETED(every=10)
    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=filtered_closing_event)
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_tqdm_logger.py

示例9: test_get_intermediate_results_during_run

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_get_intermediate_results_during_run(capsys):
    """Every 3 iterations, print the profiler's partial results and check the
    printed text for unwanted substrings."""
    true_event_handler_time = 0.0645
    true_max_epochs = 2
    true_num_iters = 5

    profiler = BasicTimeProfiler()
    dummy_trainer = get_prepared_engine(true_event_handler_time)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.ITERATION_COMPLETED(every=3))
    def log_results(_):
        profiler.print_results(profiler.get_results())
        printed = capsys.readouterr().out
        for unwanted in ("BasicTimeProfiler._", "nan"):
            assert unwanted not in printed
        assert " min/index: (0.0, " not in printed, printed

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)
开发者ID:pytorch,项目名称:ignite,代码行数:22,代码来源:test_time_profilers.py

示例10: test_save_param_history

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_save_param_history():
    """With save_history=True the scheduler records every lr it sets in
    ``trainer.state.param_history["lr"]``, matching what the optimizer saw."""
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, save_history=True)
    observed_lrs = []

    def _record_lr(engine):
        observed_lrs.append(optimizer.param_groups[0]["lr"])

    trainer = Engine(lambda engine, batch: None)
    assert not hasattr(trainer.state, "param_history")

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, _record_lr)
    trainer.run([0] * 10, max_epochs=2)

    history = trainer.state.param_history["lr"]
    assert len(history) == len(observed_lrs)
    # each history entry is indexable; its first element is compared to the observed lr
    assert [entry[0] for entry in history] == observed_lrs
开发者ID:pytorch,项目名称:ignite,代码行数:23,代码来源:test_param_scheduler.py

示例11: _test_distrib_one_rank_only_with_engine

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def _test_distrib_one_rank_only_with_engine(device):
    """Check that an ``idist.one_rank_only`` handler accumulates batches on rank 0 only,
    both with and without a barrier."""

    def _test(barrier):
        engine = Engine(lambda e, b: b)

        batch_sum = torch.tensor(0).to(device)

        @engine.on(Events.ITERATION_COMPLETED)
        @idist.one_rank_only(with_barrier=barrier)  # ie rank == 0
        def _(_):
            batch_sum.data += torch.tensor(engine.state.batch).to(device)

        engine.run([1, 2, 3], max_epochs=2)

        gathered = idist.all_gather(tensor=batch_sum)

        # rank 0 sums (1 + 2 + 3) over 2 epochs = 12; every other rank stays at 0
        for rank in range(idist.get_world_size()):
            expected = 12 if rank == 0 else 0
            assert gathered[rank].item() == expected

    for use_barrier in (True, False):
        _test(barrier=use_barrier)
开发者ID:pytorch,项目名称:ignite,代码行数:25,代码来源:test_utils.py

示例12: attach

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def attach(self, engine, name):
        """Register this metric's handlers on ``engine``.

        ``self.started`` runs on EPOCH_STARTED, ``self.iteration_completed`` on every
        ITERATION_COMPLETED, and ``self.completed`` on EPOCH_COMPLETED with ``name``
        passed as an extra argument.
        """
        register = engine.add_event_handler
        register(Events.EPOCH_STARTED, self.started)
        register(Events.ITERATION_COMPLETED, self.iteration_completed)
        register(Events.EPOCH_COMPLETED, self.completed, name)
开发者ID:hrhodin,项目名称:UnsupervisedGeometryAwareRepresentationLearning,代码行数:6,代码来源:metric.py

示例13: gradual_unfreezing

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def gradual_unfreezing(self):
        """Freeze most of the model, then unfreeze one transformer block at a time.

        Initially only parameters whose name contains ``'embeddings'`` or
        ``'classification'`` stay trainable. An ``ITERATION_COMPLETED`` handler
        on ``self.trainer`` then sets ``requires_grad = True`` for one block every
        ``unfreezing_interval`` iterations. It starts at block
        ``num_layers - 1`` and works down to block 0. A block is any parameter
        whose name matches ``transformer.<name>.<index>.``.

        NOTE(review): re-enabling ``requires_grad`` only changes training if the
        optimizer was built over all parameters — confirm against the optimizer setup.
        """
        for name, param in self.model.named_parameters():
            if 'embeddings' not in name and 'classification' not in name:
                param.detach_()
                param.requires_grad = False

            else:
                param.requires_grad = True

        full_parameters = sum(p.numel() for p in self.model.parameters())
        trained_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        # `.3e` (precision) rather than the original `3e`, which was only a minimum field width.
        logger.info(f"We will start by training {trained_parameters:.3e} parameters out of {full_parameters:.3e},"
                    f" i.e. {100 * trained_parameters / full_parameters:.2f}%")

        # We will unfreeze blocks regularly along the training: one block every `unfreezing_interval` step.
        # Clamp to 1: with a short loader / few epochs the division rounds down to 0,
        # and `iteration % 0` would raise ZeroDivisionError inside the handler.
        unfreezing_interval = max(
            1, int(len(self.dataset_splits.train_data_loader()) * self.num_epochs / (self.model.num_layers + 1)))

        @self.trainer.on(Events.ITERATION_COMPLETED)
        def unfreeze_layer_if_needed(engine):
            if engine.state.iteration % unfreezing_interval != 0:
                return

            # Which layer should we unfreeze now
            unfreezing_index = self.model.num_layers - (engine.state.iteration // unfreezing_interval)
            # Every block has already been unfrozen; a negative index matches nothing.
            if unfreezing_index < 0:
                return

            block_pattern = re.compile(r"transformer\.[^\.]*\." + str(unfreezing_index) + r"\.")
            unfreezed = []
            for name, param in self.model.named_parameters():
                if block_pattern.match(name):
                    unfreezed.append(name)
                    # The original wrote `param.require_grad` (missing "s"). That only set
                    # an unused attribute, so no block was ever actually unfrozen.
                    param.requires_grad = True
            logger.info(f"Unfreezing block {unfreezing_index} with {unfreezed}")
开发者ID:feedly,项目名称:transfer-nlp,代码行数:34,代码来源:trainers.py

示例14: run_with_pbar

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def run_with_pbar(engine, loader, desc=None):
    """Run ``engine`` over ``loader`` while advancing a tqdm bar once per iteration.

    The bar is sized by ``len(loader)``. The per-iteration handler is removed
    again after the run, and the bar is closed, even if ``engine.run`` raises.
    Previously the handler stayed attached, so each call added another handler
    that kept updating already-closed bars.

    Args:
        engine: ignite engine to run.
        loader: data passed to ``engine.run``; must support ``len()``.
        desc: optional description shown on the progress bar.
    """
    def _advance(_):
        pbar.update(1)

    with tqdm.trange(len(loader), desc=desc) as pbar:
        engine.add_event_handler(Events.ITERATION_COMPLETED, _advance)
        try:
            engine.run(loader)
        finally:
            engine.remove_event_handler(_advance, Events.ITERATION_COMPLETED)
开发者ID:lopuhin,项目名称:kaggle-kuzushiji-2019,代码行数:7,代码来源:utils.py

示例15: __call__

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def __call__(self, engine, logger, event_name):
        """Send the engine's output metrics to a ChainerUI logger.

        Each metric name is prefixed with ``self.tag``. The record also gets
        ``iteration``, ``epoch`` and ``elapsed_time`` entries. When
        ``self.interval <= 1`` the record is posted immediately. Otherwise it is
        cached in ``logger.cache[self.tag]``. The cache is posted and cleared
        whenever the counter for ``event_name`` is a multiple of ``self.interval``.

        Raises:
            RuntimeError: if ``logger`` is not a ``ChainerUILogger``.
        """
        if not isinstance(logger, ChainerUILogger):
            error_message = ('`chainerui.contrib.ignite.handler.OutputHandler` works only '
                             'with ChainerUILogger, but set {}'.format(type(logger)))
            raise RuntimeError(error_message)

        metrics = self._setup_output_metrics(engine)
        if not metrics:
            return

        current_iteration = self.global_step_transform(engine, Events.ITERATION_COMPLETED)
        current_epoch = self.global_step_transform(engine, Events.EPOCH_COMPLETED)

        # convert metrics name to "<tag>/<name>"
        record = {'{}/{}'.format(self.tag, metric_name): metric_value
                  for metric_name, metric_value in metrics.items()}
        record['iteration'] = current_iteration
        record['epoch'] = current_epoch
        # NOTE(review): every metric key above contains "/", so this check looks
        # always true — confirm whether `metrics` was the intended dict.
        if 'elapsed_time' not in record:
            record['elapsed_time'] = _get_time() - logger.start_at

        if self.interval <= 1:
            logger.post_log([record])
        else:
            # interval enabled: cache the record and flush in batches
            logger.cache.setdefault(self.tag, []).append(record)
            # select the appropriate event set by handler init
            if self.global_step_transform(engine, event_name) % self.interval == 0:
                logger.post_log(logger.cache[self.tag])
                logger.cache[self.tag].clear()
开发者ID:chainer,项目名称:chainerui,代码行数:35,代码来源:handler.py


注:本文中的ignite.engine.Events.ITERATION_COMPLETED属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。