本文整理汇总了Python中onmt.Statistics方法的典型用法代码示例。如果您正苦于以下问题:Python onmt.Statistics方法的具体用法?Python onmt.Statistics怎么用?Python onmt.Statistics使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类onmt的用法示例。
在下文中一共展示了onmt.Statistics方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: monolithic_compute_loss
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def monolithic_compute_loss(self, batch, output, attns):
    """Compute the loss statistics for a whole batch in a single pass.

    The full target range ``(0, batch.tgt.size(0))`` is handed to
    ``self._make_shard_state`` and the resulting state is forwarded, as
    keyword arguments, to ``self._compute_loss``. No backward pass is
    triggered from this method.

    Args:
        batch (batch): batch of labeled examples.
        output (:obj:`FloatTensor`): decoder output,
            `[tgt_len x batch x hidden]`.
        attns (dict of :obj:`FloatTensor`): attention distributions,
            `[tgt_len x batch x src_len]`.

    Returns:
        :obj:`onmt.Statistics`: loss statistics, i.e. the second value
        returned by ``self._compute_loss``.
    """
    whole_range = (0, batch.tgt.size(0))
    state = self._make_shard_state(batch, output, whole_range, attns)
    # The first returned value (the loss itself) is not needed here.
    _loss, stats = self._compute_loss(batch, **state)
    return stats
示例2: _stats
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def _stats(self, loss, scores, target):
    """Build the statistics object for one batch.

    Args:
        loss (:obj:`FloatTensor`): the loss computed by the loss
            criterion; its first element is what gets recorded.
        scores (:obj:`FloatTensor`): a score for each possible output.
        target (:obj:`FloatTensor`): true targets.

    Returns:
        :obj:`Statistics`: statistics for this batch.
    """
    # Positions equal to the padding index are excluded from the counts.
    mask = target.ne(self.padding_idx)
    # Index of the best-scoring output along dimension 1.
    predicted = scores.max(1)[1]
    hits = predicted.eq(target).masked_select(mask)
    return onmt.Statistics(loss[0], mask.sum(), hits.sum())
示例3: _stats
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def _stats(self, loss, scores, target):
    """Build the statistics object for one batch.

    Args:
        loss (:obj:`FloatTensor`): the loss computed by the loss
            criterion; converted with ``.item()``.
        scores (:obj:`FloatTensor`): a score for each possible output,
            or ``None``, in which case zero correct predictions are
            reported.
        target (:obj:`FloatTensor`): true targets.

    Returns:
        :obj:`Statistics`: statistics for this batch.
    """
    # Positions equal to the padding index are excluded from the counts.
    mask = target.ne(self.padding_idx)
    if scores is None:
        n_correct = 0
    else:
        # Index of the best-scoring output along dimension 1.
        predicted = scores.max(1)[1]
        n_correct = predicted.eq(target).masked_select(mask).sum()
    n_words = float(mask.sum())
    return onmt.Statistics(loss.item(), n_words, float(n_correct))
示例4: report_func
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def report_func(epoch, batch, num_batches,
                start_time, lr, report_stats):
    """User-defined batch-level training progress report function.

    A report is emitted on the last batch of every ``opt.report_every``
    batches; on all other batches the statistics object is returned
    untouched.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.

    Returns:
        report_stats(Statistics): the same instance when nothing was
        reported, otherwise a fresh ``onmt.Statistics()``.
    """
    interval = opt.report_every
    # ``-1 % interval`` equals ``interval - 1``: the last batch of a window.
    if batch % interval != -1 % interval:
        return report_stats

    report_stats.output(epoch, batch + 1, num_batches, start_time)
    if opt.exp_host:
        report_stats.log("progress", experiment, lr)
    # Start a new accumulation window after reporting.
    return onmt.Statistics()
示例5: monolithic_compute_loss
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def monolithic_compute_loss(self, batch, output, attns, dist_info=None, output_baseline=None):
    """Compute the loss statistics for a whole batch in a single pass.

    Args:
        batch (batch): batch of labeled examples.
        output (:obj:`FloatTensor`): decoder output,
            `[tgt_len x batch x hidden]`.
        attns (dict of :obj:`FloatTensor`): attention distributions,
            `[tgt_len x batch x src_len]`.
        dist_info: optional object; when given, ``dist_info.p.dist_type``
            is stored on ``self.dist_type`` and the object is forwarded
            to ``self._make_shard_state``.
        output_baseline: optional value forwarded unchanged to
            ``self._make_shard_state``.

    Returns:
        :obj:`onmt.Statistics`: loss statistics, i.e. the second value
        returned by ``self._compute_loss``.
    """
    if dist_info is not None:
        self.dist_type = dist_info.p.dist_type

    whole_range = (0, batch.tgt.size(0))
    state = self._make_shard_state(
        batch, output, whole_range, attns,
        dist_info=dist_info, output_baseline=output_baseline)
    # The first returned value (the loss itself) is not needed here.
    _loss, stats = self._compute_loss(batch, **state)
    return stats
