本文整理匯總了Python中torch.LongTensor類的典型用法代碼示例。如果您正苦於以下問題：Python LongTensor類的具體用法？Python LongTensor怎麼用？Python LongTensor使用的例子？那麼，這裡精選的類代碼示例或許可以為您提供幫助。
在下文中一共展示了LongTensor類的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: __call__
def __call__(self,  # type: ignore
             predictions: torch.LongTensor,
             gold_targets: torch.LongTensor) -> None:
    """
    Update the modified-precision and length counts with a new batch.

    Parameters
    ----------
    predictions : ``torch.LongTensor``, required
        Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
    gold_targets : ``torch.LongTensor``, required
        Batched reference (gold) translations with shape `(batch_size, max_gold_sequence_length)`.

    Returns
    -------
    None
    """
    predictions, gold_targets = self.unwrap_to_tensors(predictions, gold_targets)
    # One n-gram order per entry of ``self._ngram_weights``: 1-grams, 2-grams, ...
    for ngram_size, _ in enumerate(self._ngram_weights, start=1):
        precision_matches, precision_totals = self._get_modified_precision_counts(
            predictions, gold_targets, ngram_size)
        self._precision_matches[ngram_size] += precision_matches
        self._precision_totals[ngram_size] += precision_totals
    if not self._exclude_indices:
        # Nothing is excluded, so every position counts towards the lengths
        # (numel() == batch_size * sequence_length for these 2-D inputs).
        self._prediction_lengths += predictions.numel()
        self._reference_lengths += gold_targets.numel()
    else:
        # Only count tokens that survive the exclusion mask.
        valid_predictions_mask = self._get_valid_tokens_mask(predictions)
        self._prediction_lengths += valid_predictions_mask.sum().item()
        valid_gold_targets_mask = self._get_valid_tokens_mask(gold_targets)
        self._reference_lengths += valid_gold_targets_mask.sum().item()
示例2: flattened_index_select
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    Gather, for every batch element, the subsets of ``target`` named by ``indices``.

    The given ``indices`` of size ``(set_size, subset_size)`` specifies subsets of the
    ``target`` that each of the set_size rows should select. The same indices are applied
    to every batch element. The ``target`` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has
    size ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.
        It does not need to be contiguous.

    Returns
    -------
    selected : ``torch.Tensor``
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).

    Raises
    ------
    ConfigurationError
        If ``indices`` is not 2-dimensional.
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # ``reshape`` instead of ``view``: ``view`` raises for non-contiguous index
    # tensors (e.g. a transposed one); ``reshape`` only copies when it has to.
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.reshape(-1))
    # index_select returns a fresh contiguous tensor, so ``view`` is safe here.
    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
示例3: _get_mask_for_eval
def _get_mask_for_eval(self,
                       mask: torch.LongTensor,
                       pos_tags: torch.LongTensor) -> torch.LongTensor:
    """
    Build the mask used for dependency evaluation, which skips punctuation.

    Every position whose part-of-speech tag is listed in ``self._pos_to_ignore``
    (the "punctuation-like" tags) is zeroed in a detached version of ``mask``.
    The input ``mask`` itself is not modified.

    Parameters
    ----------
    mask : ``torch.LongTensor``, required.
        The original mask.
    pos_tags : ``torch.LongTensor``, required.
        The pos tags for the sequence.

    Returns
    -------
    The evaluation mask: ``mask`` with every ignored-tag position set to 0.
    """
    evaluation_mask = mask.detach()
    for ignored_tag in self._pos_to_ignore:
        # 1 where this token carries the ignored tag, 0 elsewhere.
        is_ignored = pos_tags.eq(ignored_tag).long()
        evaluation_mask = evaluation_mask * (1 - is_ignored)
    return evaluation_mask
