本文整理汇总了Python中torch.split方法的典型用法代码示例。如果您正苦于以下问题:Python torch.split方法的具体用法?Python torch.split怎么用?Python torch.split使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch
的用法示例。
在下文中一共展示了torch.split方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: node_forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def node_forward(self, inputs, child_c, child_h):
    """Compose a Child-Sum TreeLSTM node state from its children's states.

    :param inputs: input features for this node — assumed shape (1, in_dim);
        TODO confirm with caller.
    :param child_c: stacked cell states of the children, (num_children, mem_dim).
    :param child_h: stacked hidden states of the children, (num_children, mem_dim).
    :return: tuple ``(c, h)`` — this node's cell and hidden state, each (1, mem_dim).
    """
    # Child-Sum: the i/o/u gates see the *sum* of the children's hidden states.
    child_h_sum = torch.sum(child_h, dim=0, keepdim=True)
    # self.ioux / self.iouh emit all three gates in one projection; split thirds.
    iou = self.ioux(inputs) + self.iouh(child_h_sum)
    i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
    # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
    i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
    # One forget gate per child: broadcast the input projection over all children.
    f = torch.sigmoid(
        self.fh(child_h) +
        self.fx(inputs).repeat(len(child_h), 1)
    )
    fc = torch.mul(f, child_c)
    c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
    h = torch.mul(o, torch.tanh(c))
    return c, h
示例2: sampling_decode
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def sampling_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
    """Run sampling-based decoding example-by-example and average the losses.

    Splits every batched tensor into per-example slices, decodes each example
    whose belief span has requestable slots, and returns the mean loss —
    or ``None`` when no example qualified.
    """
    n_samples = 1  # decoding samples drawn per example
    # Per-example slices: batch is dim 1 everywhere except degree_input (dim 0).
    slices = zip(
        torch.split(pz_dec_outs, 1, dim=1),
        torch.split(u_enc_out, 1, dim=1),
        torch.split(m_tm1, 1, dim=1),
        torch.split(last_hidden, 1, dim=1),
        torch.split(degree_input, 1, dim=0),
    )
    losses = []
    for idx, (pz_s, u_s, m_s, hid_s, deg_s) in enumerate(slices):
        # Skip examples that request no slots.
        if not self.get_req_slots(bspan_index[idx]):
            continue
        for _ in range(n_samples):
            losses.append(self.sampling_decode_single(
                pz_s, u_s, m_s, u_input_np[:, idx].reshape((-1, 1)),
                hid_s, deg_s, bspan_index[idx]))
    if losses:
        return sum(losses) / len(losses)
    return None
示例3: sample
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def sample(verts, faces, num=10000, ret_choice=False):
    """Sample points uniformly from the surface of a triangle mesh.

    :param verts: (V, 3) float tensor of vertex positions.
    :param faces: (F, 3) long tensor of per-triangle vertex indices.
    :param num: number of points to draw.
    :param ret_choice: when True, also return the chosen face index per point.
    :return: (num, 3) points tensor, optionally with the (num,) face choices.
    """
    # Use the mesh's own device instead of hard-coding CUDA, so CPU meshes work too.
    device = verts.device
    dist_uni = torch.distributions.Uniform(
        torch.tensor([0.0], device=device), torch.tensor([1.0], device=device))
    # Components of two edge vectors of every triangle.
    x1, x2, x3 = torch.split(torch.index_select(verts, 0, faces[:, 0]) - torch.index_select(verts, 0, faces[:, 1]), 1, dim=1)
    y1, y2, y3 = torch.split(torch.index_select(verts, 0, faces[:, 1]) - torch.index_select(verts, 0, faces[:, 2]), 1, dim=1)
    # Triangle area = |cross(e0, e1)| / 2, written out component-wise.
    a = (x2 * y3 - x3 * y2) ** 2
    b = (x3 * y1 - x1 * y3) ** 2
    c = (x1 * y2 - x2 * y1) ** 2
    Areas = torch.sqrt(a + b + c) / 2
    Areas = Areas / torch.sum(Areas)  # normalize into a distribution over faces
    # Pick faces proportionally to area; .sample((n,)) replaces the deprecated .sample_n(n).
    cat_dist = torch.distributions.Categorical(Areas.view(-1))
    choices = cat_dist.sample((num,))
    select_faces = faces[choices]
    xs = torch.index_select(verts, 0, select_faces[:, 0])
    ys = torch.index_select(verts, 0, select_faces[:, 1])
    zs = torch.index_select(verts, 0, select_faces[:, 2])
    # Barycentric sampling: sqrt(u) keeps the point distribution uniform in area.
    u = torch.sqrt(dist_uni.sample((num,)))
    v = dist_uni.sample((num,))
    points = (1 - u) * xs + (u * (1 - v)) * ys + u * v * zs
    if ret_choice:
        return points, choices
    return points