示例6: _stats
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def _stats(self, xent, kl, scores, target):
    """Build the statistics object for one batch.

    Args:
        xent: single-element tensor, converted with ``.item()``
            (presumably the cross-entropy term -- confirm with caller).
        kl: single-element tensor, converted with ``.item()``
            (presumably the KL term -- confirm with caller).
        scores (:obj:`FloatTensor`): a score for each possible output.
        target (:obj:`FloatTensor`): true targets.

    Returns:
        :obj:`Statistics`: statistics for this batch.
    """
    # Positions equal to the padding index are excluded from the counts.
    mask = target.ne(self.padding_idx)
    # Index of the best-scoring output along dimension 1.
    predicted = scores.max(1)[1]
    n_correct = predicted.eq(target).masked_select(mask).sum()
    return onmt.Statistics(xent.item(), kl.item(), mask.sum().item(), n_correct.item())
示例7: sharded_compute_loss
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def sharded_compute_loss(self, batch, output, attns,
                         cur_trunc, trunc_size, shard_size, teacher_outputs=None):
    """Compute the loss shard by shard and backpropagate each shard.

    The generation state for the window
    ``(cur_trunc, cur_trunc + trunc_size)`` is built with
    ``make_gen_state`` and split by ``shards``. Each shard's loss is
    divided by ``batch.batch_size`` before ``backward()`` is called.

    Returns:
        onmt.Statistics: statistics accumulated over all shards.
    """
    accumulated = onmt.Statistics()
    window = (cur_trunc, cur_trunc + trunc_size)
    gen_state = make_gen_state(output, batch, attns, window,
                               self.copy_attn, teacher_outputs)

    for shard in shards(gen_state, shard_size):
        shard_loss, shard_stats = self.compute_loss(batch, **shard)
        scaled = shard_loss.div(batch.batch_size)
        scaled.backward()
        accumulated.update(shard_stats)

    return accumulated
示例8: report_func
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def report_func(epoch, batch, num_batches, start_time, lr, report_stats, options=None):
    """User-defined batch-level training progress report function.

    A report is emitted through ``report_stats.output`` on the last batch
    of every ``report_every`` batches.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate (not used by this function).
        report_stats(Statistics): a Statistics instance.
        options: optional source of the ``report_every`` setting. It may
            be a mapping (``options['report_every']``) or an object with
            a ``report_every`` attribute, such as an ``argparse``
            namespace. When ``None``, the value is read from
            ``onmt.standard_options.stdOptions['report_every']``.
    """
    if options is None:
        report_every = onmt.standard_options.stdOptions['report_every']
    else:
        try:
            report_every = options['report_every']
        except (KeyError, TypeError):
            # KeyError: a mapping without the key.
            # TypeError: the object is not subscriptable at all (e.g. an
            # argparse.Namespace), so fall back to attribute access.
            report_every = options.report_every

    # ``-1 % report_every`` equals ``report_every - 1``: the last batch of
    # each reporting window.
    if batch % report_every == -1 % report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)
示例9: sharded_compute_loss
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def sharded_compute_loss(self, batch, output, attns,
                         cur_trunc, trunc_size, shard_size,
                         normalization):
    """Compute the forward loss shard by shard and backpropagate.

    Work is split into shards, and optionally truncated, to save memory.
    Truncated BPTT for long sequences is supported by restricting
    backpropagation to the decoder output range
    `(cur_trunc, cur_trunc + trunc_size)`.

    Sharding is an exact efficiency trick that relieves the memory
    required for the generation buffers. Truncation is an approximate
    efficiency trick that relieves the memory required in the RNN
    buffers.

    Args:
        batch (batch) : batch of labeled examples
        output (:obj:`FloatTensor`) :
            output of decoder model `[tgt_len x batch x hidden]`
        attns (dict) : dictionary of attention distributions
            `[tgt_len x batch x src_len]`
        cur_trunc (int) : starting position of truncation window
        trunc_size (int) : length of truncation window
        shard_size (int) : maximum number of examples in a shard
        normalization (int) : each shard loss is divided by this number
            before the backward pass

    Returns:
        :obj:`onmt.Statistics`: statistics accumulated over all shards
    """
    accumulated = onmt.Statistics()
    window = (cur_trunc, cur_trunc + trunc_size)
    state = self._make_shard_state(batch, output, window, attns)

    for shard in shards(state, shard_size):
        shard_loss, shard_stats = self._compute_loss(batch, **shard)
        scaled = shard_loss.div(normalization)
        scaled.backward()
        accumulated.update(shard_stats)

    return accumulated
