当前位置: 首页>>代码示例>>Python>>正文


Python torch.where方法代码示例

本文整理汇总了Python中torch.where方法的典型用法代码示例。如果您正苦于以下问题:Python torch.where方法的具体用法?Python torch.where怎么用?Python torch.where使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了与torch.where方法相关的15个代码示例，这些例子默认根据受欢迎程度排序（注：示例6、7、10、11的代码片段中并未直接调用torch.where）。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: smooth_l1_loss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def smooth_l1_loss(pred, target, beta=1.0):
    """Compute the element-wise Smooth L1 loss.

    With ``d = |pred - target|`` each element's loss is ``0.5 * d**2 / beta``
    when ``d < beta`` and ``d - 0.5 * beta`` otherwise.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target; must have the same size
            as ``pred`` and be non-empty.
        beta (float, optional): Point where the loss switches from the
            quadratic to the linear piece; must be positive. Defaults to 1.0.

    Returns:
        torch.Tensor: Un-reduced loss with the same shape as ``pred``.

    Raises:
        AssertionError: If ``beta <= 0``, the sizes differ, or ``target``
            is empty.
    """
    assert beta > 0
    assert pred.size() == target.size() and target.numel() > 0
    abs_err = (pred - target).abs()
    in_quadratic_zone = abs_err < beta
    quadratic = 0.5 * abs_err * abs_err / beta
    linear = abs_err - 0.5 * beta
    return torch.where(in_quadratic_zone, quadratic, linear)
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:20,代码来源:smooth_l1_loss.py

示例2: _get_body

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def _get_body(self, x, target):
        cos_t = torch.gather(x, 1, target.unsqueeze(1))  # cos(theta_yi)
        if self.easy_margin:
            cond = torch.relu(cos_t)
        else:
            cond_v = cos_t - self.threshold
            cond = torch.relu(cond_v)
        cond = cond.bool()
        # Apex would convert FP16 to FP32 here
        # cos(theta_yi + m)
        new_zy = torch.cos(torch.acos(cos_t) + self.m).type(cos_t.dtype)
        if self.easy_margin:
            zy_keep = cos_t
        else:
            zy_keep = cos_t - self.mm  # (cos(theta_yi) - sin(pi - m)*m)
        new_zy = torch.where(cond, new_zy, zy_keep)
        diff = new_zy - cos_t  # cos(theta_yi + m) - cos(theta_yi)
        gt_one_hot = F.one_hot(target, num_classes=self.classes)
        body = gt_one_hot * diff
        return body 
开发者ID:PistonY,项目名称:torch-toolbox,代码行数:22,代码来源:loss.py

示例3: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def forward(self, x, target):
        """Circle loss over all pairs in the batch.

        For every anchor row ``i`` the loss is

            softplus(logsumexp_{p in P_i} logit_p + logsumexp_{n in N_i} logit_n)
            = log(1 + sum_n exp(logit_n) * sum_p exp(logit_p))

        averaged over rows. ``P_i`` are the other samples with the same label
        (the self-pair is excluded) and ``N_i`` are samples with a different
        label. Pairs outside ``P_i`` / ``N_i`` are masked with ``-inf`` so
        they contribute nothing to the sums. A row with no positives
        therefore contributes 0.

        Args:
            x: embeddings, indexed as ``(batch, dim)`` by ``x @ x.T``;
                presumably L2-normalised so similarities are cosines — TODO confirm.
            target: class label per row of ``x``.

        Returns:
            Scalar loss tensor.
        """
        similarity_matrix = x @ x.T  # needs grad: gradients flow through sp/sn
        label_matrix = target.unsqueeze(1) == target.unsqueeze(0)
        negative_matrix = label_matrix.logical_not()
        # Clone so that dropping self-pairs does not mutate label_matrix in place.
        positive_matrix = label_matrix.clone().fill_diagonal_(False)

        sp = torch.where(positive_matrix, similarity_matrix,
                         torch.zeros_like(similarity_matrix))
        sn = torch.where(negative_matrix, similarity_matrix,
                         torch.zeros_like(similarity_matrix))

        # Self-paced weights; detached so they act as constants in backward.
        ap = torch.clamp_min(1 + self.m - sp.detach(), min=0.)
        an = torch.clamp_min(sn.detach() + self.m, min=0.)

        logit_p = -self.gamma * ap * (sp - self.dp)
        logit_n = self.gamma * an * (sn - self.dn)

        # Excluded pairs must contribute exp(-inf) = 0 to each log-sum-exp.
        # Filling them with 0 (as before) added a spurious exp(0) = 1 per pair.
        # masked_fill's backward zeroes the gradient at masked positions, so
        # rows with no positives/negatives do not leak NaN gradients.
        neg_inf = float('-inf')
        logit_p = logit_p.masked_fill(positive_matrix.logical_not(), neg_inf)
        logit_n = logit_n.masked_fill(negative_matrix.logical_not(), neg_inf)

        loss = F.softplus(torch.logsumexp(logit_p, dim=1) +
                          torch.logsumexp(logit_n, dim=1)).mean()
        return loss