示例4: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, x):
"""
:param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
"""
xs = list()
x0, h = x.unsqueeze(2), x
for i in range(self.num_layers):
x = x0 * h.unsqueeze(1)
batch_size, f0_dim, fin_dim, embed_dim = x.shape
x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
x = F.relu(self.conv_layers[i](x))
if self.split_half and i != self.num_layers - 1:
x, h = torch.split(x, x.shape[1] // 2, dim=1)
else:
h = x
xs.append(x)
return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
示例5: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, input):
    """Synchronized batch-norm forward pass: average the batch statistics
    across all distributed workers before normalizing.

    :param input: NCHW float tensor with a non-empty batch dimension.
    :return: normalized, scaled and shifted tensor, same shape as ``input``.
    """
    # Single worker or eval mode: fall back to plain BatchNorm behavior.
    if get_world_size() == 1 or not self.training:
        return super().forward(input)
    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    # Per-channel mean and mean-of-squares over batch and spatial dims.
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])
    # Pack both statistics into one vector so a single all-reduce suffices.
    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean  # Var[x] = E[x^2] - E[x]^2
    # Update running statistics in-place from the globally averaged values.
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)
    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    # Reshape to (1, C, 1, 1) so they broadcast over the NCHW input.
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
示例6: intersection_area
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :return: The matrix (size [N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)  # each (N1, 1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)  # each (N2, 1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    # Broadcasting (N1, 1) against (1, N2) yields the full (N1, N2) grid; this
    # replaces the explicit .repeat() workaround for an old PyTorch bug.
    max_ymin = torch.max(ymin1, ymin2.t())
    min_ymax = torch.min(ymax1, ymax2.t())
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1, xmin2.t())
    min_xmax = torch.min(xmax1, xmax2.t())
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
示例7: batch_intersection_area
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes for N independent batches.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :return: The matrics (size [N, N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)  # each (N, N1, 1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)  # each (N, N2, 1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    # Broadcasting (N, N1, 1) against (N, 1, N2) yields (N, N1, N2); this
    # replaces the explicit .repeat() workaround for an old PyTorch bug.
    max_ymin = torch.max(ymin1, ymin2.transpose(1, 2))
    min_ymax = torch.min(ymax1, ymax2.transpose(1, 2))
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1, xmin2.transpose(1, 2))
    min_xmax = torch.min(xmax1, xmax2.transpose(1, 2))
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
示例8: _test
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def _test(self, bbox1, bbox2, ans, batch_size=2, dtype=np.float32):
    """Check batch_iou_matrix against a known answer under shuffled box order.

    Tiles the given boxes and expected matrix into a batch, independently
    permutes each batch element (permuting the answer rows/columns to match),
    and asserts the computed IoU matrix equals the expected one.
    """
    arrays = [np.expand_dims(np.array(a, dtype), 0) for a in (bbox1, bbox2, ans)]
    if batch_size > 1:
        arrays = [np.tile(a, (batch_size, 1, 1)) for a in arrays]
    bbox1, bbox2, ans = arrays
    for b in range(batch_size):
        # Shuffle boxes within each batch element; the expected matrix must be
        # permuted consistently along both axes.
        perm1 = np.random.permutation(bbox1.shape[1])
        perm2 = np.random.permutation(bbox2.shape[1])
        bbox1[b] = bbox1[b][perm1]
        bbox2[b] = bbox2[b][perm2]
        ans[b] = ans[b][perm1][:, perm2]
    yx_min1, yx_max1 = np.split(bbox1, 2, -1)
    yx_min2, yx_max2 = np.split(bbox2, 2, -1)
    # Sanity: well-formed boxes and a non-negative expected IoU matrix.
    assert np.all(yx_min1 <= yx_max1)
    assert np.all(yx_min2 <= yx_max2)
    assert np.all(ans >= 0)
    tensors = [torch.autograd.Variable(torch.from_numpy(a))
               for a in (yx_min1, yx_max1, yx_min2, yx_max2)]
    if torch.cuda.is_available():
        tensors = [t.cuda() for t in tensors]
    yx_min1, yx_max1, yx_min2, yx_max2 = tensors
    matrix = batch_iou_matrix(yx_min1, yx_max1, yx_min2, yx_max2).data.cpu().numpy()
    np.testing.assert_almost_equal(matrix, ans)
