本文整理汇总了Python中torch.repeat_interleave函数的典型用法代码示例。如果您正苦于以下问题：Python中torch.repeat_interleave的具体用法是什么？该怎么用？有哪些使用示例？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该函数所在模块torch的其他用法示例。
下文一共展示了torch.repeat_interleave函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。
示例1: hard_k_hot
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def hard_k_hot(logits, k, temperature=0.1):
    r"""Returns a hard k-hot sample given a categorical
    distribution defined by a tensor of unnormalized
    log-likelihoods.

    This is useful for example to sample a set of pixels in an
    image to move from a grid-structured data representation to a
    set- or graph-structured representation within a network.

    The hard entries are the top-``k`` positions of ``logits`` itself; the
    relaxed sample from ``soft_k_hot`` (defined elsewhere) is combined with
    them through ``replace_gradient`` (defined elsewhere; presumably the
    forward value of ``hard`` with the gradient of ``soft`` — confirm).

    Args:
        logits (torch.Tensor): unnormalized log-likelihood tensor of shape
            ``(batch, num_items)``; the two-index assignment below only
            supports 2-D input.
        k (int): number of items to sample without replacement.
        temperature (float): temperature of the soft distribution.

    Returns:
        Hard k-hot vector from the relaxed k-hot distribution
        defined by logits and temperature.
    """
    soft = soft_k_hot(logits, k, temperature=temperature)
    hard = torch.zeros_like(soft)
    _, top_k = torch.topk(logits, k)
    # Row id of every selected entry: [0]*k + [1]*k + ... . Built on the device
    # of ``hard`` so the indexed assignment needs no host/device transfer.
    rows = torch.repeat_interleave(
        torch.arange(0, hard.size(0), device=hard.device), k
    )
    hard[rows, top_k.view(-1)] = 1.0
    return replace_gradient(hard, soft)
示例2: pairwise_no_pad
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def pairwise_no_pad(op, data, indices):
    """Apply ``op`` to every pair of distinct rows that share a group.

    Rows of ``data`` are grouped by ``indices``. Groups must be stored
    contiguously and in ascending order of their index value, because the
    group sizes come from ``indices.unique(return_counts=True)``, which
    reports them in sorted order. Within each group the pairs ``(i, j)`` with
    ``i < j`` are enumerated in row-major order (``i`` slow, ``j`` fast).
    No padding tensor is materialised.

    Args:
        op: callable ``op(first_rows, second_rows)`` applied to the stacked
            first and second rows of all pairs.
        data: tensor whose first dimension indexes rows.
        indices: group id of each row of ``data``.

    Returns:
        ``(result, pair_indices)``: ``result = op(data[first], data[second])``
        over all pairs, and ``pair_indices`` holds the group id of each pair.
        Empty input yields empty outputs.
    """
    _, counts = indices.unique(return_counts=True)
    # Keep all index arithmetic on the data's device (no mixed-device ops).
    counts = counts.to(data.device)
    position = torch.arange(int(counts.sum()), device=data.device)
    # Number of later rows in the same group == pairs in which a row is first.
    group_end = torch.repeat_interleave(counts.cumsum(dim=0), counts)
    expansion = group_end - position - 1
    expanded = torch.repeat_interleave(data, expansion, dim=0)
    # Exclusive prefix sum (cumsum minus self) gives where each row's block of
    # pairs starts; unlike roll + ``[0] = 0`` it is safe on empty input.
    pair_start = expansion.cumsum(dim=0) - expansion
    local = torch.arange(int(expansion.sum()), device=data.device)
    local = local - torch.repeat_interleave(pair_start, expansion)
    # The partner of row i in its j-th pair is row i + 1 + j.
    partner = torch.repeat_interleave(position + 1, expansion) + local
    result = op(expanded, data[partner])
    pair_indices = torch.repeat_interleave(
        indices, expansion.to(indices.device), dim=0
    )
    return result, pair_indices
