当前位置: 首页>>代码示例>>Python>>正文


Python torch.split方法代码示例

本文整理汇总了Python中torch.split方法的典型用法代码示例。如果您想了解torch.split方法的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的其他用法示例。


下文共展示了torch.split方法的15个代码示例，这些示例默认按受欢迎程度排序。注意：示例8仅直接调用了np.split，示例13仅使用了字符串方法str.split，二者均未直接调用torch.split。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: node_forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def node_forward(self, inputs, child_c, child_h):
        """Compute one Child-Sum Tree-LSTM node update.

        The children's hidden states are summed to drive the input (i),
        output (o) and update (u) gates. Each child gets its own forget
        gate (f), which is applied to that child's memory cell.

        Args:
            inputs: node input, fed to ``self.ioux`` and ``self.fx``.
                NOTE(review): presumably a single row ``(1, in_dim)``; the
                ``repeat(len(child_h), 1)`` below implies a 2-D projection.
            child_c: children memory cells, stacked along dim 0.
            child_h: children hidden states, stacked along dim 0.

        Returns:
            ``(c, h)``: the node's memory cell and hidden state.
        """
        child_h_sum = torch.sum(child_h, dim=0, keepdim=True)

        iou = self.ioux(inputs) + self.iouh(child_h_sum)
        # The fused projection holds i, o and u side by side along dim 1.
        i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
        # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)

        # One forget gate per child: the input projection is tiled once per child.
        f = torch.sigmoid(
            self.fh(child_h) +
            self.fx(inputs).repeat(len(child_h), 1)
        )
        fc = torch.mul(f, child_c)

        c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
        h = torch.mul(o, torch.tanh(c))
        return c, h
开发者ID:dasguptar,项目名称:treelstm.pytorch,代码行数:18,代码来源:model.py

示例2: sampling_decode

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def sampling_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
        """Average the sampling-decode loss over the examples of a batch.

        Each tensor argument is sliced into per-example pieces with
        ``torch.split(..., 1, dim)``: dim 1 for ``pz_dec_outs``, ``u_enc_out``,
        ``m_tm1`` and ``last_hidden``, and dim 0 for ``degree_input``. The
        matching column of ``u_input_np`` is reshaped to ``(-1, 1)``.
        Examples for which ``self.get_req_slots(bspan_index[i])`` is falsy are
        skipped.

        Returns:
            The mean of the per-example losses returned by
            ``self.sampling_decode_single``, or ``None`` when every example
            was skipped.
        """
        # Previously named ``vars``, which shadowed the builtin of that name.
        per_example = zip(
            torch.split(pz_dec_outs, 1, dim=1),
            torch.split(u_enc_out, 1, dim=1),
            torch.split(m_tm1, 1, dim=1),
            torch.split(last_hidden, 1, dim=1),
            torch.split(degree_input, 1, dim=0),
        )
        batch_loss = []

        sample_num = 1  # number of sampled decodes per example

        for i, (pz_dec_out_s, u_enc_out_s, m_tm1_s, last_hidden_s, degree_input_s) in enumerate(per_example):
            if not self.get_req_slots(bspan_index[i]):
                continue
            for _ in range(sample_num):
                loss = self.sampling_decode_single(pz_dec_out_s, u_enc_out_s, m_tm1_s, u_input_np[:, i].reshape((-1, 1)),
                                                   last_hidden_s, degree_input_s, bspan_index[i])
                batch_loss.append(loss)
        if not batch_loss:
            return None
        return sum(batch_loss) / len(batch_loss)
开发者ID:ConvLab,项目名称:ConvLab,代码行数:20,代码来源:tsd_net.py

