

Python v2.argmax Method Code Examples

This article collects typical usage examples of the tensorflow.compat.v2.argmax method in Python. If you are unsure what v2.argmax does or how to call it, the curated examples below should help. You can also explore further usage examples from the tensorflow.compat.v2 module.


The following presents 8 code examples of the v2.argmax method, sorted by popularity by default.
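
Before diving in, a quick reminder of the method's behavior: tf.argmax returns the index of the largest entry along a given axis (dtype int64 by default). A minimal, self-contained sketch, independent of the projects below:

import tensorflow.compat.v2 as tf

logits = tf.constant([[0.1, 2.3, -1.0],
                      [4.2, 0.0, 1.5]])
indices = tf.argmax(logits, axis=-1)  # index of the max along the last axis
print(indices.numpy())  # [1 0]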

Example 1: select_actor_action

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def select_actor_action(self, env_output, agent_output):
    assert self._mode, 'mode must be set for selecting action in actor.'
    oracle_next_action = env_output.observation[
        streetview_constants.ORACLE_NEXT_ACTION]
    if self._mode == 'train':
      if self._loss_type == common.CE_LOSS:
        # This is teacher-forcing mode, so choose action same as oracle action.
        action_idx = oracle_next_action
      elif self._loss_type == common.AC_LOSS:
        action_idx = tfp.distributions.Categorical(
            logits=agent_output.policy_logits).sample()
      else:
        # Without this branch, an unsupported loss type would leave
        # action_idx undefined and raise a NameError below.
        raise ValueError('Unsupported loss type {}'.format(self._loss_type))
    else:
      # In non-train modes, choose greedily.
      action_idx = tf.argmax(agent_output.policy_logits, axis=-1)

    # Return ActorAction and the action to be passed to the env step function.
    return common.ActorAction(
        chosen_action_idx=int(action_idx.numpy()),
        oracle_next_action_idx=int(
            oracle_next_action.numpy())), action_idx.numpy() 
Developer: google-research, Project: valan, Lines: 22, Source: td_problem.py
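
Outside the actor loop, the two branches above can be exercised directly; a minimal sketch with made-up logits (not part of the valan source):

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

policy_logits = tf.constant([1.0, 0.2, 3.5])
# Training with an actor-critic loss: sample stochastically from the policy.
sampled_idx = tfp.distributions.Categorical(logits=policy_logits).sample()
# Evaluation: pick the highest-scoring action deterministically.
greedy_idx = tf.argmax(policy_logits, axis=-1)  # == 2 here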

Example 2: select_actor_action

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def select_actor_action(self, env_output, agent_output):
    oracle_next_action = env_output.observation[constants.ORACLE_NEXT_ACTION]
    oracle_next_action_indices = tf.where(
        tf.equal(env_output.observation[constants.CONN_IDS],
                 oracle_next_action))
    oracle_next_action_idx = tf.reduce_min(oracle_next_action_indices)
    assert self._mode, 'mode must be set.'
    if self._mode == 'train':
      if self._loss_type == common.CE_LOSS:
        # This is teacher-forcing mode, so choose action same as oracle action.
        action_idx = oracle_next_action_idx
      elif self._loss_type == common.AC_LOSS:
        # Choose next pano from probability distribution over next panos
        action_idx = tfp.distributions.Categorical(
            logits=agent_output.policy_logits).sample()
      else:
        raise ValueError('Unsupported loss type {}'.format(self._loss_type))
    else:
      # In non-train modes, choose greedily.
      action_idx = tf.argmax(agent_output.policy_logits, axis=-1)
    action_val = env_output.observation[constants.CONN_IDS][action_idx]
    return common.ActorAction(
        chosen_action_idx=int(action_idx.numpy()),
        oracle_next_action_idx=int(oracle_next_action_idx.numpy())), int(
            action_val.numpy()) 
Developer: google-research, Project: valan, Lines: 27, Source: ndh_problem.py

Example 3: argmax

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def argmax(a, axis=None):
  return _argminmax(tf.argmax, a, axis) 
Developer: google, Project: trax, Lines: 4, Source: math_ops.py
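
The _argminmax helper is not shown on this page; a plausible sketch, assuming it follows NumPy semantics and flattens the input when axis is None (the behavior is inferred from the call site, not copied from the trax source):

import tensorflow.compat.v2 as tf

def _argminmax(fn, a, axis=None):
  a = tf.convert_to_tensor(a)
  if axis is None:
    # NumPy flattens the array when no axis is given.
    return fn(tf.reshape(a, [-1]), axis=0)
  return fn(a, axis=axis)

With a helper like this, argmax(a) behaves like np.argmax(a), returning a single index into the flattened array.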

Example 4: config_model_evaluation

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def config_model_evaluation(self, model, labels_ph, params=None):
    model.accuracy = tf1.metrics.accuracy(
        tf.argmax(input=model.labels, axis=1),
        tf.argmax(input=model.predicted_y.tf, axis=1))
    model.top_labels = util.labels_of_top_ranked_predictions_in_batch(
        model.labels, model.predicted_y.tf)
    model.precision_at_one = tf1.metrics.mean(model.top_labels)
    model.evaluations = {
        "accuracy": model.accuracy,
        "precision@1": model.precision_at_one
    } 
Developer: google-research, Project: language, Lines: 13, Source: util_test.py

Example 5: convert_to_one_hot

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def convert_to_one_hot(self, samples):
    return tf.one_hot(
        tf.argmax(samples, axis=-1),
        self.distribution.event_size, dtype=self._output_dtype) 
Developer: tensorflow, Project: agents, Lines: 6, Source: gumbel_softmax.py
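
Standalone, the argmax-then-one_hot pattern snaps a soft (e.g. Gumbel-softmax) sample back to a hard one-hot vector; a minimal sketch with invented values, where depth stands in for the distribution's event_size:

import tensorflow.compat.v2 as tf

samples = tf.constant([[0.1, 0.7, 0.2]])  # a relaxed categorical sample
hard = tf.one_hot(tf.argmax(samples, axis=-1), depth=3, dtype=tf.float32)
# hard == [[0., 1., 0.]]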

Example 6: _quantization_offset

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def _quantization_offset(self):
    # Picks the "peakiest" of the component quantization offsets.
    offsets = helpers.quantization_offset(self.components_distribution)
    rank = self.batch_shape.rank
    transposed_offsets = tf.transpose(offsets, [rank] + list(range(rank)))
    component = tf.argmax(self.log_prob(transposed_offsets), axis=0)
    return tf.gather(offsets, component, axis=-1, batch_dims=rank) 
Developer: tensorflow, Project: compression, Lines: 9, Source: uniform_noise.py
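
The argmax-plus-gather idiom above (score candidates, then select one per batch element) can be seen on a toy tensor; a hedged sketch with invented shapes and values:

import tensorflow.compat.v2 as tf

offsets = tf.constant([[0.0, 0.3],
                       [0.1, -0.2]])             # [batch=2, num_components=2]
component = tf.constant([1, 0], dtype=tf.int64)  # winning component per batch element
picked = tf.gather(offsets, component, axis=-1, batch_dims=1)
# picked == [0.3, 0.1]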

Example 7: predict_class

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def predict_class(self, text_token_ids, action_panos):
    """Takes in an instruction and action and returns classifier outputs.

    Args:
      text_token_ids: Tensor of token indices for the input instruction.
      action_panos: Tensor of concatenated image panoramas.

    Returns:
      (class_outputs, predictions): Output of last layer of MLP and prediction.
    """
    text_enc_outputs, current_lstm_state = self.encode_instruction(
        text_token_ids)
    lstm_output, next_lstm_state = self.encode_action(current_lstm_state,
                                                      action_panos)
    lstm_output = tf.expand_dims(lstm_output, axis=1)
    batch_size = text_enc_outputs.shape[0]

    # c_text has shape [batch_size, 1, self._text_attention_size]
    c_text = self._text_attention([
        self._text_attention_project_hidden(lstm_output),
        self._text_attention_project_text(text_enc_outputs)
    ])
    # convert ListWrapper output of next_lstm_state to tuples
    result_state = []
    for one_state in next_lstm_state:
      result_state.append((one_state[0], one_state[1]))

    hidden_state = lstm_output
    c_visual = self._visual_attention([
        self._visual_attention_project_ctext(c_text),
        self._visual_attention_project_feature(action_panos),
    ])

    input_feature = tf.concat([hidden_state, c_text, c_visual], axis=2)
    class_outputs = self._mlp_layer(input_feature)
    class_outputs = tf.reshape(class_outputs, (batch_size, 2))
    predictions = tf.argmax(class_outputs, axis=-1)
    return (class_outputs, predictions) 
Developer: google-research, Project: valan, Lines: 40, Source: language_reward_net.py

Example 8: accuracy

# Required import: from tensorflow.compat import v2 [as alias]
# Or: from tensorflow.compat.v2 import argmax [as alias]
def accuracy(self, true_labels, predictions):
    """Takes in predictions and true labels and returns accuracy.

    Args:
      true_labels: a tensor of shape [batch_size, n_classes].
      predictions: a tensor of shape [batch_size, 1].

    Returns:
      accuracy: a scalar accuracy value.
    """
    true_labels = tf.keras.backend.argmax(true_labels)
    metric = tf.keras.metrics.Accuracy()
    # update_state returns an op, not the metric value; read it via result().
    metric.update_state(true_labels, predictions)
    return metric.result() 
Developer: google-research, Project: valan, Lines: 16, Source: language_reward_net.py
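
A quick self-contained use of the same metric pattern, with invented labels (not from the valan source):

import tensorflow.compat.v2 as tf

true_labels = tf.constant([[0., 1.], [1., 0.]])  # one-hot, [batch=2, n_classes=2]
predictions = tf.constant([1, 0], dtype=tf.int64)  # predicted class indices
metric = tf.keras.metrics.Accuracy()
metric.update_state(tf.keras.backend.argmax(true_labels), predictions)
print(float(metric.result()))  # 1.0: both predictions match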


Note: The tensorflow.compat.v2.argmax method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from community open-source projects; copyright remains with the original authors, and redistribution or use should follow each project's license. Do not repost without permission.