本文整理汇总了 Python 中 allennlp.nn.util.replace_masked_values 方法的典型用法代码示例。如果您正苦于以下问题：Python util.replace_masked_values 方法的具体用法？Python util.replace_masked_values 怎么用？Python util.replace_masked_values 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 allennlp.nn.util 的用法示例。
在下文中一共展示了 util.replace_masked_values 方法的 8 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: masked_mean
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def masked_mean(tensor, dim, mask):
    """
    Compute the mean of ``tensor`` along ``dim`` over non-masked entries only.

    Args:
        tensor: Tensor of values to average.
        dim: Dimension of ``tensor`` to reduce.
        mask: Tensor with the same number of dimensions as ``tensor``;
            entries equal to 0 are excluded from the mean. If ``None``, a
            plain ``torch.mean`` over ``dim`` is returned instead.

    Returns:
        The sum of the unmasked values along ``dim`` divided by the number
        of unmasked entries (with that count floored at 1).

    Raises:
        ConfigurationError: If ``tensor`` and ``mask`` differ in their
            number of dimensions.
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    # Replace masked entries with 0 so they do not contribute to the sum.
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    total_tensor = torch.sum(masked_tensor, dim)
    # Number of unmasked entries in each reduced slice.
    count_tensor = torch.sum(mask != 0, dim)
    # Floor the count at 1 so a fully-masked slice does not divide by zero
    # (which would produce NaN). Clamping avoids mixing bool and integer
    # tensors in an addition.
    safe_count = count_tensor.clamp(min=1).float()
    return total_tensor / safe_count
示例2: masked_mean
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def masked_mean(tensor, dim, mask):
    """
    Compute the mean of ``tensor`` along ``dim`` over non-masked entries only.

    From the Decomposable Graph Entailment Model code, replicated from the
    SciTail repo: https://github.com/allenai/scitail

    Args:
        tensor: Tensor of values to average.
        dim: Dimension of ``tensor`` to reduce.
        mask: Tensor with the same number of dimensions as ``tensor``;
            entries equal to 0 are excluded from the mean. If ``None``, a
            plain ``torch.mean`` over ``dim`` is returned instead.

    Returns:
        The sum of the unmasked values along ``dim`` divided by the number
        of unmasked entries (with that count floored at 1).

    Raises:
        ConfigurationError: If ``tensor`` and ``mask`` differ in their
            number of dimensions.
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    # Replace masked entries with 0 so they do not contribute to the sum.
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    total_tensor = torch.sum(masked_tensor, dim)
    # Number of unmasked entries in each reduced slice.
    count_tensor = torch.sum(mask != 0, dim)
    # Floor the count at 1 so a fully-masked slice does not divide by zero
    # (which would produce NaN). Clamping avoids mixing bool and integer
    # tensors in an addition.
    safe_count = count_tensor.clamp(min=1).float()
    return total_tensor / safe_count
示例3: test_replace_masked_values_replaces_masked_values_with_finite_value
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
    """Masked-out rows must be overwritten with the finite fill value."""
    values = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
    # Boolean mask over the three rows; the last row is masked out.
    row_mask = torch.tensor([[True, True, False]])
    # unsqueeze(-1) appends a trailing axis so the mask has the same number
    # of dimensions as the values.
    filled = util.replace_masked_values(values, row_mask.unsqueeze(-1), 2)
    expected = [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]]
    assert_almost_equal(filled.data.numpy(), expected)
示例4: test_replace_masked_values_replaces_masked_values_with_finite_value
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
    """Masked-out rows must be overwritten with the finite fill value."""
    values = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
    # Float 0/1 mask over the three rows; the last row (0) is masked out.
    row_mask = torch.FloatTensor([[1, 1, 0]])
    # unsqueeze(-1) appends a trailing axis so the mask has the same number
    # of dimensions as the values.
    filled = util.replace_masked_values(values, row_mask.unsqueeze(-1), 2)
    expected = [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]]
    assert_almost_equal(filled.data.numpy(), expected)
