本文整理汇总了Python中torch.Tensor.float方法的典型用法代码示例。如果您正苦于以下问题：Python Tensor.float方法的具体用法？Python Tensor.float怎么用？Python Tensor.float使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch.Tensor的用法示例。
在下文中一共展示了Tensor.float方法的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _construct_loss
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def _construct_loss(self,
                    arc_scores: torch.Tensor,
                    arc_tag_logits: torch.Tensor,
                    arc_tags: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag loss for an adjacency matrix.

    Parameters
    ----------
    arc_scores : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a
        binary classification decision for whether an edge is present between two words.
    arc_tag_logits : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used to generate
        a distribution over edge tags for a given edge.
    arc_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length).
        The labels for every arc; ``-1`` marks "no edge between these two words".
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``
        The summed (masked) arc loss divided by the number of unmasked gold arcs.
    tag_nll : ``torch.Tensor``
        The summed tag loss over unmasked gold arcs divided by the number of such arcs.

    If the batch contains no unmasked gold arcs, the normaliser is clamped to 1 so that
    both losses stay finite instead of becoming NaN.
    """
    float_mask = mask.float()
    # (batch, seq, seq): 1 only where both endpoints of the word pair are real tokens.
    pair_mask = float_mask.unsqueeze(1) * float_mask.unsqueeze(2)
    arc_indices = (arc_tags != -1).float()
    # Make the arc tags not have negative values anywhere (by default, no edge is
    # indicated with -1), so every entry is a valid class index for the tag loss.
    arc_tags = arc_tags * arc_indices
    arc_nll = self._arc_loss(arc_scores, arc_indices) * pair_mask
    # The tag mask only includes unmasked word pairs, and only the gold arcs among them.
    tag_mask = pair_mask * arc_indices
    batch_size, sequence_length, _, num_tags = arc_tag_logits.size()
    original_shape = [batch_size, sequence_length, sequence_length]
    reshaped_logits = arc_tag_logits.view(-1, num_tags)
    reshaped_tags = arc_tags.view(-1)
    tag_nll = self._tag_loss(reshaped_logits, reshaped_tags.long()).view(original_shape) * tag_mask
    # Guard against 0 / 0 = NaN when the batch has no gold arcs at all.
    valid_positions = tag_mask.sum().clamp(min=1.0)
    arc_nll = arc_nll.sum() / valid_positions
    tag_nll = tag_nll.sum() / valid_positions
    return arc_nll, tag_nll
示例2: __call__
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def __call__(self,
             predictions: torch.Tensor,
             gold_labels: torch.Tensor,
             mask: Optional[torch.Tensor] = None):
    """
    Update the running true/false positive/negative counts for ``self._positive_label``.

    Parameters
    ----------
    predictions : ``torch.Tensor``, required.
        A tensor of predictions of shape (batch_size, ..., num_classes).
    gold_labels : ``torch.Tensor``, required.
        A tensor of integer class label of shape (batch_size, ...). It must be the same
        shape as the ``predictions`` tensor without the ``num_classes`` dimension.
    mask: ``torch.Tensor``, optional (default = None).
        A masking tensor the same size as ``gold_labels``.

    Raises
    ------
    ConfigurationError
        If any gold label is ``>= num_classes``.
    """
    # Get the data from the Variables.
    predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask)

    num_classes = predictions.size(-1)
    if (gold_labels >= num_classes).any():
        raise ConfigurationError("A gold label passed to F1Measure contains an id >= {}, "
                                 "the number of classes.".format(num_classes))
    if mask is None:
        mask = ones_like(gold_labels)
    mask = mask.float()
    gold_labels = gold_labels.float()

    positive_label_mask = gold_labels.eq(self._positive_label).float()
    negative_label_mask = 1.0 - positive_label_mask

    # ``max(-1)`` already removes the class dimension, so the argmax has exactly the shape
    # of ``gold_labels``. No extra ``squeeze(-1)``: it would also drop a genuine size-1
    # trailing dimension (e.g. sequence_length == 1), and the products below would then
    # broadcast against ``gold_labels`` into a larger, wrong shape and inflate the counts.
    argmax_predictions = predictions.max(-1)[1].float()
    predicted_positive = (argmax_predictions == self._positive_label).float()
    predicted_negative = (argmax_predictions != self._positive_label).float()

    # True Negatives: correct non-positive predictions.
    self._true_negatives += (predicted_negative * negative_label_mask * mask).sum()
    # True Positives: correct positively labeled predictions.
    self._true_positives += (predicted_positive * positive_label_mask * mask).sum()
    # False Negatives: positive gold labels that were predicted as something else.
    self._false_negatives += (predicted_negative * positive_label_mask * mask).sum()
    # False Positives: non-positive gold labels that were predicted as positive.
    self._false_positives += (predicted_positive * negative_label_mask * mask).sum()
示例3: _input_likelihood
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def _input_likelihood(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Return the (batch_size,) log-space denominator of the CRF log-likelihood.

    The denominator is the combined score of every possible tag sequence. It is computed
    with a forward recursion over timesteps, using ``util.logsumexp`` to reduce over the
    previous tag.

    Parameters
    ----------
    logits : ``torch.Tensor``
        Per-timestep tag scores, unpacked as (batch_size, sequence_length, num_tags).
    mask : ``torch.Tensor``
        (batch_size, sequence_length). At timesteps where it is 0, the running forward
        variable is carried over unchanged.
    """
    batch_size, sequence_length, num_tags = logits.size()

    # Time-major copies, so that indexing with ``step`` selects one timestep for the batch.
    time_major_mask = mask.float().transpose(0, 1).contiguous()
    time_major_logits = logits.transpose(0, 1).contiguous()

    # Forward variable for the first timestep, shape (batch_size, num_tags).
    forward_var = time_major_logits[0]
    if self.include_start_end_transitions:
        forward_var = self.start_transitions.view(1, num_tags) + forward_var

    # Transition scores (current_tag, next_tag), broadcast along the instance axis.
    pairwise_transitions = self.transitions.view(1, num_tags, num_tags)

    for step in range(1, sequence_length):
        # Emission scores for the "next_tag" axis, broadcast along "current_tag".
        step_emissions = time_major_logits[step].view(batch_size, 1, num_tags)
        step_mask = time_major_mask[step].view(batch_size, 1)
        # scores[b, i, j] = forward_var[b, i] + emission[b, j] + transition[i, j]
        scores = forward_var.view(batch_size, num_tags, 1) + step_emissions + pairwise_transitions
        # Reduce over current_tag where the mask is on; otherwise keep the previous value.
        forward_var = util.logsumexp(scores, 1) * step_mask + forward_var * (1 - step_mask)

    # Every sequence ends with a transition to the stop state, when those are modelled.
    if self.include_start_end_transitions:
        forward_var = forward_var + self.end_transitions.view(1, num_tags)

    # Final reduction over the tag axis -> (batch_size,)
    return util.logsumexp(forward_var)