示例3: __patched_conv_ops
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def __patched_conv_ops(op, x, y, *args, **kwargs):
    """Run the convolution ``torch.<op>`` on encoded ``x``/``y`` via float64.

    Both operands are encoded with ``CUDALongTensor.__encode_as_fp64``
    (defined elsewhere). The reshapes below imply that each encoding stacks 3
    slices along dim 0. The slices are paired 3 x 3 = 9 ways, and all nine
    convolutions run in a single grouped call (``groups=9``). The stacked
    result is handed to ``CUDALongTensor.__decode_as_int64``.

    Args:
        op: name of a ``torch`` convolution function. For ``"conv1d"`` and
            ``"conv2d"`` the output channel count is ``y.size(0)``. For any
            other name it is ``y.size(1)`` (presumably the transposed
            convolutions — confirm which ops are routed here).
        x: tensor whose size unpacks as ``(batch, channels, *spatial)``.
        y: weight whose size unpacks as ``(c_out, c_in, *kernel)``.
        *args, **kwargs: forwarded to ``torch.<op>``; must not contain
            ``groups``, which is fixed to 9.
    """
    x_parts = CUDALongTensor.__encode_as_fp64(x).data
    y_parts = CUDALongTensor.__encode_as_fp64(y).data

    # x is tiled block-wise (a b c a b c a b c) and y entry-wise
    # (a a a b b b c c c), so position k pairs x slice k % 3 with y slice k // 3.
    x_tiled = x_parts.repeat(3, *([1] * (x_parts.dim() - 1)))
    y_tiled = torch.repeat_interleave(y_parts, repeats=3, dim=0)

    batch_size, channels, *spatial = x.size()
    out_channels, in_channels, *kernel = y.size()

    # Fold the 9 pairings into the channel axis: group g of the grouped conv
    # sees x_tiled[g] as input and y_tiled[g] as weight.
    x_grouped = x_tiled.transpose_(0, 1).reshape(batch_size, 9 * channels, *spatial)
    y_grouped = y_tiled.reshape(9 * out_channels, in_channels, *kernel)

    z_channels = out_channels if op in ["conv1d", "conv2d"] else in_channels
    conv_fn = getattr(torch, op)
    z_grouped = conv_fn(x_grouped, y_grouped, *args, **kwargs, groups=9)

    # Unfold back to (9, batch, z_channels, *out_spatial) before decoding.
    z_parts = z_grouped.reshape(
        batch_size, 9, z_channels, *z_grouped.size()[2:]
    ).transpose_(0, 1)
    return CUDALongTensor.__decode_as_int64(z_parts)
