当前位置: 首页>>代码示例>>Python>>正文


Python chainer.functions.argmax方法代码示例

本文整理汇总了Python中chainer.functions.argmax方法的典型用法代码示例。如果您正苦于以下问题：Python chainer.functions.argmax方法的具体用法？Python chainer.functions.argmax怎么用？Python chainer.functions.argmax使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块chainer.functions的用法示例。


在下文中一共展示了chainer.functions.argmax方法的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __call__

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def __call__(self, x):
        """Decode per-channel heatmap peaks into (x, y, score) points.

        The coarse position of each channel is the location of its maximum
        value. When that location is strictly inside the map, each coordinate
        is shifted by 0.25 toward the larger of its two neighbouring heatmap
        values (no shift when the neighbours are equal).

        :param x: heatmap batch indexed as (batch, channels, height, width);
            it must expose ``.array`` (e.g. a chainer Variable), because the
            refinement loop reads ``heatmap[b, k, :, :].array``.
        :return: array of shape (batch, channels, 3) holding x, y and the
            peak heatmap value of every channel. Channels whose peak is not
            positive get x = y = 0 and are not refined.
        """
        heatmap = x
        vector_dim = 2
        batch = heatmap.shape[0]
        channels = heatmap.shape[1]
        in_size = x.shape[2:]
        # Flatten the spatial dims so a single argmax/max per channel finds the peak.
        heatmap_vector = F.reshape(heatmap, shape=(batch, channels, -1))
        indices = F.cast(F.expand_dims(F.argmax(heatmap_vector, axis=vector_dim), axis=vector_dim), np.float32)
        scores = F.max(heatmap_vector, axis=vector_dim, keepdims=True)
        # Zero the coordinates of channels without a positive response.
        scores_mask = (scores.array > 0.0).astype(np.float32)
        # Convert the flat index back to (column, row) using the map width.
        pts_x = (indices.array % in_size[1]) * scores_mask
        pts_y = (indices.array // in_size[1]) * scores_mask
        pts = F.concat((pts_x, pts_y, scores), axis=vector_dim).array
        for b in range(batch):
            for k in range(channels):
                hm = heatmap[b, k, :, :].array
                # pts_x / pts_y have shape (batch, channels, 1): index the
                # trailing axis so int() gets a scalar rather than a
                # 1-element array (deprecated since NumPy 1.25).
                px = int(pts_x[b, k, 0])
                py = int(pts_y[b, k, 0])
                # Only refine when both neighbours exist on each axis.
                if (0 < px < in_size[1] - 1) and (0 < py < in_size[0] - 1):
                    pts[b, k, 0] += np.sign(hm[py, px + 1] - hm[py, px - 1]) * 0.25
                    pts[b, k, 1] += np.sign(hm[py + 1, px] - hm[py - 1, px]) * 0.25
        return pts
开发者ID:osmr,项目名称:imgclsmob,代码行数:24,代码来源:common.py

示例2: update

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def update(Q, target_Q, opt, samples, gamma=0.99, target_type='double_dqn'):
    """Take one optimizer step moving Q(s, a) toward a bootstrapped target.

    Args:
        Q: online Q-function being trained (must expose ``xp``).
        target_Q: Q-function used to evaluate next-state values.
        opt: optimizer applied after back-propagating the loss.
        samples: transitions read positionally as
            (obs, action, reward, done, obs_next).
        gamma: discount factor.
        target_type: ``'dqn'`` (max over target_Q) or ``'double_dqn'``
            (Q selects the action, target_Q evaluates it).

    Raises:
        ValueError: if ``target_type`` is not one of the values above.
    """
    xp = Q.xp
    float_dtype = chainer.get_dtype()

    def column(index, dtype):
        # Gather field ``index`` of every transition into one batched array.
        return xp.asarray([transition[index] for transition in samples], dtype=dtype)

    obs = column(0, float_dtype)
    action = column(1, np.int32)
    reward = column(2, float_dtype)
    done = column(3, float_dtype)
    obs_next = column(4, float_dtype)

    # Q(s, a) for the actions actually taken.
    predicted = F.select_item(Q(obs), action)

    # Target r + gamma * (1 - done) * Q'(s', .) carries no gradient.
    with chainer.no_backprop_mode():
        if target_type == 'double_dqn':
            target_values = target_Q(obs_next)
            greedy_next = F.argmax(Q(obs_next), axis=1)
            next_q = F.select_item(target_values, greedy_next)
        elif target_type == 'dqn':
            next_q = F.max(target_Q(obs_next), axis=1)
        else:
            raise ValueError('Unsupported target_type: {}'.format(target_type))
        target = reward + gamma * (1 - done) * next_q

    loss = mean_clipped_loss(predicted, target)
    Q.cleargrads()
    loss.backward()
    opt.update()
开发者ID:chainer,项目名称:chainer,代码行数:27,代码来源:dqn_cartpole.py

示例3: determine_best_prediction_indices

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def determine_best_prediction_indices(self, raw_classification_result):
        """Pick, for every image, the variant with the most confident prediction.

        :param raw_classification_result: classifier scores; softmax and argmax
            are taken over axis 3, and it is indexed as [image, variant, ...]
            (NOTE(review): presumably (image, variant, position, class) --
            confirm against the caller).
        :return: tuple ``(best_indices, scores)``. ``scores[i, j]`` is the mean
            of the per-position max probability over the first ``n`` positions
            of variant ``j`` of image ``i``, where ``n`` is the number of
            argmax predictions in that variant that differ from
            ``self.blank_label_class``. It is 0 when ``n`` is 0.
            ``best_indices`` is the argmax of ``scores`` over the variant axis,
            with an extra trailing axis of size 1.
        """
        distribution = F.softmax(raw_classification_result, axis=3)
        predicted_classes = F.argmax(distribution, axis=3)

        scores = []
        for i, image in enumerate(predicted_classes):
            means = []
            for j, image_variant in enumerate(image):
                num_predictions = sum(
                    1 for prediction in image_variant
                    if prediction.array != self.blank_label_class
                )
                probs = F.max(distribution[i, j, :num_predictions], axis=1).array
                if len(probs) == 0:
                    # No non-blank prediction: score 0. Taking the mean here
                    # would yield NaN, and appending it as well would give this
                    # variant two entries and break the stack below.
                    means.append(self.xp.array(0, dtype=probs.dtype))
                else:
                    means.append(self.xp.mean(probs))
            means = self.xp.stack(means, axis=0)
            scores.append(means)
        scores = self.xp.stack(scores, axis=0)
        best_indices = F.argmax(scores, axis=1).array
        best_indices = best_indices[:, self.xp.newaxis]
        return best_indices, scores
开发者ID:Bartzi,项目名称:kiss,代码行数:22,代码来源:text_recognition_evaluator.py

示例4: forward

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def forward(self, v1):
        """Return the argmax of ``v1``; no axis is given, so it is taken over the flattened input."""
        flat_index = F.argmax(v1)
        return flat_index
开发者ID:pfnet-research,项目名称:chainer-compiler,代码行数:4,代码来源:ArgMinMax.py

示例5: main

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def main():
    """Generate argmin/argmax test cases (plain and axis variants, Chainer and NumPy) on one fixed random input."""
    np.random.seed(314)
    a1 = np.random.rand(6, 2, 3).astype(np.float32)

    # (model class, subname) pairs, in generation order.
    cases = (
        (ArgMin, 'argmin'),
        (ArgMinNumpy, 'argmin_np'),
        (ArgMinAxis, 'argmin_axis'),
        (ArgMinAxisNumpy, 'argmin_axis_np'),
        (ArgMax, 'argmax'),
        (ArgMaxNumpy, 'argmax_np'),
        (ArgMaxAxis, 'argmax_axis'),
        (ArgMaxAxisNumpy, 'argmax_axis_np'),
    )
    for model_cls, subname in cases:
        testtools.generate_testcase(model_cls(), [a1], subname=subname)
开发者ID:pfnet-research,项目名称:chainer-compiler,代码行数:15,代码来源:ArgMinMax.py

示例6: get_greedy_action

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def get_greedy_action(Q, obs):
    """Return, as a Python int, the action index with the largest Q-value.

    A leading batch axis of size 1 is added to ``obs`` before calling ``Q``,
    and row 0 of the output is used. No computational graph is recorded.
    """
    batched_obs = Q.xp.asarray(obs[None], dtype=chainer.get_dtype())
    with chainer.no_backprop_mode():
        action_values = Q(batched_obs).array[0]
    best_action = action_values.argmax()
    return int(best_action)
开发者ID:chainer,项目名称:chainer,代码行数:9,代码来源:dqn_cartpole.py

示例7: decode_prediction

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def decode_prediction(self, x):
        """Greedily decode classifier output.

        :param x: the output of the classifier; classes lie along axis 2.
        :return: the most probable class index at every position (axis 2 is
            reduced away).
        """
        probabilities = F.softmax(x, axis=2)
        # softmax is monotonic, so this equals argmax of x up to float ties.
        return F.argmax(probabilities, axis=2)
开发者ID:chainer,项目名称:models,代码行数:9,代码来源:copy_transformer.py

示例8: argmax

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def argmax(self, hs_pad):
        """Return the argmax of frame activations.

        ``hs_pad`` is padded with ``F.pad_sequence`` and projected by
        ``self.ctc_lo``, with the first two axes treated as batch axes.
        The index of the largest activation along the last axis is then taken.

        :param chainer variable hs_pad: 3d tensor (B, Tmax, eprojs)
            (NOTE(review): since it goes through ``F.pad_sequence``, callers
            may pass a list of variable-length sequences -- confirm)
        :return: argmax applied 2d tensor (B, Tmax)
        :rtype: chainer.Variable
        """
        padded = F.pad_sequence(hs_pad)
        activations = self.ctc_lo(padded, n_batch_axes=2)
        return F.argmax(activations, axis=-1)
开发者ID:espnet,项目名称:espnet,代码行数:10,代码来源:ctc.py

示例9: argmax

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def argmax(self, hs_pad):
        """Argmax of frame activations.

        :param chainer variable hs_pad: 3d tensor (B, Tmax, eprojs); it is
            passed through ``F.pad_sequence`` before projection
            (NOTE(review): presumably a list of sequences -- confirm)
        :return: argmax applied 2d tensor (B, Tmax)
        :rtype: chainer.Variable.
        """
        projected = self.ctc_lo(F.pad_sequence(hs_pad), n_batch_axes=2)
        best_labels = F.argmax(projected, axis=-1)
        return best_labels
开发者ID:espnet,项目名称:espnet,代码行数:10,代码来源:ctc.py

示例10: decode_prediction

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def decode_prediction(self, prediction):
        """Return the argmax over axis 2 of the softmax (over axis 2) of ``prediction``."""
        class_probabilities = F.softmax(prediction, axis=2)
        best_classes = F.argmax(class_probabilities, axis=2)
        return best_classes
开发者ID:Bartzi,项目名称:kiss,代码行数:4,代码来源:transformer_recognizer.py

示例11: decode_prediction

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def decode_prediction(self, prediction):
        """Greedily decode every character slot of every box.

        ``prediction`` is split along axis 1 into boxes, and each box is split
        along its axis 1 into characters. Each character is reduced with
        softmax + argmax over axis 1. The per-character results are stacked
        along axis 1 into a word, and the words are stacked along axis 1.
        """
        def decode_box(box):
            # Softmax uses its default axis (1), matching the argmax axis.
            characters = F.separate(box, axis=1)
            return F.stack([F.argmax(F.softmax(char), axis=1) for char in characters], axis=1)

        words = [decode_box(box) for box in F.separate(prediction, axis=1)]
        return F.stack(words, axis=1)
开发者ID:Bartzi,项目名称:kiss,代码行数:9,代码来源:text_recognizer.py

示例12: __call__

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def __call__(self, x_img, t_detection, **others):
        """AlexNet-style forward pass that returns a classification loss.

        Besides the loss, this reports via ``chainer.report`` the loss value,
        the input images paired with the predicted and teacher detection
        labels, and the conv1-conv5 weight parameters.

        :param x_img: input image batch fed to ``conv1``.
        :param t_detection: target labels for ``F.softmax_cross_entropy``.
        :param others: extra keyword arguments; accepted but not used.
        :return: softmax cross-entropy loss between the fc8 output and
            ``t_detection``.
        """
        # Alexnet
        h = F.relu(self.conv1(x_img))  # conv1
        h = F.max_pooling_2d(h, 3, stride=2, pad=0)  # max1
        h = F.local_response_normalization(h)  # norm1
        h = F.relu(self.conv2(h))  # conv2
        h = F.max_pooling_2d(h, 3, stride=2, pad=0)  # max2
        h = F.local_response_normalization(h)  # norm2
        h = F.relu(self.conv3(h))  # conv3
        h = F.relu(self.conv4(h))  # conv4
        h = F.relu(self.conv5(h))  # conv5
        h = F.max_pooling_2d(h, 3, stride=2, pad=0)  # pool5

        # NOTE(review): the ``train=`` keyword of F.dropout is Chainer v1 API;
        # Chainer v2+ rejects it and reads chainer.config.train instead
        # (chainer.using_config('train', ...)). Confirm the pinned version.
        h = F.dropout(F.relu(self.fc6(h)), train=self.train)  # fc6
        h = F.dropout(F.relu(self.fc7(h)), train=self.train)  # fc7
        h_detection = self.fc8(h)  # fc8

        # Loss
        loss = F.softmax_cross_entropy(h_detection, t_detection)

        chainer.report({'loss': loss}, self)

        # Prediction: replace the fc8 scores with the index of the
        # top-scoring entry along axis 1 (the loss above used the raw scores).
        h_detection = F.argmax(h_detection, axis=1)

        # Report results
        predict_data = {'img': x_img, 'detection': h_detection}
        teacher_data = {'img': x_img, 'detection': t_detection}
        chainer.report({'predict': predict_data}, self)
        chainer.report({'teacher': teacher_data}, self)

        # Report layer weights
        chainer.report({'conv1_w': {'weights': self.conv1.W},
                        'conv2_w': {'weights': self.conv2.W},
                        'conv3_w': {'weights': self.conv3.W},
                        'conv4_w': {'weights': self.conv4.W},
                        'conv5_w': {'weights': self.conv5.W}}, self)

        return loss 
开发者ID:takiyu,项目名称:hyperface,代码行数:41,代码来源:models.py

示例13: set_init_grad

# 需要导入模块: from chainer import functions [as 别名]
# 或者: from chainer.functions import argmax [as 别名]
def set_init_grad(self, var, label):
		"""Seed ``var.grad`` with a one-hot entry in row 0 for backpropagation.

		The gradient is reset to zeros. Then ``var.grad[0][class_id]`` is set
		to 1, so only row 0 is seeded.

		:param var: Variable whose gradient is initialised; it is indexed as
			``[0][class_id]``, so it must have at least two dimensions.
		:param label: candidate class ids to sample one from at random, or
			None to use the highest-scoring entry of row 0.
		:return: the class id(s) that received gradient 1.
		"""
		var.grad = self.xp.zeros_like(var.data)
		if label is None:
			# Argmax of row 0 only. A bare F.argmax(var) flattens the whole
			# array, so with more than one row its flat index is not a valid
			# column of row 0. Using the raw data avoids adding a graph node.
			class_id = F.argmax(var.data[0]).data
		else:
			class_id = self.xp.random.choice(label, 1)
		var.grad[0][class_id] = 1
		return class_id
开发者ID:alokwhitewolf,项目名称:Guided-Attention-Inference-Network,代码行数:12,代码来源:GAIN.py


注:本文中的chainer.functions.argmax方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。