当前位置: 首页>>代码示例>>Python>>正文


Python torch.repeat_interleave方法代码示例

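本文整理汇总了 Python 中 torch.repeat_interleave 方法的典型用法代码示例。torch.repeat_interleave(input, repeats, dim=None) 会把 input 的每个元素(或沿 dim 的每个切片)连续重复 repeats 次,例如 torch.repeat_interleave(torch.tensor([1, 2]), 2) 的结果是 tensor([1, 1, 2, 2]);若只传入一个整数张量 repeats(如示例6、示例7),则返回按 repeats 重复的下标序列 0, 0, …, 1, 1, …。如果您想了解 torch.repeat_interleave 的具体用法和使用示例,这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块 torch 的其他用法示例。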


在下文中一共展示了torch.repeat_interleave方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: hard_k_hot

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def hard_k_hot(logits, k, temperature=0.1):
  r"""Returns a hard k-hot sample given a categorical
  distribution defined by a tensor of unnormalized
  log-likelihoods.

  The hard mask has a 1.0 at the positions of the ``k`` largest entries of
  ``logits`` along the last dimension and 0.0 elsewhere. It is combined with
  the relaxed sample from ``soft_k_hot`` through ``replace_gradient`` (both
  defined elsewhere; presumably the hard values are used in the forward pass
  and the soft sample supplies the gradient — confirm).

  NOTE(review): the hard mask is the top-k of ``logits`` themselves, not of
  the relaxed sample; confirm this is intended.

  This is useful for example to sample a set of pixels in an
  image to move from a grid-structured data representation to a
  set- or graph-structured representation within a network.

  Args:
    logits (torch.Tensor): unnormalized log-likelihood tensor; the
      categories are taken along its last dimension.
    k (int): number of items to sample without replacement.
    temperature (float): temperature of the soft distribution.

  Returns:
    Hard k-hot vector from the relaxed k-hot distribution
    defined by logits and temperature.
  """
  soft = soft_k_hot(logits, k, temperature=temperature)
  _, top_k = torch.topk(logits, k)
  # Scatter along the last dimension: the indices from topk already live on
  # logits' device, and this works for any number of leading batch dims
  # (a 2-D input gives the same result as explicit row/column indexing).
  hard = torch.zeros_like(soft).scatter_(-1, top_k, 1.0)
  return replace_gradient(hard, soft)
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:26,代码来源:gradient.py

示例2: pairwise_no_pad

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def pairwise_no_pad(op, data, indices):
  """Apply ``op`` to every pair of distinct rows within each group, unpadded.

  Rows of ``data`` are grouped by ``indices``. For each row ``i`` and every
  later row ``j > i`` of the same group, ``data[i]`` is placed in the left
  operand and ``data[j]`` in the right operand; ``op`` is then called once on
  the two flat batches.

  Assumes (TODO confirm with callers) that ``indices`` is sorted so that each
  group occupies a contiguous run of rows; the offset arithmetic relies on it.

  Args:
    op: binary callable, applied as ``op(left, right)`` where both arguments
      have one row per pair.
    data: tensor whose first dimension indexes rows.
    indices: group id of each row of ``data``.

  Returns:
    ``(result, pair_indices)``: the output of ``op`` and, for each pair, the
    group id of the row it came from.
  """
  # NOTE(review): ``unique`` is unused.
  unique, counts = indices.unique(return_counts=True)
  # For each row: (exclusive end of its group) - (its position) - 1, i.e. the
  # number of rows that follow it inside the same group.
  # NOTE(review): ``counts`` lives on indices' device while ``offset`` is on
  # data's device; mixing devices here would fail.
  expansion = torch.cumsum(counts, dim=0)
  expansion = torch.repeat_interleave(expansion, counts)
  offset = torch.arange(0, counts.sum(), device=data.device)
  expansion = expansion - offset - 1
  # Left operand: row i repeated once for every later row of its group.
  expanded = torch.repeat_interleave(data, expansion.to(data.device), dim=0)

  # Start position of each row's group (exclusive cumulative sum of counts),
  # broadcast first to rows, then to pairs.
  expansion_offset = counts.roll(1)
  expansion_offset[0] = 0
  expansion_offset = expansion_offset.cumsum(dim=0)
  expansion_offset = torch.repeat_interleave(expansion_offset, counts)
  expansion_offset = torch.repeat_interleave(expansion_offset, expansion)
  # group_size - rows_after_i == local position of row i + 1, i.e. the local
  # index of row i's first partner, broadcast to pairs.
  off_start = torch.repeat_interleave(torch.repeat_interleave(counts, counts) - expansion, expansion)
  # Global row of each pair's right operand:
  #   (rank of the pair inside row i's block) + (first partner) + (group start).
  # expansion.roll(1).cumsum() equals the exclusive cumsum of expansion only
  # because the last row has expansion 0 (true when indices is sorted).
  access = torch.arange(expansion.sum(), device=data.device)
  access = access - torch.repeat_interleave(expansion.roll(1).cumsum(dim=0), expansion) + off_start + expansion_offset

  result = op(expanded, data[access.to(data.device)])
  return result, torch.repeat_interleave(indices, expansion, dim=0)
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:21,代码来源:scatter.py

示例3: __patched_conv_ops

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def __patched_conv_ops(op, x, y, *args, **kwargs):
        """Evaluate ``torch.<op>`` on int64 operands through float64 chunk encodings.

        Both operands go through ``CUDALongTensor.__encode_as_fp64`` (not shown).
        The encoded chunks are arranged so that one grouped call
        ``torch.<op>(..., groups=9)`` convolves every x chunk with every y chunk.
        The result is handed to ``CUDALongTensor.__decode_as_int64`` (not shown).

        NOTE(review): the reshapes to ``9 * c`` / ``9 * c_out`` imply each
        encoding holds 3 chunks stacked along dim 0 — confirm against
        ``__encode_as_fp64``.
        NOTE(review): ``CUDALongTensor.__encode_as_fp64`` is only name-mangled to
        ``CUDALongTensor._CUDALongTensor__encode_as_fp64`` when this ``def`` is
        inside a class named ``CUDALongTensor``; at module level, as shown
        here, the double-underscore name is looked up unmangled.
        NOTE(review): a ``groups`` entry in ``kwargs`` raises TypeError
        (duplicate keyword argument).

        Args:
            op: name of a ``torch`` function such as ``"conv1d"``/``"conv2d"``;
                for any other name the output channel count is ``y.size(1)``.
            x: input tensor; its size unpacks as ``(bs, c, *img)``.
            y: weight tensor; its size unpacks as ``(c_out, c_in, *ks)``.
            *args, **kwargs: forwarded to ``torch.<op>``.

        Returns:
            Whatever ``CUDALongTensor.__decode_as_int64`` returns for the
            stacked partial results.
        """
        x_encoded = CUDALongTensor.__encode_as_fp64(x).data
        y_encoded = CUDALongTensor.__encode_as_fp64(y).data

        # Tile x's chunks ([x0, x1, x2, x0, x1, x2, ...]) and repeat each y
        # chunk in place ([y0, y0, y0, y1, ...]) so that, position by position
        # along dim 0, every x chunk meets every y chunk (assuming 3 chunks).
        repeat_idx = [1] * (x_encoded.dim() - 1)
        x_enc_span = x_encoded.repeat(3, *repeat_idx)
        y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

        bs, c, *img = x.size()
        c_out, c_in, *ks = y.size()

        # Fold the 9 chunk pairs into the channel axis; with groups=9 each
        # channel group of x is convolved only with its paired weight group.
        x_enc_span = x_enc_span.transpose_(0, 1).reshape(bs, 9 * c, *img)
        y_enc_span = y_enc_span.reshape(9 * c_out, c_in, *ks)

        # Output channels per group: y.size(0) for conv1d/conv2d, otherwise
        # y.size(1) (presumably for conv_transpose*, whose weight layout is
        # (in, out / groups, ...)).
        c_z = c_out if op in ["conv1d", "conv2d"] else c_in

        z_encoded = getattr(torch, op)(
            x_enc_span, y_enc_span, *args, **kwargs, groups=9
        )
        # Split channels back into (9 chunk pairs, c_z) and move the pair axis
        # to the front before decoding.
        z_encoded = z_encoded.reshape(bs, 9, c_z, *z_encoded.size()[2:]).transpose_(
            0, 1
        )

        return CUDALongTensor.__decode_as_int64(z_encoded)
开发者ID:facebookresearch,项目名称:CrypTen,代码行数:26,代码来源:cuda_tensor.py

示例4: get_tiled_batch

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def get_tiled_batch(self, num_tiles: int):
    """Repeat every row of the float features ``num_tiles`` times.

    The tiled feature has shape ``(batch_size * num_tiles, feature_dim)``,
    and for every ``i`` in ``range(batch_size)``,
    ``tiled_feature[i*num_tiles:(i+1)*num_tiles]`` equals ``feat[i]``.

    Args:
        num_tiles: number of consecutive copies of each row.

    Returns:
        A new ``FeatureData`` whose ``float_features`` are the tiled rows.

    Raises:
        AssertionError: if the instance does not hold float features only,
            or if ``float_features`` is not 2-D.
    """
    assert (
        self.has_float_features_only
    ), f"only works for float features now: {self}"
    feat = self.float_features
    assert (
        len(feat.shape) == 2
    ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
    # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
    tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
    return FeatureData(float_features=tiled_feat)
开发者ID:facebookresearch,项目名称:ReAgent,代码行数:19,代码来源:types.py

示例5: get_output

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def get_output(encoder_output, duration_predictor_output, alpha, mel_max_length=None):
    """Expand encoder frames by their predicted durations (length regulation).

    For each batch item ``i``, row ``t`` of ``encoder_output[i]`` is repeated
    ``round(duration_predictor_output[i][t] * alpha)`` times along dim 0.
    The per-item results are right-padded with zeros to a common length.

    Args:
        encoder_output: tensor whose dim 0 is the batch; each item is
            expanded along its own dim 0.
        duration_predictor_output: per-row durations for each batch item.
        alpha: scalar multiplier applied to every duration before rounding.
        mel_max_length: if truthy, both outputs are truncated to this many
            frames; ``None`` (or ``0``) disables truncation.

    Returns:
        ``(output, dec_pos)``. ``output`` holds the padded expanded frames.
        ``dec_pos`` is a ``long`` tensor of 1-based positions ``1..len_i``
        for each item, with 0 in the padded slots.
    """
    output = []
    dec_pos = []

    for i, frames in enumerate(encoder_output):
        repeats = torch.round(duration_predictor_output[i].float() * alpha).long()
        expanded = torch.repeat_interleave(frames, repeats, dim=0)
        output.append(expanded)
        # Build 1-based positions directly on the frames' device instead of a
        # per-item NumPy array copied over from host memory.
        dec_pos.append(torch.arange(1, expanded.size(0) + 1, device=expanded.device))

    output = torch.nn.utils.rnn.pad_sequence(output, batch_first=True)
    dec_pos = torch.nn.utils.rnn.pad_sequence(dec_pos, batch_first=True)

    if mel_max_length:
        output = output[:, :mel_max_length]
        dec_pos = dec_pos[:, :mel_max_length]

    return output, dec_pos
开发者ID:NVIDIA,项目名称:NeMo,代码行数:22,代码来源:fastspeech.py

示例6: test_instance_norm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def test_instance_norm():
    # Batch vector for ten graphs of ten nodes each: [0]*10 + [1]*10 + ... + [9]*10.
    nodes_per_graph = torch.full((10, ), 10, dtype=torch.long)
    batch = torch.repeat_interleave(nodes_per_graph)

    # Default configuration: check the repr and the output shape.
    default_norm = InstanceNorm(16)
    expected_repr = ('InstanceNorm(16, eps=1e-05, momentum=0.1, affine=False, '
                     'track_running_stats=False)')
    assert repr(default_norm) == expected_repr
    assert default_norm(torch.randn(100, 16), batch).size() == (100, 16)

    # Affine parameters plus running statistics.
    stateful_norm = InstanceNorm(16, affine=True, track_running_stats=True)
    assert stateful_norm(torch.randn(100, 16), batch).size() == (100, 16)

    # Should behave equally to `BatchNorm` for mini-batches of size 1.
    x = torch.randn(100, 16)

    def close(a, b):
        return torch.allclose(a, b, atol=1e-6)

    inst = InstanceNorm(16, affine=False, track_running_stats=False)
    bn = BatchNorm(16, affine=False, track_running_stats=False)
    assert close(inst(x), bn(x))

    inst = InstanceNorm(16, affine=False, track_running_stats=True)
    bn = BatchNorm(16, affine=False, track_running_stats=True)
    # Two training-mode passes; each one updates the running statistics,
    # which must stay in sync between the two modules.
    for _ in range(2):
        assert close(inst(x), bn(x))
        assert close(inst.running_mean, bn.running_mean)
        assert close(inst.running_var, bn.running_var)

    # Evaluation mode uses the accumulated running statistics.
    inst.eval()
    bn.eval()
    assert close(inst(x), bn(x))
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:33,代码来源:test_instance_norm.py

示例7: test_graph_size_norm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def test_graph_size_norm():
    # Ten graphs with ten nodes each -> batch vector [0]*10 + ... + [9]*10.
    graph_sizes = torch.full((10, ), 10, dtype=torch.long)
    batch = torch.repeat_interleave(graph_sizes)
    layer = GraphSizeNorm()
    normalized = layer(torch.randn(100, 16), batch)
    assert normalized.size() == (100, 16)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:7,代码来源:test_graph_size_norm.py

示例8: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, seq_value_len_list):
    """Pool a padded sequence of embeddings over its valid positions.

    Args:
        seq_value_len_list: a pair ``(uiseq_embed_list, x)`` where
            ``uiseq_embed_list`` is ``[B, T, E]``. When
            ``self.supports_masking`` is true, ``x`` is a ``[B, T]`` mask of
            valid positions. Otherwise ``x`` is a ``[B, 1]`` tensor of
            lengths, turned into a mask by ``self._sequence_mask``.

    Returns:
        A ``[B, 1, E]`` tensor. With ``self.mode == 'max'`` it is the maximum
        over unmasked positions (masked ones are pushed down by ``1e9``).
        Otherwise it is the masked sum, divided by ``length + self.eps`` when
        ``self.mode == 'mean'``.
    """
    if self.supports_masking:
        uiseq_embed_list, mask = seq_value_len_list  # [B, T, E], [B, T]
        mask = mask.float()
        user_behavior_length = torch.sum(mask, dim=-1, keepdim=True)  # [B, 1]
        mask = mask.unsqueeze(2)  # [B, T, 1]
    else:
        uiseq_embed_list, user_behavior_length = seq_value_len_list  # [B, T, E], [B, 1]
        mask = self._sequence_mask(user_behavior_length, maxlen=uiseq_embed_list.shape[1],
                                   dtype=torch.float32)  # [B, 1, maxlen]
        mask = torch.transpose(mask, 1, 2)  # [B, maxlen, 1]

    # The trailing singleton axis of the [B, T, 1] mask broadcasts over the
    # embedding axis, so no [B, T, E] copy needs to be materialised.
    if self.mode == 'max':
        hist = uiseq_embed_list - (1 - mask) * 1e9
        return torch.max(hist, dim=1, keepdim=True)[0]

    hist = torch.sum(uiseq_embed_list * mask.float(), dim=1, keepdim=False)

    if self.mode == 'mean':
        hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps)

    return torch.unsqueeze(hist, dim=1)
开发者ID:shenweichen,项目名称:DeepCTR-Torch,代码行数:30,代码来源:sequence.py

示例9: repeat

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def repeat(input, repeats, dim):
    """Repeat each slice of ``input`` ``repeats`` times along ``dim``.

    For an integer ``repeats`` this matches
    ``th.repeat_interleave(input, repeats, dim)``. It is written with
    stack + flatten, presumably to support PyTorch versions older than 1.1,
    where ``repeat_interleave`` is unavailable.
    """
    axis = dim + input.dim() if dim < 0 else dim
    # Stack the copies on a fresh axis right after ``axis``, then merge the
    # two axes so each slice's copies end up adjacent.
    copies = th.stack([input for _ in range(repeats)], dim=axis + 1)
    return th.flatten(copies, axis, axis + 1)
开发者ID:dmlc,项目名称:dgl,代码行数:7,代码来源:tensor.py

示例10: sample

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def sample(self, data):
  """Draw generator inputs for a batch of sets.

  A latent code is drawn (reparameterised) from the Gaussian returned by
  ``self.condition(support)``. It is repeated ``self.size`` times along dim 0
  and concatenated with 16-dimensional standard-normal noise per row.

  Args:
    data: a ``(support, values)`` pair; ``values`` is not used here.

  Returns:
    ``((support, sample), (mean, logvar))``, where ``sample`` has
    ``support.size(0) * self.size`` rows.
  """
  support, _values = data
  mean, logvar = self.condition(support)
  distribution = Normal(mean, torch.exp(0.5 * logvar))
  latent_sample = distribution.rsample()
  latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
  # torch.randn defaults to the CPU and the default dtype; create the noise
  # alongside the latent so the concatenation also works on a GPU.
  local_samples = torch.randn(
    support.size(0) * self.size, 16,
    device=latent_sample.device, dtype=latent_sample.dtype
  )
  combined = torch.cat((latent_sample, local_samples), dim=1)
  return (support, combined), (mean, logvar)
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:11,代码来源:set_mnist_gan.py

示例11: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, data):
  """Score flattened set members together with their set's latent code.

  A latent code is drawn (reparameterised) from the Gaussian given by
  ``self.encoder(support)`` and repeated ``self.size`` times along dim 0.
  It is appended after ``values`` flattened to 28*28 columns, and the
  result is passed to ``self.verdict``.
  """
  support, values = data
  mean, logvar = self.encoder(support)
  std = torch.exp(0.5 * logvar)
  code = Normal(mean, std).rsample()
  per_member_code = code.repeat_interleave(self.size, dim=0)
  flat_values = values.view(-1, 28 * 28)
  features = torch.cat((flat_values, per_member_code), dim=1)
  return self.verdict(features)
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:10,代码来源:set_mnist_gan.py

示例12: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, image, condition):
  """Score 3x64x64 images given a condition encoded as a Gaussian.

  ``self.condition(condition)`` yields ``(mean, logvar)``. One sample
  ``mean + eps * exp(0.5 * logvar)`` is drawn and post-processed. It is then
  repeated 5 times along dim 0 and concatenated (dim 1) with the processed
  image features before ``self.combine``.

  Returns:
    ``(result, (mean, logvar))``.
  """
  images = image.view(-1, 3, 64, 64)
  features = self.input_process(self.input(images))
  mean, logvar = self.condition(condition)
  std = torch.exp(0.5 * logvar)
  latent = mean + torch.randn_like(mean) * std
  code = self.postprocess(latent).repeat_interleave(5, dim=0)
  joint = torch.cat((features, code), dim=1)
  return self.combine(joint), (mean, logvar)
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:12,代码来源:set_yeast_ebm.py

示例13: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def forward(self, image, condition):
  """Score flattened 28x28 images given a condition encoded as a Gaussian.

  ``self.condition(condition)`` yields ``(mean, logvar)``. One sample
  ``mean + eps * exp(0.5 * logvar)`` is drawn and post-processed. It is then
  repeated 5 times along dim 0 and concatenated (dim 1) with the processed
  image features before ``self.combine``.

  Returns:
    ``(result, (mean, logvar))``.
  """
  flat_images = image.view(-1, 28 * 28)
  features = self.input_process(self.input(flat_images))
  mean, logvar = self.condition(condition)
  std = torch.exp(0.5 * logvar)
  latent = mean + torch.randn_like(mean) * std
  code = self.postprocess(latent).repeat_interleave(5, dim=0)
  joint = torch.cat((features, code), dim=1)
  return self.combine(joint), (mean, logvar)
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:12,代码来源:set_mnist_ebm.py

示例14: repack

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def repack(data, indices, target_indices):
  """Scatter grouped rows into a layout with (possibly) larger groups.

  Rows of ``data`` are grouped by ``indices``. Group ``g`` has ``lengths[g]``
  rows in ``data`` and ``target_lengths[g]`` rows in the layout described by
  ``target_indices``. Each group's rows are copied, in order, to the front of
  that group's slot in the output. The remaining
  ``target_lengths[g] - lengths[g]`` rows stay zero.

  Assumes (TODO confirm with callers) that both index tensors are sorted,
  contain the same set of group ids, and that every target group is at least
  as large as its source group.

  Returns:
    ``(out, target_indices)``. ``out`` has ``target_indices.size(0)`` rows and
    the same trailing shape, dtype and device as ``data``.
  """
  out = torch.zeros(
    target_indices.size(0), *data.shape[1:],
    dtype=data.dtype, device=data.device
  )
  if indices.numel() == 0:
    # Nothing to copy; also avoids indexing an empty offset tensor below.
    return out, target_indices
  _, lengths = indices.unique(return_counts=True)
  _, target_lengths = target_indices.unique(return_counts=True)
  # Padding added after each group, turned into an exclusive cumulative sum
  # so that group g is shifted by the total padding of groups 0..g-1.
  offset = target_lengths - lengths
  offset = offset.roll(1, 0)
  offset[0] = 0
  offset = torch.repeat_interleave(offset.cumsum(dim=0), lengths, dim=0)
  index = offset.to(data.device) + torch.arange(len(indices), device=data.device)

  out[index] = data
  return out, target_indices
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:17,代码来源:scatter.py

示例15: pairwise

# 需要导入模块: import torch [as 别名]
# 或者: from torch import repeat_interleave [as 别名]
def pairwise(op, data, indices, padding_value=0):
  """Apply ``op`` to every ordered pair of rows within each group.

  The groups are padded into a dense tensor via ``pad``, and ``op`` is
  evaluated on all padded pairs by broadcasting. The entries that belong to
  real (non-padding) pairs are then gathered into a flat result, ordered by
  group, then first row, then second row.

  NOTE(review): ``pad`` is defined elsewhere. This code assumes its fourth
  return value is the per-group row count. It also assumes that, after the
  ``transpose(1, 2)`` and the broadcasted ``op``, dims 1 and 2 of
  ``op_result`` index the two rows of a pair — confirm against ``pad``'s
  output layout.
  NOTE(review): the ``torch.arange`` calls create CPU tensors; if ``counts``
  is on another device the subtractions below will fail.

  Args:
    op: binary callable, applied as ``op(padded.unsqueeze(-2),
      padded.unsqueeze(-1))`` after ``padded`` is transposed on dims 1 and 2.
    data, indices: rows and their group ids, forwarded to ``pad``.
    padding_value: fill value forwarded to ``pad``.

  Returns:
    ``(result, batch_indices, first_indices, second_indices, access)``.
    ``result[k]`` is the op output for the k-th pair. The three index tensors
    give that pair's group and the local positions of its two rows.
    ``access`` is ``(access_batch, access_first)``: pair ``(g, p, q)`` sits
    at ``access_batch[g] + access_first[g] * p + q`` in ``result``.
  """
  padded, _, _, counts = pad(data, indices, value=padding_value)
  padded = padded.transpose(1, 2)
  reference = padded.unsqueeze(-1)
  padded = padded.unsqueeze(-2)
  op_result = op(padded, reference)

  # batch indices into pairwise tensor:
  # group g repeated counts[g] ** 2 times, one entry per ordered pair.
  batch_indices = torch.arange(counts.size(0))
  batch_indices = torch.repeat_interleave(batch_indices, counts ** 2)

  # first dimension indices:
  # local position p of every row (global position minus group start),
  # each repeated counts[g] times -> [0]*c, [1]*c, ..., [c-1]*c per group.
  first_offset = counts.roll(1)
  first_offset[0] = 0
  first_offset = torch.cumsum(first_offset, dim=0)
  first_offset = torch.repeat_interleave(first_offset, counts)
  first_indices = torch.arange(counts.sum()) - first_offset
  first_indices = torch.repeat_interleave(
    first_indices,
    torch.repeat_interleave(counts, counts)
  )

  # second dimension indices:
  # each row owns a block of counts[g] pairs; subtracting the block start
  # (exclusive cumsum of per-row group sizes) yields 0..c-1 inside the block.
  second_offset = torch.repeat_interleave(counts, counts).roll(1)
  second_offset[0] = 0
  second_offset = torch.cumsum(second_offset, dim=0)
  second_offset = torch.repeat_interleave(second_offset, torch.repeat_interleave(counts, counts))
  second_indices = torch.arange((counts ** 2).sum()) - second_offset

  # extract tensor from padded result using indices:
  result = op_result[batch_indices, first_indices, second_indices]

  # access: cumsum(counts ** 2)[idx] + counts[idx] * idy + idz
  # (access_batch is the exclusive cumulative sum of counts ** 2)
  access_batch = (counts ** 2).roll(1)
  access_batch[0] = 0
  access_batch = torch.cumsum(access_batch, dim=0)
  access_first = counts

  access = (access_batch, access_first)

  return result, batch_indices, first_indices, second_indices, access
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:43,代码来源:scatter.py


注:本文中的torch.repeat_interleave方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。