示例5: forward
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
            gt_span=None, mode=ForwardMode.TRAIN):
    """
    Produce masked span start/end logits and, in TRAIN mode, the span loss.

    Args:
        input_ids: Passed to ``self.bert_encoder``.
        token_type_ids: Passed to ``self.bert_encoder``.
        attention_mask: Passed to ``self.bert_encoder`` and also used to
            mask the start/end logits and to derive sequence lengths.
        gt_span: Ground-truth spans, required in TRAIN mode; column 0 is
            read as the start target and column 1 as the end target
            (documented by the original author as shape [B, 2]).
        mode: ``ForwardMode`` value; TRAIN returns the loss, anything else
            returns the logits.

    Returns:
        In TRAIN mode, ``start_loss + end_loss``. Otherwise the tuple
        ``(start_logits, end_logits, joint_length)``.
    """
    encoded, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                   output_all_encoded_layers=False)
    joint_length = allen_util.get_lengths_from_binary_sequence_mask(attention_mask)
    # Last axis carries the two per-token scores: index 0 = start, 1 = end.
    span_logits = self.qa_outputs(encoded)

    # Masking approach taken from AllenNLP's BiDAF: positions outside the
    # mask receive a very large negative logit.
    start_logits = allen_util.replace_masked_values(span_logits[:, :, 0], attention_mask, -1e18)
    end_logits = allen_util.replace_masked_values(span_logits[:, :, 1], attention_mask, -1e18)

    is_training = mode == BertSpan.ForwardMode.TRAIN
    if not is_training:
        return start_logits, end_logits, joint_length

    assert gt_span is not None
    # gt_span: [B, 2] -> targets of shape [B]. The targets are deliberately
    # NOT squeezed: squeezing breaks when the batch size is 1, and the
    # targets must keep shape [B].
    start_target = gt_span[:, 0]
    end_target = gt_span[:, 1]
    start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, attention_mask), start_target)
    end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, attention_mask), end_target)
    return start_loss + end_loss
示例6: forward
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_span=None,
            gt_span=None, max_context_length=0, mode=ForwardMode.TRAIN):
    """
    Produce masked span start/end logits over the selected context and,
    in TRAIN mode, the span loss.

    Args:
        input_ids: Passed to ``self.bert_encoder``.
        token_type_ids: Passed to ``self.bert_encoder``.
        attention_mask: Passed to ``self.bert_encoder``.
        context_span: Passed to ``span_util.span_select`` to pick the
            context portion of the encoder logits.
        gt_span: Ground-truth spans, required in TRAIN mode; column 0 is
            read as the start target and column 1 as the end target
            (documented by the original author as shape [B, 2]).
        max_context_length: Length used for span selection and for the
            context mask. It must be precomputed by the caller so the same
            value is shared across GPUs; computing it dynamically is not
            feasible.
        mode: ``ForwardMode`` value; TRAIN returns the loss, anything else
            returns the logits.

    Returns:
        In TRAIN mode, ``start_loss + end_loss``. Otherwise the tuple
        ``(start_logits, end_logits, context_length)``.
    """
    sequence_output, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                           output_all_encoded_layers=False)
    joint_seq_logits = self.qa_outputs(sequence_output)
    context_logits, context_length = span_util.span_select(joint_seq_logits, context_span, max_context_length)
    context_mask = allen_util.get_mask_from_sequence_lengths(context_length, max_context_length)

    # Masking approach taken from AllenNLP's BiDAF: positions outside the
    # context mask receive a very large negative logit.
    # context_logits is indexed as [B, T, 2]: 0 = start score, 1 = end score.
    start_logits = allen_util.replace_masked_values(context_logits[:, :, 0], context_mask, -1e18)
    end_logits = allen_util.replace_masked_values(context_logits[:, :, 1], context_mask, -1e18)

    if mode == BertSpan.ForwardMode.TRAIN:
        assert gt_span is not None
        # Flatten the targets to 1-D instead of squeeze(-1): squeezing a
        # [B] tensor collapses it to a scalar when the batch size is 1,
        # while the loss needs targets of shape [B].
        gt_start = gt_span[:, 0].reshape(-1)
        gt_end = gt_span[:, 1].reshape(-1)
        start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, context_mask), gt_start)
        end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, context_mask), gt_end)
        return start_loss + end_loss

    return start_logits, end_logits, context_length