示例10: validate
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def validate(self, valid_iter):
    """Run the model over a validation iterator and collect statistics.

    The model is put in evaluation mode before the loop and switched
    back to training mode once the loop has finished.

    Args:
        valid_iter: validation data iterator; it must also provide
            ``get_cur_dataset()``.

    Returns:
        :obj:`onmt.Statistics`: validation loss statistics
    """
    self.model.eval()

    total_stats = Statistics()
    for batch in valid_iter:
        # Tell the validation loss which dataset this batch belongs to.
        self.valid_loss.cur_dataset = valid_iter.get_cur_dataset()

        src = onmt.io.make_features(batch, 'src', self.data_type)
        # For text data ``batch.src`` unpacks into a pair whose second item
        # is the source lengths; for other data types no lengths are passed.
        src_lengths = None
        if self.data_type == 'text':
            _, src_lengths = batch.src
        tgt = onmt.io.make_features(batch, 'tgt')

        # Forward pass; the third returned value is not used here.
        outputs, attns, _ = self.model(src, tgt, src_lengths)

        total_stats.update(
            self.valid_loss.monolithic_compute_loss(batch, outputs, attns))

    self.model.train()
    return total_stats
示例11: report_func
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def report_func(epoch, batch, num_batches,
                progress_step,
                start_time, lr, report_stats):
    """User-defined batch-level training progress report function.

    A report is emitted on the last batch of every ``opt.report_every``
    batches; on all other batches the statistics object is returned
    untouched.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        progress_step(int): the progress step.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.

    Returns:
        report_stats(Statistics): the same instance when nothing was
        reported, otherwise a fresh ``onmt.Statistics()``.
    """
    interval = opt.report_every
    # ``-1 % interval`` equals ``interval - 1``: the last batch of a window.
    if batch % interval != -1 % interval:
        return report_stats

    report_stats.output(epoch, batch + 1, num_batches, start_time)
    if opt.exp_host:
        report_stats.log("progress", experiment, lr)
    if opt.tensorboard:
        # ``progress_step`` is passed as the step value for the
        # tensorboard entry.
        report_stats.log_tensorboard("progress", writer, lr, progress_step)
    # Start a new accumulation window after reporting.
    return onmt.Statistics()
示例12: _compute_loss
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def _compute_loss(self, batch, output, target, **kwargs):
"""
Compute the loss. Subclass must define this method.
Args:
batch: the current batch.
output: the predict output from the model.
target: the validate target to compare output with.
**kwargs(optional): additional info for computing loss.
"""
return NotImplementedError
# def monolithic_compute_loss(self, batch, output, attns):
# """
# Compute the forward loss for the batch.
# Args:
# batch (batch): batch of labeled examples
# output (:obj:`FloatTensor`):
# output of decoder model `[tgt_len x batch x hidden]`
# attns (dict of :obj:`FloatTensor`) :
# dictionary of attention distributions
# `[tgt_len x batch x src_len]`
# Returns:
# :obj:`onmt.Statistics`: loss statistics
# """
# # range_ = (0, batch.tgt.size(0))
# # shard_state = self._make_shard_state(batch, output, range_, attns)
# # loss, batch_stats = self._compute_loss(batch, **shard_state)
# target = batch.tgt[1:batch.tgt.size(0)]
# loss, batch_stats = self._compute_loss(batch, output, target)
# return loss, batch_stats
示例13: sharded_compute_loss
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def sharded_compute_loss(self, batch, output, attns,
                         cur_trunc, trunc_size, shard_size):
    """Compute the forward loss and backpropagate.

    Computation is done with shards and optionally truncation for
    memory efficiency. Truncated BPTT for long sequences is supported by
    taking a range in the decoder output sequence to back propagate in.
    Range is from `(cur_trunc, cur_trunc + trunc_size)`.

    Note sharding is an exact efficiency trick to relieve memory
    required for the generation buffers. Truncation is an
    approximate efficiency trick to relieve the memory required
    in the RNN buffers.

    Args:
        batch (batch) : batch of labeled examples
        output (:obj:`FloatTensor`) :
            output of decoder model `[tgt_len x batch x hidden]`
        attns (dict) : dictionary of attention distributions
            `[tgt_len x batch x src_len]`
        cur_trunc (int) : starting position of truncation window
        trunc_size (int) : length of truncation window
        shard_size (int) : maximum number of examples in a shard

    Returns:
        :obj:`onmt.Statistics`: statistics accumulated over all shards
    """
    batch_stats = onmt.Statistics()
    range_ = (cur_trunc, cur_trunc + trunc_size)
    shard_state = self._make_shard_state(batch, output, range_, attns)

    for shard in shards(shard_state, shard_size):
        loss, stats = self._compute_loss(batch, **shard)
        # Each shard loss is scaled by the batch size before backprop.
        loss.div(batch.batch_size).backward()
        batch_stats.update(stats)

    return batch_stats