示例4: forward
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Represent each span by combining the embeddings at its two endpoints.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``, required.
        The sequence embeddings looked up with ``util.batched_index_select``;
        presumably ``(batch_size, sequence_length, embedding_dim)`` — TODO confirm.
    span_indices : ``torch.LongTensor``, required.
        Inclusive span endpoints. The last dimension is split into a ``(start, end)``
        pair, so the expected shape is ``(batch_size, num_spans, 2)``.
    sequence_mask : ``torch.LongTensor``, optional.
        Accepted for interface compatibility; not used by this method.
    span_indices_mask : ``torch.LongTensor``, optional.
        ``(batch_size, num_spans)`` mask of valid spans. It is used to zero padded span
        indices before lookup and to zero the returned representations of padded spans.

    Returns
    -------
    The span representations: the ``self._combination`` of the endpoint embeddings,
    followed by a span-width embedding when ``self._span_width_embedding`` is set.
    Masked-out spans are all zeros.

    Raises
    ------
    ValueError
        If, with exclusive start indices, an adjusted start index is still negative.
    """
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]
    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask
    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))
        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")
        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
            + float_start_sentinel_mask * self._start_sentinel
    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts
        span_width_embeddings = self._span_width_embedding(span_widths)
        combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)
    # Apply the span mask on both paths; previously the width-embedding branch
    # returned early and left padded spans unmasked.
    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors
示例5: _action_history_match
def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
    """
    Return 1 if ``predicted`` matches the leading positions of any row of ``targets``, else 0.

    ``targets`` is indexed as ``(num_targets, target_length)``. Each row is trimmed to
    ``len(predicted)`` positions and compared element-wise with ``predicted``. A prediction
    longer than the targets never matches. An empty ``targets`` batch yields 0.
    """
    # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
    # Check if target is big enough to cover prediction (including start/end symbols)
    if len(predicted) > targets.size(1):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    targets_trimmed = targets[:, :len(predicted)]
    # A row matches when every compared position agrees. ``all``/``any`` also handle
    # empty batches, where ``torch.min`` over dim 1 used to raise. ``int`` keeps the
    # declared return type, since ``.item()`` on a bool tensor gives a Python bool.
    row_matches = targets_trimmed.eq(predicted_tensor).all(dim=1)
    return int(row_matches.any().item())
示例6: sequence_cross_entropy_with_logits
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar: the mean over
    sequences whose weights are not all zero.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).
    """
    # ``reshape`` (not ``view``) so non-contiguous logits/targets are accepted too.
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.reshape(-1, logits.size(-1))
    # Normalize over the class dimension explicitly. Omitting ``dim`` relies on a
    # deprecated implicit choice and emits a warning on every call.
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * sequence_length, 1)
    targets_flat = targets.reshape(-1, 1).long()
    # Contribution to the negative log likelihood only comes from the exact indices
    # of the targets, as the target distributions are one-hot. Here we use torch.gather
    # to extract the indices of the num_classes dimension which contribute to the loss.
    # shape : (batch * sequence_length, 1)
    negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # The 1e-13 terms avoid dividing by zero for all-padding sequences.
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
示例7: _get_modified_precision_counts
def _get_modified_precision_counts(self,
                                   predicted_tokens: torch.LongTensor,
                                   reference_tokens: torch.LongTensor,
                                   ngram_size: int) -> Tuple[int, int]:
    """
    Return ``(clipped_matches, total_predicted)`` for modified n-gram precision.

    For each batch row, every predicted n-gram of length ``ngram_size`` is credited
    at most as many times as it occurs in the matching reference row; that credit is
    the numerator. The denominator is the plain count of predicted n-grams.
    ``self._ngrams`` must return a mapping that yields 0 for absent n-grams (a
    ``Counter`` does).
    """
    matched = 0
    predicted_total = 0
    for row in range(predicted_tokens.size(0)):
        hypothesis_counts = self._ngrams(predicted_tokens[row, :], ngram_size)
        reference_counts = self._ngrams(reference_tokens[row, :], ngram_size)
        matched += sum(min(occurrences, reference_counts[gram])
                       for gram, occurrences in hypothesis_counts.items())
        predicted_total += sum(hypothesis_counts.values())
    return matched, predicted_total
