当前位置: 首页>>代码示例>>Python>>正文


Python torch.index_select方法代码示例

本文整理汇总了Python中torch.index_select方法的典型用法代码示例。如果您正苦于以下问题:Python torch.index_select方法的具体用法?Python torch.index_select怎么用?Python torch.index_select使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.index_select方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: pose_inv_full

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def pose_inv_full(pose):
  '''
  Invert a batch of 2x3 affine transformation matrices.

  param pose: N x 6 tensor; each row is a row-major 2x3 matrix [A | b],
              i.e. [a00, a01, b0, a10, a11, b1].
  returns: N x 2 x 3 tensor [A^{-1} | -A^{-1} b].

  The constant gather indices and sign pattern are created on pose's device
  and with pose's dtype (they used to be hard-coded float32 CUDA tensors), so
  CPU and float64 inputs work as well.
  '''
  N, _ = pose.size()
  b = pose.view(N, 2, 3)[:, :, 2:]
  # A^{-1} = adj(A) / det(A); the small epsilon avoids dividing by an exactly
  # zero determinant.
  determinant = (pose[:, 0] * pose[:, 4] - pose[:, 1] * pose[:, 3] + 1e-8).view(N, 1)
  # For A = [[p0, p1], [p3, p4]], adj(A) = [[p4, -p1], [-p3, p0]]: gather
  # columns (4, 1, 3, 0) and negate the off-diagonal entries.
  indices = torch.tensor([4, 1, 3, 0], dtype=torch.long, device=pose.device)
  scale = pose.new_tensor([1, -1, -1, 1])
  A_inv = torch.index_select(pose, 1, indices) * scale / determinant
  A_inv = A_inv.view(N, 2, 2)
  # b' = - A^{-1} b
  b_inv = - A_inv.matmul(b).view(N, 2, 1)
  transformer_inv = torch.cat([A_inv, b_inv], dim=2)
  return transformer_inv
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:20,代码来源:DDPAE_utils.py

示例2: gen_base_anchors

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def gen_base_anchors(self):
    """Build the base anchors for every feature level.

    For level ``i`` the anchors produced by ``gen_single_level_base_anchors``
    are re-ordered so that the anchor at position ``len(self.ratios[i])`` is
    moved to position 1. All other anchors keep their relative order.

    Returns:
        list(torch.Tensor): Base anchors of a feature grid in multiple
            feature levels.
    """
    per_level = []
    for level, size in enumerate(self.base_sizes):
        level_ratios = self.ratios[level]
        anchors = self.gen_single_level_base_anchors(
            size,
            scales=self.scales[level],
            ratios=level_ratios,
            center=self.centers[level])
        n_ratios = len(level_ratios)
        order = list(range(n_ratios))
        # Splice the extra anchor index right after the first one.
        order[1:1] = [n_ratios]
        per_level.append(
            torch.index_select(anchors, 0, torch.LongTensor(order)))
    return per_level
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:22,代码来源:anchor_generator.py

示例3: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def forward(self, x):
    """Apply the graph filter, zero-padding the node dimension when needed.

    ``x`` is indexed as batchSize x dimInFeatures x numberNodesIn. If
    numberNodesIn is smaller than ``self.N``, zeros are appended along dim 2
    up to ``self.N`` nodes before ``LSIGF`` runs. The output is then cut back
    to the first numberNodesIn nodes, so it is batchSize x dimOutFeatures x
    numberNodesIn.
    """
    batch, feats, n_in = x.shape[0], x.shape[1], x.shape[2]
    if n_in < self.N:
        filler = torch.zeros(batch, feats, self.N - n_in).type(x.dtype).to(x.device)
        x = torch.cat((x, filler), dim=2)
    out = LSIGF(self.weight, self.S, x, self.bias)
    # Nodes past numberNodesIn were only padding; drop them again.
    if n_in < self.N:
        keep = torch.arange(n_in).to(out.device)
        out = torch.index_select(out, 2, keep)
    return out
