This article collects typical usage examples of the Python method torch.chunk. If you have been wondering what torch.chunk does, how to call it, or where to find examples, the curated code samples below may help. You can also explore further usage examples from the containing module, torch.
The following 15 code examples of torch.chunk are sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
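Before the examples, a quick note on the API itself: torch.chunk(input, chunks, dim=0) splits a tensor into at most the requested number of chunks along the given dimension; when the size does not divide evenly, the last chunk is smaller, and fewer chunks than requested may be returned. A minimal sketch:

import torch

t = torch.arange(10)
a, b, c = torch.chunk(t, 3)        # chunk sizes are 4, 4 and the remaining 2
print(a)                           # tensor([0, 1, 2, 3])
print(c)                           # tensor([8, 9])

# Fewer chunks than requested when the tensor is too short along dim:
parts = torch.chunk(torch.arange(6), 4)
print(len(parts))                  # 3 (three chunks of size 2)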
Example 1: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(self, x, query_input=None):
    """x: [bs, num cols, d_model]. Output has the same shape."""
    assert x.dim() == 3, x.size()
    bs, ncols, _ = x.size()
    # [bs, num cols, d_state * 3 * num_heads]
    qkv = self.qkv_linear(x)
    # [bs, num heads, num cols, d_state] each
    qs, ks, vs = map(self._split_heads, torch.chunk(qkv, 3, dim=-1))
    if query_input is not None:
        # TODO: obviously can avoid redundant calc.
        qkv = self.qkv_linear(query_input)
        qs, _, _ = map(self._split_heads, torch.chunk(qkv, 3, dim=-1))
    # [bs, num heads, num cols, d_state]
    x = self._do_attention(qs, ks, vs, mask=self.attn_mask.to(x.device))
    # [bs, num cols, num heads, d_state]
    x = x.transpose(1, 2)
    # Concat all heads' outputs: [bs, num cols, num heads * d_state]
    x = x.contiguous().view(bs, ncols, -1)
    # Then do a transform: [bs, num cols, d_model].
    x = self.linear(x)
    return x
Example 2: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(self, h, emb):
    sbatch, nsq, lchunk = h.size()
    h = h.contiguous()
    """
    # Slower version
    ws = list(self.adapt_w(emb).view(sbatch, self.ncha, 1, self.kw))
    bs = list(self.adapt_b(emb))
    hs = list(torch.chunk(h, sbatch, dim=0))
    out = []
    for hi, wi, bi in zip(hs, ws, bs):
        out.append(torch.nn.functional.conv1d(hi, wi, bias=bi, padding=self.kw // 2, groups=nsq))
    h = torch.cat(out, dim=0)
    """
    # Faster version, fully using group convolution
    w = self.adapt_w(emb).view(-1, 1, self.kw)
    b = self.adapt_b(emb).view(-1)
    h = torch.nn.functional.conv1d(
        h.view(1, -1, lchunk), w, bias=b,
        padding=self.kw // 2, groups=sbatch * nsq,
    ).view(sbatch, self.ncha, lchunk)
    h = self.net.forward(h)
    # Split into scale and shift halves along the channel dim.
    s, m = torch.chunk(h, 2, dim=1)
    s = torch.sigmoid(s + 2) + 1e-7
    return s, m
Example 3: __init__
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
# Also needs: from scipy.io import loadmat, plus a Tree class from the surrounding code.
def __init__(self, dir, transform=None):
    self.dir = dir
    box_data = torch.from_numpy(loadmat(self.dir + '/box_data.mat')['boxes']).float()
    op_data = torch.from_numpy(loadmat(self.dir + '/op_data.mat')['ops']).int()
    sym_data = torch.from_numpy(loadmat(self.dir + '/sym_data.mat')['syms']).float()
    #weight_list = torch.from_numpy(loadmat(self.dir+'/weights.mat')['weights']).float()
    num_examples = op_data.size()[1]
    # Split each data matrix into per-example column slices.
    box_data = torch.chunk(box_data, num_examples, 1)
    op_data = torch.chunk(op_data, num_examples, 1)
    sym_data = torch.chunk(sym_data, num_examples, 1)
    #weight_list = torch.chunk(weight_list, num_examples, 1)
    self.transform = transform
    self.trees = []
    for i in range(len(op_data)):
        boxes = torch.t(box_data[i])
        ops = torch.t(op_data[i])
        syms = torch.t(sym_data[i])
        tree = Tree(boxes, ops, syms)
        self.trees.append(tree)
Example 4: __init__
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
# Python 2 variant of Example 3 (note xrange and u'' strings).
def __init__(self, dir, transform=None):
    self.dir = dir
    box_data = torch.from_numpy(loadmat(self.dir + u'/box_data.mat')[u'boxes']).float()
    op_data = torch.from_numpy(loadmat(self.dir + u'/op_data.mat')[u'ops']).int()
    sym_data = torch.from_numpy(loadmat(self.dir + u'/sym_data.mat')[u'syms']).float()
    #weight_list = torch.from_numpy(loadmat(self.dir+'/weights.mat')['weights']).float()
    num_examples = op_data.size()[1]
    box_data = torch.chunk(box_data, num_examples, 1)
    op_data = torch.chunk(op_data, num_examples, 1)
    sym_data = torch.chunk(sym_data, num_examples, 1)
    #weight_list = torch.chunk(weight_list, num_examples, 1)
    self.transform = transform
    self.trees = []
    for i in xrange(len(op_data)):
        boxes = torch.t(box_data[i])
        ops = torch.t(op_data[i])
        syms = torch.t(sym_data[i])
        tree = Tree(boxes, ops, syms)
        self.trees.append(tree)
Example 5: pack_sequence_for_linear
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def pack_sequence_for_linear(inputs, lengths, batch_first=True):
    """
    :param inputs: [B, T, D] if batch_first
    :param lengths: [B]
    :param batch_first:
    :return:
    """
    batch_list = []
    if batch_first:
        for i, l in enumerate(lengths):
            # print(inputs[i, :l].size())
            batch_list.append(inputs[i, :l])
        packed_sequence = torch.cat(batch_list, 0)
        # if chuck:
        #     return list(torch.chunk(packed_sequence, chuck, dim=0))
        # else:
        return packed_sequence
    else:
        raise NotImplementedError()
Example 6: seq2seq_cross_entropy
# Required imports: import torch [as alias]; import functools; import torch.nn.functional as F
# Or: from torch import chunk [as alias]
def seq2seq_cross_entropy(logits, label, l, chuck=None, sos_truncate=True):
    """
    :param logits: [exB, V] : exB = sum(l)
    :param label: [B, T] : a batch of labels
    :param l: [B] : a batch of LongTensors indicating the length of each input
    :param chuck: number of chunks to split the computation into
    :return: a loss value
    """
    packed_label = pack_sequence_for_linear(label, l)
    cross_entropy_loss = functools.partial(F.cross_entropy, size_average=False)
    total = sum(l)
    assert total == logits.size(0) or packed_label.size(0) == logits.size(0), \
        "logits length mismatch with label length."
    if chuck:
        logits_losses = 0
        for x, y in zip(torch.chunk(logits, chuck, dim=0),
                        torch.chunk(packed_label, chuck, dim=0)):
            logits_losses += cross_entropy_loss(x, y)
        return logits_losses * (1 / total)
    else:
        return cross_entropy_loss(logits, packed_label) * (1 / total)
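The chunked path above exists to bound peak memory when the vocabulary V is large; summing per-chunk losses is exact, not an approximation. A minimal self-contained check of that equivalence (assuming only torch):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)            # [exB, V]
labels = torch.randint(0, 5, (8,))    # [exB]
full = F.cross_entropy(logits, labels, reduction='sum')
chunked = sum(F.cross_entropy(x, y, reduction='sum')
              for x, y in zip(torch.chunk(logits, 4, dim=0),
                              torch.chunk(labels, 4, dim=0)))
assert torch.allclose(full, chunked)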
Example 7: forward
# Required imports: import torch [as alias]; import torch.nn.functional as F
# Or: from torch import chunk [as alias]
def forward(self, x):
    # input x with shape [B, F, T]
    # FORWARD THROUGH DRN
    # ----------------------------
    if self.frontend is not None:
        x = self.frontend(x)
        if not self.ft_fe:
            x = x.detach()
    x = F.pad(x, (4, 5))
    x = self.drn(x)
    # FORWARD THROUGH RNN
    # ----------------------------
    x = x.transpose(1, 2)
    x, _ = self.rnn(x)
    # Slice out the last time step: chunk along time into T pieces of size 1.
    xt = torch.chunk(x, x.shape[1], dim=1)
    x = xt[-1].transpose(1, 2)
    # FORWARD THROUGH DNN
    # ----------------------------
    x = self.mlp(x)
    return x
Example 8: format_frontend_chunk
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def format_frontend_chunk(batch, device='cpu'):
    if isinstance(batch, dict):
        if 'chunk_ctxt' in batch and 'chunk_rand' in batch:
            keys = ['chunk', 'chunk_ctxt', 'chunk_rand', 'cchunk']
            # cluster all 'chunk's, including possible 'cchunk'
            batches = [batch[k] for k in keys if k in batch]
            x = torch.cat(batches, dim=0).to(device)
            # store the number of batches condensed as the format
            data_fmt = len(batches)
        else:
            x = batch['chunk'].to(device)
            data_fmt = 1
    else:
        x = batch
        data_fmt = 0
    return x, data_fmt
Example 9: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(self, batch, device=None, mode=None):
    # batch possible chunk and contexts, or just forward a non-dict tensor
    x, data_fmt = format_frontend_chunk(batch, device)
    sinc_out = self.sinc(x).unsqueeze(1)
    conv_out = self.conv1(sinc_out)
    res_out = self.resnet(conv_out)
    h = self.conv2(res_out).squeeze(2)
    return format_frontend_output(h, data_fmt, mode)
Example 10: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(self, x):
    h = self.frontend(x)
    if not self.ft_fe:
        h = h.detach()
    if hasattr(self, 'z_bnorm'):
        h = self.z_bnorm(h)
    ht, state = self.rnn(h.transpose(1, 2))
    if self.return_sequence:
        ht = ht.transpose(1, 2)
    else:
        if not self.uni:
            # pick the last time-step for each direction;
            # first chunk the feature dim into the two directions
            bsz, slen, feats = ht.size()
            ht = torch.chunk(ht.view(bsz, slen, 2, feats // 2), 2, dim=2)
            # now select fwd (last step) and bwd (first step)
            ht_fwd = ht[0][:, -1, 0, :].unsqueeze(2)
            ht_bwd = ht[1][:, 0, 0, :].unsqueeze(2)
            ht = torch.cat((ht_fwd, ht_bwd), dim=1)
        else:
            # just the last time-step works
            ht = ht[:, -1, :].unsqueeze(2)
    y = self.model(ht)
    return y
Example 11: step
# Required imports: import torch [as th]; import time; import torch.nn.functional as F
# Or: from torch import chunk [as alias]
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """One step of training. Relies on model, optimizer, args, dev, y_list and
    compute_overlap from the enclosing scope."""
    deg_g = deg_g.to(dev)
    deg_lg = deg_lg.to(dev)
    pm_pd = pm_pd.to(dev)
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    t_forward = time.time() - t0
    # Split the stacked output back into per-graph predictions.
    z_list = th.chunk(z, args.batch_size, 0)
    loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size
    overlap = compute_overlap(z_list)
    optimizer.zero_grad()
    t0 = time.time()
    loss.backward()
    t_backward = time.time() - t0
    optimizer.step()
    return loss, overlap, t_forward, t_backward
Example 12: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(self, input):
    input = input.clone()
    input = self.preprocess_range(input)
    if self.preprocessing_type == 'caffe':
        # Caffe-style VGG preprocessing: RGB -> BGR, [0, 1] -> [0, 255], subtract mean.
        r, g, b = torch.chunk(input, 3, dim=1)
        bgr = torch.cat([b, g, r], 1)
        input = bgr * 255 - self.vgg_mean
    elif self.preprocessing_type == 'pytorch':
        input = input - self.vgg_mean
        input = input / self.vgg_std
    output = input
    outputs = []
    for block in self.blocks:
        output = block(output)
        outputs.append(output)
    return outputs
Example 13: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(self, x):
    if self.downsample:
        y1 = self.shortcut_dconv(x)
        y1 = self.shortcut_conv(y1)
        x2 = x
    else:
        # Split channels in half: one branch passes through, one is transformed.
        y1, x2 = torch.chunk(x, chunks=2, dim=1)
    y2 = self.conv1(x2)
    y2 = self.dconv(y2)
    y2 = self.conv2(y2)
    if self.use_se:
        y2 = self.se(y2)
    if self.use_residual and not self.downsample:
        y2 = y2 + x2
    x = torch.cat((y1, y2), dim=1)
    x = self.c_shuffle(x)
    return x
Example 14: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(ctx, x, fm, gm, *params):
    with torch.no_grad():
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.contiguous()
        x2 = x2.contiguous()
        # Additive coupling: y1 = x1 + F(x2), y2 = x2 + G(y1).
        y1 = x1 + fm(x2)
        y2 = x2 + gm(y1)
        y = torch.cat((y1, y2), dim=1)
        # Free the intermediate storages; they can be recomputed in backward.
        x1.set_()
        x2.set_()
        y1.set_()
        y2.set_()
        del x1, x2, y1, y2
    ctx.save_for_backward(x, y)
    ctx.fm = fm
    ctx.gm = gm
    return y
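The set_() calls above work because the additive coupling is exactly invertible, so the backward pass can reconstruct the inputs from the outputs instead of keeping them alive. A hedged sketch of that inversion (invert is a hypothetical helper for illustration, not taken from the source):

def invert(y, fm, gm):
    y1, y2 = torch.chunk(y, chunks=2, dim=1)
    x2 = y2 - gm(y1)   # undo y2 = x2 + gm(y1)
    x1 = y1 - fm(x2)   # undo y1 = x1 + fm(x2)
    return torch.cat((x1, x2), dim=1)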
Example 15: forward
# Required import: import torch [as alias]
# Or: from torch import chunk [as alias]
def forward(self, inputs, children, arities):
    # Tree-LSTM cell: each child state packs (h, c) side by side.
    i = self.wi_net(inputs)
    o = self.wo_net(inputs)
    u = self.wu_net(inputs)
    f_base = self.wf_net(inputs)
    fc_sum = inputs.new_zeros(self.memory_size)
    for k, child in enumerate(children):
        child_h, child_c = torch.chunk(child, 2, dim=1)
        i.add_(self.ui_nets[k](child_h))
        o.add_(self.uo_nets[k](child_h))
        u.add_(self.uu_nets[k](child_h))
        f = f_base
        for l, other_child in enumerate(children):
            other_child_h, _ = torch.chunk(other_child, 2, dim=1)
            f = f.add(self.uf_nets[k][l](other_child_h))
        # Accumulate the forget-gated child memory.
        fc_sum = fc_sum + torch.sigmoid(f) * child_c
    c = torch.sigmoid(i) * torch.tanh(u) + fc_sum
    h = torch.sigmoid(o) * torch.tanh(c)
    return torch.cat([h, c], dim=1)