示例9: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def __init__(self, boxes, ops, syms):
    """Rebuild the tree encoded by a serialized operation sequence.

    Row 0 of ``ops`` lists node opcodes in order: BOX nodes consume the next
    box parameters, SYM nodes consume the next symmetry parameters, and ADJ
    nodes merge the two most recently built subtrees.
    """
    # Per-row slices, reversed so .pop() yields them in original order.
    remaining_boxes = list(torch.split(boxes, 1, 0))[::-1]
    remaining_syms = list(torch.split(syms, 1, 0))[::-1]
    stack = []
    for step in range(ops.size()[1]):
        op = ops[0, step]
        if op == Tree.NodeType.BOX.value:
            # Leaf node wrapping the next box parameters.
            stack.append(Tree.Node(box=remaining_boxes.pop(), node_type=Tree.NodeType.BOX))
        elif op == Tree.NodeType.ADJ.value:
            # Adjacency node: join the two subtrees on top of the stack
            # (first pop becomes the left child, matching the encoding).
            left = stack.pop()
            right = stack.pop()
            stack.append(Tree.Node(left=left, right=right, node_type=Tree.NodeType.ADJ))
        elif op == Tree.NodeType.SYM.value:
            # Symmetry node: decorate the top subtree with symmetry parameters.
            child = stack.pop()
            stack.append(Tree.Node(left=child, sym=remaining_syms.pop(), node_type=Tree.NodeType.SYM))
    assert len(stack) == 1
    self.root = stack[0]
示例10: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def __init__(self, boxes, ops, syms):
    """Rebuild the tree encoded by a serialized operation sequence.

    Row 0 of ``ops`` lists node opcodes in order: BOX nodes consume the next
    box parameters, SYM nodes consume the next symmetry parameters, and ADJ
    nodes merge the two most recently built subtrees.
    """
    box_list = [b for b in torch.split(boxes, 1, 0)]
    sym_param = [s for s in torch.split(syms, 1, 0)]
    # Reversed so .pop() yields slices in their original order.
    box_list.reverse()
    sym_param.reverse()
    queue = []
    # range() replaces the Python-2-only xrange(), which raises NameError on Python 3.
    for id in range(ops.size()[1]):
        if ops[0, id] == Tree.NodeType.BOX.value:
            # Leaf node wrapping the next box parameters.
            queue.append(Tree.Node(box=box_list.pop(), node_type=Tree.NodeType.BOX))
        elif ops[0, id] == Tree.NodeType.ADJ.value:
            # Adjacency node: first pop becomes the left child.
            left_node = queue.pop()
            right_node = queue.pop()
            queue.append(Tree.Node(left=left_node, right=right_node, node_type=Tree.NodeType.ADJ))
        elif ops[0, id] == Tree.NodeType.SYM.value:
            # Symmetry node: decorate the top subtree with symmetry parameters.
            node = queue.pop()
            queue.append(Tree.Node(left=node, sym=sym_param.pop(), node_type=Tree.NodeType.SYM))
    assert len(queue) == 1
    self.root = queue[0]
示例11: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, input, logdet=None, reverse=False, local_condition=None):
    """Affine coupling layer, invertible via ``reverse``.

    The first half of the channels passes through unchanged and parameterizes
    (via ``self.wavenet``) an affine transform of the second half.

    :param input: (batch, input_channels, time) tensor — assumed; confirm with caller.
    :param logdet: running log-determinant accumulator, or None to skip it.
    :param reverse: when True, apply the exact inverse transform.
    :param local_condition: optional conditioning passed to ``self.wavenet``.
    :return: tuple ``(output, logdet)``.
    """
    half = self.input_channels // 2
    # The transform parameters depend only on x_a, which both directions leave
    # untouched — so forward and inverse share this computation (previously
    # duplicated verbatim in both branches).
    x_a, x_b = torch.split(input, half, 1)
    log_s, t = torch.split(self.wavenet(x_a, local_condition), half, 1)
    if not reverse:
        x_b = torch.exp(log_s) * x_b + t
        if logdet is not None:
            logdet = logdet + torch.sum(log_s, (1, 2))
    else:
        # Exact inverse of the affine map above.
        x_b = (x_b - t) * torch.exp(-log_s)
        if logdet is not None:
            logdet = logdet - torch.sum(log_s, (1, 2))
    output = torch.cat([x_a, x_b], 1)
    return output, logdet