开发者ID:alelab-upenn,项目名称:graph-neural-networks,代码行数:22,代码来源:graphML.py

示例4: sample

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def sample(verts, faces, num=10000, ret_choice = False):
    """Sample points uniformly over the surface of a triangle mesh.

    :param verts: vertex positions; rows are gathered by vertex index and
        split into 3 coordinates (assumed shape (num_verts, 3))
    :param faces: integer vertex indices, one triangle per row (num_faces, 3)
    :param num: number of points to draw
    :param ret_choice: if True, also return the face index of every point
    :return: points of shape (num, 3), plus the (num,) face indices when
        ret_choice is True

    A face is picked with probability proportional to its area. A point is
    then drawn inside it with the square-root barycentric trick
    ((1 - sqrt(r1)) A + sqrt(r1)(1 - r2) B + sqrt(r1) r2 C). Random tensors
    are created on verts' device and dtype; they used to be hard-coded to
    CUDA. ``sample((n,))`` replaces the deprecated ``sample_n(n)``.
    """
    dist_uni = torch.distributions.Uniform(verts.new_tensor([0.0]), verts.new_tensor([1.0]))
    v0 = torch.index_select(verts, 0, faces[:, 0])
    v1 = torch.index_select(verts, 0, faces[:, 1])
    v2 = torch.index_select(verts, 0, faces[:, 2])
    # Triangle area = |(v0 - v1) x (v1 - v2)| / 2
    x1, x2, x3 = torch.split(v0 - v1, 1, dim=1)
    y1, y2, y3 = torch.split(v1 - v2, 1, dim=1)
    a = (x2 * y3 - x3 * y2) ** 2
    b = (x3 * y1 - x1 * y3) ** 2
    c = (x1 * y2 - x2 * y1) ** 2
    Areas = torch.sqrt(a + b + c) / 2
    Areas = Areas / torch.sum(Areas)
    cat_dist = torch.distributions.Categorical(Areas.view(-1))
    choices = cat_dist.sample((num,))
    select_faces = faces[choices]
    xs = torch.index_select(verts, 0, select_faces[:, 0])
    ys = torch.index_select(verts, 0, select_faces[:, 1])
    zs = torch.index_select(verts, 0, select_faces[:, 2])
    u = torch.sqrt(dist_uni.sample((num,)))
    v = dist_uni.sample((num,))
    points = (1 - u) * xs + (u * (1 - v)) * ys + u * v * zs
    if ret_choice:
        return points, choices
    return points
开发者ID:nv-tlabs,项目名称:DIB-R,代码行数:24,代码来源:check_chamfer.py

示例5: visualize

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def visualize(args):
  '''
  Export the label vectors of a saved model to TensorBoard's embedding projector.

  Loads "<constant.EXP_ROOT>/<args.model_id>_best.pt" into a freshly built
  model on the GPU and collects one vector per label of the "open" answer
  vocabulary. When args.gcn is set, the vectors come from
  model.decoder.transform applied to the connection-weighted classifier
  weights. Otherwise they are the raw rows of model.decoder.linear.weight.
  The vectors are written with SummaryWriter.add_embedding under
  "../visualize/<args.model_id>".

  The writer is closed in a finally block, so its files are flushed and the
  handle is released even if the export fails. Previously the writer was
  never closed, and the local writer variable shadowed this function's name.
  '''
  saved_path = constant.EXP_ROOT
  model = models.Model(args, constant.ANSWER_NUM_DICT[args.goal])
  model.cuda()
  model.eval()
  model.load_state_dict(torch.load(saved_path + '/' + args.model_id + '_best.pt')["state_dict"])

  label2id = constant.ANS2ID_DICT["open"]
  writer = SummaryWriter("../visualize/" + args.model_id)
  try:
    label_list = list(label2id.keys())
    ids = [label2id[label] for label in label_list]
    if args.gcn:
      # Mix the classifier weights through the label connection matrix,
      # normalising each row by its total connection weight.
      connection_matrix = model.decoder.label_matrix + model.decoder.weight * model.decoder.affinity
      label_vectors = model.decoder.transform(connection_matrix.mm(model.decoder.linear.weight) / connection_matrix.sum(1, keepdim=True))
    else:
      label_vectors = model.decoder.linear.weight.data

    interested_vectors = torch.index_select(label_vectors, 0, torch.tensor(ids).to(torch.device("cuda")))
    writer.add_embedding(interested_vectors, metadata=label_list, label_img=None)
  finally:
    writer.close()