示例4: get_tiled_batch
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def get_tiled_batch(self, num_tiles: int):
    """Repeat every row of ``self.float_features`` ``num_tiles`` times in a row.

    The tiled feature has shape ``(batch_size * num_tiles, feature_dim)`` and,
    for every ``i`` in ``range(batch_size)``,
    ``tiled_feature[i * num_tiles:(i + 1) * num_tiles]`` equals ``feat[i]``.

    Args:
        num_tiles: number of consecutive copies emitted for each row.

    Returns:
        A new ``FeatureData`` holding only the tiled ``float_features``.

    Raises:
        AssertionError: if the object carries non-float features, or if
            ``float_features`` is not 2-D.
    """
    # The docstring used to follow this assert, which made it a dead string
    # expression (``__doc__`` was None); it now comes first.
    assert (
        self.has_float_features_only
    ), f"only works for float features now: {self}"
    feat = self.float_features
    assert (
        len(feat.shape) == 2
    ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
    # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
    tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
    return FeatureData(float_features=tiled_feat)
示例5: get_output
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def get_output(encoder_output, duration_predictor_output, alpha, mel_max_length=None):
    """Expand each encoder frame by its predicted duration (length regulation).

    For batch item ``i``, frame ``t`` of ``encoder_output[i]`` is repeated
    ``round(duration_predictor_output[i][t] * alpha)`` times along dim 0.
    Negative rounded durations are treated as 0 frames; previously they made
    ``repeat_interleave`` raise.

    Args:
        encoder_output: tensor indexed as ``encoder_output[i]`` per batch item,
            whose dim 0 is the frame axis that gets expanded.
        duration_predictor_output: per-item durations; ``[i]`` must broadcast
            to one value per frame of ``encoder_output[i]``.
        alpha: factor applied to the durations before rounding.
        mel_max_length: if truthy, outputs are truncated to this many frames.

    Returns:
        ``(output, dec_pos)``. ``output`` is batch-first and zero-padded to the
        longest expanded item. ``dec_pos`` holds the 1-based frame positions
        (0 on padding), as int64 on the same device as ``output``.
    """
    output = []
    dec_pos = []
    for i, frames in enumerate(encoder_output):
        repeats = torch.round(duration_predictor_output[i].float() * alpha)
        repeats = repeats.long().clamp(min=0)
        expanded = torch.repeat_interleave(frames, repeats, dim=0)
        output.append(expanded)
        # Positions are built directly on the target device, replacing the
        # numpy round trip (whose integer dtype varies by platform).
        dec_pos.append(torch.arange(1, expanded.size(0) + 1, device=expanded.device))
    output = torch.nn.utils.rnn.pad_sequence(output, batch_first=True)
    dec_pos = torch.nn.utils.rnn.pad_sequence(dec_pos, batch_first=True)
    if mel_max_length:
        output = output[:, :mel_max_length]
        dec_pos = dec_pos[:, :mel_max_length]
    return output, dec_pos
示例6: test_instance_norm
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def test_instance_norm():
    # Ten graphs of ten nodes each: batch = [0]*10 + [1]*10 + ... + [9]*10.
    sizes = torch.full((10, ), 10, dtype=torch.long)
    batch = torch.repeat_interleave(sizes)

    norm = InstanceNorm(16)
    expected_repr = ('InstanceNorm(16, eps=1e-05, momentum=0.1, affine=False, '
                     'track_running_stats=False)')
    assert norm.__repr__() == expected_repr
    assert norm(torch.randn(100, 16), batch).size() == (100, 16)

    norm = InstanceNorm(16, affine=True, track_running_stats=True)
    assert norm(torch.randn(100, 16), batch).size() == (100, 16)

    # Called without ``batch`` (a single graph, i.e. a mini-batch of size 1),
    # InstanceNorm should behave exactly like BatchNorm.
    x = torch.randn(100, 16)
    norm1 = InstanceNorm(16, affine=False, track_running_stats=False)
    norm2 = BatchNorm(16, affine=False, track_running_stats=False)
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)

    norm1 = InstanceNorm(16, affine=False, track_running_stats=True)
    norm2 = BatchNorm(16, affine=False, track_running_stats=True)
    # Two training passes; outputs and running statistics must agree after each.
    for _ in range(2):
        assert torch.allclose(norm1(x), norm2(x), atol=1e-6)
        assert torch.allclose(norm1.running_mean, norm2.running_mean, atol=1e-6)
        assert torch.allclose(norm1.running_var, norm2.running_var, atol=1e-6)

    # Compare once more after switching both modules to eval mode.
    norm1.eval()
    norm2.eval()
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)
示例7: test_graph_size_norm
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def test_graph_size_norm():
    # Ten graphs of ten nodes each, so the batch vector covers 100 nodes.
    nodes_per_graph = torch.full((10, ), 10, dtype=torch.long)
    batch = torch.repeat_interleave(nodes_per_graph)
    norm = GraphSizeNorm()
    features = torch.randn(100, 16)
    # Normalisation must not change the feature shape.
    assert norm(features, batch).size() == (100, 16)