示例3: sample

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def sample(verts, faces, num=10000, ret_choice = False):
    """Sample points on the surface of a triangle mesh, weighted by face area.

    Faces are drawn with probability proportional to their area. A point is
    then drawn inside each chosen face with the square-root barycentric trick
    (``u = sqrt(r1)``, ``v = r2``).

    Args:
        verts: vertex positions indexed along dim 0. Three coordinates per
            vertex are required by the 3-way split below.
        faces: integer vertex indices, one row of three per triangle.
        num: number of points to draw.
        ret_choice: when true, also return the face index of each point.

    Returns:
        ``points`` of shape ``(num, 3)``, or ``(points, choices)`` when
        ``ret_choice`` is true.
    """
    # Sample on the mesh's own device instead of hard-coding CUDA, so the
    # function also works on CPU-only machines.
    low = torch.tensor([0.0], device=verts.device)
    high = torch.tensor([1.0], device=verts.device)
    dist_uni = torch.distributions.Uniform(low, high)

    # Two edge vectors per face; the norm of their cross product is 2 * area.
    x1, x2, x3 = torch.split(torch.index_select(verts, 0, faces[:, 0]) - torch.index_select(verts, 0, faces[:, 1]), 1, dim=1)
    y1, y2, y3 = torch.split(torch.index_select(verts, 0, faces[:, 1]) - torch.index_select(verts, 0, faces[:, 2]), 1, dim=1)
    a = (x2 * y3 - x3 * y2) ** 2
    b = (x3 * y1 - x1 * y3) ** 2
    c = (x1 * y2 - x2 * y1) ** 2
    areas = torch.sqrt(a + b + c) / 2
    areas = areas / torch.sum(areas)

    # Distribution.sample_n is deprecated; sample((n,)) is its replacement.
    cat_dist = torch.distributions.Categorical(areas.view(-1))
    choices = cat_dist.sample((num,))
    select_faces = faces[choices]
    xs = torch.index_select(verts, 0, select_faces[:, 0])
    ys = torch.index_select(verts, 0, select_faces[:, 1])
    zs = torch.index_select(verts, 0, select_faces[:, 2])
    u = torch.sqrt(dist_uni.sample((num,)))
    v = dist_uni.sample((num,))
    points = (1 - u) * xs + (u * (1 - v)) * ys + u * v * zs
    if ret_choice:
        return points, choices
    return points
开发者ID:nv-tlabs,项目名称:DIB-R,代码行数:24,代码来源:check_chamfer.py

示例4: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, x):
        """Run the Compressed Interaction Network (CIN) over field embeddings.

        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        :return: ``self.fc`` applied to the per-layer feature maps, after they
            are concatenated along dim 1 and summed over dim 2.
        """
        base = x.unsqueeze(2)
        hidden = x
        layer_outputs = []
        last_layer = self.num_layers - 1
        for layer_idx in range(self.num_layers):
            # Element-wise product of every base field with every hidden map.
            interaction = base * hidden.unsqueeze(1)
            n, f0, fk, d = interaction.shape
            conv = self.conv_layers[layer_idx]
            feature_map = F.relu(conv(interaction.view(n, f0 * fk, d)))
            if self.split_half and layer_idx != last_layer:
                # One half is emitted; the other half feeds the next layer.
                feature_map, hidden = torch.split(
                    feature_map, feature_map.shape[1] // 2, dim=1)
            else:
                hidden = feature_map
            layer_outputs.append(feature_map)
        pooled = torch.sum(torch.cat(layer_outputs, dim=1), 2)
        return self.fc(pooled)
开发者ID:rixwew,项目名称:pytorch-fm,代码行数:19,代码来源:layer.py

示例5: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, input):
        """Batch-norm forward pass with statistics averaged across processes.

        Falls back to the parent class's ``forward`` when ``get_world_size()``
        is 1 or the module is not in training mode. Otherwise:

        1. Each process computes the per-channel mean and mean of squares
           over dims 0, 2 and 3.
        2. These are combined with ``AllReduce`` and scaled by
           ``1 / dist.get_world_size()``.
        3. The result normalizes ``input`` and updates the running
           statistics.

        Raises AssertionError on an empty batch.

        NOTE(review): the reductions over ``[0, 2, 3]`` and the
        ``reshape(1, -1, 1, 1)`` below imply a 4-D ``(N, C, H, W)`` input.
        NOTE(review): the guard uses ``get_world_size()`` while the
        normalization uses ``dist.get_world_size()``; confirm they agree.
        """
        if get_world_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        C = input.shape[1]
        # Local per-channel first and second moments.
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        # Pack both moments into one vector so a single AllReduce call suffices.
        # Presumably AllReduce sums across processes, hence the division by
        # the world size to get an average; confirm against its definition.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        # Var[x] = E[x^2] - E[x]^2 (biased estimate).
        var = meansqr - mean * mean
        # Exponential moving average of the running statistics, computed on
        # detached values so no gradient reaches the buffers.
        # NOTE(review): running_var receives the biased variance here.
        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        # Fold normalization and the affine transform into one scale and bias.
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return input * scale + bias
开发者ID:soeaver,项目名称:Parsing-R-CNN,代码行数:25,代码来源:batch_norm.py