示例4: forward
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def forward(self, tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
            mask: torch.Tensor = None) -> torch.Tensor:
    """
    Compute a weighted average of the ``tensors``. The input tensors can be any shape
    with at least two dimensions, but must all be the same shape.

    The mixing weights are the softmax of ``self.scalar_parameters``, and the weighted sum
    is scaled by ``self.gamma``.

    When ``do_layer_norm=True``, the ``mask`` is required input. If the ``tensors`` are
    dimensioned ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
    ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
    ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.
    When ``do_layer_norm=False`` the ``mask`` is ignored.

    Raises
    ------
    ConfigurationError
        If ``len(tensors) != self.mixture_size``, or if ``do_layer_norm=True`` and no
        ``mask`` is given.
    """
    if len(tensors) != self.mixture_size:
        raise ConfigurationError("{} tensors were passed, but the module was initialized to "
                                 "mix {} tensors.".format(len(tensors), self.mixture_size))

    def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
        # A single mean/variance over all unmasked elements of ``tensor`` (every
        # dimension at once), then standardise the whole tensor with them.
        tensor_masked = tensor * broadcast_mask
        mean = torch.sum(tensor_masked) / num_elements_not_masked
        variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
        return (tensor - mean) / torch.sqrt(variance + 1E-12)

    normed_weights = torch.nn.functional.softmax(torch.cat(list(self.scalar_parameters)), dim=0)
    # One weight per tensor. The chunk size is passed positionally: the old ``split_size=``
    # keyword was renamed to ``split_size_or_sections`` (PyTorch 0.4) and current releases
    # reject it with a TypeError.
    normed_weights = torch.split(normed_weights, 1)

    if not self.do_layer_norm:
        return self.gamma * sum(weight * tensor for weight, tensor in zip(normed_weights, tensors))

    if mask is None:
        raise ConfigurationError("A mask is required when do_layer_norm=True.")
    mask_float = mask.float()
    # Trailing singleton axis so the mask broadcasts over the last (feature) dimension.
    broadcast_mask = mask_float.unsqueeze(-1)
    input_dim = tensors[0].size(-1)
    num_elements_not_masked = torch.sum(mask_float) * input_dim
    return self.gamma * sum(
        weight * _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked)
        for weight, tensor in zip(normed_weights, tensors)
    )
