當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch ModuleDict用法及代碼示例


本文簡要介紹python語言中 torch.nn.ModuleDict 的用法。

用法:

class torch.nn.ModuleDict(modules=None)

參數

modules(可迭代的,可選的) -(字符串:模塊)的映射(字典)或類型(字符串,模塊)的鍵值對的可迭代

在字典中保存子模塊。

ModuleDict 可以像常規 Python 字典一樣進行索引,但它包含的模塊已正確注冊,並且所有 Module 方法都可見。

ModuleDict是一個排序尊重的字典

  • 插入順序,以及

  • update() 中,合並 OrderedDictdict (從 Python 3.6 開始)或另一個 ModuleDict (update() 的參數)的順序。

請注意,update() 與其他無序映射類型(例如,Python 3.6 版本之前的 Python 普通 dict)不會保留合並映射的順序。

例子:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.ModuleDict。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。