开发者ID:xwhan,项目名称:Extremely-Fine-Grained-Entity-Typing,代码行数:23,代码来源:main.py

示例6: pytorch_tile

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def pytorch_tile(tensor, n_tile, dim=0):
    """
    Repeat every slice of `tensor` along `dim` `n_tile` times in a row.

    Despite the name, this is element-wise repetition (like `numpy.repeat` /
    `torch.repeat_interleave`), not `numpy.tile`:
    `pytorch_tile(torch.tensor([1, 2]), 2)` gives `[1, 1, 2, 2]`.

    Args:
        tensor (torch.Tensor): Tensor to tile.
        n_tile (int or torch.Size): Number of repetitions. For a
            `torch.Size`, its first entry is used.
        dim (int): Dim to tile.

    Returns:
        torch.Tensor: Tensor whose size along `dim` is `n_tile` times the
            input's.
    """
    if isinstance(n_tile, torch.Size):
        n_tile = n_tile[0]
    init_dim = tensor.size(dim)
    repeat_idx = [1] * tensor.dim()
    repeat_idx[dim] = n_tile
    tensor = tensor.repeat(*repeat_idx)
    # After `repeat`, copy k of original slice i sits at k * init_dim + i.
    # Reading that (n_tile, init_dim) grid column by column yields
    # i0, i0, ..., i1, i1, ... . The index is built with torch on the
    # input's device, so CUDA tensors work and an empty `dim` no longer
    # crashes np.concatenate.
    order_index = (torch.arange(init_dim * n_tile, device=tensor.device)
                   .view(n_tile, init_dim).t().reshape(-1))
    return torch.index_select(tensor, dim, order_index)


# TODO remove when we have handled pytorch placeholder inference better. 
开发者ID:rlgraph,项目名称:rlgraph,代码行数:24,代码来源:pytorch_util.py

示例7: symmetricImagePad

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def symmetricImagePad(self, image_batch, padding_factor):
    """Mirror-pad a batch of images on all four sides.

    ``image_batch`` is unpacked as (b, c, h, w). ``int(w * padding_factor)``
    columns are mirrored onto the left and right edges, and
    ``int(h * padding_factor)`` rows onto the top and bottom; the edge pixel
    itself is repeated. Columns are padded first, then rows of the widened
    batch. The index tensors are moved to the GPU when ``self.use_cuda`` is
    true.
    """
    _, _, height, width = image_batch.size()
    pad_rows = int(height * padding_factor)
    pad_cols = int(width * padding_factor)

    def reversed_index(start, stop):
        # Descending indices start, start - 1, ..., stop + 1 (empty if none).
        idx = torch.tensor(list(range(start, stop, -1)), dtype=torch.long)
        return idx.cuda() if self.use_cuda else idx

    left = reversed_index(pad_cols - 1, -1)
    right = reversed_index(width - 1, width - pad_cols - 1)
    top = reversed_index(pad_rows - 1, -1)
    bottom = reversed_index(height - 1, height - pad_rows - 1)

    widened = torch.cat(
        (image_batch.index_select(3, left), image_batch, image_batch.index_select(3, right)), 3)
    return torch.cat(
        (widened.index_select(2, top), widened, widened.index_select(2, bottom)), 2)