示例5: bucket_values
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def bucket_values(distances: torch.Tensor,
                  num_identity_buckets: int = 4,
                  num_total_buckets: int = 10) -> torch.Tensor:
    """
    Assign each value (designed for distances) to one of ``num_total_buckets`` semi-log buckets.

    Any value ``<= num_identity_buckets`` is its own bucket index, so the single-value
    buckets are ``0 .. num_identity_buckets``. Larger values share power-of-two ranges,
    and the result is clamped into ``[0, num_total_buckets - 1]``. With the default
    settings the buckets are:
    [0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63, 64+].

    Parameters
    ----------
    distances : ``torch.Tensor``, required.
        A Tensor of any size, to be bucketed.
    num_identity_buckets: int, optional (default = 4).
        The largest value that keeps its own single-value bucket.
    num_total_buckets : int, (default = 10)
        The total number of buckets to bucket values into.

    Returns
    -------
    A tensor of the same shape as the input, containing the bucket index of each value.
    """
    # 1 where the value is small enough to act as its own bucket index, 0 elsewhere.
    in_identity_range = (distances <= num_identity_buckets).long()
    in_log_range = 1 - in_identity_range

    # floor(log2(value)) picks the power-of-two range. The shift by
    # (num_identity_buckets - 1) starts those indices right after the single-value buckets.
    log2_of_values = distances.float().log() / math.log(2)
    log_bucket = log2_of_values.floor().long() + (num_identity_buckets - 1)

    bucket = in_identity_range * distances + in_log_range * log_bucket
    # Values past the last bucket collapse into it; negative values end up in bucket 0.
    return torch.clamp(bucket, 0, num_total_buckets - 1)