开发者ID:PistonY,项目名称:torch-toolbox,代码行数:27,代码来源:loss.py

示例4: smooth_l1_loss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def smooth_l1_loss(input, target, beta=1. / 9, size_average=True):
    """Smooth L1 loss with a configurable transition point ``beta``.

    Very similar to PyTorch's ``smooth_l1_loss`` but with the extra ``beta``
    parameter. Modified according to detectron2's fvcore, see
    https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/smooth_l1_loss.py

    Args:
        input: predictions.
        target: targets, broadcast-compatible with ``input``.
        beta: transition point between the quadratic and linear pieces.
            Values below 1e-5 are treated as 0, i.e. plain L1 loss.
        size_average: if truthy return the mean of the element losses,
            otherwise their sum.

    Returns:
        Scalar tensor.
    """
    abs_err = torch.abs(input - target)
    if beta < 1e-5:
        # With beta == 0, torch.where would yield NaN gradients: the unused
        # branch "0.5 * n ** 2 / 0" still receives a zero incoming gradient
        # rather than "no gradient". Tiny betas are therefore exactly L1.
        loss = abs_err
    else:
        loss = torch.where(abs_err < beta,
                           0.5 * abs_err ** 2 / beta,
                           abs_err - 0.5 * beta)
    reduce = loss.mean if size_average else loss.sum
    return reduce()
开发者ID:soeaver,项目名称:Parsing-R-CNN,代码行数:25,代码来源:smooth_l1_loss.py

示例5: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def forward(self, inputs, target, size_average=True):
        """Smooth L1 loss whose transition point tracks running statistics.

        Under ``torch.no_grad`` the running mean and variance of
        ``|inputs - target|`` (taken over dim 0) are updated with
        ``self.momentum``. The update is skipped when the batch variance
        contains NaN (e.g. a single-sample batch). The transition point is
        ``running_mean - running_var``, clamped to ``[1e-3, self.beta]`` and
        viewed as ``(-1, self.num_features)``.

        Args:
            inputs: predictions; last dim presumably ``self.num_features`` —
                TODO confirm.
            target: regression targets, broadcast-compatible with ``inputs``.
            size_average: if truthy return the mean of the element losses,
                otherwise their sum.

        Returns:
            Scalar loss tensor.
        """
        abs_err = torch.abs(inputs - target)
        with torch.no_grad():
            batch_var = abs_err.var(dim=0)
            if not torch.isnan(batch_var).any().item():
                keep = 1 - self.momentum
                self.running_mean = self.running_mean.to(abs_err.device)
                self.running_mean *= keep
                self.running_mean += self.momentum * abs_err.mean(dim=0)
                self.running_var = self.running_var.to(abs_err.device)
                self.running_var *= keep
                self.running_var += self.momentum * batch_var

        beta = (self.running_mean - self.running_var).clamp(max=self.beta, min=1e-3)
        beta = beta.view(-1, self.num_features).to(abs_err.device)
        in_quadratic_zone = abs_err < beta.expand_as(abs_err)
        loss = torch.where(in_quadratic_zone,
                           0.5 * abs_err ** 2 / beta,
                           abs_err - 0.5 * beta)
        return loss.mean() if size_average else loss.sum()