开发者ID:ignacio-rocco,项目名称:weakalign,代码行数:19,代码来源:transformation.py

示例8: project

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def project(self, label, lin_indices_3d, lin_indices_2d, num_points):
    """
    Scatter 2d image features onto 3d points (forward pass of backprojection).

    Both index vectors are packed as ``[header, idx_0, idx_1, ...]``. The
    number of valid correspondences ``count`` is read from
    ``lin_indices_3d[0]``, and entries ``1 .. count`` of each vector are used.
    Position 0 of ``lin_indices_2d`` is not read here. Flattened feature
    ``lin_indices_2d[1 + k]`` is written to point ``lin_indices_3d[1 + k]``.
    Element 0 is used as a scalar count, so the index vectors are expected
    to be 1-d.

    :param label: image features; a 2-d tensor is treated as a single
        channel, otherwise dim 0 is the channel dimension
    :param lin_indices_3d: ``[count, point indices...]``
    :param lin_indices_2d: ``[<unused>, flattened pixel indices...]``
    :param num_points: number of points in one sample
    :return: tensor of shape (num_channels, num_points); points without a
        correspondence stay 0
    """
    channels = label.shape[0] if len(label.shape) != 2 else 1
    output = label.new(channels, num_points).fill_(0)
    count = lin_indices_3d[0]
    if not (count > 0):
        return output
    src = lin_indices_2d[1:1 + count]
    dst = lin_indices_3d[1:1 + count]
    flat_features = label.view(channels, -1)
    output.view(channels, -1)[:, dst] = torch.index_select(flat_features, 1, src)
    return output


# Inherit from Function 
开发者ID:daveredrum,项目名称:Pointnet2.ScanNet,代码行数:26,代码来源:projection.py

示例9: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def forward(ctx, label, lin_indices_3d, lin_indices_2d, num_points):
    """
    Forward pass of backprojection: copy 2d image features onto 3d points.

    ``lin_indices_3d[0]`` holds the number of correspondences ``count``.
    The next ``count`` entries of ``lin_indices_2d`` are flattened pixel
    indices, and the next ``count`` entries of ``lin_indices_3d`` are the
    target point indices. Nothing is stored on ``ctx``.

    :param label: image features; 2-d means one channel, otherwise dim 0
        is the channel dimension
    :param lin_indices_3d: 1-d ``[count, point indices...]``
    :param lin_indices_2d: 1-d ``[<unused here>, pixel indices...]``
    :param num_points: number of points in one sample
    :return: (num_channels, num_points) tensor, zero where no pixel maps
    """
    # ctx.save_for_backward(lin_indices_3d, lin_indices_2d)
    num_label_ft = 1 if label.dim() == 2 else label.size(0)
    output = label.new_zeros(num_label_ft, num_points)
    count = lin_indices_3d[0]
    if count > 0:
        pixel_idx, point_idx = lin_indices_2d[1:1 + count], lin_indices_3d[1:1 + count]
        gathered = label.view(num_label_ft, -1).index_select(1, pixel_idx)
        output.view(num_label_ft, -1)[:, point_idx] = gathered
    return output
开发者ID:daveredrum,项目名称:Pointnet2.ScanNet,代码行数:22,代码来源:projection.py

示例10: _feature_distance

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def _feature_distance(self, feaMat):
        probe_feature = torch.index_select(feaMat, dim=0, index=torch.from_numpy(self.probe_index).long().cuda())
        gallery_feature = torch.index_select(feaMat, dim=0, index=torch.from_numpy(self.gallery_index).long().cuda())

        idx = 0
        while idx + self.probe_dst_max < self.probe_num:
            tmp_probe_fea = probe_feature[idx:idx+self.probe_dst_max]
            dst_pg = self._feature_distance_mini(tmp_probe_fea, gallery_feature)
            self.distMat[idx:idx+self.probe_dst_max] += dst_pg
            idx += self.probe_dst_max
        tmp_probe_fea = probe_feature[idx:self.probe_num]
        dst_pg = self._feature_distance_mini(tmp_probe_fea, gallery_feature)
        self.distMat[idx:self.probe_num] += dst_pg

        for i_p, p in enumerate(self.probe_index):
            for i_g, g in enumerate(self.gallery_index):
                if self.test_info[p, 0] != self.test_info[g, 0]:
                    self.avgDiff = self.avgDiff + self.distMat[i_p, i_g]
                    self.avgDiffCount = self.avgDiffCount + 1
                elif p != g:
                    self.avgSame = self.avgSame + self.distMat[i_p, i_g]
                    self.avgSameCount = self.avgSameCount + 1 
