当前位置: 首页>>代码示例>>Python>>正文


Python torch.ge方法代码示例

本文整理汇总了Python中torch.ge方法的典型用法代码示例。如果您正苦于以下问题:Python torch.ge方法的具体用法?Python torch.ge怎么用?Python torch.ge使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.ge方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _compute_xi

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def _compute_xi(self, s, aug, y):
        """Return the per-class weighting ``xi`` used by the hinge loss.

        ``xi`` starts as the float one-hot encoding of the argmax of ``aug``
        along dim 1 (compared against ``self._range``). When
        ``MultiClassHingeLoss.smooth`` is set, every row whose softmax-weighted
        augmented score ``sum(softmax(s) * aug, dim=1)`` is >= 0 uses
        ``softmax(s, dim=1)`` instead; the remaining rows keep the one-hot row.

        Args:
            s: scores, softmax-normalised over dim 1 in the smooth case.
            aug: augmented scores; presumably (batch, n_classes) — TODO confirm.
            y: not used by this method.

        Returns:
            Float tensor combining the smooth and hard encodings row by row.
        """
        # Hard part: index of the largest augmented score, one-hot encoded.
        _, best = aug.max(dim=1)
        hard_xi = torch.eq(best.unsqueeze(1), self._range).float()

        if not MultiClassHingeLoss.smooth:
            return hard_xi

        soft_xi = nn.functional.softmax(s, dim=1)
        # 1.0 for rows whose contribution to the loss is non-negative, else 0.0.
        contributes = (soft_xi * aug).sum(dim=1).ge(0).float().unsqueeze(1)
        # Smooth only those rows; the others stay one-hot.
        return contributes * soft_xi + (1 - contributes) * hard_xi
开发者ID:oval-group,项目名称:dfw,代码行数:21,代码来源:hinge.py

示例2: iouloss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def iouloss(input, target, smooth=1.):
    """Return the soft Dice loss between ``input`` and ``target``.

    Despite its name, this computes ``1 - Dice`` rather than ``1 - IoU``::

        1 - (2 * sum(input * target) + smooth) / (sum(input) + sum(target) + smooth)

    Both tensors are flattened before the sums, so any shapes with the same
    number of elements are accepted.

    Args:
        input: predicted tensor; presumably probabilities in [0, 1].
        target: ground-truth tensor with as many elements as ``input``.
        smooth: additive smoothing term. It keeps the ratio defined when both
            tensors sum to zero (the loss is then 0). Defaults to 1.

    Returns:
        0-dim tensor holding the loss.
    """
    # reshape (not view): view raises for non-contiguous tensors such as
    # transposed or sliced predictions.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1. - ((2. * intersection + smooth) /
                 (iflat.sum() + tflat.sum() + smooth))
开发者ID:saeedizadi,项目名称:binseg_pytoch,代码行数:22,代码来源:losses.py

示例3: training_step

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def training_step(self, batch, batch_idx):
        """Run one training step of a single-logit binary classifier.

        ``batch`` is unpacked as ``(inputs, targets)``. Targets are reshaped to a
        column and cast to the dtype/device of the inputs. A prediction counts as
        positive when its logit is >= 0.

        Returns:
            OrderedDict with keys ``loss``, ``num_correct`` (tensor count of
            matching predictions), ``log`` and ``progress_bar``. The last two
            are the same dict object, ``{'train_loss': loss}``.
        """
        inputs, targets = batch
        logits = self.forward(inputs)
        targets = targets.view((-1, 1)).type_as(inputs)
        predicted_positive = torch.ge(logits, 0)

        loss = self.loss(logits, targets)
        hits = torch.eq(predicted_positive.view(-1), targets.view(-1)).sum()

        # Shared between the logger and the progress bar.
        metrics = {'train_loss': loss}
        return OrderedDict([
            ('loss', loss),
            ('num_correct', hits),
            ('log', metrics),
            ('progress_bar', metrics),
        ])
开发者ID:PyTorchLightning,项目名称:pytorch-lightning,代码行数:22,代码来源:computer_vision_fine_tuning.py

示例4: pruneWeights

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def pruneWeights(self, minWeight):
    """
    Zero, in place, every weight whose absolute value is below minWeight.

    Only parameters whose name contains 'weight' are affected (biases are left
    alone). Entries with ``abs(w) >= minWeight`` are kept unchanged.

    :param minWeight: magnitude threshold. A value of exactly 0.0 disables pruning.
    :type minWeight: float
    """
    if minWeight != 0.0:
        for name, param in self.named_parameters():
            if 'weight' not in name:
                continue
            # 1.0 where the weight survives, 0.0 where it is pruned
            keep = torch.abs(param.data).ge(minWeight)
            param.data.mul_(keep.to(torch.float32))