开发者ID:soeaver,项目名称:Parsing-R-CNN,代码行数:24,代码来源:adjust_smooth_l1_loss.py

示例6: prepare_boxlist

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def prepare_boxlist(self, boxes, scores, image_shape):
        """Flatten per-class detections into a single ``BoxList``.

        ``boxes`` has shape ``(#detections, 4 * #classes)``. Each row holds
        one predicted box per object class in the dataset (including the
        background class), and all boxes in a row originate from the same
        object proposal. ``scores`` has shape ``(#detections, #classes)``
        with the matching confidences, so ``scores[i, j]`` corresponds to the
        box at ``boxes[i, j * 4:(j + 1) * 4]``.

        Returns:
            BoxList in ``"xyxy"`` mode with one row per (detection, class)
            pair. The flattened scores are attached as the ``"scores"`` field.
        """
        flat_boxes = boxes.reshape(-1, 4)
        flat_scores = scores.reshape(-1)
        result = BoxList(flat_boxes, image_shape, mode="xyxy")
        result.add_field("scores", flat_scores)
        return result
开发者ID:soeaver,项目名称:Parsing-R-CNN,代码行数:20,代码来源:inference.py

示例7: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def __init__(self, dim=-1, k=None):
        """1.5-entmax: normalizing sparse transform (a la softmax).

        Solves the optimization problem:

            max_p <x, p> - H_1.5(p)    s.t.    p >= 0, sum(p) == 1.

        where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5.

        Parameters
        ----------
        dim : int
            The dimension along which to apply 1.5-entmax.

        k : int or None
            number of largest elements to partial-sort over. For optimal
            performance, should be slightly bigger than the expected number of
            nonzeros in the solution. If the solution is more than k-sparse,
            this function is recursively called with a 2*k schedule.
            If `None`, full sorting is performed from the beginning.
        """
        # Run the base initializer before assigning any attribute. torch.nn.Module
        # (presumably the base class here — not visible in this snippet) requires
        # this so its internal registries exist before __setattr__ is used.
        super(Entmax15, self).__init__()
        self.dim = dim
        self.k = k
开发者ID:deep-spin,项目名称:entmax,代码行数:26,代码来源:activations.py

示例8: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def forward(self, input, adj):
        """Single-head graph-attention forward pass.

        Args:
            input: node features, multiplied as ``torch.mm(input, self.W)``
                (so 2-D; presumably ``(num_nodes, in_features)`` — TODO confirm).
            adj: adjacency matrix; only entries ``> 0`` receive attention.

        Returns:
            Attention-weighted node features, passed through ELU when
            ``self.concat`` is truthy.
        """
        h = torch.mm(input, self.W)

        # Pairwise scores e[i, j] = LeakyReLU(f_1[i] + f_2[j]), assuming
        # a1/a2 are column vectors so the sum broadcasts to (N, N).
        f_1 = torch.matmul(h, self.a1)
        f_2 = torch.matmul(h, self.a2)
        e = self.leakyrelu(f_1 + f_2.transpose(0, 1))

        # Non-edges get a huge negative score so softmax assigns them ~0 weight.
        masked_score = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, masked_score)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        return h_prime
开发者ID:meliketoy,项目名称:graph-cnn.pytorch,代码行数:20,代码来源:layers.py

