本文简要介绍python语言中 torch.jit.freeze
的用法。
用法:
torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)
mod(
ScriptModule
) -要冻结的模块preserved_attrs(可选的[List[str]]) -除了 forward 方法之外要保留的属性列表。
modified in preserved methods will also be preserved.(属性) -
optimize_numerics(bool) -如果
True
,将运行一组不严格的优化通道numerics. Full details of optimization can be found at torch.jit.run_frozen_optimizations.(保存) -
冷冻
ScriptModule
。冻结
ScriptModule
将克隆它,并尝试将克隆模块的子模块、参数和属性作为常量内联到 TorchScript IR 图表中。默认情况下,forward
将被保留,以及preserved_attrs
中指定的属性和方法。此外,在保留方法中修改的任何属性都将被保留。冻结当前仅接受处于评估模式的ScriptModules。
冻结应用通用优化,无论机器如何,都可以加速您的模型。要使用server-specific 设置进一步优化,请在冻结后运行
optimize_for_inference
。示例(使用参数冻结一个简单的模块):
def forward(self, input): output = self.weight.mm(input) output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3).eval()) frozen_module = torch.jit.freeze(scripted_module) # parameters have been removed and inlined into the Graph as constants assert len(list(frozen_module.named_parameters())) == 0 # See the compiled graph as Python code print(frozen_module.code)
示例(冻结具有保留属性的模块)
def forward(self, input): self.modified_tensor += 1 return input + self.modified_tensor scripted_module = torch.jit.script(MyModule2().eval()) frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"]) # we've manually preserved `version`, so it still exists on the frozen module and can be modified assert frozen_module.version == 1 frozen_module.version = 2 # `modified_tensor` is detected as being mutated in the forward, so freezing preserves # it to retain model semantics assert frozen_module(torch.tensor(1)) == torch.tensor(12) # now that we've run it once, the next result will be incremented by one assert frozen_module(torch.tensor(1)) == torch.tensor(13)
注意
如果您不确定为什么某个属性没有作为常量内联,您可以在 frozen_module.forward.graph 上运行
dump_alias_db
以查看冻结是否检测到该属性正在被修改。注意
由于冻结会使权重保持不变并删除模块层次结构,因此
to
和其他 nn.Module 操作设备或 dtype 的方法不再起作用。作为解决方法,您可以通过在torch.jit.load
中指定map_location
来重新映射设备,但是特定于设备的逻辑可能已融入模型中。
参数:
返回:
相关用法
- Python PyTorch frexp用法及代码示例
- Python PyTorch frombuffer用法及代码示例
- Python PyTorch fractional_max_pool3d用法及代码示例
- Python PyTorch frac用法及代码示例
- Python PyTorch fractional_max_pool2d用法及代码示例
- Python PyTorch from_numpy用法及代码示例
- Python PyTorch fft2用法及代码示例
- Python PyTorch fftn用法及代码示例
- Python PyTorch flip用法及代码示例
- Python PyTorch float_power用法及代码示例
- Python PyTorch floor_divide用法及代码示例
- Python PyTorch fp16_compress_hook用法及代码示例
- Python PyTorch fftshift用法及代码示例
- Python PyTorch fake_quantize_per_channel_affine用法及代码示例
- Python PyTorch flipud用法及代码示例
- Python PyTorch fliplr用法及代码示例
- Python PyTorch fp16_compress_wrapper用法及代码示例
- Python PyTorch fftfreq用法及代码示例
- Python PyTorch filter_wikipedia_xml用法及代码示例
- Python PyTorch fuse_modules用法及代码示例
- Python PyTorch fasterrcnn_mobilenet_v3_large_320_fpn用法及代码示例
- Python PyTorch fmax用法及代码示例
- Python PyTorch fork用法及代码示例
- Python PyTorch fmin用法及代码示例
- Python PyTorch fasterrcnn_resnet50_fpn用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.freeze。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。