开发者ID:yolomax,项目名称:person-reid-lib,代码行数:24,代码来源:eval_base.py

示例11: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def forward(self, pred, target, weights):
    """Weighted MSE plus a pairwise ranking penalty between positives and negatives.

    ``mseloss = sum(weights * (pred - target) ** 2)``. The set term gathers
    the predictions at the columns where ``target == 1`` (positives) and
    ``target == 0`` (negatives), averages ``exp(neg - pos)`` over every
    positive/negative pair, and scales the average by the module-level
    ``labmda``. The returned loss has shape (1,), as before.

    Fixes:
      * The set term used to be computed on ``.data`` and re-wrapped in a
        fresh leaf ``Variable``, so none of its gradient reached ``pred``.
        It is now computed on ``pred`` directly, and therefore already lives
        on ``pred``'s device, so the ``use_cuda`` move is no longer needed.
      * ``.clone().view(1, -1)`` on a transposed tensor fails on recent
        PyTorch (``clone`` preserves strides); ``reshape`` is used instead.
      * With no positives or no negatives the term is 0 instead of NaN (0/0).

    NOTE(review): the pairwise layout below only lines up when ``pred`` and
    ``target`` have a single row (shape (1, n)), the same restriction as the
    original code.
    """
    mseloss = torch.sum(weights * torch.pow((pred - target), 2))
    target = target.data
    ones_idx_set = (target == 1).nonzero()
    zeros_idx_set = (target == 0).nonzero()
    if ones_idx_set.shape[0] == 0 or zeros_idx_set.shape[0] == 0:
        set_loss = mseloss.new_zeros(1)
    else:
        ones_set = torch.index_select(pred, 1, ones_idx_set[:, 1])
        zeros_set = torch.index_select(pred, 1, zeros_idx_set[:, 1])
        num_ones = ones_set.shape[1]
        num_zeros = zeros_set.shape[1]
        # Lay every (positive, negative) pair side by side: repeat_ones
        # cycles through the positives, and repeat_zeros holds each negative
        # num_ones times in a row.
        repeat_ones = ones_set.repeat(1, num_zeros)
        repeat_zeros = torch.transpose(zeros_set.repeat(num_ones, 1), 0, 1).reshape(1, -1)
        exp_loss = torch.sum(torch.exp(repeat_zeros - repeat_ones))
        normalized_loss = exp_loss / (num_zeros * num_ones)
        set_loss = (labmda * normalized_loss).view(1)
    loss = mseloss + set_loss
    return loss
开发者ID:HaojiHu,项目名称:Sets2Sets,代码行数:27,代码来源:Sets2Sets.py