示例9: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def forward(self, inputs, labels):
        """Additive-angular-margin (ArcFace-style) logits.

        ``cos_th`` is the clamped cosine between ``inputs`` and the
        L2-normalised rows of ``self.weight``. For each sample's label column
        it is replaced by ``cos_th * cos_m - sin_th * sin_m`` (``cos(theta + m)``
        when ``cos_m``/``sin_m`` are ``cos(m)``/``sin(m)`` — TODO confirm).
        Where ``cos_th <= self.th`` the fallback ``cos_th - self.mm`` is used
        instead. All logits are then scaled by ``self.s``.

        Args:
            inputs: features passed to ``F.linear`` against ``self.weight``.
            labels: class index per sample, shape ``(N,)`` or ``(N, 1)``.

        Returns:
            Scaled logits with the same shape, dtype and device as ``cos_th``.
        """
        cos_th = F.linear(inputs, F.normalize(self.weight))
        cos_th = cos_th.clamp(-1, 1)
        sin_th = torch.sqrt(1.0 - torch.pow(cos_th, 2))
        cos_th_m = cos_th * self.cos_m - sin_th * self.sin_m
        # Where cos_th <= th, use the linear fallback. This single where()
        # already covers the cond <= 0 case the former masked reassignment
        # handled, so that second pass was redundant.
        cos_th_m = torch.where(cos_th > self.th, cos_th_m, cos_th - self.mm)

        if labels.dim() == 1:
            labels = labels.unsqueeze(-1)
        # Same device/dtype as the logits; the previous .cuda() failed on CPU.
        onehot = torch.zeros_like(cos_th)
        onehot.scatter_(1, labels, 1)
        outputs = onehot * cos_th_m + (1.0 - onehot) * cos_th
        return outputs * self.s
开发者ID:pudae,项目名称:kaggle-humpback,代码行数:20,代码来源:identifier.py

示例10: glu

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def glu(input, dim=-1):
    # type: (Tensor, int) -> Tensor
    r"""
    glu(input, dim=-1) -> Tensor

    The gated linear unit. Computes:

    .. math ::

        H = A \times \sigma(B)

    where `input` is split in half along `dim` to form `A` and `B`.

    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_.

    Args:
        input (Tensor): input tensor
        dim (int): dimension on which to split the input

    Raises:
        RuntimeError: if `input` is a 0-dim (scalar) tensor.
    """
    if input.dim() == 0:
        # Fixed typo in the message ("suppport" -> "support").
        raise RuntimeError("glu does not support scalars because halving size must be even")
    return torch._C._nn.glu(input, dim)
开发者ID:MagicChuyi,项目名称:SlowFast-Network-pytorch,代码行数:24,代码来源:functional.py

示例11: linear

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def linear(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    Shape:

        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    weight_t = weight.t()
    if input.dim() == 2 and bias is not None:
        # 2-D input with a bias: the fused addmm kernel is marginally faster.
        ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight_t)
    else:
        ret = input.matmul(weight_t)
        if bias is not None:
            ret += torch.jit._unwrap_optional(bias)
    # Single exit point, kept for TorchScript-compatible control flow.
    return ret
开发者ID:MagicChuyi,项目名称:SlowFast-Network-pytorch,代码行数:24,代码来源:functional.py

示例12: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def forward(self, g, h, weights):
        """Weighted neighbour aggregation followed by row-wise L2 normalisation.

        Args:
            g: graph, used via ``local_scope``/``update_all`` (presumably a DGL
               graph or block — TODO confirm).
            h: pair ``(h_src, h_dst)`` of source and destination node features.
            weights: scalar weight per edge; cast to float.

        Returns:
            Destination-node representations divided by their L2 norm along
            dim 1. Rows with zero norm are divided by 1 instead.
        """
        h_src, h_dst = h
        with g.local_scope():
            # Per-destination sum over in-edges of w_e * act(Q(dropout(h_src))).
            g.srcdata['n'] = self.act(self.Q(self.dropout(h_src)))
            g.edata['w'] = weights.float()
            g.update_all(fn.u_mul_e('n', 'w', 'm'), fn.sum('m', 'n'))
            # Per-destination sum of incoming edge weights.
            g.update_all(fn.copy_e('w', 'm'), fn.sum('m', 'ws'))
            weighted_sum = g.dstdata['n']
            # clamp(min=1) guards against dividing by a zero weight total.
            weight_total = g.dstdata['ws'].unsqueeze(1).clamp(min=1)
            combined = torch.cat([weighted_sum / weight_total, h_dst], 1)
            z = self.act(self.W(self.dropout(combined)))
            norm = z.norm(2, 1, keepdim=True)
            safe_norm = torch.where(norm == 0, torch.tensor(1.).to(norm), norm)
            return z / safe_norm
开发者ID:dmlc,项目名称:dgl,代码行数:21,代码来源:layers.py