示例6: intersection_area

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :return: The matrix (size [N1, N2]) of the intersection area.

    Leading batch dimensions are also accepted: inputs of size [..., N1, 2]
    and [..., N2, 2] give a result of size [..., N1, N2].
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    # Broadcasting [..., N1, 1] against [..., 1, N2] builds the pairwise
    # matrices directly. The old repeat() workaround for missing broadcasting
    # support materialized full copies and is no longer needed.
    max_ymin = torch.max(ymin1, ymin2.transpose(-1, -2))
    min_ymax = torch.min(ymax1, ymax2.transpose(-1, -2))
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1, xmin2.transpose(-1, -2))
    min_xmax = torch.min(xmax1, xmax2.transpose(-1, -2))
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:23,代码来源:torch.py

示例7: batch_intersection_area

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes for N independent batches.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :return: The matrices (size [N, N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    # Broadcasting [N, N1, 1] against [N, 1, N2] yields the [N, N1, N2]
    # matrices directly, replacing the old repeat() workaround that
    # materialized full copies.
    max_ymin = torch.max(ymin1, ymin2.transpose(-1, -2))
    min_ymax = torch.min(ymax1, ymax2.transpose(-1, -2))
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1, xmin2.transpose(-1, -2))
    min_xmax = torch.min(xmax1, xmax2.transpose(-1, -2))
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:23,代码来源:torch.py

示例8: _test

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def _test(self, bbox1, bbox2, ans, batch_size=2, dtype=np.float32):
        """Check ``batch_iou_matrix`` against an expected matrix ``ans``.

        The last axis of ``bbox1``/``bbox2`` is split in half into the min and
        max ``(y, x)`` corners. When ``batch_size > 1``, the inputs are tiled
        into a batch. Each batch element's box lists are then randomly
        permuted, and ``ans`` is permuted to match, so the checked result must
        not depend on box order.
        """
        bbox1, bbox2, ans = (np.expand_dims(np.array(a, dtype), 0) for a in (bbox1, bbox2, ans))
        if batch_size > 1:
            bbox1, bbox2, ans = (np.tile(a, (batch_size, 1, 1)) for a in (bbox1, bbox2, ans))
            for b in range(batch_size):
                indices1 = np.random.permutation(bbox1.shape[1])
                indices2 = np.random.permutation(bbox2.shape[1])
                bbox1[b] = bbox1[b][indices1]
                bbox2[b] = bbox2[b][indices2]
                # Permute rows and columns of the expected matrix consistently.
                ans[b] = ans[b][indices1][:, indices2]
        yx_min1, yx_max1 = np.split(bbox1, 2, -1)
        yx_min2, yx_max2 = np.split(bbox2, 2, -1)
        assert np.all(yx_min1 <= yx_max1)
        assert np.all(yx_min2 <= yx_max2)
        assert np.all(ans >= 0)
        # torch.autograd.Variable has been a deprecated no-op wrapper since
        # PyTorch 0.4, so plain tensors are used directly.
        tensors = [torch.from_numpy(a) for a in (yx_min1, yx_max1, yx_min2, yx_max2)]
        if torch.cuda.is_available():
            tensors = [t.cuda() for t in tensors]
        matrix = batch_iou_matrix(*tensors).detach().cpu().numpy()
        np.testing.assert_almost_equal(matrix, ans)
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:23,代码来源:torch.py

示例9: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def __init__(self, boxes, ops, syms):
        """Rebuild a tree from its flattened tensor encoding.

        The entries ``ops[0, k]`` are replayed in order against a stack:

        * ``BOX`` pushes a leaf holding the next row of ``boxes``.
        * ``ADJ`` pops two nodes and pushes an adjacency parent. The first
          node popped (the most recently pushed) becomes ``left``.
        * ``SYM`` pops one node and pushes a symmetry parent holding the next
          row of ``syms``.

        Rows of ``boxes`` and ``syms`` are consumed front to back. Exactly one
        node, the root, must remain at the end; otherwise AssertionError is
        raised.
        """
        # Reversed so that list.pop() hands rows out in their original order.
        box_list = list(torch.split(boxes, 1, 0))
        sym_param = list(torch.split(syms, 1, 0))
        box_list.reverse()
        sym_param.reverse()
        queue = []  # used as a stack
        # ``idx`` rather than ``id`` so the builtin is not shadowed.
        for idx in range(ops.size(1)):
            op = ops[0, idx]
            if op == Tree.NodeType.BOX.value:
                queue.append(Tree.Node(box=box_list.pop(), node_type=Tree.NodeType.BOX))
            elif op == Tree.NodeType.ADJ.value:
                left_node = queue.pop()
                right_node = queue.pop()
                queue.append(Tree.Node(left=left_node, right=right_node, node_type=Tree.NodeType.ADJ))
            elif op == Tree.NodeType.SYM.value:
                node = queue.pop()
                queue.append(Tree.Node(left=node, sym=sym_param.pop(), node_type=Tree.NodeType.SYM))
        assert len(queue) == 1
        self.root = queue[0]
开发者ID:kevin-kaixu,项目名称:grass_pytorch,代码行数:20,代码来源:grassdata.py

示例10: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def __init__(self, boxes, ops, syms):
        """Rebuild a tree from its flattened tensor encoding.

        The entries ``ops[0, k]`` are replayed in order against a stack:

        * ``BOX`` pushes a leaf holding the next row of ``boxes``.
        * ``ADJ`` pops two nodes and pushes an adjacency parent. The first
          node popped (the most recently pushed) becomes ``left``.
        * ``SYM`` pops one node and pushes a symmetry parent holding the next
          row of ``syms``.

        Rows of ``boxes`` and ``syms`` are consumed front to back. Exactly one
        node, the root, must remain at the end; otherwise AssertionError is
        raised.
        """
        # Reversed so that list.pop() hands rows out in their original order.
        box_list = list(torch.split(boxes, 1, 0))
        sym_param = list(torch.split(syms, 1, 0))
        box_list.reverse()
        sym_param.reverse()
        queue = []  # used as a stack
        # range() replaces the Python 2-only xrange(), which raises NameError
        # on Python 3. ``idx`` avoids shadowing the builtin ``id``.
        for idx in range(ops.size(1)):
            op = ops[0, idx]
            if op == Tree.NodeType.BOX.value:
                queue.append(Tree.Node(box=box_list.pop(), node_type=Tree.NodeType.BOX))
            elif op == Tree.NodeType.ADJ.value:
                left_node = queue.pop()
                right_node = queue.pop()
                queue.append(Tree.Node(left=left_node, right=right_node, node_type=Tree.NodeType.ADJ))
            elif op == Tree.NodeType.SYM.value:
                node = queue.pop()
                queue.append(Tree.Node(left=node, sym=sym_param.pop(), node_type=Tree.NodeType.SYM))
        assert len(queue) == 1
        self.root = queue[0]
开发者ID:kevin-kaixu,项目名称:grass_pytorch,代码行数:20,代码来源:grassdata.py

示例11: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, input, logdet=None, reverse=False, local_condition=None):
        """Affine coupling layer, usable in either direction.

        The channels (dim 1) are split in half into ``x_a`` and ``x_b``.
        ``x_a`` passes through unchanged and is fed, together with
        ``local_condition``, to ``self.wavenet``. The wavenet output is split
        in half into ``log_s`` and ``t``.

        * forward: ``x_b <- exp(log_s) * x_b + t`` and ``logdet += sum(log_s)``
        * reverse: ``x_b <- (x_b - t) * exp(-log_s)`` and ``logdet -= sum(log_s)``

        The sum runs over dims 1 and 2, and ``logdet`` is only touched when it
        is not ``None``.

        Returns:
            ``(output, logdet)`` where ``output`` is ``cat([x_a, x_b], 1)``.
        """
        half = self.input_channels // 2
        x_a, x_b = torch.split(input, half, 1)
        log_s, t = torch.split(self.wavenet(x_a, local_condition), half, 1)
        if reverse:
            x_b = (x_b - t) * torch.exp(-log_s)
        else:
            x_b = torch.exp(log_s) * x_b + t
        output = torch.cat([x_a, x_b], 1)
        if logdet is not None:
            log_s_total = torch.sum(log_s, (1, 2))
            logdet = logdet - log_s_total if reverse else logdet + log_s_total
        return output, logdet
