本文簡要介紹python語言中 torch.jit.freeze
的用法。
用法:
torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)
mod(
ScriptModule
) -要凍結的模塊preserved_attrs(可選的[List[str]]) -除了 forward 方法之外要保留的屬性列表。
modified in preserved methods will also be preserved.(屬性) -
optimize_numerics(bool) -如果
True
,將運行一組不嚴格的優化通道numerics. Full details of optimization can be found at torch.jit.run_frozen_optimizations.(保存) -
冷凍
ScriptModule
。凍結
ScriptModule
將克隆它,並嘗試將克隆模塊的子模塊、參數和屬性作為常量內聯到 TorchScript IR 圖表中。默認情況下,forward
將被保留,以及preserved_attrs
中指定的屬性和方法。此外,在保留方法中修改的任何屬性都將被保留。凍結當前僅接受處於評估模式的ScriptModules。
凍結應用通用優化,無論機器如何,都可以加速您的模型。要使用server-specific 設置進一步優化,請在凍結後運行
optimize_for_inference
。示例(使用參數凍結一個簡單的模塊):
def forward(self, input): output = self.weight.mm(input) output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3).eval()) frozen_module = torch.jit.freeze(scripted_module) # parameters have been removed and inlined into the Graph as constants assert len(list(frozen_module.named_parameters())) == 0 # See the compiled graph as Python code print(frozen_module.code)
示例(凍結具有保留屬性的模塊)
def forward(self, input): self.modified_tensor += 1 return input + self.modified_tensor scripted_module = torch.jit.script(MyModule2().eval()) frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"]) # we've manually preserved `version`, so it still exists on the frozen module and can be modified assert frozen_module.version == 1 frozen_module.version = 2 # `modified_tensor` is detected as being mutated in the forward, so freezing preserves # it to retain model semantics assert frozen_module(torch.tensor(1)) == torch.tensor(12) # now that we've run it once, the next result will be incremented by one assert frozen_module(torch.tensor(1)) == torch.tensor(13)
注意
如果您不確定為什麽某個屬性沒有作為常量內聯,您可以在 frozen_module.forward.graph 上運行
dump_alias_db
以查看凍結是否檢測到該屬性正在被修改。注意
由於凍結會使權重保持不變並刪除模塊層次結構,因此
to
和其他 nn.Module 操作設備或 dtype 的方法不再起作用。作為解決方法,您可以通過在torch.jit.load
中指定map_location
來重新映射設備,但是特定於設備的邏輯可能已融入模型中。
參數:
返回:
相關用法
- Python PyTorch frexp用法及代碼示例
- Python PyTorch frombuffer用法及代碼示例
- Python PyTorch fractional_max_pool3d用法及代碼示例
- Python PyTorch frac用法及代碼示例
- Python PyTorch fractional_max_pool2d用法及代碼示例
- Python PyTorch from_numpy用法及代碼示例
- Python PyTorch fft2用法及代碼示例
- Python PyTorch fftn用法及代碼示例
- Python PyTorch flip用法及代碼示例
- Python PyTorch float_power用法及代碼示例
- Python PyTorch floor_divide用法及代碼示例
- Python PyTorch fp16_compress_hook用法及代碼示例
- Python PyTorch fftshift用法及代碼示例
- Python PyTorch fake_quantize_per_channel_affine用法及代碼示例
- Python PyTorch flipud用法及代碼示例
- Python PyTorch fliplr用法及代碼示例
- Python PyTorch fp16_compress_wrapper用法及代碼示例
- Python PyTorch fftfreq用法及代碼示例
- Python PyTorch filter_wikipedia_xml用法及代碼示例
- Python PyTorch fuse_modules用法及代碼示例
- Python PyTorch fasterrcnn_mobilenet_v3_large_320_fpn用法及代碼示例
- Python PyTorch fmax用法及代碼示例
- Python PyTorch fork用法及代碼示例
- Python PyTorch fmin用法及代碼示例
- Python PyTorch fasterrcnn_resnet50_fpn用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.freeze。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。