This page collects typical usage examples of the Python method tensorflow.compat.v2.argmax. If you are unsure what v2.argmax does or how to call it, the hand-picked code examples below should help; you can also explore the other methods exposed by the containing module, tensorflow.compat.v2.
The following 8 code examples of v2.argmax are ordered by popularity by default.
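Before the extracted examples, a minimal standalone sketch of the method itself may help; the tensor values below are invented purely for illustration:

import tensorflow.compat.v2 as tf

# Two rows of logits (illustrative values only).
logits = tf.constant([[0.1, 2.0, 0.3],
                      [1.5, 0.2, 0.4]])

# Index of the largest entry along the last axis; dtype is int64 by default.
indices = tf.argmax(logits, axis=-1)
print(indices.numpy())  # => [1 0]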
Example 1: select_actor_action
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def select_actor_action(self, env_output, agent_output):
  assert self._mode, 'mode must be set for selecting action in actor.'
  oracle_next_action = env_output.observation[
      streetview_constants.ORACLE_NEXT_ACTION]
  if self._mode == 'train':
    if self._loss_type == common.CE_LOSS:
      # This is teacher-forcing mode, so choose action same as oracle action.
      action_idx = oracle_next_action
    elif self._loss_type == common.AC_LOSS:
      action_idx = tfp.distributions.Categorical(
          logits=agent_output.policy_logits).sample()
  else:
    # In non-train modes, choose greedily.
    action_idx = tf.argmax(agent_output.policy_logits, axis=-1)
  # Return ActorAction and the action to be passed to the env step function.
  return common.ActorAction(
      chosen_action_idx=int(action_idx.numpy()),
      oracle_next_action_idx=int(
          oracle_next_action.numpy())), action_idx.numpy()
Example 2: select_actor_action
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def select_actor_action(self, env_output, agent_output):
  oracle_next_action = env_output.observation[constants.ORACLE_NEXT_ACTION]
  oracle_next_action_indices = tf.where(
      tf.equal(env_output.observation[constants.CONN_IDS],
               oracle_next_action))
  oracle_next_action_idx = tf.reduce_min(oracle_next_action_indices)
  assert self._mode, 'mode must be set.'
  if self._mode == 'train':
    if self._loss_type == common.CE_LOSS:
      # This is teacher-forcing mode, so choose action same as oracle action.
      action_idx = oracle_next_action_idx
    elif self._loss_type == common.AC_LOSS:
      # Choose next pano from probability distribution over next panos.
      action_idx = tfp.distributions.Categorical(
          logits=agent_output.policy_logits).sample()
    else:
      raise ValueError('Unsupported loss type {}'.format(self._loss_type))
  else:
    # In non-train modes, choose greedily.
    action_idx = tf.argmax(agent_output.policy_logits, axis=-1)
  action_val = env_output.observation[constants.CONN_IDS][action_idx]
  return common.ActorAction(
      chosen_action_idx=int(action_idx.numpy()),
      oracle_next_action_idx=int(oracle_next_action_idx.numpy())), int(
          action_val.numpy())
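Examples 1 and 2 both switch between sampling from tfp.distributions.Categorical during training and greedy tf.argmax selection otherwise. A reduced sketch of that pattern, with made-up logits (assuming tensorflow_probability is installed):

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

policy_logits = tf.constant([[0.5, 1.2, -0.3]])  # illustrative values

# Training with an actor-critic loss: sample stochastically for exploration.
sampled_idx = tfp.distributions.Categorical(logits=policy_logits).sample()

# Evaluation: pick the highest-scoring action deterministically.
greedy_idx = tf.argmax(policy_logits, axis=-1)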
Example 3: argmax
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def argmax(a, axis=None):
  return _argminmax(tf.argmax, a, axis)
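The _argminmax helper is not shown on this page. A plausible minimal version (an assumption about its behavior, not the library's actual code) only has to emulate NumPy's flattening when axis is None, since tf.argmax itself defaults to axis 0:

import tensorflow.compat.v2 as tf

def _argminmax(fn, a, axis=None):
  # Hypothetical helper: NumPy flattens the input when axis is None,
  # whereas tf.argmax/tf.argmin default to axis 0.
  a = tf.convert_to_tensor(a)
  if axis is None:
    return fn(tf.reshape(a, [-1]), axis=0)
  return fn(a, axis=axis)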
Example 4: config_model_evaluation
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def config_model_evaluation(self, model, labels_ph, params=None):
  model.accuracy = tf1.metrics.accuracy(
      tf.argmax(input=model.labels, axis=1),
      tf.argmax(input=model.predicted_y.tf, axis=1))
  model.top_labels = util.labels_of_top_ranked_predictions_in_batch(
      model.labels, model.predicted_y.tf)
  model.precision_at_one = tf1.metrics.mean(model.top_labels)
  model.evaluations = {
      "accuracy": model.accuracy,
      "precision@1": model.precision_at_one
  }
Example 5: convert_to_one_hot
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def convert_to_one_hot(self, samples):
  return tf.one_hot(
      tf.argmax(samples, axis=-1),
      self.distribution.event_size, dtype=self._output_dtype)
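The argmax → one_hot round trip above hardens a soft sample into its mode. A quick demonstration with a fabricated event size of 4:

import tensorflow.compat.v2 as tf

samples = tf.constant([[0.1, 0.7, 0.15, 0.05]])  # e.g. a relaxed sample
hard = tf.one_hot(tf.argmax(samples, axis=-1), depth=4, dtype=tf.float32)
# hard == [[0., 1., 0., 0.]] -- collapsed onto the most probable category.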
Example 6: _quantization_offset
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def _quantization_offset(self):
  # Picks the "peakiest" of the component quantization offsets.
  offsets = helpers.quantization_offset(self.components_distribution)
  rank = self.batch_shape.rank
  transposed_offsets = tf.transpose(offsets, [rank] + list(range(rank)))
  component = tf.argmax(self.log_prob(transposed_offsets), axis=0)
  return tf.gather(offsets, component, axis=-1, batch_dims=rank)
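The tf.argmax → tf.gather pairing above picks, per batch element, whichever candidate scores highest. A reduced sketch with fabricated shapes (a rank-1 batch with two candidates each):

import tensorflow.compat.v2 as tf

scores = tf.constant([[0.2, 0.9], [0.8, 0.1]])  # [batch, num_candidates]
values = tf.constant([[10., 20.], [30., 40.]])  # candidate values per batch
best = tf.argmax(scores, axis=-1)               # [batch] -> [1, 0]
picked = tf.gather(values, best, axis=-1, batch_dims=1)  # => [20., 30.]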
Example 7: predict_class
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def predict_class(self, text_token_ids, action_panos):
  """Takes in an instruction and action and returns classifier outputs.

  Args:
    text_token_ids: Tensor of token indices for the input instruction.
    action_panos: Tensor of concatenated image panoramas.

  Returns:
    (class_outputs, predictions): Output of last layer of MLP and prediction.
  """
  text_enc_outputs, current_lstm_state = self.encode_instruction(
      text_token_ids)
  lstm_output, next_lstm_state = self.encode_action(current_lstm_state,
                                                    action_panos)
  lstm_output = tf.expand_dims(lstm_output, axis=1)
  batch_size = text_enc_outputs.shape[0]
  # c_text has shape [batch_size, 1, self._text_attention_size].
  c_text = self._text_attention([
      self._text_attention_project_hidden(lstm_output),
      self._text_attention_project_text(text_enc_outputs)
  ])
  # Convert ListWrapper output of next_lstm_state to tuples.
  result_state = []
  for one_state in next_lstm_state:
    result_state.append((one_state[0], one_state[1]))
  hidden_state = lstm_output
  c_visual = self._visual_attention([
      self._visual_attention_project_ctext(c_text),
      self._visual_attention_project_feature(action_panos),
  ])
  input_feature = tf.concat([hidden_state, c_text, c_visual], axis=2)
  class_outputs = self._mlp_layer(input_feature)
  class_outputs = tf.reshape(class_outputs, (batch_size, 2))
  predictions = tf.argmax(class_outputs, axis=-1)
  return (class_outputs, predictions)
Example 8: accuracy
# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def accuracy(self, true_labels, predictions):
  """Takes in predictions and true labels and returns accuracy.

  Args:
    true_labels: a tensor of shape [batch_size, n_classes].
    predictions: a tensor of shape [batch_size, 1].

  Returns:
    accuracy: a scalar accuracy tensor.
  """
  true_labels = tf.keras.backend.argmax(true_labels)
  metric = tf.keras.metrics.Accuracy()
  # update_state returns an update op, not the metric value, so read the
  # accumulated accuracy from result().
  metric.update_state(true_labels, predictions)
  return metric.result()
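A short usage sketch of this accuracy pattern, with fabricated one-hot labels and predictions:

import tensorflow.compat.v2 as tf

true_labels = tf.constant([[0., 1.], [1., 0.]])  # one-hot, [batch, n_classes]
predictions = tf.constant([1, 1])                # predicted class indices

metric = tf.keras.metrics.Accuracy()
metric.update_state(tf.keras.backend.argmax(true_labels), predictions)
print(metric.result().numpy())  # 0.5 -- one of the two predictions matches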