开发者ID:npuichigo,项目名称:waveglow,代码行数:21,代码来源:modules.py

示例12: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, inputs):
        """Pairwise bilinear interactions between field embeddings.

        :param inputs: 3-D tensor. It is split into single-field slices along
            dim 1, and every unordered pair ``(v_i, v_j)`` with ``i < j`` is
            combined as ``bilinear(v_i) * v_j``.

        ``self.bilinear_type`` selects how ``bilinear`` is chosen:

        * ``"all"``: one shared ``self.bilinear``.
        * ``"each"``: ``self.bilinear[i]``, indexed by the pair's first field.
        * ``"interaction"``: one ``self.bilinear`` entry per pair, in
          ``itertools.combinations`` order.

        :return: the pair products concatenated along dim 1.
        :raises ValueError: if ``inputs`` is not 3-D.
        :raises NotImplementedError: for any other ``bilinear_type``.
        """
        rank = len(inputs.shape)
        if rank != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (rank))
        fields = torch.split(inputs, 1, dim=1)
        pairs = list(itertools.combinations(range(len(fields)), 2))
        mode = self.bilinear_type
        if mode == "all":
            products = [torch.mul(self.bilinear(fields[a]), fields[b]) for a, b in pairs]
        elif mode == "each":
            products = [torch.mul(self.bilinear[a](fields[a]), fields[b]) for a, b in pairs]
        elif mode == "interaction":
            products = [torch.mul(layer(fields[a]), fields[b])
                        for (a, b), layer in zip(pairs, self.bilinear)]
        else:
            raise NotImplementedError
        return torch.cat(products, dim=1)
