本文整理汇总了Python中ignite.engine.Events.ITERATION_COMPLETED属性的典型用法代码示例。如果您正苦于以下问题:Python Events.ITERATION_COMPLETED属性的具体用法?Python Events.ITERATION_COMPLETED怎么用?Python Events.ITERATION_COMPLETED使用的例子?那么, 这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在类ignite.engine.Events
的用法示例。
在下文中一共展示了Events.ITERATION_COMPLETED属性的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: make_engine
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def make_engine(generator_name, config, device=torch.device("cuda")):
    """Build an inference Engine that runs a generator and saves its outputs.

    Args:
        generator_name (str): key into ``IMPLEMENTED_GENERATOR`` selecting the
            generator implementation module.
        config (dict): configuration; ``config["output"]`` is the directory
            visualisation images are written to.
        device (torch.device): device every batch is moved to before generation.

    Returns:
        Engine: engine whose step output is ``(paths, images)``; a handler
        saves one side-by-side visualisation image per sample after each
        iteration.
    """
    try:
        generator_module = import_module(IMPLEMENTED_GENERATOR[generator_name])
    except KeyError:
        # unknown key -> surface a clear error instead of a bare KeyError
        raise RuntimeError("not implemented generator <{}>".format(generator_name))
    generate = generator_module.make_generator(config, device)

    def _step(engine, batch):
        batch = convert_tensor(batch, device)
        generated = generate(batch)
        paths = (batch["condition_path"], batch["target_path"])
        images = (batch["condition_img"], batch["target_img"], generated)
        return paths, images

    engine = Engine(_step)
    ProgressBar(ncols=0).attach(engine)

    @engine.on(Events.ITERATION_COMPLETED)
    def save(e):
        names, images = e.state.output
        batch_size = images[0].size(0)
        for i in range(batch_size):
            # file name combines condition and target sample names
            image_name = os.path.join(config["output"], "{}___{}_vis.jpg".format(names[0][i], names[1][i]))
            save_image([imgs.data[i] for imgs in images], image_name,
                       nrow=len(images), normalize=True, padding=0)

    return engine
示例2: add_tensorboard_handler
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def add_tensorboard_handler(tensorboard_folder, engine, every_iteration=False):
    """
    Every key in engine.state.epoch_history[-1] is logged to TensorBoard.

    Args:
        tensorboard_folder (str): Where the tensorboard logs should go.
        engine (ignite.Engine): The engine to log.
        every_iteration (bool, optional): Whether to also log the values at every
            iteration.
    """
    @engine.on(ValidationEvents.VALIDATION_COMPLETED)
    def log_to_tensorboard(engine):
        writer = SummaryWriter(tensorboard_folder)
        try:
            # log the latest value of every tracked epoch-level metric
            for key in engine.state.epoch_history:
                writer.add_scalar(
                    key, engine.state.epoch_history[key][-1], engine.state.epoch)
        finally:
            # BUG FIX: the original never closed the writer, leaking one
            # event-file handle per validation event; close() also flushes.
            writer.close()

    if every_iteration:
        @engine.on(Events.ITERATION_COMPLETED)
        def log_iteration_to_tensorboard(engine):
            writer = SummaryWriter(tensorboard_folder)
            try:
                # same as above but per iteration, from iter_history
                for key in engine.state.iter_history:
                    writer.add_scalar(
                        key, engine.state.iter_history[key][-1], engine.state.iteration)
            finally:
                writer.close()
示例3: test_tb_image_shape
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_tb_image_shape(self, shape):
    """Smoke-test TensorBoardImageHandler: run one epoch over random data and
    verify the handler recreated the log dir and wrote something into it."""
    log_dir = tempfile.mkdtemp()
    shutil.rmtree(log_dir, ignore_errors=True)

    # minimal engine whose step emits a fixed-size image tensor
    def _step(engine, batch):
        return torch.zeros((1, 1, 10, 10))

    engine = Engine(_step)

    # attach the handler under test
    handler = TensorBoardImageHandler(log_dir=log_dir)
    engine.add_event_handler(Events.ITERATION_COMPLETED, handler)

    inputs = np.random.normal(size=(10, 4, *shape))
    labels = np.random.normal(size=(10, 4, *shape))
    engine.run(zip(inputs, labels), epoch_length=10, max_epochs=1)

    self.assertTrue(os.path.exists(log_dir))
    self.assertTrue(len(glob.glob(log_dir)) > 0)
    shutil.rmtree(log_dir)
示例4: attach
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def attach(self, engine, metric_names=None, output_transform=None):
    """
    Attaches the progress bar to an engine object.

    Args:
        engine (Engine): engine object.
        metric_names (list, optional): list of the metrics names to log as the bar progresses
        output_transform (callable, optional): a function to select what you want to print from the engine's
            output. This function may return either a dictionary with entries in the format of ``{name: value}``,
            or a single scalar, which will be displayed with the default name `output`.
    """
    # validate arguments before registering anything on the engine
    bad_metric_names = metric_names is not None and not isinstance(metric_names, list)
    if bad_metric_names:
        raise TypeError("metric_names should be a list, got {} instead.".format(type(metric_names)))

    bad_transform = output_transform is not None and not callable(output_transform)
    if bad_transform:
        raise TypeError("output_transform should be a function, got {} instead."
                        .format(type(output_transform)))

    engine.add_event_handler(Events.ITERATION_COMPLETED, self._update, metric_names, output_transform)
    engine.add_event_handler(Events.EPOCH_COMPLETED, self._close)