示例8: _get_checklist_info
def _get_checklist_info(agenda: torch.LongTensor,
                        all_actions: List[ProductionRule],
                        terminal_productions: Set[str],
                        max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
    world, and a length to pad the checklist vectors to, and returns a target checklist against
    which the checklist at each state will be compared to compute a loss, indices of
    ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
    are relevant for checklist loss computation.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)`` (a flat ``(agenda_size,)``
        tensor is accepted as well).
    ``all_actions`` : ``List[ProductionRule]``
        All actions for one instance.
    ``terminal_productions`` : ``Set[str]``
        String representations of terminal productions in the corresponding world.
    ``max_num_terminals`` : ``int``
        Length to which the checklist vectors will be padded till. This is the max number of
        terminal productions in all the worlds in the batch.

    Returns
    -------
    ``(target_checklist, terminal_actions, checklist_mask)``, each a column vector with
    one row per terminal action. Padding rows have target 0 and terminal index -1.
    ``target_checklist`` is 1.0 for terminals on the agenda. ``checklist_mask`` is 1.0
    exactly where the target is non-zero.
    """
    terminal_indices = []
    target_checklist_list = []
    # Flatten before converting. ``squeeze(0)`` only removed a dimension when
    # agenda_size == 1; otherwise ``int()`` ran on shape-(1,) numpy arrays, which is
    # deprecated in NumPy >= 1.25.
    agenda_indices_set = {int(x) for x in agenda.detach().reshape(-1).tolist()}
    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRule, a tuple where the first item is the production
        # rule string.
        if action[0] in terminal_productions:
            terminal_indices.append([index])
            target_checklist_list.append([1 if index in agenda_indices_set else 0])
    # Pad to a batch-wide length so these can be stacked across instances.
    while len(target_checklist_list) < max_num_terminals:
        target_checklist_list.append([0])
        terminal_indices.append([-1])
    # (max_num_terminals, 1)
    terminal_actions = agenda.new_tensor(terminal_indices)
    # (max_num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
    checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
示例9: forward
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor) -> torch.Tensor:
    """
    Run the stacked LSTM over ``inputs`` and return every layer's output.

    Parameters
    ----------
    inputs : ``torch.Tensor``, required.
        A Tensor of shape ``(batch_size, sequence_length, hidden_size)``.
    mask : ``torch.LongTensor``, required.
        A binary mask of shape ``(batch_size, sequence_length)`` representing the
        non-padded elements in each sequence in the batch.

    Returns
    -------
    A ``torch.Tensor`` of shape (num_layers, batch_size, sequence_length, hidden_size),
    where the num_layers dimension represents the LSTM output from that layer.
    Rows and timesteps the RNN skipped are filled with zeros.
    """
    batch_size, total_sequence_length = mask.size()
    stacked_sequence_output, final_states, restoration_indices = \
        self.sort_and_run_forward(self._lstm_forward, inputs, mask)
    num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
    # Add back invalid rows which were removed in the call to sort_and_run_forward.
    # ``new_zeros`` matches dtype/device and replaces the deprecated
    # ``Variable(x.data.new(...).fill_(0))`` idiom.
    if num_valid < batch_size:
        num_missing = batch_size - num_valid
        zeros = stacked_sequence_output.new_zeros(num_layers, num_missing,
                                                  returned_timesteps, encoder_dim)
        stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 1)
        # The states also need to have invalid rows added back.
        final_states = [
            torch.cat([state, state.new_zeros(num_layers, num_missing, state.size(-1))], 1)
            for state in final_states
        ]
    # It's possible to need to pass sequences which are padded to longer than the
    # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
    # the sequences mean that the returned tensor won't include these dimensions, because
    # the RNN did not need to process them. We add them back on in the form of zeros here.
    sequence_length_difference = total_sequence_length - returned_timesteps
    if sequence_length_difference > 0:
        zeros = stacked_sequence_output.new_zeros(num_layers,
                                                  batch_size,
                                                  sequence_length_difference,
                                                  encoder_dim)
        stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 2)
    self._update_states(final_states, restoration_indices)
    # Restore the original indices and return the sequence.
    # Has shape (num_layers, batch_size, sequence_length, hidden_size)
    return stacked_sequence_output.index_select(1, restoration_indices)
示例10: _get_checklist_info
def _get_checklist_info(self,
                        agenda: torch.LongTensor,
                        all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                         torch.Tensor,
                                                                         torch.Tensor]:
    """
    Takes an agenda and a list of all actions and returns a target checklist against which the
    checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
    and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
    checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
    ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
    ``False``, indices of all terminals that are not in the agenda will be masked.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)`` (a flat ``(agenda_size,)``
        tensor is accepted as well).
    ``all_actions`` : ``List[ProductionRuleArray]``
        All actions for one instance.

    Returns
    -------
    ``(target_checklist, terminal_actions, checklist_mask)``, each of shape
    ``(num_terminals, 1)``.
    """
    terminal_indices = []
    target_checklist_list = []
    # Flatten before converting. ``squeeze(0)`` only removed a dimension when
    # agenda_size == 1; otherwise ``int()`` ran on shape-(1,) numpy arrays, which is
    # deprecated in NumPy >= 1.25.
    agenda_indices_set = {int(x) for x in agenda.detach().reshape(-1).tolist()}
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRuleArray, a tuple where the first item is the production
        # rule string.
        if action[0] in self._terminal_productions:
            terminal_indices.append([index])
            target_checklist_list.append([1 if index in agenda_indices_set else 0])
    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    # (num_terminals, 1)
    terminal_actions = agenda.new_tensor(terminal_indices)
    # (num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
    if self._penalize_non_agenda_actions:
        # All terminal actions are relevant
        checklist_mask = torch.ones_like(target_checklist)
    else:
        checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
示例11: batched_index_select
def batched_index_select(target: torch.Tensor,
                         indices: torch.LongTensor,
                         flattened_indices: Optional[torch.LongTensor] = None) -> torch.Tensor:
    """
    The given ``indices`` of size ``(batch_size, d_1, ..., d_n)`` indexes into the sequence
    dimension (dimension 1, the second dimension) of the target, which has size
    ``(batch_size, sequence_length, embedding_size)``.

    This function returns selected values in the target with respect to the provided indices, which
    have size ``(batch_size, d_1, ..., d_n, embedding_size)``. This can use the optionally
    precomputed :func:`~flattened_indices` with size ``(batch_size * d_1 * ... * d_n)`` if given.

    An example use case of this function is looking up the start and end indices of spans in a
    sequence tensor. This is used in the
    :class:`~allennlp.models.coreference_resolution.CoreferenceResolver`. Model to select
    contextual word representations corresponding to the start and end indices of mentions. The key
    reason this can't be done with basic torch functions is that we want to be able to use look-up
    tensors with an arbitrary number of dimensions (for example, in the coref model, we don't know
    a-priori how many spans we are looking up).

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed. It does not need to be contiguous.
    indices : ``torch.LongTensor``
        A tensor of shape (batch_size, ...), where each element is an index into the
        ``sequence_length`` dimension of the ``target`` tensor.
    flattened_indices : Optional[torch.Tensor], optional (default = None)
        An optional tensor representing the result of calling :func:~`flatten_and_batch_shift_indices`
        on ``indices``. This is helpful in the case that the indices can be flattened once and
        cached for many batch lookups.

    Returns
    -------
    selected_targets : ``torch.Tensor``
        A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
        extracted from the batch flattened target tensor.
    """
    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))
    # ``reshape`` instead of ``view`` so a non-contiguous ``target`` (e.g. a transposed
    # one) is accepted; it only copies when a view is impossible.
    # Shape: (batch_size * sequence_length, embedding_size)
    flattened_target = target.reshape(-1, target.size(-1))
    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    flattened_selected = flattened_target.index_select(0, flattened_indices)
    selected_shape = list(indices.size()) + [target.size(-1)]
    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    selected_targets = flattened_selected.view(*selected_shape)
    return selected_targets
示例12: greedy_predict
def greedy_predict(self,
                   final_encoder_output: torch.Tensor,
                   target_embedder: Embedding,
                   decoder_cell: GRUCell,
                   output_projection_layer: Linear) -> torch.Tensor:
    """
    Greedily produces a sequence using the provided ``decoder_cell``.

    Decoding starts from ``self._start_index``. Each step feeds the previous prediction
    back in and picks the highest-probability class. It always runs for
    ``self._max_decoding_steps`` steps; there is no early stop on an end token.

    Parameters
    ----------
    final_encoder_output : ``torch.Tensor``, required
        Vector produced by ``self._encoder``. It is the initial decoder hidden state, so
        it is a float tensor whose first dimension is the batch size (presumably
        ``(batch_size, hidden_dim)`` — confirm against the encoder). It was previously
        mis-annotated as ``torch.LongTensor``.
    target_embedder : ``Embedding``, required
        Used to embed the target tokens.
    decoder_cell: ``GRUCell``, required
        The recurrent cell used at each time step.
    output_projection_layer: ``Linear``, required
        Linear layer mapping to the desired number of classes.

    Returns
    -------
    A long tensor of shape ``(batch_size, self._max_decoding_steps)`` with the predicted
    class indices. The start symbol is excluded.
    """
    batch_size = final_encoder_output.size(0)
    decoder_hidden = final_encoder_output
    # Every sequence starts from the start symbol.
    predictions = [final_encoder_output.new_full(
        (batch_size,), fill_value=self._start_index, dtype=torch.long
    )]
    for _ in range(self._max_decoding_steps):
        decoder_input = target_embedder(predictions[-1])
        decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
        # (batch_size, num_classes)
        output_projections = output_projection_layer(decoder_hidden)
        class_probabilities = F.softmax(output_projections, dim=-1)
        _, predicted_classes = torch.max(class_probabilities, 1)
        predictions.append(predicted_classes)
    # (batch_size, num_decoding_steps + 1); drop the start symbol and return.
    all_predictions = torch.stack(predictions, dim=1)
    return all_predictions[:, 1:]
示例13: forward
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Embed wordpiece ids with the wrapped BERT model.

    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, max_sequence_length) tensor of wordpiece ids. Id 0 is treated
        as padding when building the attention mask.
    offsets : ``torch.LongTensor``, optional
        Positions of one chosen wordpiece per original token, for example the last (or
        first) wordpiece of each token, depending on how the token indexer is configured.
        For the sentence "Definitely not" split into ["Def", "##in", "##ite", "##ly",
        "not"], the "last wordpiece" offsets would be [3, 4]. When given, only those
        positions are returned, one embedding per token. Otherwise every wordpiece
        embedding is returned.
    token_type_ids : ``torch.LongTensor``, optional
        Segment ids: 0 for tokens of the first sentence, 1 for the second (as in the
        BERT paper). Defaults to all zeros, which is what the default BertIndexer
        implies.

    Returns
    -------
    The last encoder layer, or the scalar mix of all layers when ``self._scalar_mix``
    is set. It is optionally gathered at ``offsets``.
    """
    # pylint: disable=arguments-differ
    segment_ids = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids
    attention_mask = (input_ids != 0).long()
    encoder_layers, _ = self.bert_model(input_ids, attention_mask, segment_ids)
    if self._scalar_mix is None:
        embeddings = encoder_layers[-1]
    else:
        embeddings = self._scalar_mix(encoder_layers, attention_mask)
    if offsets is None:
        return embeddings
    # Pair each batch row with its own offsets through advanced indexing.
    row_index = util.get_range_vector(input_ids.size(0),
                                      device=util.get_device_of(embeddings)).unsqueeze(1)
    return embeddings[row_index, offsets]
