当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch EmbeddingBag.from_pretrained用法及代码示例


本文简要介绍python语言中 torch.nn.EmbeddingBag.from_pretrained 的用法。

用法:

classmethod from_pretrained(embeddings, freeze=True, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, mode='mean', sparse=False, include_last_offset=False, padding_idx=None)

参数

  • embeddings(Tensor) -FloatTensor 包含 EmbeddingBag 的权重。第一个维度作为 ‘num_embeddings’ 传递给 EmbeddingBag,第二个维度作为 ‘embedding_dim’ 传递给 EmbeddingBag。

  • freeze(布尔值,可选的) -如果 True ,则张量在学习过程中不会更新。等效于 embeddingbag.weight.requires_grad = False 。默认值:True

  • max_norm(float,可选的) -请参阅模块初始化文档。默认值:None

  • norm_type(float,可选的) -请参阅模块初始化文档。默认 2

  • scale_grad_by_freq(布尔值,可选的) -请参阅模块初始化文档。默认 False

  • mode(string,可选的) -请参阅模块初始化文档。默认值:"mean"

  • sparse(bool,可选的) -请参阅模块初始化文档。默认值:False

  • include_last_offset(bool,可选的) -请参阅模块初始化文档。默认值:False

  • padding_idx(int,可选的) -请参阅模块初始化文档。默认值:None

从给定的二维 FloatTensor 创建 EmbeddingBag 实例。

例子:

>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([[1, 0]])
>>> embeddingbag(input)
tensor([[ 2.5000,  3.7000,  4.6500]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.EmbeddingBag.from_pretrained。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。