

Python onmt.Statistics Method Code Examples

This article collects typical usage examples of the onmt.Statistics method in Python. If you are wondering what onmt.Statistics does, how to use it, or where to find it used in practice, the curated code examples below may help. You can also explore further usage examples from the onmt module itself.


The following presents 15 code examples of the onmt.Statistics method, ordered by popularity.
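For context: all of the examples below target the old (0.x) OpenNMT-py API, in which onmt.Statistics accumulates token-level loss statistics across batches. The following is a minimal sketch of the interface the examples rely on, reconstructed here for illustration only; it is not the library's actual source, and the exact constructor signature varies slightly between the forked projects shown below.

import math
import time

class Statistics(object):
    """Accumulates loss statistics over batches (simplified sketch)."""

    def __init__(self, loss=0, n_words=0, n_correct=0):
        self.loss = loss            # summed loss over the tokens seen so far
        self.n_words = n_words      # number of non-padding target tokens
        self.n_correct = n_correct  # number of correctly predicted tokens
        self.start_time = time.time()

    def update(self, stat):
        """Fold another Statistics instance into this one."""
        self.loss += stat.loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct

    def accuracy(self):
        return 100 * (self.n_correct / self.n_words)

    def ppl(self):
        """Per-token perplexity, clamped to avoid overflow."""
        return math.exp(min(self.loss / self.n_words, 100))

    def output(self, epoch, batch, n_batches, start):
        """Print a one-line progress report (cf. report_func below)."""
        print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %.0f s elapsed"
              % (epoch, batch, n_batches, self.accuracy(), self.ppl(),
                 time.time() - start))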

Example 1: monolithic_compute_loss

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def monolithic_compute_loss(self, batch, output, attns):
        """
        Compute the forward loss for the batch.

        Args:
          batch (batch): batch of labeled examples
          output (:obj:`FloatTensor`):
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict of :obj:`FloatTensor`) :
              dictionary of attention distributions
              `[tgt_len x batch x src_len]`
        Returns:
            :obj:`onmt.Statistics`: loss statistics
        """
        range_ = (0, batch.tgt.size(0))
        shard_state = self._make_shard_state(batch, output, range_, attns)
        _, batch_stats = self._compute_loss(batch, **shard_state)

        return batch_stats 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 21, Source: Loss.py

Example 2: _stats

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def _stats(self, loss, scores, target):
        """
        Args:
            loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
            scores (:obj:`FloatTensor`): a score for each possible output
            target (:obj:`FloatTensor`): true targets

        Returns:
            :obj:`Statistics` : statistics for this batch.
        """
        pred = scores.max(1)[1]
        non_padding = target.ne(self.padding_idx)
        num_correct = pred.eq(target) \
                          .masked_select(non_padding) \
                          .sum()
        return onmt.Statistics(loss[0], non_padding.sum(), num_correct) 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 18, Source: Loss.py
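The loss[0] indexing above reflects pre-0.4 PyTorch, where a scalar loss was a one-element tensor; Example 3 below shows the modern loss.item() equivalent. The prediction/masking logic itself can be checked on dummy tensors (all names here are illustrative):

import torch

padding_idx = 0
scores = torch.randn(4, 5)                     # logits: 4 target positions, vocab of 5
target = torch.tensor([2, 1, 3, padding_idx])  # last position is padding

pred = scores.max(1)[1]                        # argmax over the vocabulary dimension
non_padding = target.ne(padding_idx)           # mask out padding positions
num_correct = pred.eq(target).masked_select(non_padding).sum()
print(num_correct.item(), non_padding.sum().item())  # correct tokens, real tokens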

Example 3: _stats

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def _stats(self, loss, scores, target):
        """
        Args:
            loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
            scores (:obj:`FloatTensor`): a score for each possible output
            target (:obj:`FloatTensor`): true targets

        Returns:
            :obj:`Statistics` : statistics for this batch.
        """
        non_padding = target.ne(self.padding_idx)
        if scores is not None:
            pred = scores.max(1)[1]
            num_correct = pred.eq(target).masked_select(non_padding).sum()
        else:
            num_correct = 0
        return onmt.Statistics(loss.item(), float(non_padding.sum()), float(num_correct)) 
Author: matthewmackay, Project: reversible-rnn, Lines: 19, Source: Trainer.py

Example 4: report_func

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def report_func(epoch, batch, num_batches,
                start_time, lr, report_stats):
    """
    This is the user-defined batch-level training progress
    report function.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.
    Returns:
        report_stats(Statistics): updated Statistics instance.
    """
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch+1, num_batches, start_time)
        if opt.exp_host:
            report_stats.log("progress", experiment, lr)
        report_stats = onmt.Statistics()

    return report_stats 
Author: abaheti95, Project: DC-NeuralConversation, Lines: 25, Source: train.py
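The reporting condition looks cryptic but is a standard idiom: since -1 % n == n - 1 in Python, the test fires on zero-based batch indices report_every - 1, 2 * report_every - 1, ..., which batch + 1 then reports as batches report_every, 2 * report_every, .... A quick check:

report_every = 50
fires = [b for b in range(200) if b % report_every == -1 % report_every]
print(fires)  # [49, 99, 149, 199] -> reported as batches 50, 100, 150, 200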

Example 5: monolithic_compute_loss

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def monolithic_compute_loss(self, batch, output, attns, dist_info=None, output_baseline=None):
        """
        Compute the forward loss for the batch.

        Args:
          batch (batch): batch of labeled examples
          output (:obj:`FloatTensor`):
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict of :obj:`FloatTensor`) :
              dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          dist_info (optional): attention distribution info; its
              `p.dist_type` field is copied onto `self.dist_type`
          output_baseline (optional): baseline decoder output, passed
              through to `_make_shard_state`
        Returns:
            :obj:`onmt.Statistics`: loss statistics
        """
        if dist_info is not None:
            self.dist_type = dist_info.p.dist_type
        range_ = (0, batch.tgt.size(0))
        shard_state = self._make_shard_state(batch, output, range_, attns, dist_info=dist_info, output_baseline=output_baseline)
        _, batch_stats = self._compute_loss(batch, **shard_state)

        return batch_stats 
Author: harvardnlp, Project: var-attn, Lines: 23, Source: Loss.py

Example 6: _stats

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def _stats(self, xent, kl, scores, target):
        """
        Args:
            xent (:obj:`FloatTensor`): the cross-entropy term of the loss.
            kl (:obj:`FloatTensor`): the KL-divergence term of the loss.
            scores (:obj:`FloatTensor`): a score for each possible output
            target (:obj:`FloatTensor`): true targets

        Returns:
            :obj:`Statistics` : statistics for this batch.
        """
        pred = scores.max(1)[1]
        non_padding = target.ne(self.padding_idx)
        num_correct = pred.eq(target) \
                          .masked_select(non_padding) \
                          .sum()
        return onmt.Statistics(xent.item(), kl.item(), non_padding.sum().item(), num_correct.item()) 
Author: harvardnlp, Project: var-attn, Lines: 18, Source: Loss.py
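Note that this fork's Statistics variant evidently takes four constructor arguments (cross-entropy, KL divergence, word count, correct count) so that the two loss terms of variational attention are tracked separately, unlike the three-argument version in Examples 2 and 3. The class name is the same across projects, but the signature is not.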

Example 7: sharded_compute_loss

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def sharded_compute_loss(self, batch, output, attns,
                             cur_trunc, trunc_size, shard_size, teacher_outputs=None):
        """
        Compute the loss in shards for efficiency.
        """
        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        gen_state = make_gen_state(output, batch, attns, range_,
                                   self.copy_attn, teacher_outputs)

        for shard in shards(gen_state, shard_size):
            loss, stats = self.compute_loss(batch, **shard)
            loss.div(batch.batch_size).backward()
            batch_stats.update(stats)

        return batch_stats 
Author: antspy, Project: quantized_distillation, Lines: 18, Source: Loss.py
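The shards helper used here (and in Examples 9, 13 and 14) slices the loss-computation state along the target-length dimension so that only one slice's generator activations are alive at a time. A simplified sketch of the idea, assuming every value in the state dict is a tensor with tgt_len as its first dimension (the real onmt implementation additionally stitches gradients back together across shards):

def shards_sketch(state, shard_size):
    """Yield dicts whose tensor values are consecutive slices of at most
    shard_size steps along dimension 0 (hypothetical simplification)."""
    tgt_len = next(iter(state.values())).size(0)
    for start in range(0, tgt_len, shard_size):
        yield {k: v[start:start + shard_size] for k, v in state.items()}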

Example 8: report_func

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def report_func(epoch, batch, num_batches, start_time, lr, report_stats, options=None):

    """
    This is the user-defined batch-level training progress
    report function.
    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): a Statistics instance.
    """

    if options is None:
        report_every = onmt.standard_options.stdOptions['report_every']
    else:
        try:
            report_every = options['report_every']
        except KeyError:
            report_every = options.report_every

    if batch % report_every == -1 % report_every:
        report_stats.output(epoch, batch+1, num_batches, start_time) 
Author: antspy, Project: quantized_distillation, Lines: 26, Source: model.py

Example 9: sharded_compute_loss

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def sharded_compute_loss(self, batch, output, attns,
                             cur_trunc, trunc_size, shard_size,
                             normalization):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard
          normalization (int) : Loss is divided by this number

        Returns:
            :obj:`onmt.Statistics`: loss statistics

        """
        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        shard_state = self._make_shard_state(batch, output, range_, attns)

        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(normalization).backward()
            batch_stats.update(stats)

        return batch_stats 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 42, Source: Loss.py
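For context on cur_trunc and trunc_size: a trainer would typically slide the truncation window across the target sequence and call this method once per window. A hedged sketch of that calling pattern (train_loss and total_stats are assumed trainer attributes, and in practice the decoder is re-run per window to produce output):

tgt_len = batch.tgt.size(0)
for cur_trunc in range(0, tgt_len - 1, trunc_size):
    window = min(trunc_size, tgt_len - 1 - cur_trunc)
    # the decoder forward pass over this window would happen here
    batch_stats = train_loss.sharded_compute_loss(
        batch, output, attns, cur_trunc, window, shard_size, normalization)
    total_stats.update(batch_stats)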

Example 10: validate

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validation data iterator
        Returns:
            :obj:`onmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = Statistics()

        for batch in valid_iter:
            cur_dataset = valid_iter.get_cur_dataset()
            self.valid_loss.cur_dataset = cur_dataset

            src = onmt.io.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = onmt.io.make_features(batch, 'tgt')

            # F-prop through the model.
            outputs, attns, _ = self.model(src, tgt, src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 39, Source: Trainer.py
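One detail worth flagging: this validation loop predates torch.no_grad(), so the forward passes build an autograd graph they never use. On PyTorch 0.4+ the same method would normally wrap the loop (a minimal adaptation, loop body unchanged):

self.model.eval()
stats = Statistics()
with torch.no_grad():   # skip autograd bookkeeping during validation
    for batch in valid_iter:
        ...             # same loop body as above
self.model.train()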

Example 11: report_func

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def report_func(epoch, batch, num_batches,
                progress_step,
                start_time, lr, report_stats):
    """
    This is the user-defined batch-level training progress
    report function.

    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        progress_step(int): the progress step.
        start_time(float): last report time.
        lr(float): current learning rate.
        report_stats(Statistics): old Statistics instance.
    Returns:
        report_stats(Statistics): updated Statistics instance.
    """
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)
        if opt.exp_host:
            report_stats.log("progress", experiment, lr)
        if opt.tensorboard:
            # Log the progress using the number of batches on the x-axis.
            report_stats.log_tensorboard(
                "progress", writer, lr, progress_step)
        report_stats = onmt.Statistics()

    return report_stats 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 31, Source: train.py

Example 12: _compute_loss

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def _compute_loss(self, batch, output, target, **kwargs):
        """
        Compute the loss. Subclass must define this method.

        Args:

            batch: the current batch.
            output: the predicted output from the model.
            target: the target to compare the output against.
            **kwargs(optional): additional info for computing loss.
        """
        raise NotImplementedError

Author: matthewmackay, Project: reversible-rnn, Lines: 35, Source: Loss.py

Example 13: sharded_compute_loss

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def sharded_compute_loss(self, batch, output, attns,
                             cur_trunc, trunc_size, shard_size):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard

        Returns:
            :obj:`onmt.Statistics`: loss statistics

        """

        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        shard_state = self._make_shard_state(batch, output, range_, attns)

        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(batch.batch_size).backward()
            batch_stats.update(stats)

        return batch_stats 
Author: matthewmackay, Project: reversible-rnn, Lines: 42, Source: Loss.py

Example 14: sharded_compute_loss

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def sharded_compute_loss(self, batch, output, attns,
                             cur_trunc, trunc_size, shard_size,
                             normalization):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard

        Returns:
            :obj:`onmt.Statistics`: loss statistics

        """
        batch_stats = onmt.Statistics()
        range_ = (cur_trunc, cur_trunc + trunc_size)
        shard_state = self._make_shard_state(batch, output, range_, attns)

        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)

            loss.div(normalization).backward()
            batch_stats.update(stats)

        return batch_stats 
Author: abaheti95, Project: DC-NeuralConversation, Lines: 42, Source: Loss.py

Example 15: validate

# Required import: import onmt [as alias]
# Or: from onmt import Statistics [as alias]
def validate(self, valid_iter):
        """ Validate model.
            valid_iter: validation data iterator
        Returns:
            :obj:`onmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()

        stats = Statistics()

        for batch in valid_iter:
            cur_dataset = valid_iter.get_cur_dataset()
            self.valid_loss.cur_dataset = cur_dataset

            src = onmt.io.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = onmt.io.make_features(batch, 'tgt')

            # F-prop through the model.
            outputs, attns, _ = self.model(src, tgt, src_lengths,
                                           entities_list=batch.entities_list,
                                           entities_len=batch.entities_len,
                                           count_entities=batch.count_entities,
                                           total_entities_list=batch.total_entities_list)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats 
Author: ratishsp, Project: data2text-entity-py, Lines: 43, Source: Trainer.py


Note: The onmt.Statistics method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by the community; copyright in the source code remains with the original authors. For distribution and use, please refer to each project's license. Do not republish without permission.