示例8: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, seq_value_len_list):
    """Pool a padded sequence of embeddings into one vector per example.

    With ``self.supports_masking`` the input is ``(embeddings, mask)`` and the
    valid length is the row sum of ``mask``. Otherwise the input is
    ``(embeddings, lengths)`` and the mask is built by
    ``self._sequence_mask`` (defined elsewhere).

    ``self.mode`` selects the pooling:
      * ``'max'``: masked positions get 1e9 subtracted before the max;
      * ``'mean'``: masked sum divided by ``length + self.eps``;
      * anything else: masked sum.

    Returns:
        Tensor of shape ``[B, 1, E]``.
    """
    if not self.supports_masking:
        embeddings, lengths = seq_value_len_list  # [B, T, E], [B, 1]
        mask = self._sequence_mask(lengths, maxlen=embeddings.shape[1],
                                   dtype=torch.float32)  # [B, 1, maxlen]
        mask = mask.transpose(1, 2)  # [B, maxlen, 1]
    else:
        embeddings, mask = seq_value_len_list  # [B, T, E], [B, T]
        mask = mask.float()
        lengths = mask.sum(dim=-1, keepdim=True)  # [B, 1]
        mask = mask.unsqueeze(2)  # [B, T, 1]

    emb_dim = embeddings.shape[-1]
    mask = torch.repeat_interleave(mask, emb_dim, dim=2)  # [B, T, E]

    if self.mode == 'max':
        shifted = embeddings - (1 - mask) * 1e9
        return shifted.max(dim=1, keepdim=True)[0]

    pooled = (embeddings * mask.float()).sum(dim=1, keepdim=False)  # [B, E]
    if self.mode == 'mean':
        pooled = pooled / (lengths.type(torch.float32) + self.eps)
    return pooled.unsqueeze(dim=1)
示例9: repeat
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def repeat(input, repeats, dim):
    """Repeat each slice of ``input`` along ``dim`` ``repeats`` times in a row.

    Equivalent to ``th.repeat_interleave(input, repeats, dim)`` for an integer
    ``repeats``. It serves as a fallback for PyTorch releases that predate
    ``repeat_interleave`` (introduced in PyTorch 1.1).

    Args:
        input: tensor to repeat.
        repeats: non-negative number of copies per slice; 0 yields an empty
            ``dim`` (this previously crashed in ``th.stack``).
        dim: dimension to repeat along; negative values count from the end.

    Returns:
        Tensor whose ``dim`` is ``repeats`` times longer, with copies adjacent.
    """
    if dim < 0:
        dim += input.dim()
    # Add a singleton axis right after ``dim``, broadcast it to ``repeats``
    # copies (a view), then merge the two axes into one.
    expanded = input.unsqueeze(dim + 1)
    sizes = [-1] * expanded.dim()
    sizes[dim + 1] = repeats
    return th.flatten(expanded.expand(*sizes), dim, dim + 1)
示例10: sample
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def sample(self, data):
    """Draw one latent per support item and pair it with fresh local noise.

    ``self.condition(support)`` returns ``(mean, logvar)``. A latent is drawn
    from ``Normal(mean, exp(0.5 * logvar))`` with ``rsample``, so it stays
    differentiable. It is repeated ``self.size`` times along dim 0 and
    concatenated with 16 standard-normal local features per row.

    Args:
        data: ``(support, values)`` pair; ``values`` is not used here.

    Returns:
        ``((support, sample), (mean, logvar))`` where ``sample`` has
        ``support.size(0) * self.size`` rows.
    """
    support, _values = data
    mean, logvar = self.condition(support)
    distribution = Normal(mean, torch.exp(0.5 * logvar))
    latent_sample = distribution.rsample()
    latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
    # Create the noise on the latent's device and dtype. torch.randn alone
    # would use the CPU default dtype, and torch.cat would fail on GPU.
    local_samples = torch.randn(
        support.size(0) * self.size, 16,
        device=latent_sample.device, dtype=latent_sample.dtype,
    )
    combined = torch.cat((latent_sample, local_samples), dim=1)
    return (support, combined), (mean, logvar)