示例14: _ngrams
def _ngrams(self,
            tensor: torch.LongTensor,
            ngram_size: int) -> Dict[Tuple[int, ...], int]:
    """
    Count the contiguous n-grams of length ``ngram_size`` in a 1-D token tensor.

    N-grams that contain any id in ``self._exclude_indices`` are skipped. The result
    is a ``Counter``, so looking up a missing n-gram yields 0. It is empty when
    ``ngram_size`` is not positive or is longer than ``tensor``.
    """
    ngram_counts: Dict[Tuple[int, ...], int] = Counter()
    if ngram_size <= 0 or ngram_size > tensor.size(-1):
        return ngram_counts
    # Convert to Python ints once, instead of calling ``.item()`` on every
    # element of every sliced sub-tensor.
    tokens = tensor.tolist()
    # Slide a window of width ``ngram_size`` over the sequence.
    for start in range(len(tokens) - ngram_size + 1):
        ngram = tuple(tokens[start:start + ngram_size])
        if any(token in self._exclude_indices for token in ngram):
            continue
        ngram_counts[ngram] += 1
    return ngram_counts
示例15: _prepare_decode_step_input
def _prepare_decode_step_input(self,
                               input_indices: torch.LongTensor,
                               decoder_hidden_state: torch.Tensor = None,
                               encoder_outputs: torch.Tensor = None,
                               encoder_outputs_mask: torch.Tensor = None) -> torch.Tensor:
    """
    Build the decoder input for the current timestep.

    Given the input indices for the current timestep of the decoder, and all the encoder
    outputs, compute the input at the current timestep. Note: This method is agnostic to
    whether the indices are gold indices or the predictions made by the decoder at the last
    timestep. So, this can be used even if we're doing some kind of scheduled sampling.

    If we're not using attention, the output of this method is just an embedding of the input
    indices. If we are, the output is the concatenation of an attended average of the encoder
    outputs followed by the embedding of the input indices.

    Parameters
    ----------
    input_indices : torch.LongTensor
        Indices of either the gold inputs to the decoder or the predicted labels from the
        previous timestep.
    decoder_hidden_state : torch.Tensor, optional (not needed if no attention)
        Output of the decoder at the last time step. Needed only if using attention.
    encoder_outputs : torch.Tensor, optional (not needed if no attention)
        Encoder outputs from all time steps. Needed only if using attention.
    encoder_outputs_mask : torch.Tensor, optional (not needed if no attention)
        Masks on encoder outputs. Needed only if using attention; converted to float here.

    Returns
    -------
    A float tensor: the embedded input, or ``[attended_input; embedded_input]`` along the
    last dimension when attention is used.
    """
    # input_indices : (batch_size,) since we are processing these one timestep at a time.
    # (batch_size, target_embedding_dim)
    embedded_input = self._target_embedder(input_indices)
    # NOTE(review): gated on ``_attention_function`` but calls ``_decoder_attention``;
    # presumably the latter is built from the former in ``__init__`` — confirm.
    if self._attention_function:
        # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
        # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
        # complain.
        encoder_outputs_mask = encoder_outputs_mask.float()
        # (batch_size, input_sequence_length)
        input_weights = self._decoder_attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
        # (batch_size, encoder_output_dim)
        attended_input = weighted_sum(encoder_outputs, input_weights)
        # (batch_size, encoder_output_dim + target_embedding_dim)
        return torch.cat((attended_input, embedded_input), -1)
    return embedded_input