当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch optimize_for_inference用法及代码示例


本文简要介绍python语言中 torch.jit.optimize_for_inference 的用法。

用法:

torch.jit.optimize_for_inference(mod)

执行一组优化传递以优化模型以进行推理。如果模型尚未冻结,optimize_for_inference 将自动调用 torch.jit.freeze

除了应该在任何环境下加速您的模型的通用优化之外,为推理做准备还将烘焙构建特定设置,例如 CUDNN 或 MKLDNN 的存在,并且将来可能会进行转换,从而在一台机器上加快速度但速度很慢事情就另当别论了。因此,在调用 optimize_for_inference 后未实现序列化,因此无法保证。

这仍处于原型中,可能会降低您的模型速度。到目前为止,主要针对的用例是 cpu 和 gpu 上的视觉模型,但程度较小。

示例(使用 Conv->Batchnorm 优化模块):

import torch
in_channels, out_channels = 3, 32
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
assert "batch_norm" not in str(frozen_mod.graph)
# if built with MKLDNN, convolution will be run with MKLDNN weights
assert "MKLDNN" in frozen_mod.graph

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.optimize_for_inference。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。