示例11: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, data):
    """Score ``values`` jointly with a latent inferred from ``support``.

    ``self.encoder(support)`` returns ``(mean, logvar)``. A reparameterised
    draw from ``Normal(mean, exp(0.5 * logvar))`` is repeated ``self.size``
    times along dim 0 and appended after the flattened 28 x 28 ``values``.
    The result is passed to ``self.verdict``.
    """
    support, values = data
    mean, logvar = self.encoder(support)
    scale = torch.exp(0.5 * logvar)
    latent = Normal(mean, scale).rsample()
    latent = torch.repeat_interleave(latent, self.size, dim=0)
    flat_values = values.view(-1, 28 * 28)
    return self.verdict(torch.cat((flat_values, latent), dim=1))
示例12: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, image, condition):
    """Combine 3 x 64 x 64 image features with a latent sampled from ``condition``.

    ``self.condition(condition)`` returns ``(mean, logvar)``. The latent
    ``mean + eps * exp(0.5 * logvar)`` (``eps`` standard normal, i.e. a
    reparameterised Normal draw) is post-processed and repeated 5 times
    along dim 0. That lines it up with the image features (assumes 5 images
    per condition — TODO confirm with callers).

    Returns:
        ``(result, (mean, logvar))``.
    """
    features = self.input_process(self.input(image.view(-1, 3, 64, 64)))
    mean, logvar = self.condition(condition)
    noise = torch.randn_like(mean)
    latent = mean + noise * torch.exp(0.5 * logvar)
    cond = torch.repeat_interleave(self.postprocess(latent), 5, dim=0)
    result = self.combine(torch.cat((features, cond), dim=1))
    return result, (mean, logvar)
示例13: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, image, condition):
    """Combine flattened 28 x 28 image features with a latent from ``condition``.

    ``self.condition(condition)`` returns ``(mean, logvar)``. The latent
    ``mean + eps * exp(0.5 * logvar)`` (``eps`` standard normal, i.e. a
    reparameterised Normal draw) is post-processed and repeated 5 times
    along dim 0. That lines it up with the image features (assumes 5 images
    per condition — TODO confirm with callers).

    Returns:
        ``(result, (mean, logvar))``.
    """
    features = self.input_process(self.input(image.view(-1, 28 * 28)))
    mean, logvar = self.condition(condition)
    noise = torch.randn_like(mean)
    latent = mean + noise * torch.exp(0.5 * logvar)
    cond = torch.repeat_interleave(self.postprocess(latent), 5, dim=0)
    result = self.combine(torch.cat((features, cond), dim=1))
    return result, (mean, logvar)
示例14: repack
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def repack(data, indices, target_indices):
    """Scatter grouped rows of ``data`` into a zero-filled layout of ``target_indices``.

    ``indices`` and ``target_indices`` give the group of each row in the source
    and target layouts. Both are expected to contain the same groups, stored
    contiguously in ascending order, and each target group should be at least
    as large as its source group. Row ``r`` of source group ``g`` lands at row
    ``r`` of target group ``g``; all other target rows stay zero.

    Returns:
        ``(out, target_indices)``. ``out`` has ``target_indices.size(0)`` rows
        and the trailing shape, dtype and device of ``data``.
    """
    out = torch.zeros(
        target_indices.size(0), *data.shape[1:],
        dtype=data.dtype, device=data.device
    )
    _, lengths = indices.unique(return_counts=True)
    _, target_lengths = target_indices.unique(return_counts=True)
    # Keep index arithmetic on the data's device (no mixed-device ops).
    lengths = lengths.to(data.device)
    target_lengths = target_lengths.to(data.device)
    # Each group is shifted by the padding added to all earlier groups: an
    # exclusive prefix sum, which, unlike roll + ``[0] = 0``, works when empty.
    padding = target_lengths - lengths
    shift = padding.cumsum(dim=0) - padding
    index = torch.repeat_interleave(shift, lengths, dim=0)
    index = index + torch.arange(indices.size(0), device=data.device)
    out[index] = data
    # Return the repacked tensor; the original returned ``data`` unchanged.
    return out, target_indices