开发者ID:numenta,项目名称:htmpapers,代码行数:18,代码来源:sparse_net.py

示例5: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def forward(self, words, frequent_tuning=False):
        """Embed token ids ``words`` and apply ``self.dropout``.

        Outside of ``frequent_tuning`` during training, every id goes through
        ``self.fixed_word_embedding``. Otherwise ids below
        ``self.threshold_index`` (excluding id 0, treated as padding here) are
        embedded with ``self.fine_tune_word_embedding``. Ids at or above the
        threshold come from ``self.fixed_word_embedding`` with the result
        detached. The two masked embeddings are summed.
        """
        if not (frequent_tuning and self.training):
            return self.dropout(self.fixed_word_embedding(words))

        # Tuned part: frequent ids (< threshold_index) that are not padding (0).
        is_padding = words.eq(0).long()
        tune_mask = torch.lt(words, self.threshold_index) * is_padding.eq(0)
        tuned = self.fine_tune_word_embedding(words * tune_mask.long())
        # f.masked_zero is a project helper; presumably zeroes positions outside the mask.
        tuned = f.masked_zero(tuned, tune_mask)

        # Frozen part: ids >= threshold_index, gradient blocked with detach().
        frozen_mask = torch.ge(words, self.threshold_index)
        frozen = self.fixed_word_embedding(words).detach()
        frozen = f.masked_zero(frozen, frozen_mask)

        return self.dropout(tuned + frozen)
开发者ID:naver,项目名称:claf,代码行数:27,代码来源:frequent_word_embedding.py

示例6: create_negative_mask

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def create_negative_mask(
    labels: torch.Tensor, neg_label: int = -1
) -> torch.Tensor:
    """Build a boolean mask over index triplets ``(i, j, k)``.

    Let ``below[x] = not (labels[x] >= neg_label)``. Then ``mask[i, j, k]`` is
    True when ``below[i]`` and ``below[j]`` both hold, and additionally either
    ``below[k]`` holds or ``labels[i] + neg_label == labels[k]``.

    Args:
        labels: label tensor; for 1-D ``labels`` of length ``n`` the result has
            shape ``(n, n, n)``.
        neg_label: threshold and offset used above (default -1).

    Returns:
        Boolean mask tensor.
    """
    below = ~torch.ge(labels, neg_label)

    # Broadcast "below" along the i, j and k axes respectively.
    first_ok = below[:, None, None]
    second_ok = below[None, :, None]
    third_ok = below[None, None, :]

    shifted_match = torch.eq(labels[:, None, None] + neg_label, labels[None, None, :])
    return first_ok & second_ok & (shifted_match | third_ok)
开发者ID:catalyst-team,项目名称:catalyst,代码行数:21,代码来源:functional.py

示例7: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def forward(self, x):
        """Attention-pooled forward pass over one bag of instances.

        The leading dim of ``x`` is squeezed away. Instance features are
        extracted (flattened to 50*4*4 between the two extractor stages) and
        scored by ``self.attention``. The scores are softmax-normalised over
        the instances and used to pool the features before classification.

        Returns:
            (Y_prob, Y_hat, A): classifier output, ``(Y_prob >= 0.5)`` as float,
            and the attention weights (K x N).
        """
        bag = x.squeeze(0)

        feats = self.feature_extractor_part1(bag).view(-1, 50 * 4 * 4)
        feats = self.feature_extractor_part2(feats)  # NxL

        # NxK scores -> KxN, normalised over the N instances.
        weights = F.softmax(self.attention(feats).transpose(0, 1), dim=1)

        pooled = weights.mm(feats)  # KxL

        Y_prob = self.classifier(pooled)
        Y_hat = Y_prob.ge(0.5).float()
        return Y_prob, Y_hat, weights

    # AUXILIARY METHODS 
开发者ID:AMLab-Amsterdam,项目名称:AttentionDeepMIL,代码行数:21,代码来源:model.py

示例8: get_accuracy_bin

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def get_accuracy_bin(scores, labels):
    """Return the binary accuracy of ``scores`` against ``labels``.

    A score >= 0 predicts class 1, anything else class 0 (as int64). The
    number of matches is divided by ``labels.nelement()``.
    """
    predicted = (scores >= 0).long()
    n_match = predicted.eq(labels).float().sum()
    return n_match / labels.nelement()
开发者ID:wengong-jin,项目名称:hgraph2graph,代码行数:6,代码来源:nnutils.py

