當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch freeze用法及代碼示例


本文簡要介紹python語言中 torch.jit.freeze 的用法。

用法:

torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)

參數

  • mod(ScriptModule) -要凍結的模塊

  • preserved_attrs(可選的[List[str]]) -除了 forward 方法之外要保留的屬性列表。

  • modified in preserved methods will also be preserved.(屬性) -

  • optimize_numerics(bool) -如果 True ,將運行一組不嚴格的優化通道

  • numerics. Full details of optimization can be found at torch.jit.run_frozen_optimizations.(保存) -

返回

冷凍 ScriptModule

凍結 ScriptModule 將克隆它,並嘗試將克隆模塊的子模塊、參數和屬性作為常量內聯到 TorchScript IR 圖表中。默認情況下, forward 將被保留,以及 preserved_attrs 中指定的屬性和方法。此外,在保留方法中修改的任何屬性都將被保留。

凍結當前僅接受處於評估模式的ScriptModules。

凍結應用通用優化,無論機器如何,都可以加速您的模型。要使用server-specific 設置進一步優化,請在凍結後運行optimize_for_inference

示例(使用參數凍結一個簡單的模塊):

def forward(self, input):
        output = self.weight.mm(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0
# See the compiled graph as Python code
print(frozen_module.code)

示例(凍結具有保留屬性的模塊)

def forward(self, input):
        self.modified_tensor += 1
        return input + self.modified_tensor

scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)

注意

如果您不確定為什麽某個屬性沒有作為常量內聯,您可以在 frozen_module.forward.graph 上運行 dump_alias_db 以查看凍結是否檢測到該屬性正在被修改。

注意

由於凍結會使權重保持不變並刪除模塊層次結構,因此 to 和其他 nn.Module 操作設備或 dtype 的方法不再起作用。作為解決方法,您可以通過在 torch.jit.load 中指定 map_location 來重新映射設備,但是特定於設備的邏輯可能已融入模型中。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.freeze。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。