示例14: sharded_compute_loss
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def sharded_compute_loss(self, batch, output, attns,
                         cur_trunc, trunc_size, shard_size,
                         normalization):
    """Compute the forward loss shard by shard and backpropagate.

    Work is split into shards, and optionally truncated, to save memory.
    Truncated BPTT for long sequences is supported by restricting
    backpropagation to the decoder output range
    `(cur_trunc, cur_trunc + trunc_size)`.

    Sharding is an exact efficiency trick that relieves the memory
    required for the generation buffers. Truncation is an approximate
    efficiency trick that relieves the memory required in the RNN
    buffers.

    Args:
        batch (batch) : batch of labeled examples
        output (:obj:`FloatTensor`) :
            output of decoder model `[tgt_len x batch x hidden]`
        attns (dict) : dictionary of attention distributions
            `[tgt_len x batch x src_len]`
        cur_trunc (int) : starting position of truncation window
        trunc_size (int) : length of truncation window
        shard_size (int) : maximum number of examples in a shard
        normalization : each shard loss is divided by this value before
            the backward pass

    Returns:
        :obj:`onmt.Statistics`: statistics accumulated over all shards
    """
    accumulated = onmt.Statistics()
    window = (cur_trunc, cur_trunc + trunc_size)
    state = self._make_shard_state(batch, output, window, attns)

    for shard in shards(state, shard_size):
        shard_loss, shard_stats = self._compute_loss(batch, **shard)
        scaled = shard_loss.div(normalization)
        scaled.backward()
        accumulated.update(shard_stats)

    return accumulated
示例15: validate
# 需要导入模块: import onmt [as 别名]
# 或者: from onmt import Statistics [as 别名]
def validate(self, valid_iter):
    """Run the model over a validation iterator and collect statistics.

    The model is put in evaluation mode before the loop and switched
    back to training mode once the loop has finished. Besides source,
    target and source lengths, the entity fields of each batch are
    forwarded to the model as keyword arguments.

    Args:
        valid_iter: validation data iterator; it must also provide
            ``get_cur_dataset()``.

    Returns:
        :obj:`onmt.Statistics`: validation loss statistics
    """
    self.model.eval()

    total_stats = Statistics()
    for batch in valid_iter:
        # Tell the validation loss which dataset this batch belongs to.
        self.valid_loss.cur_dataset = valid_iter.get_cur_dataset()

        src = onmt.io.make_features(batch, 'src', self.data_type)
        # For text data ``batch.src`` unpacks into a pair whose second item
        # is the source lengths; for other data types no lengths are passed.
        src_lengths = None
        if self.data_type == 'text':
            _, src_lengths = batch.src
        tgt = onmt.io.make_features(batch, 'tgt')

        # Forward pass; the third returned value is not used here.
        outputs, attns, _ = self.model(
            src, tgt, src_lengths,
            entities_list=batch.entities_list,
            entities_len=batch.entities_len,
            count_entities=batch.count_entities,
            total_entities_list=batch.total_entities_list)

        total_stats.update(
            self.valid_loss.monolithic_compute_loss(batch, outputs, attns))

    self.model.train()
    return total_stats