示例5: test_concepts_snippet_warning
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_concepts_snippet_warning():
    """Docs-concepts snippet: a filtered handler (every=3) that reseeds the
    torch RNG while a DeterministicEngine consumes random training data."""

    def random_train_data_generator():
        # endless stream of one-element random integer tensors in [0, 100)
        while True:
            yield torch.randint(0, 100, size=(1,))

    def print_train_data(engine, batch):
        state = engine.state
        print("train", state.epoch, state.iteration, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def user_handler(_):
        # handler synchronizes the random state
        torch.manual_seed(12)
        _ = torch.rand(1)

    trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
示例6: test_reset_should_terminate
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_reset_should_terminate():
    """engine.terminate() must stop the run at iteration 10, and a second run
    must stop at the same point (state resets between runs)."""
    engine = Engine(lambda e, b: None)

    @engine.on(Events.ITERATION_COMPLETED)
    def stop_at_ten(engine):
        if engine.state.iteration == 10:
            engine.terminate()

    # run twice: behavior must be identical on both runs
    for _ in range(2):
        engine.run([0] * 20)
        assert engine.state.iteration == 10
示例7: test_run_finite_iterator_no_epoch_length_2
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_run_finite_iterator_no_epoch_length_2():
    """A finite iterator with no explicit epoch_length: a filtered handler
    re-arms the dataloader every `known_size` iterations so training can
    continue for all epochs.

    FR: https://github.com/pytorch/ignite/issues/871
    """
    known_size = 11

    def finite_size_data_iter(size):
        yield from range(size)

    bc = BatchChecker(data=list(range(known_size)))
    engine = Engine(lambda e, b: bc.check(b))

    @engine.on(Events.ITERATION_COMPLETED(every=known_size))
    def restart_iter():
        # swap in a fresh iterator once the previous one is exhausted
        engine.state.dataloader = finite_size_data_iter(known_size)

    engine.run(finite_size_data_iter(known_size), max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == known_size * 5
示例8: test_pbar_wrong_events_order
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_pbar_wrong_events_order():
    """ProgressBar.attach must reject opening events that do not strictly
    precede their closing event, and filtered closing events."""
    engine = Engine(update_fn)
    pbar = ProgressBar()

    # every (event_name, closing_event_name) pair here is out of order
    bad_pairs = [
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    ]
    for opening, closing in bad_pairs:
        with pytest.raises(ValueError, match="should be called before closing event"):
            pbar.attach(engine, event_name=opening, closing_event_name=closing)

    # a filtered event is never a valid closing event
    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))
示例9: test_get_intermediate_results_during_run
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_get_intermediate_results_during_run(capsys):
    """Calling profiler.get_results()/print_results() mid-run (every 3rd
    iteration) must produce a well-formed table: no raw internal handler
    names, no NaNs, no zero-minimum artifacts."""
    true_event_handler_time = 0.0645
    true_max_epochs = 2
    true_num_iters = 5

    profiler = BasicTimeProfiler()
    dummy_trainer = get_prepared_engine(true_event_handler_time)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.ITERATION_COMPLETED(every=3))
    def log_results(_):
        profiler.print_results(profiler.get_results())
        out = capsys.readouterr().out
        assert "BasicTimeProfiler._" not in out
        assert "nan" not in out
        assert " min/index: (0.0, " not in out, out

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)
示例10: test_save_param_history
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def test_save_param_history():
    """With save_history=True the scheduler must record every lr it set in
    trainer.state.param_history; compare against values captured live by an
    ITERATION_COMPLETED handler."""
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, save_history=True)

    seen_lrs = []

    trainer = Engine(lambda engine, batch: None)
    # history must not exist before the scheduler runs
    assert not hasattr(trainer.state, "param_history")

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        lambda engine: seen_lrs.append(optimizer.param_groups[0]["lr"]),
    )
    trainer.run([0] * 10, max_epochs=2)

    recorded = trainer.state.param_history["lr"]
    assert len(recorded) == len(seen_lrs)
    # param_history stores one value per param group; unpack the singleton lists
    assert [group[0] for group in recorded] == seen_lrs