示例9: distance_bin

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def distance_bin(self, mention_distance):
        """Bucket integer mention distances into bins 0-9.

        Bin ``t + 1`` holds the distances in the ``t``-th inclusive range below.
        Any distance outside every range (0, negatives, > 300) gets bin 0::

            1 | 2 | 3 | 4 | 5-7 | 8-15 | 16-31 | 32-63 | 64-300

        Args:
            mention_distance: integer tensor of any shape.

        Returns:
            int64 tensor of the same shape, on ``self.device``.
        """
        ranges = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 7], [8, 15], [16, 31], [32, 63], [64, 300]]
        distance = mention_distance.to(self.device)
        bins = torch.zeros_like(distance, dtype=torch.long)
        for t, (low, high) in enumerate(ranges):
            # Compare against Python scalars: broadcasting works for any input
            # rank, and no bound tensors are allocated per range.
            in_range = torch.ge(distance, low) & torch.le(distance, high)
            bins += (t + 1) * in_range.long()
        return bins
开发者ID:fastnlp,项目名称:fastNLP,代码行数:13,代码来源:model_re.py

示例10: detect_large

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def detect_large(x, k, tau, thresh):
    """Flag rows of ``x`` for hard or smooth top-k handling.

    For each row, the gap between its k-th and (k+1)-th largest values is
    compared with ``k * tau * log(thresh)``. Rows whose gap is at least that
    large switch to hard top-k; the others stay smooth.

    Returns:
        (smooth, hard): complementary boolean tensors with one entry per row.
            ``hard`` is detached from the graph.
    """
    margin = k * tau * math.log(thresh)
    values = x.topk(k + 1, 1)[0]
    # A large gap means the (k+1)-th element is much smaller than the k-th.
    gap = values[:, k - 1] - values[:, k]
    hard = gap.ge(margin).detach()
    return hard == 0, hard
开发者ID:oval-group,项目名称:smooth-topk,代码行数:9,代码来源:utils.py

示例11: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def forward(self, inputs, targets):
        """Batch-hard triplet ranking loss.

        Args:
        - inputs: feature matrix with shape (batch_size, feat_dim)
        - targets: ground truth labels with shape (batch_size,)

        Returns:
            (loss, correct):
            - ``loss`` is ``self.ranking_loss(dist_an, dist_ap, ones)``. For each
              anchor, ``dist_ap`` is the Euclidean distance to its farthest
              same-label sample (itself included), and ``dist_an`` is the
              distance to its closest different-label sample.
            - ``correct`` is the Python int count of anchors with
              ``dist_an >= dist_ap``.

        Raises:
            RuntimeError: if some anchor has no sample with a different label,
                so its hardest negative is undefined.
        """
        n = inputs.size(0)

        # Pairwise distances via ||a||^2 + ||b||^2 - 2 a.b.
        sq_norm = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = sq_norm + sq_norm.t()
        # Keyword beta/alpha: the positional addmm_(beta, alpha, m1, m2) overload is deprecated.
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative.
        same = targets.expand(n, n).eq(targets.expand(n, n).t())
        if not bool((~same).any(dim=1).all()):
            raise RuntimeError(
                "every anchor needs at least one sample with a different label"
            )
        # Filling excluded entries with -inf / +inf lets max/min reduce whole rows at once.
        dist_ap = dist.masked_fill(~same, float('-inf')).max(dim=1)[0]
        dist_an = dist.masked_fill(same, float('inf')).min(dim=1)[0]

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)

        # compute accuracy
        correct = torch.ge(dist_an, dist_ap).sum().item()
        return loss, correct



        
        
# Adaptive weights 
开发者ID:mangye16,项目名称:Cross-Modal-Re-ID-baseline,代码行数:38,代码来源:loss.py

示例12: decode

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def decode(self, z, deterministic):
        '''Autoregressively decode latent codes into 28x28 images, pixel by pixel.

        Args:
            z: Tensor
                the tensor of latent z shape=[batch, nz]; ``self.z_transform(z)``
                must provide ``self.fm_latent * 28 * 28`` values per sample.
            deterministic: boolean
                if True, each pixel is set to 1.0 where the model's output is
                >= 0.5 (else 0.0); if False, it is sampled with ``torch.bernoulli``.

        Returns: (Tensor, Tensor)
            the decoded images, shape=[batch, nc, H, W], and ``self.forward``
            evaluated on the final canvas (image channels + latent feature maps).
        '''
        H = W = 28
        batch_size, _ = z.size()

        # Decoding is inference only. torch.no_grad() is the supported way to
        # skip autograd; the legacy Variable(..., volatile=True) flag is a
        # no-op since PyTorch 0.4 and left all H*W forward passes tracked.
        with torch.no_grad():
            # [batch, -1] --> [batch, fm, H, W]
            z = self.z_transform(z).view(batch_size, self.fm_latent, H, W)
            # [batch, nc+fm, H, W]: blank image channels followed by the latent maps
            img = torch.cat([z.new_zeros(batch_size, self.nc, H, W), z], dim=1)
            for i in range(H):
                for j in range(W):
                    # [batch, nc, H, W]
                    recon_img = self.forward(img)
                    # [batch, nc]
                    pixel_probs = recon_img[:, :, i, j]
                    if deterministic:
                        img[:, :self.nc, i, j] = torch.ge(pixel_probs, 0.5).float()
                    else:
                        img[:, :self.nc, i, j] = torch.bernoulli(pixel_probs)

            # [batch, nc, H, W]
            img_probs = self.forward(img)
        return img[:, :self.nc], img_probs
