本文简要介绍python语言中 torchrec.modules.crossnet.LowRankMixtureCrossNet
的用法。
用法:
class torchrec.modules.crossnet.LowRankMixtureCrossNet(in_features: int, num_layers: int, num_experts: int = 1, low_rank: int = 1, activation: typing.Union[torch.nn.modules.module.Module, typing.Callable[[torch.Tensor], torch.Tensor]] = <built-in method relu of type object>)
in_features(int) -输入的维度。
num_layers(int) -模块中的层数。
low_rank(int) -交叉矩阵的秩设置(默认 = 0)。值必须始终 >= 0
activation(联盟[火炬.nn.模块,可调用[[torch.Tensor],torch.Tensor]]) - 非线性激活函数,用于定义专家。默认为 relu。
基础:
torch.nn.modules.module.Module
Low Rank Mixture Cross Net 是来自 paper 的 DCN V2 实现:
LowRankMixtureCrossNet
将每层的可学习交叉参数定义为低秩矩阵 以及专家的混合。与LowRankCrossNet
相比,该模块利用了 专家,而不是依赖一位专家来学习特征组合;每个学习特征在不同子空间中相互作用,并使用取决于输入 的门控机制自适应地组合学习到的交叉。在每一层 l 上,张量转换为:
每个 定义为:
其中 、 和 是低秩矩阵, 表示逐元素乘法, 表示矩阵乘法, 是非线性激活函数。
num_expert 为 1 时,将跳过门评估和 MOE 以节省计算。
例子:
batch_size = 3 num_layers = 2 in_features = 10 input = torch.randn(batch_size, in_features) dcn = LowRankCrossNet(num_layers=num_layers, num_experts=5, low_rank=3) output = dcn(input)
参数:
相关用法
- Python PyTorch LowRankMultivariateNormal用法及代码示例
- Python torchrec.modules.crossnet.LowRankCrossNet用法及代码示例
- Python PyTorch LogSigmoid用法及代码示例
- Python PyTorch LogNormal用法及代码示例
- Python PyTorch LocalResponseNorm用法及代码示例
- Python PyTorch LogSoftmax用法及代码示例
- Python PyTorch LocalElasticAgent用法及代码示例
- Python PyTorch LazyModuleMixin用法及代码示例
- Python PyTorch LinearLR用法及代码示例
- Python PyTorch LKJCholesky用法及代码示例
- Python PyTorch L1Loss用法及代码示例
- Python PyTorch LPPool2d用法及代码示例
- Python PyTorch LeakyReLU用法及代码示例
- Python PyTorch LayerNorm用法及代码示例
- Python PyTorch LineReader用法及代码示例
- Python PyTorch LambdaLR用法及代码示例
- Python PyTorch LSTM用法及代码示例
- Python PyTorch Linear用法及代码示例
- Python PyTorch LSTMCell用法及代码示例
- Python PyTorch LPPool1d用法及代码示例
- Python PyTorch Laplace用法及代码示例
- Python PyTorch LazyModuleExtensionMixin.apply用法及代码示例
- Python PyTorch LinearReLU用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.modules.crossnet.LowRankMixtureCrossNet。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。