开发者ID:shenweichen,项目名称:DeepCTR-Torch,代码行数:19,代码来源:interaction.py

示例13: load_partial_weight

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def load_partial_weight(model, pretrained, nl_nums, nl_layer_id):
    """Merge pretrained weights into ``model``'s state dict for an NL/CGNL network.

    When ``nl_nums == 1``, each pretrained key whose first two dot-separated
    parts equal ``layer3.<nl_layer_id>`` has its block index bumped by one
    (e.g. ``layer3.5.conv1.weight`` becomes ``layer3.6.conv1.weight``).
    NOTE(review): presumably this is because a non-local block takes the
    original slot; confirm against the model definition. All other keys are
    copied unchanged.

    Returns:
        The dict returned by ``model.state_dict()``, updated with the
        (possibly renamed) pretrained entries.
    """
    target_prefix = 'layer3.{}'.format(nl_layer_id)

    def remap(key):
        # Shift the block index of keys under the target layer3 block.
        parts = key.split('.')
        if nl_nums != 1 or '.'.join(parts[:2]) != target_prefix:
            return key
        parts[1] = str(int(parts[1]) + 1)
        return '.'.join(parts)

    merged = model.state_dict()
    merged.update(OrderedDict((remap(name), tensor) for name, tensor in pretrained.items()))
    return merged
开发者ID:KaiyuYue,项目名称:cgnl-network.pytorch,代码行数:18,代码来源:resnet.py

示例14: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, input, class_id):
        """Generate images from latent codes conditioned on a class.

        ``input`` is cut into chunks of 20 along dim 1. The first chunk goes
        through ``self.G_linear`` and is reshaped to
        ``(-1, self.first_view, 4, 4)``. Each following chunk is concatenated
        with the class embedding ``self.linear(class_id)`` and conditions the
        next ``GBlock`` in ``self.conv``, in order. Other layers in
        ``self.conv`` are applied without conditioning.

        Returns:
            ``tanh`` of the colorized feature map.
        """
        codes = torch.split(input, 20, 1)
        class_emb = self.linear(class_id)

        out = self.G_linear(codes[0])
        out = out.view(-1, self.first_view, 4, 4)
        ids = 1  # index of the next latent chunk to hand to a GBlock
        for conv in self.conv:
            if isinstance(conv, GBlock):
                conv_code = codes[ids]
                ids += 1
                condition = torch.cat([conv_code, class_emb], 1)
                out = conv(out, condition)
            else:
                out = conv(out)

        out = self.ScaledCrossReplicaBN(out)
        out = F.relu(out)
        out = self.colorize(out)

        # torch.tanh replaces the deprecated F.tanh.
        return torch.tanh(out)
开发者ID:sxhxliang,项目名称:BigGAN-pytorch,代码行数:27,代码来源:model_resnet.py

示例15: get_num_level_anchors_inside

# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
        """Count how many anchors in each level are flagged as inside.

        ``inside_flags`` is cut with ``num_level_anchors`` as the split
        size(s), and the flags of each piece are summed.

        Returns:
            A list of Python ints, one per piece.
        """
        per_level = inside_flags.split(num_level_anchors)
        return [int(level_flags.sum()) for level_flags in per_level]
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:8,代码来源:gfl_head.py


注:本文中的torch.split方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。