示例6: _mst_decode
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def _mst_decode(self,
                head_tag_representation: torch.Tensor,
                child_tag_representation: torch.Tensor,
                attended_arcs: torch.Tensor,
                mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the arc/tag energy tensor for spanning-tree decoding and hand it to
    ``self._run_mst_decoding``.

    The words of the sentence are graph nodes with a directed edge in each direction
    between every pair. An edge's weight combines the arc probability with the
    probability of each dependency tag for that arc. The tree itself (Edmonds'
    algorithm) is computed by ``self._run_mst_decoding``.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length); 0 marks padded tokens, which are
        pushed towards -inf so they are never chosen.

    Returns
    -------
    Whatever ``self._run_mst_decoding`` returns for the
    (batch_size, num_head_tags, sequence_length, sequence_length) energy tensor and the
    per-sequence lengths. Per the annotation, this is presumably ``(heads, head_tags)``,
    each of shape (batch_size, sequence_length). TODO confirm.
    """
    batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()
    # Number of unpadded tokens in each sequence, as a numpy array.
    lengths = mask.data.sum(dim=1).long().cpu().numpy()

    # Heads vary along dim 1 and children along dim 2 of the expanded pair tensors.
    pair_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
    expanded_heads = head_tag_representation.unsqueeze(2).expand(*pair_shape).contiguous()
    expanded_children = child_tag_representation.unsqueeze(1).expand(*pair_shape).contiguous()

    # (batch_size, sequence_length, sequence_length, num_head_tags)
    pairwise_head_logits = self.tag_bilinear(expanded_heads, expanded_children)
    # Normalise over tags, then move the tag axis forward:
    # (batch_size, num_head_tags, sequence_length, sequence_length).
    # Pairs that involve a padded token are never selected later, so they need no masking here.
    tag_log_probs = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)

    # Large negative penalty on every pair that involves a padded token.
    padding_penalty = (1 - mask.float()) * -1e8
    masked_arcs = attended_arcs + padding_penalty.unsqueeze(2) + padding_penalty.unsqueeze(1)
    # Normalise over dim 2, then transpose so that
    # arc_log_probs[b, i, j] scores word i as the head of word j.
    arc_log_probs = F.log_softmax(masked_arcs, dim=2).transpose(1, 2)

    # energy[b, t, i, j] = exp(arc score "i heads j" + score of tag t for that arc).
    batch_energy = torch.exp(arc_log_probs.unsqueeze(1) + tag_log_probs)
    return self._run_mst_decoding(batch_energy, lengths)
示例7: _construct_loss
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def _construct_loss(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    head_indices: torch.Tensor,
                    head_tags: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag loss for a sequence given gold head indices and tags.

    Parameters
    ----------
    head_tag_representation : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag_representation : ``torch.Tensor``, required
        A tensor of shape (batch_size, sequence_length, tag_representation_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    attended_arcs : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    head_indices : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The indices of the heads for every word.
    head_tags : ``torch.Tensor``, required.
        A tensor of shape (batch_size, sequence_length).
        The dependency labels of the heads for every word.
    mask : ``torch.Tensor``, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.

    Returns
    -------
    arc_nll : ``torch.Tensor``
        Negative summed log-probability of the gold heads, averaged over the non-ROOT
        unmasked tokens in the batch.
    tag_nll : ``torch.Tensor``
        Negative summed log-probability of the gold head tags, averaged the same way.

    If the batch has no non-ROOT unmasked tokens, the normaliser is clamped to 1, so both
    losses are 0 instead of NaN.
    """
    float_mask = mask.float()
    batch_size, sequence_length, _ = attended_arcs.size()
    # shape (batch_size, 1)
    range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
    # shape (batch_size, sequence_length, sequence_length)
    normalised_arc_logits = masked_log_softmax(attended_arcs,
                                               mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)
    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
    normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
                                                    mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
    # index matrix with shape (batch, sequence_length)
    timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
    child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
    # shape (batch_size, sequence_length): log-probability of each word's gold head / gold tag.
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]
    tag_loss = tag_loss[:, 1:]
    # The number of valid positions is the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic ROOT token.
    # Clamp to at least 1 so a batch with no real (non-ROOT) tokens yields 0, not NaN.
    valid_positions = (mask.sum() - batch_size).float().clamp(min=1.0)
    arc_nll = -arc_loss.sum() / valid_positions
    tag_nll = -tag_loss.sum() / valid_positions
    return arc_nll, tag_nll