示例12: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, inputs):
    """Bilinear field-interaction layer.

    :param inputs: tensor of shape (batch, num_fields, embed_dim).
    :return: tensor of shape (batch, num_pairs, embed_dim), one row per
        unordered field pair.
    :raises ValueError: if ``inputs`` is not 3-dimensional.
    :raises NotImplementedError: for an unknown ``self.bilinear_type``.
    """
    if len(inputs.shape) != 3:
        raise ValueError(
            "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape)))
    fields = torch.split(inputs, 1, dim=1)  # one (batch, 1, embed_dim) slice per field
    if self.bilinear_type == "all":
        # A single shared bilinear weight for every field pair.
        interactions = [torch.mul(self.bilinear(u), w)
                        for u, w in itertools.combinations(fields, 2)]
    elif self.bilinear_type == "each":
        # One bilinear weight per *left* field.
        interactions = [torch.mul(self.bilinear[i](fields[i]), fields[j])
                        for i, j in itertools.combinations(range(len(fields)), 2)]
    elif self.bilinear_type == "interaction":
        # One bilinear weight per field *pair*.
        interactions = [torch.mul(w_ij(pair[0]), pair[1])
                        for pair, w_ij in zip(itertools.combinations(fields, 2), self.bilinear)]
    else:
        raise NotImplementedError
    return torch.cat(interactions, dim=1)
示例13: load_partial_weight
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def load_partial_weight(model, pretrained, nl_nums, nl_layer_id):
    """Loads the partial weights for NL/CGNL network.

    When a single non-local block is inserted into layer3, the pretrained
    sub-block named ``layer3.<nl_layer_id>`` is renamed to the next index so
    it lines up with the new model; all other entries keep their names. The
    remapped weights are merged into the model's own state dict.
    """
    merged = model.state_dict()
    remapped = OrderedDict()
    for name, tensor in pretrained.items():
        parts = name.split('.')
        prefix = '.'.join(parts[:2])
        # Only the sub-block exactly at the insertion point is shifted by one.
        if nl_nums == 1 and prefix == 'layer3.{}'.format(nl_layer_id):
            parts[1] = str(int(parts[1]) + 1)
            name = '.'.join(parts)
        remapped[name] = tensor
    merged.update(remapped)
    return merged
示例14: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def forward(self, input, class_id):
    """Conditional generator forward pass.

    Splits the latent ``input`` into 20-wide chunks along dim 1: the first
    chunk seeds the initial 4x4 feature map, and each subsequent chunk
    conditions one GBlock together with the class embedding.

    :param input: latent vector; assumed wide enough for one chunk per GBlock
        plus one — TODO confirm with caller.
    :param class_id: class input projected by ``self.linear`` into an embedding.
    :return: generated image tensor squashed into [-1, 1] by tanh.
    """
    codes = torch.split(input, 20, 1)
    class_emb = self.linear(class_id)
    out = self.G_linear(codes[0])
    out = out.view(-1, self.first_view, 4, 4)
    next_code = 1
    for conv in self.conv:
        if isinstance(conv, GBlock):
            # Each GBlock consumes its own latent chunk plus the class embedding.
            condition = torch.cat([codes[next_code], class_emb], 1)
            next_code += 1
            out = conv(out, condition)
        else:
            out = conv(out)
    out = self.ScaledCrossReplicaBN(out)
    out = F.relu(out)
    out = self.colorize(out)
    # torch.tanh replaces the deprecated F.tanh.
    return torch.tanh(out)
示例15: get_num_level_anchors_inside
# 需要导入模块: import torch [as 别名]
# 或者: from torch import split [as 别名]
def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
    """Count, per feature level, how many anchors fall inside the image.

    :param num_level_anchors: per-level anchor counts, used as split sizes.
    :param inside_flags: flat flag tensor marking anchors inside the image.
    :return: list of per-level inside-anchor counts.
    """
    per_level_flags = torch.split(inside_flags, num_level_anchors)
    return [int(level_flags.sum()) for level_flags in per_level_flags]