开发者ID:jxhe,项目名称:vae-lagging-encoder,代码行数:34,代码来源:dec_pixelcnn_v2.py

示例13: __call__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def __call__(self, x):
        """Stochastically binarize a tensor.

        Each element is compared with its own threshold drawn uniformly from
        [0, 1). The output is 1.0 where ``x >= threshold`` and 0.0 elsewhere.

        Args:
            x (Tensor): floating-point values; presumably in [0, 1] so they act
                as per-element Bernoulli probabilities.

        Returns:
            Tensor: float32 tensor of 0.0/1.0 values with the shape of ``x``.
        """
        noise = torch.empty_like(x).uniform_()
        keep = x.ge(noise)
        return keep.float()
开发者ID:CW-Huang,项目名称:torchkit,代码行数:15,代码来源:transforms.py

示例14: ge

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def ge(t1, t2):
    """
    Element-wise rich greater-than-or-equal comparison ``t1 >= t2`` (not commutative).

    Either operand may be a scalar or a tensor. The work is delegated to
    ``operations.__binary_op`` with ``torch.ge`` as the element-wise kernel.

    Parameters
    ----------
    t1: tensor or scalar
        Left operand, tested for being greater than or equal to ``t2``.
    t2: tensor or scalar
        Right operand, tested for being less than or equal to ``t1``.

    Returns
    -------
    result: ht.DNDarray
        True/1 wherever ``t1 >= t2`` and False/0 elsewhere. The dtype presumably
        follows ``torch.ge``. The examples below were recorded when that was
        ``torch.uint8``; PyTorch >= 1.2 comparisons produce ``torch.bool``.

    Examples
    --------
    >>> import heat as ht
    >>> T1 = ht.float32([[1, 2],[3, 4]])
    >>> ht.ge(T1, 3.0)
    tensor([[0, 0],
            [1, 1]], dtype=torch.uint8)

    >>> T2 = ht.float32([[2, 2], [2, 2]])
    >>> ht.ge(T1, T2)
    tensor([[0, 1],
            [1, 1]], dtype=torch.uint8)
    """
    comparison = torch.ge
    return operations.__binary_op(comparison, t1, t2)
开发者ID:helmholtz-analytics,项目名称:heat,代码行数:35,代码来源:relational.py

示例15: get_weighted_clipped_pos_diffs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import ge [as 别名]
def get_weighted_clipped_pos_diffs(sorted_std_labels):
    """Position-discounted positive label differences for pairwise ranking.

    Entry ``[i, j]`` of the pairwise matrix is ``labels[i] - labels[j]``, with
    negative differences replaced by ``tor_zero``. The matrix is clipped to the
    first ``num_pos`` rows (count of labels > 0) and the first ``total_items``
    columns (count of labels >= 0). It is then divided by ``log2(2 + j)``
    column-wise and by ``log2(2 + i)`` row-wise.

    The clipping presumably assumes a 1-D tensor sorted in descending order,
    so positives come first and ``-1`` placeholders come last — TODO confirm
    with callers.

    Returns:
        (weighted_clipped_pos_diffs, total_true_pairs, total_items), where the
        last two are Python ints and ``total_true_pairs`` counts non-zero
        entries of the clipped matrix.
    """
    # Counts as Python ints; labels of -1 are neither positive nor counted items.
    num_pos = int(torch.gt(sorted_std_labels, 0).sum())
    total_items = int(torch.ge(sorted_std_labels, 0).sum())

    diffs = sorted_std_labels.unsqueeze(1) - sorted_std_labels.unsqueeze(0)
    clipped = torch.where(diffs < 0, tor_zero, diffs)[0:num_pos, 0:total_items]
    n_true_pairs = int((clipped != 0).sum())

    col_discount = torch.log2(2.0 + torch.arange(total_items).type(tensor)).unsqueeze(0)
    row_discount = torch.log2(2.0 + torch.arange(num_pos).type(tensor)).unsqueeze(1)

    weighted = clipped / col_discount / row_discount
    return weighted, n_true_pairs, total_items
开发者ID:pt-ranking,项目名称:pt-ranking.github.io,代码行数:31,代码来源:pair_sampling.py


注:本文中的torch.ge方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。