示例11: _test_distrib_one_rank_only_with_engine
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def _test_distrib_one_rank_only_with_engine(device):
    """idist.one_rank_only must restrict an engine handler to rank 0: only
    rank 0 accumulates the batch sum, every other rank stays at zero."""

    def _run(barrier):
        engine = Engine(lambda e, b: b)
        batch_sum = torch.tensor(0).to(device)

        @engine.on(Events.ITERATION_COMPLETED)
        @idist.one_rank_only(with_barrier=barrier)  # ie rank == 0
        def _(_):
            batch_sum.data += torch.tensor(engine.state.batch).to(device)

        engine.run([1, 2, 3], max_epochs=2)

        gathered = idist.all_gather(tensor=batch_sum)
        # rank 0 saw (1 + 2 + 3) * 2 epochs = 12; the rest saw nothing
        for rank in range(idist.get_world_size()):
            expected = 12 if rank == 0 else 0
            assert gathered[rank].item() == expected

    _run(barrier=True)
    _run(barrier=False)
示例12: attach
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def attach(self, engine, name):
    """Wire this metric's lifecycle into *engine*: reset at epoch start,
    accumulate on every iteration, publish under *name* at epoch end."""
    bindings = (
        (Events.EPOCH_STARTED, self.started),
        (Events.ITERATION_COMPLETED, self.iteration_completed),
    )
    for event, handler in bindings:
        engine.add_event_handler(event, handler)
    engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name)
示例13: gradual_unfreezing
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def gradual_unfreezing(self):
    """Freeze everything except the embeddings and classification head, then
    unfreeze one transformer block at a regular interval during training.

    Registers an ITERATION_COMPLETED handler on ``self.trainer`` that walks
    from the top block (``self.model.num_layers - 1``) down to block 0, one
    block every ``unfreezing_interval`` iterations.
    """
    for name, param in self.model.named_parameters():
        if 'embeddings' not in name and 'classification' not in name:
            # detach so frozen weights keep no stale autograd history
            param.detach_()
            param.requires_grad = False
        else:
            param.requires_grad = True
    full_parameters = sum(p.numel() for p in self.model.parameters())
    trained_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
    logger.info(f"We will start by training {trained_parameters:3e} parameters out of {full_parameters:3e},"
                f" i.e. {100 * trained_parameters / full_parameters:.2f}%")

    # We will unfreeze blocks regularly along the training: one block every
    # `unfreezing_interval` step. Guard with max(1, ...) so a very short
    # schedule cannot yield interval 0 and crash the modulo below.
    unfreezing_interval = max(
        1, int(len(self.dataset_splits.train_data_loader()) * self.num_epochs / (self.model.num_layers + 1))
    )

    @self.trainer.on(Events.ITERATION_COMPLETED)
    def unfreeze_layer_if_needed(engine):
        if engine.state.iteration % unfreezing_interval == 0:
            # Which layer should we unfreeze now
            unfreezing_index = self.model.num_layers - (engine.state.iteration // unfreezing_interval)
            # Let's unfreeze it
            unfreezed = []
            for name, param in self.model.named_parameters():
                if re.match(r"transformer\.[^\.]*\." + str(unfreezing_index) + r"\.", name):
                    unfreezed.append(name)
                    # BUG FIX: the original wrote `param.require_grad = True`
                    # (missing "s"), which only creates a new attribute and
                    # leaves the parameter frozen; `requires_grad` is the real
                    # autograd flag.
                    param.requires_grad = True
            logger.info(f"Unfreezing block {unfreezing_index} with {unfreezed}")
示例14: run_with_pbar
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def run_with_pbar(engine, loader, desc=None):
    """Run *engine* over *loader* while showing a tqdm progress bar that
    advances once per completed iteration."""
    progress = tqdm.trange(len(loader), desc=desc)

    @engine.on(Events.ITERATION_COMPLETED)
    def _advance(_):
        progress.update(1)

    engine.run(loader)
    progress.close()
示例15: __call__
# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_COMPLETED [as 别名]
def __call__(self, engine, logger, event_name):
    """Post the engine's current metrics to a ChainerUI logger, batching them
    per ``self.interval`` events when an interval greater than 1 is set.

    Raises:
        RuntimeError: if *logger* is not a ChainerUILogger.
    """
    if not isinstance(logger, ChainerUILogger):
        raise RuntimeError(
            '`chainerui.contrib.ignite.handler.OutputHandler` works only '
            'with ChainerUILogger, but set {}'.format(type(logger)))

    metrics = self._setup_output_metrics(engine)
    if not metrics:
        return

    iteration = self.global_step_transform(
        engine, Events.ITERATION_COMPLETED)
    epoch = self.global_step_transform(engine, Events.EPOCH_COMPLETED)

    # convert metrics name: prefix every metric with this handler's tag
    record = {'{}/{}'.format(self.tag, key): value for key, value in metrics.items()}
    record['iteration'] = iteration
    record['epoch'] = epoch
    if 'elapsed_time' not in record:
        record['elapsed_time'] = _get_time() - logger.start_at

    # no interval configured: post immediately
    if self.interval <= 1:
        logger.post_log([record])
        return

    # enable interval, cache metrics until the interval boundary is reached
    logger.cache.setdefault(self.tag, []).append(record)
    # select appropriate event set by handler init
    global_count = self.global_step_transform(engine, event_name)
    if global_count % self.interval == 0:
        logger.post_log(logger.cache[self.tag])
        logger.cache[self.tag].clear()