当前位置: 首页>>代码示例>>Python>>正文


Python util.replace_masked_values方法代码示例

本文整理汇总了Python中allennlp.nn.util.replace_masked_values方法的典型用法代码示例。如果您正苦于以下问题:Python util.replace_masked_values方法的具体用法?Python util.replace_masked_values怎么用?Python util.replace_masked_values使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在allennlp.nn.util的用法示例。


在下文中一共展示了util.replace_masked_values方法的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: masked_mean

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def masked_mean(tensor, dim, mask):
    """
    Average ``tensor`` along ``dim``, counting only positions whose ``mask``
    entry is non-zero.

    With ``mask=None`` this is a plain ``torch.mean(tensor, dim)``. Otherwise
    ``mask`` must have the same number of dimensions as ``tensor``; a
    ``ConfigurationError`` is raised when it does not. Where no position is
    kept along ``dim``, the sum is divided by 1 instead of 0 so that no NaN
    is produced.
    """
    if mask is None:
        return torch.mean(tensor, dim)

    tensor_rank, mask_rank = tensor.dim(), mask.dim()
    if tensor_rank != mask_rank:
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor_rank, mask_rank))

    # Masked-out entries are overwritten with 0.0 before summing along ``dim``.
    summed = torch.sum(replace_masked_values(tensor, mask, 0.0), dim)
    # Number of non-zero mask entries along ``dim``.
    kept = torch.sum((mask != 0), dim)
    # Bump the divisor from 0 to 1 wherever nothing was kept.
    divisor = (kept + (kept == 0)).float()
    return summed / divisor
开发者ID:allenai,项目名称:scitail,代码行数:22,代码来源:util.py