示例8: forward
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def forward(self,
            context_1: torch.Tensor,
            mask_1: torch.Tensor,
            context_2: torch.Tensor,
            mask_2: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # pylint: disable=arguments-differ
    """
    Given the forward (or backward) representations of sentence1 and sentence2, apply four bilateral
    matching functions between them in one direction.

    NOTE(review): this listing is abridged. Only the unweighted-cosine step, the
    full-matching step and the heading of the max-pooling step are shown. The rest of the
    body, including the ``return``, is omitted.

    Parameters
    ----------
    context_1 : ``torch.Tensor``
        Tensor of shape (batch_size, seq_len1, hidden_dim) representing the encoding of the first sentence.
    mask_1 : ``torch.Tensor``
        Binary Tensor of shape (batch_size, seq_len1), indicating which
        positions in the first sentence are padding (0) and which are not (1).
    context_2 : ``torch.Tensor``
        Tensor of shape (batch_size, seq_len2, hidden_dim) representing the encoding of the second sentence.
    mask_2 : ``torch.Tensor``
        Binary Tensor of shape (batch_size, seq_len2), indicating which
        positions in the second sentence are padding (0) and which are not (1).

    Returns
    -------
    A tuple of matching vectors for the two sentences. Each of which is a list of
    matching vectors of shape (batch, seq_len, num_perspectives or 1)
    """
    # The masks must not require gradients, and both contexts must have width self.hidden_dim.
    assert (not mask_2.requires_grad) and (not mask_1.requires_grad)
    assert context_1.size(-1) == context_2.size(-1) == self.hidden_dim
    # (batch,) -- presumably the number of unpadded tokens per sentence (external helper; TODO confirm)
    len_1 = get_lengths_from_binary_sequence_mask(mask_1)
    len_2 = get_lengths_from_binary_sequence_mask(mask_2)
    # (batch, seq_len*)
    mask_1, mask_2 = mask_1.float(), mask_2.float()
    # explicitly set masked weights to zero
    # (batch_size, seq_len*, hidden_dim)
    context_1 = context_1 * mask_1.unsqueeze(-1)
    context_2 = context_2 * mask_2.unsqueeze(-1)
    # array to keep the matching vectors for the two sentences
    matching_vector_1: List[torch.Tensor] = []
    matching_vector_2: List[torch.Tensor] = []
    # Step 0. unweighted cosine
    # First calculate the cosine similarities between each forward
    # (or backward) contextual embedding and every forward (or backward)
    # contextual embedding of the other sentence.
    # (batch, seq_len1, seq_len2)
    cosine_sim = F.cosine_similarity(context_1.unsqueeze(-2), context_2.unsqueeze(-3), dim=3)
    # (batch, seq_len*, 1): max / mean over the other sentence's unmasked positions
    cosine_max_1 = masked_max(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
    cosine_mean_1 = masked_mean(cosine_sim, mask_2.unsqueeze(-2), dim=2, keepdim=True)
    cosine_max_2 = masked_max(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)
    cosine_mean_2 = masked_mean(cosine_sim.permute(0, 2, 1), mask_1.unsqueeze(-2), dim=2, keepdim=True)
    matching_vector_1.extend([cosine_max_1, cosine_mean_1])
    matching_vector_2.extend([cosine_max_2, cosine_mean_2])
    # Step 1. Full-Matching
    # Each time step of forward (or backward) contextual embedding of one sentence
    # is compared with the last time step of the forward (or backward)
    # contextual embedding of the other sentence
    if self.with_full_match:
        # (batch, 1, hidden_dim)
        if self.is_forward:
            # Forward direction: gather the last unpadded position (length - 1, clamped at 0).
            # Index tensors have shape (batch, 1, hidden_dim).
            last_position_1 = (len_1 - 1).clamp(min=0)
            last_position_1 = last_position_1.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)
            last_position_2 = (len_2 - 1).clamp(min=0)
            last_position_2 = last_position_2.view(-1, 1, 1).expand(-1, 1, self.hidden_dim)
            context_1_last = context_1.gather(1, last_position_1)
            context_2_last = context_2.gather(1, last_position_2)
        else:
            # Backward direction: the "last" time step is position 0.
            context_1_last = context_1[:, 0:1, :]
            context_2_last = context_2[:, 0:1, :]
        # (batch, seq_len*, num_perspectives)
        matching_vector_1_full = multi_perspective_match(context_1,
                                                         context_2_last,
                                                         self.full_match_weights)
        matching_vector_2_full = multi_perspective_match(context_2,
                                                         context_1_last,
                                                         self.full_match_weights_reversed)
        matching_vector_1.extend(matching_vector_1_full)
        matching_vector_2.extend(matching_vector_2_full)
    # Step 2. Maxpooling-Matching
    # Each time step of forward (or backward) contextual embedding of one sentence
    # is compared with every time step of the forward (or backward)
    # contextual embedding of the other sentence, and only the max value of each
    # dimension is retained.
    # ......... (the remainder of this function is omitted in this listing) .........
示例9: create_from_tensors
# 需要导入模块: from torch import Tensor [as 别名]
# 或者: from torch.Tensor import float [as 别名]
def create_from_tensors(
cls,
trainer: RLTrainer,
mdp_ids: np.ndarray,
sequence_numbers: torch.Tensor,
states: torch.Tensor,
actions: torch.Tensor,
propensities: torch.Tensor,
rewards: torch.Tensor,
possible_actions_state_concat: Optional[torch.Tensor],
possible_actions_mask: torch.Tensor,
metrics: Optional[torch.Tensor] = None,
):
with torch.no_grad():
# Switch to evaluation mode for the network
old_q_train_state = trainer.q_network.training
old_reward_train_state = trainer.reward_network.training
trainer.q_network.train(False)
trainer.reward_network.train(False)
if possible_actions_state_concat is not None:
state_action_pairs = torch.cat((states, actions), dim=1)
# Parametric actions
rewards = rewards
model_values = trainer.q_network(possible_actions_state_concat)
assert (
model_values.shape[0] * model_values.shape[1]
== possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
), (
"Invalid shapes: "
+ str(model_values.shape)
+ " != "
+ str(possible_actions_mask.shape)
)
model_values = model_values.reshape(possible_actions_mask.shape)
model_rewards = trainer.reward_network(possible_actions_state_concat)
assert (
model_rewards.shape[0] * model_rewards.shape[1]
== possible_actions_mask.shape[0] * possible_actions_mask.shape[1]
), (
"Invalid shapes: "
+ str(model_rewards.shape)
+ " != "
+ str(possible_actions_mask.shape)
)
model_rewards = model_rewards.reshape(possible_actions_mask.shape)
model_values_for_logged_action = trainer.q_network(state_action_pairs)
model_rewards_for_logged_action = trainer.reward_network(
state_action_pairs
)
action_mask = (
torch.abs(model_values - model_values_for_logged_action) < 1e-3
).float()
model_metrics = None
model_metrics_for_logged_action = None
model_metrics_values = None
model_metrics_values_for_logged_action = None
else:
action_mask = actions.float()
# Switch to evaluation mode for the network
old_q_cpe_train_state = trainer.q_network_cpe.training
trainer.q_network_cpe.train(False)
# Discrete actions
rewards = trainer.boost_rewards(rewards, actions)
model_values = trainer.get_detached_q_values(states)[0]
assert model_values.shape == actions.shape, (
"Invalid shape: "
+ str(model_values.shape)
+ " != "
+ str(actions.shape)
)
assert model_values.shape == possible_actions_mask.shape, (
"Invalid shape: "
+ str(model_values.shape)
+ " != "
+ str(possible_actions_mask.shape)
)
model_values_for_logged_action = torch.sum(
model_values * action_mask, dim=1, keepdim=True
)
rewards_and_metric_rewards = trainer.reward_network(states)
num_actions = trainer.num_actions
model_rewards = rewards_and_metric_rewards[:, 0:num_actions]
assert model_rewards.shape == actions.shape, (
"Invalid shape: "
+ str(model_rewards.shape)
+ " != "
+ str(actions.shape)
)
model_rewards_for_logged_action = torch.sum(
#.........这里部分代码省略.........