示例7: forward
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def forward(self,  # pylint: disable=arguments-differ
            premises_relevance_logits: torch.Tensor,
            premises_presence_mask: torch.Tensor,
            relevance_presence_mask: torch.Tensor) -> torch.Tensor:  # pylint: disable=unused-argument
    """
    Compute a scalar coverage loss from premise relevance logits.

    Args:
        premises_relevance_logits: Logits to score; positions outside
            ``premises_presence_mask`` are replaced with -1e10 first.
        premises_presence_mask: Mask used both for the logit replacement
            and for the masked mean over dimension 1.
        relevance_presence_mask: Passed to ``self._loss`` as its second
            argument together with the masked logits.

    Returns:
        The mean, over all remaining entries, of the masked mean (dim 1)
        of the losses produced by ``self._loss``.
    """
    masked_logits = replace_masked_values(premises_relevance_logits, premises_presence_mask, -1e10)
    elementwise_losses = self._loss(masked_logits, relevance_presence_mask)
    # NOTE(review): the mask is passed positionally and ``dim`` by keyword,
    # so this presumably targets a ``masked_mean(vector, mask, dim)``
    # signature -- confirm which ``masked_mean`` is imported here.
    reduced_losses = masked_mean(elementwise_losses, premises_presence_mask, dim=1)
    return reduced_losses.mean()
示例8: compute_location_spans
# 需要导入模块: from allennlp.nn import util [as 别名]
# 或者: from allennlp.nn.util import replace_masked_values [as 别名]
def compute_location_spans(self, contextual_seq_embedding, embedded_sentence_verb_entity, mask):
    """
    Predict the best "after" location span from span start/end logits.

    Args:
        contextual_seq_embedding: 5-D tensor, unpacked as (batch_size,
            num_sentences, num_participants, sentence_length, encoder_dim).
        embedded_sentence_verb_entity: Tensor concatenated with
            ``contextual_seq_embedding`` along the last axis to form the
            predictor inputs.
        mask: Mask given to the masked softmaxes, to the span-end encoder
            and to ``replace_masked_values``.

    Returns:
        Tuple ``(best_span, start_logits, end_logits)`` where both logit
        tensors have had masked positions replaced with -1e7 and
        ``best_span`` comes from ``self.get_best_span`` on those logits.
    """
    # Unpacking also fixes the target shape for the tiling step below.
    batch_size, num_sentences, num_participants, sentence_length, encoder_dim = \
        contextual_seq_embedding.shape
    tiled_shape = (batch_size, num_sentences, num_participants, sentence_length, encoder_dim)

    # --- Span start ---------------------------------------------------
    start_features = torch.cat([embedded_sentence_verb_entity, contextual_seq_embedding], dim=-1)
    # Original notes give the logits shape as (bs, ns, np, sl) after the
    # trailing axis is squeezed away -- TODO confirm against the predictor.
    start_logits = self._span_start_predictor_after(start_features).squeeze(-1)
    start_probs = util.masked_softmax(start_logits, mask)

    # Summary of the contextual embeddings weighted by the start
    # distribution, then tiled along a new axis 3 to the shape of
    # contextual_seq_embedding.
    start_summary = util.weighted_sum(contextual_seq_embedding, start_probs)
    tiled_start_summary = start_summary.unsqueeze(3).expand(*tiled_shape)

    # --- Span end -----------------------------------------------------
    end_features = torch.cat(
        [
            embedded_sentence_verb_entity,
            contextual_seq_embedding,
            tiled_start_summary,
            contextual_seq_embedding * tiled_start_summary,
        ],
        dim=-1,
    )
    encoded_end = self.time_distributed_encoder_span_end_after(end_features, mask)
    end_logits = self._span_end_predictor_after(encoded_end).squeeze(-1)
    # Not consumed below; retained because the FIXME further down plans to
    # use the start/end probabilities.
    end_probs = util.masked_softmax(end_logits, mask)

    # Give masked positions a large negative logit before picking the span.
    start_logits = util.replace_masked_values(start_logits, mask, -1e7)
    end_logits = util.replace_masked_values(end_logits, mask, -1e7)

    # FIXME: condition this on predicted_action so that '-' can be output
    # when needed.
    # FIXME: also add the ability to output '?'; the start/end
    # probabilities computed above can be used for that.
    best_span = self.get_best_span(start_logits, end_logits)
    return best_span, start_logits, end_logits