示例13: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def forward(ctx, input, target):
        """Forward pass of a per-sample sparse (sparsemax-style) loss.

        Computes ``0.5 * sum_{j in S}(z_j**2 - tau**2) - z_k + 0.5`` per row,
        where ``S = {j : z_j > tau}``, ``tau`` comes from
        ``_threshold_and_support`` and ``z_k`` is the target-class score. The
        form matches the sparsemax loss — verify against that helper.

        Args:
            ctx: autograd context; ``input``, ``target`` and ``tau_z`` are
                saved on it for the backward pass.
            input (FloatTensor): n x num_classes scores.
            target (LongTensor): n, the indices of the target classes.

        Returns:
            Per-sample losses, clamped to be >= 0.
        """
        input_batch, _classes = input.size()
        target_batch = target.size(0)
        aeq(input_batch, target_batch)

        # Gold-class score per row. squeeze(1) rather than squeeze() keeps
        # the batch dimension even when n == 1.
        z_k = input.gather(1, target.unsqueeze(1)).squeeze(1)
        tau_z, _support_size = _threshold_and_support(input, dim=1)
        support = input > tau_z
        # 0-d zero with input's dtype and device. Older PyTorch releases
        # (before torch.where type promotion) reject a float32 literal here
        # for float16/float64 inputs.
        zero = input.new_zeros(())
        x = torch.where(support, input ** 2 - tau_z ** 2, zero).sum(dim=1)
        ctx.save_for_backward(input, target, tau_z)
        # clamping necessary because of numerical errors: loss should be lower
        # bounded by zero, but negative values near zero are possible without
        # the clamp
        return torch.clamp(x / 2 - z_k + 0.5, min=0.0)
开发者ID:lizekang,项目名称:ITDD,代码行数:23,代码来源:sparse_losses.py

示例14: balanced_l1_loss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def balanced_l1_loss(pred,
                     target,
                     beta=1.0,
                     alpha=0.5,
                     gamma=1.5,
                     reduction='mean'):
    """Element-wise Balanced L1 loss.

    With ``d = |pred - target|`` and ``b = e**(gamma / alpha) - 1``::

        alpha / b * (b*d + 1) * log(b*d/beta + 1) - alpha*d   if d < beta
        gamma*d + gamma/b - alpha*beta                        otherwise

    Args:
        pred: predictions.
        target: targets; must have the same size as ``pred`` and be non-empty.
        beta: transition point between the two pieces; must be positive.
        alpha, gamma: shape parameters of the loss.
        reduction: accepted but not used by this body. NOTE(review):
            presumably consumed by a wrapping decorator in the original
            project — confirm before relying on it.

    Returns:
        Un-reduced loss with the same shape as ``pred``.

    Raises:
        AssertionError: if ``beta <= 0``, the sizes differ, or ``target`` is empty.
    """
    assert beta > 0
    assert pred.size() == target.size() and target.numel() > 0

    d = (pred - target).abs()
    b = np.e**(gamma / alpha) - 1
    inner = alpha / b * (b * d + 1) * torch.log(b * d / beta + 1) - alpha * d
    outer = gamma * d + gamma / b - alpha * beta
    return torch.where(d < beta, inner, outer)
开发者ID:xvjiarui,项目名称:GCNet,代码行数:19,代码来源:balanced_l1_loss.py

示例15: _calc_loss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import where [as 别名]
def _calc_loss(self, errors):
        """Calculates the losses given the batch-wise 'td-errors'

        This is either squared-error or huber loss
        """
        if self.loss_mode == "mse":
            return errors.pow(2)
        elif self.loss_mode == "huber":
            # Huber loss element-wise
            abs_errors = torch.abs(errors)
            return torch.where(
                abs_errors <= self.huber_kappa,
                0.5 * errors.pow(2),
                self.huber_kappa * (abs_errors - (0.5 * self.huber_kappa)))
        else:
            assert(False), \
                f"{self.loss_mode} is not a valid q-learning loss mode" 
开发者ID:opherlieber,项目名称:rltime,代码行数:19,代码来源:dqn.py


注:本文中的torch.where方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。