示例2: masked_mean

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def masked_mean(tensor, dim, mask):
    """
    Mean of ``tensor`` over ``dim`` that ignores positions whose ``mask``
    entry is zero.

    Falls back to ``torch.mean`` when ``mask`` is ``None``. Raises
    ``ConfigurationError`` when ``tensor`` and ``mask`` differ in their
    number of dimensions.

    =====================================================================
    From Decomposable Graph Entailment Model code replicated from SciTail repo
    https://github.com/allenai/scitail
    =====================================================================
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        message = "tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim())
        raise ConfigurationError(message)

    # Numerator: sum over ``dim`` after masked entries are replaced by 0.0.
    zero_filled = replace_masked_values(tensor, mask, 0.0)
    numerator = zero_filled.sum(dim)

    # Denominator: count of non-zero mask entries, with empty counts raised
    # to 1 so the division below cannot produce NaN.
    num_kept = (mask != 0).sum(dim)
    nothing_kept = (num_kept == 0)
    denominator = (num_kept + nothing_kept).float()

    return numerator / denominator
开发者ID:allenai,项目名称:ARC-Solvers,代码行数:27,代码来源:util.py

示例3: test_replace_masked_values_replaces_masked_values_with_finite_value

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
        # One batch entry with three rows of four values; the boolean mask
        # marks the last row as masked out.
        values = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
        keep = torch.tensor([[True, True, False]])
        # The mask gets a trailing axis so it has the same rank as ``values``.
        keep_expanded = keep.unsqueeze(-1)
        result = util.replace_masked_values(values, keep_expanded, 2)
        # The masked row is expected to be filled with the replacement value 2.
        expected = [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]]
        assert_almost_equal(result.data.numpy(), expected)
开发者ID:allenai,项目名称:allennlp,代码行数:7,代码来源:util_test.py

示例4: test_replace_masked_values_replaces_masked_values_with_finite_value

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
        # One batch entry with three rows of four values; the float mask uses
        # 0 for the last row.
        data = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
        float_mask = torch.FloatTensor([[1, 1, 0]])
        # The mask gets a trailing axis so it has the same rank as ``data``.
        filled = util.replace_masked_values(data, float_mask.unsqueeze(-1), 2)
        filled_array = filled.data.numpy()
        # The row whose mask entry is 0 is expected to hold the fill value 2.
        wanted = [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]]
        assert_almost_equal(filled_array, wanted)
开发者ID:plasticityai,项目名称:magnitude,代码行数:7,代码来源:util_test.py

示例5: forward

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                gt_span=None, mode=ForwardMode.TRAIN):
        """
        Score every token as a span start / span end.

        In ``TRAIN`` mode ``gt_span`` is required and the return value is the
        sum of the start and end ``nll_loss`` terms. In any other mode the
        return value is ``(start_logits, end_logits, joint_length)``.
        """
        encoded, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                       output_all_encoded_layers=False)
        joint_length = allen_util.get_lengths_from_binary_sequence_mask(attention_mask)

        span_logits = self.qa_outputs(encoded)

        # Masking approach taken from AllenNLP bidaf: index 0 of the last axis
        # is used as the start score, index 1 as the end score, and masked
        # positions are pushed to -1e18.
        start_logits, end_logits = [
            allen_util.replace_masked_values(span_logits[:, :, channel], attention_mask, -1e18)
            for channel in (0, 1)
        ]

        is_training = (mode == BertSpan.ForwardMode.TRAIN)
        if not is_training:
            return start_logits, end_logits, joint_length

        assert gt_span is not None
        # gt_span: [B, 2] -> two targets of shape [B]. They are deliberately
        # not squeezed: squeezing caused problems when the batch size is 1.
        gt_start, gt_end = gt_span[:, 0], gt_span[:, 1]

        start_log_probs = allen_util.masked_log_softmax(start_logits, attention_mask)
        start_loss = nll_loss(start_log_probs, gt_start)
        end_log_probs = allen_util.masked_log_softmax(end_logits, attention_mask)
        end_loss = nll_loss(end_log_probs, gt_end)

        return start_loss + end_loss
开发者ID:easonnie,项目名称:semanticRetrievalMRS,代码行数:30,代码来源:hotpot_bert_v0.py

示例6: forward

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_span=None,
                gt_span=None, max_context_length=0, mode=ForwardMode.TRAIN):
        """
        Score every token of the selected context as a span start / span end.

        ``max_context_length`` must be precomputed by the caller: the same
        value has to be shared across GPUs, so computing it dynamically here
        is not feasible.

        In ``TRAIN`` mode ``gt_span`` is required and the return value is the
        sum of the start and end ``nll_loss`` terms. In any other mode the
        return value is ``(start_logits, end_logits, context_length)``.
        """
        encoded, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                       output_all_encoded_layers=False)

        span_logits = self.qa_outputs(encoded)
        # Restrict the logits to the context span and build the matching mask.
        context_logits, context_length = span_util.span_select(span_logits, context_span, max_context_length)
        context_mask = allen_util.get_mask_from_sequence_lengths(context_length, max_context_length)

        # Masking approach taken from AllenNLP bidaf: index 0 of the last axis
        # is used as the start score, index 1 as the end score, and masked
        # positions are pushed to -1e18.
        start_logits, end_logits = [
            allen_util.replace_masked_values(context_logits[:, :, channel], context_mask, -1e18)
            for channel in (0, 1)
        ]

        is_training = (mode == BertSpan.ForwardMode.TRAIN)
        if not is_training:
            return start_logits, end_logits, context_length

        assert gt_span is not None
        # gt_span: [B, 2]
        gt_start, gt_end = gt_span[:, 0], gt_span[:, 1]

        start_log_probs = allen_util.masked_log_softmax(start_logits, context_mask)
        start_loss = nll_loss(start_log_probs, gt_start.squeeze(-1))
        end_log_probs = allen_util.masked_log_softmax(end_logits, context_mask)
        end_loss = nll_loss(end_log_probs, gt_end.squeeze(-1))

        return start_loss + end_loss
开发者ID:easonnie,项目名称:semanticRetrievalMRS,代码行数:30,代码来源:bert_span_v0.py

示例7: forward

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def forward(self, # pylint: disable=arguments-differ
                premises_relevance_logits: torch.Tensor,
                premises_presence_mask: torch.Tensor,
                relevance_presence_mask: torch.Tensor) -> torch.Tensor: # pylint: disable=unused-argument
        """
        Compute the coverage loss.

        Logits at masked premise positions are replaced by -1e10, ``self._loss``
        is applied against ``relevance_presence_mask``, the result is averaged
        over ``dim=1`` under ``premises_presence_mask``, and the mean of those
        averages is returned.
        """
        masked_logits = replace_masked_values(premises_relevance_logits, premises_presence_mask, -1e10)
        per_premise_losses = self._loss(masked_logits, relevance_presence_mask)
        per_instance_losses = masked_mean(per_premise_losses, premises_presence_mask, dim=1)
        return per_instance_losses.mean()
开发者ID:StonyBrookNLP,项目名称:multee,代码行数:11,代码来源:coverage_loss.py

示例8: compute_location_spans

# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def compute_location_spans(self, contextual_seq_embedding, embedded_sentence_verb_entity, mask):
        """
        Predict the "after" location span (layer 5: span prediction).

        Start logits are predicted from the concatenated inputs; the resulting
        start distribution is used to build a start representation that is
        tiled and fed, together with the inputs, into the end-span encoder and
        predictor. Both logit tensors are then masked with -1e7 and passed to
        ``self.get_best_span``.

        Returns ``(best_span_after, span_start_logits_after, span_end_logits_after)``
        where the two logit tensors are the masked versions.
        """
        # contextual_seq_embedding must be 5-D:
        # (batch_size, num_sentences, num_participants, sentence_length, encoder_dim)
        bs, ns, n_part, sent_len, enc_dim = contextual_seq_embedding.shape

        # ---- span start ----
        start_features = torch.cat([embedded_sentence_verb_entity, contextual_seq_embedding], dim=-1)
        # Presumably (bs, ns, np, sl) after dropping the trailing axis -- per the original notes.
        start_logits = self._span_start_predictor_after(start_features).squeeze(-1)
        start_probs = util.masked_softmax(start_logits, mask)

        # Attention-weighted summary of the sequence embedding under the start
        # distribution, tiled back to the full 5-D embedding shape.
        start_summary = util.weighted_sum(contextual_seq_embedding, start_probs)
        tiled_start = start_summary.unsqueeze(3).expand(bs, ns, n_part, sent_len, enc_dim)

        # ---- span end ----
        end_features = torch.cat(
            [
                embedded_sentence_verb_entity,
                contextual_seq_embedding,
                tiled_start,
                contextual_seq_embedding * tiled_start,
            ],
            dim=-1,
        )
        encoded_end = self.time_distributed_encoder_span_end_after(end_features, mask)
        end_logits = self._span_end_predictor_after(encoded_end).squeeze(-1)

        # End probabilities are computed but not consumed in this method.
        span_end_probs_after = util.masked_softmax(end_logits, mask)

        # Push masked positions to -1e7 before picking the best span.
        start_logits = util.replace_masked_values(start_logits, mask, -1e7)
        end_logits = util.replace_masked_values(end_logits, mask, -1e7)

        # Fixme: we should condition this on predicted_action so that we can output '-' when needed
        # Fixme: also add a functionality to be able to output '?': we can use the start/end probabilities
        best_span = self.get_best_span(start_logits, end_logits)
        return best_span, start_logits, end_logits
开发者ID:allenai,项目名称:propara,代码行数:59,代码来源:prostruct_model.py


注:本文中的allennlp.nn.util.replace_masked_values方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。