示例12: sort_mention

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def sort_mention(self, mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_lens):
    """Keep the highest-scoring candidate mentions and return them in start order.

    Candidates are ranked by score (best first), and the top
    ``int(self.config.mention_ratio * sum(seq_lens))`` are kept. The
    survivors are then re-sorted by ascending start position; end positions
    are not used to break ties.

    Returns:
        (starts, ends, scores, embeddings) of the kept mentions, all ordered
        by start position.
    """
    # Rank by score, highest first, then cut to the kept fraction.
    ranked_scores, ranked_ids = torch.sort(candidate_mention_score, descending=True)
    keep = int(self.config.mention_ratio * sum(seq_lens))
    top_ids = ranked_ids[:keep]
    top_scores = ranked_scores[:keep]

    def pick(positions):
        # numpy positions -> tensor on the model device, restricted to top_ids
        return torch.from_numpy(positions).to(self.device).index_select(dim=0, index=top_ids)

    starts = pick(mention_start)
    ends = pick(mention_end)
    embs = candidate_mention_emb.index_select(index=top_ids, dim=0)
    for kept in (top_scores, starts, ends, embs):
        assert kept.shape[0] == keep
    # TODO: overlapping (crossing) mentions are not handled.

    # Re-order by actual start position (end is not considered).
    starts, order = torch.sort(starts)
    ends = ends.index_select(dim=0, index=order)
    embs = embs.index_select(dim=0, index=order)
    top_scores = top_scores.index_select(dim=0, index=order)
    return starts, ends, top_scores, embs
开发者ID:fastnlp,项目名称:fastNLP,代码行数:27,代码来源:model_re.py

示例13: expand_pose

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def expand_pose(pose):
  '''
  param pose: N x 3
  Takes 3-dimensional vectors, and massages them into 2x3 affine transformation matrices:
  [s,x,y] -> [[s,0,x],
              [0,s,y]]
  returns: N x 2 x 3 tensor

  The zero column and the gather indices are built on pose's device and with
  pose's dtype; they used to be hard-coded float32 CUDA tensors. CPU and
  float64 inputs therefore work as well.
  '''
  n = pose.size(0)
  # After prepending a zero column, the columns are [0, s, x, y]; gathering
  # (1, 0, 2, 0, 1, 3) gives [s, 0, x, 0, s, y], which is viewed as 2 rows.
  expansion_indices = torch.tensor([1, 0, 2, 0, 1, 3], dtype=torch.long, device=pose.device)
  zeros = pose.new_zeros(n, 1)
  out = torch.cat([zeros, pose], dim=1)
  return torch.index_select(out, 1, expansion_indices).view(n, 2, 3)
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:14,代码来源:DDPAE_utils.py

示例14: get_input

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def get_input(i,gnn_input):
    """Select the dim-1 entries of ``gnn_input`` listed in ``edge_matrix[i]``.

    ``edge_matrix`` and ``device`` are not defined in this function; they are
    looked up as globals. Returns the selected tensor (entries in the order
    given by ``edge_matrix[i]``) and the number of selected entries.
    """
    neighbours = edge_matrix[i]
    count = len(neighbours)
    index = torch.tensor(neighbours).to(device)
    return torch.index_select(gnn_input, 1, index), count
开发者ID:HaiyangLiu1997,项目名称:Pytorch-Networks,代码行数:9,代码来源:GNNlikeCNN2015.py

示例15: get_input_paf

# 需要导入模块: import torch [as 别名]
# 或者: from torch import index_select [as 别名]
def get_input_paf(i,gnn_input):
    """Gather dim-1 entries for node ``i`` from ``gnn_input``.

    The result is the entries listed in ``edge_matrix_paf[i]``, followed by
    entries ``2*i + 19`` and ``2*i + 20``, concatenated along dim 1.
    ``edge_matrix_paf`` and ``device`` are looked up as globals. The fixed
    19 offset presumably skips a block of leading channels; confirm against
    the model definition.
    """
    neighbour_idx = torch.tensor(edge_matrix_paf[i]).to(device)
    paf_idx = torch.tensor([2 * i + 19, 2 * i + 20]).to(device)
    paf_part = torch.index_select(gnn_input, 1, paf_idx)
    neighbour_part = torch.index_select(gnn_input, 1, neighbour_idx)
    return torch.cat([neighbour_part, paf_part], 1)
开发者ID:HaiyangLiu1997,项目名称:Pytorch-Networks,代码行数:14,代码来源:GNNlikeCNN2015.py


注:本文中的torch.index_select方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。