示例15: pairwise
# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def pairwise(op, data, indices, padding_value=0):
    """Apply ``op`` to all ordered pairs of rows within each group, then unpad.

    Rows of ``data`` are grouped by ``indices`` through the helper ``pad``
    (defined elsewhere). Its fourth output is used as ``counts``, presumably
    the number of rows per group — confirm against ``pad``.

    NOTE(review): ``op`` receives ``padded.unsqueeze(-2)`` and
    ``padded.unsqueeze(-1)``. For an elementwise, broadcasting ``op`` the two
    pair axes of ``op_result`` are therefore its LAST two. The gather below
    indexes axes 0, 1 and 2 of ``op_result`` instead. Confirm that these agree
    for the layout ``pad`` returns and the ops actually passed in.

    NOTE(review): with zero groups, ``first_offset[0] = 0`` raises an
    IndexError. The ``arange`` calls also use the default device, not the
    device of ``counts``/``data``.

    Args:
        op: binary callable applied to two broadcastable views of ``padded``.
        data: rows to pair up, forwarded to ``pad``.
        indices: group id of each row, forwarded to ``pad``.
        padding_value: fill value forwarded to ``pad``.

    Returns:
        ``(result, batch_indices, first_indices, second_indices, access)``.
        ``result[p] = op_result[batch_indices[p], first_indices[p],
        second_indices[p]]`` covers every ``(group, i, j)`` with
        ``0 <= i, j < count`` of that group, in row-major order (``i`` slow,
        ``j`` fast). ``access`` is ``(exclusive cumsum of counts ** 2, counts)``.
    """
    padded, _, _, counts = pad(data, indices, value=padding_value)
    padded = padded.transpose(1, 2)
    # Two broadcastable views; for an elementwise op,
    # op_result[..., x, y] = op(padded[..., y], padded[..., x]).
    reference = padded.unsqueeze(-1)
    padded = padded.unsqueeze(-2)
    op_result = op(padded, reference)
    # batch indices into pairwise tensor:
    # each group id repeated count ** 2 times (one entry per (i, j) pair).
    batch_indices = torch.arange(counts.size(0))
    batch_indices = torch.repeat_interleave(batch_indices, counts ** 2)
    # first dimension indices:
    # exclusive prefix sum of counts = start row of each group, per row.
    first_offset = counts.roll(1)
    first_offset[0] = 0
    first_offset = torch.cumsum(first_offset, dim=0)
    first_offset = torch.repeat_interleave(first_offset, counts)
    # position of every row inside its group (0 .. count - 1) ...
    first_indices = torch.arange(counts.sum()) - first_offset
    # ... each repeated count times, making i the slow-varying pair index.
    first_indices = torch.repeat_interleave(
        first_indices,
        torch.repeat_interleave(counts, counts)
    )
    # second dimension indices:
    # exclusive prefix sum of the per-row pair counts = start of each row's
    # block of ``count`` pairs, repeated over that block.
    second_offset = torch.repeat_interleave(counts, counts).roll(1)
    second_offset[0] = 0
    second_offset = torch.cumsum(second_offset, dim=0)
    second_offset = torch.repeat_interleave(second_offset, torch.repeat_interleave(counts, counts))
    # j runs 0 .. count - 1 inside each row's block (fast-varying pair index).
    second_indices = torch.arange((counts ** 2).sum()) - second_offset
    # extract tensor from padded result using indices:
    result = op_result[batch_indices, first_indices, second_indices]
    # access: cumsum(counts ** 2)[idx] + counts[idx] * idy + idz
    # (the cumsum is exclusive: flat position in ``result`` of pair
    # (idy, idz) of group idx).
    access_batch = (counts ** 2).roll(1)
    access_batch[0] = 0
    access_batch = torch.cumsum(access_batch, dim=0)
    access_first = counts
    access = (access_batch, access_first)
    return result, batch_indices, first_indices, second_indices, access