本文整理汇总了Python中torch.jit.trace方法的典型用法代码示例。如果您正苦于以下问题:Python jit.trace方法的具体用法?Python jit.trace怎么用?Python jit.trace使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.jit的用法示例。
在下文中一共展示了jit.trace方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __call__
# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def __call__(self, *args, **kwargs):
    """Trace ``self.method_name`` of ``self.model``, then run its forward.

    The traced module is stored in ``self.tracing_result``. The return value
    is ``self.model.forward(*args, **kwargs)``, so this object can stand in
    for the model inside a prediction function.

    Example inputs for ``torch.jit.trace`` are built as follows:

    * keyword-only call whose keys cover every argument of the traced method
      (as reported by ``_get_input_argnames``): the keyword values are
      passed in the method's argument order;
    * positional-only call (legacy behaviour): the positional arguments are
      passed as one tuple.

    Args:
        *args: positional inputs for the traced method.
        **kwargs: keyword inputs for the traced method.

    Returns:
        Whatever ``self.model.forward(*args, **kwargs)`` returns.

    Raises:
        TypeError: if keyword inputs are mixed with positional inputs, or do
            not cover all arguments of the traced method.
    """
    method_model = _ForwardOverrideModel(self.model, self.method_name)

    example_inputs = None
    if not args:
        fn = getattr(self.model, self.method_name)
        method_argnames = _get_input_argnames(fn=fn, exclude=["self"])
        if all(name in kwargs for name in method_argnames):
            # jit.trace takes positional example inputs only, so order the
            # keyword values by the method signature.
            example_inputs = tuple(kwargs[name] for name in method_argnames)

    if example_inputs is None:
        if kwargs:
            raise TypeError(
                f"Cannot trace `{self.method_name}`: keyword inputs "
                f"{sorted(kwargs)} must cover all of its arguments and "
                "cannot be combined with positional inputs."
            )
        # Legacy positional path. The inputs go in as ONE tuple: splatting
        # them (`jit.trace(m, *args)`) would bind the 2nd+ inputs to
        # jit.trace's own parameters (optimize, check_trace, ...).
        example_inputs = args

    # Tracing errors propagate as-is instead of being swallowed and retried
    # with a call that fails for an unrelated reason.
    self.tracing_result = jit.trace(method_model, example_inputs)
    output = self.model.forward(*args, **kwargs)
    return output
示例2: translate
# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def translate(self):
    """Translate ``self.plan`` into TorchScript with ``torch.jit.trace``.

    A copy of the plan (with ``forward`` cleared and input type validation
    disabled) is traced on the dummy arguments from ``create_dummy_args()``.
    If the plan has parameters, they are passed to the traced function as an
    extra trailing argument instead of being stored in the TorchScript graph.

    Returns:
        ``self.plan``, with its ``torchscript`` attribute set to the traced
        result.
    """
    translation_plan = self.plan.copy()
    translation_plan.forward = None

    args = translation_plan.create_dummy_args()

    # jit.trace clones input args and can change their type, so we have to
    # skip the types check.
    # TODO see if type check can be made less strict,
    # e.g. tensor/custom tensor/nn.Parameter could be considered same type
    translation_plan.validate_input_types = False

    # To avoid storing Plan state tensors in torchscript, they will be sent as
    # parameters: we trace a wrapper func which accepts state parameters as its
    # last arg and sets them into the Plan before executing the Plan.
    def wrap_stateful_plan(*args):
        role = translation_plan.role
        state = args[-1]
        # The type guard runs first so that len() is only called on a list or
        # tuple; otherwise a non-sized / 0-d tensor last argument would raise
        # TypeError instead of simply skipping the state injection.
        if isinstance(state, (list, tuple)) and (
            0 < len(role.state.state_placeholders) == len(state)
        ):
            state_placeholders = tuple(
                role.placeholders[ph.id.value] for ph in role.state.state_placeholders
            )
            PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state)
            PlaceHolder.instantiate_placeholders(state_placeholders, state)

        # Run the plan on every argument except the trailing state.
        return translation_plan(*args[:-1])

    plan_params = translation_plan.parameters()
    if len(plan_params) > 0:
        torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params))
    else:
        torchscript_plan = jit.trace(translation_plan, args)

    self.plan.torchscript = torchscript_plan
    return self.plan
示例3: __init__
# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def __init__(self, window_size=3):
    """Build the traced blur filter used by this SSIM module.

    A depthwise ``nn.Conv2d`` (3 in / 3 out channels, ``groups=3``) is
    loaded with the window from ``create_gaussian_window(window_size, 3)``
    and a zero bias. It is moved to the module-level ``device`` and traced
    with ``trace`` on a random ``(3, 3, 16, 16)`` float32 input. The traced
    result is stored as ``self.gaussian_blur``.

    Args:
        window_size: side length of the square kernel; the convolution
            padding is ``window_size // 2`` (so an even size yields an
            output one pixel larger than the input).
    """
    super(SSIM, self).__init__()

    kernel_state = dict(
        weight=create_gaussian_window(window_size, 3).float(),
        bias=torch.zeros(3),
    )

    half_window = window_size // 2
    depthwise_conv = nn.Conv2d(3, 3, window_size, padding=half_window, groups=3)
    depthwise_conv = depthwise_conv.to(device)
    depthwise_conv.load_state_dict(kernel_state)

    # jit tracing needs a concrete example input on the same device.
    example_input = torch.rand(3, 3, 16, 16, dtype=torch.float32, device=device)
    self.gaussian_blur = trace(depthwise_conv, example_input)
示例4: export_model
# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def export_model(model, path=None, input_shape=(1, 3, 64, 64)):
    """
    Exports the model. If the model is a `ScriptModule`, it is saved as is. If not,
    it is traced (with the given input_shape) and the resulting ScriptModule is saved
    (this requires the `input_shape`, which defaults to the competition default).

    The caller's model is not modified: a deep copy is moved to CPU and put
    in eval mode before tracing/saving.

    Parameters
    ----------
    model : torch.nn.Module or torch.jit.ScriptModule
        Pytorch Module or a ScriptModule.
    path : str
        Path to the file where the model is saved. Defaults to the value set by the
        `get_model_path` function above.
    input_shape : tuple or list
        Shape of the input to trace the module with. This is only required if model is not a
        torch.jit.ScriptModule.

    Returns
    -------
    str
        Path to where the model is saved.

    Raises
    ------
    ValueError
        If `model` is not a `ScriptModule` and `input_shape` is None.
    """
    path = get_model_path() if path is None else path
    model = deepcopy(model).cpu().eval()
    if isinstance(model, torch.jit.ScriptModule):
        traced_model = model
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if input_shape is None:
            raise ValueError(
                "`input_shape` must be provided since model is not a "
                "`ScriptModule`."
            )
        # Trace with an all-zeros dummy input of the requested shape.
        traced_model = trace(model, torch.zeros(*input_shape))
    torch.jit.save(traced_model, path)
    return path
示例5: trace_model
# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def trace_model(
    model: Model,
    predict_fn: Callable,
    batch=None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> jit.ScriptModule:
    """Traces model using a prediction function and a batch.

    Args:
        model: Model to trace
        predict_fn: Function to run prediction with the model provided,
            takes model, inputs parameters
        batch: Batch to trace the model
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): Apex FP16 init level, optional
        device (str): Torch device; the model is moved to it only when
            ``opt_level`` is set (before Apex AMP initialization)
        predict_params (dict): additional parameters for model forward

    Returns:
        jit.ScriptModule: Traced model

    Raises:
        ValueError: if ``batch`` or ``predict_fn`` is not specified, or
            ``mode`` is neither ``'eval'`` nor ``'train'``.
    """
    if batch is None or predict_fn is None:
        raise ValueError("Both batch and predict_fn must be specified.")
    if mode not in ["train", "eval"]:
        raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")
    predict_params = predict_params or {}

    if opt_level is not None:
        assert_fp16_available()
        # If traced in AMP we need to initialize the model before calling
        # the jit
        # https://github.com/NVIDIA/apex/issues/303#issuecomment-493142950
        from apex import amp

        model = model.to(device)
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)

    # Build the tracer only after `model` may have been rebound above, so the
    # traced object is the same one whose mode / requires_grad are set below.
    tracer = _TracingModelWrapper(model, method_name)

    getattr(model, mode)()
    set_requires_grad(model, requires_grad=requires_grad)

    # predict_fn calls the tracer like a model; the tracer records the trace.
    predict_fn(tracer, batch, **predict_params)
    return tracer.tracing_result
示例6: save_traced_model
# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def save_traced_model(
    model: jit.ScriptModule,
    logdir: Union[str, Path] = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    out_dir: Union[str, Path] = None,
    out_model: Union[str, Path] = None,
    checkpoint_name: str = None,
) -> None:
    """Saves traced model.

    The destination is chosen in priority order: ``out_model`` if given;
    otherwise ``out_dir / <trace name>``; otherwise
    ``logdir / "trace" / <trace name>``. The trace name comes from
    ``get_trace_name``. When a directory is used, it is created if missing.

    Args:
        model (ScriptModule): Traced model
        logdir (Union[str, Path]): Path to experiment
        method_name (str): Name of the method was traced
        mode (str): Model's mode - `train` or `eval`
        requires_grad (bool): Whether model was traced with require_grad or not
        opt_level (str): Apex FP16 init level used during tracing
        out_dir (Union[str, Path]): Directory to save model to
            (overrides logdir)
        out_model (Union[str, Path]): Path to save model to
            (overrides logdir & out_dir)
        checkpoint_name (str): Checkpoint name used to restore the model

    Raises:
        ValueError: if nothing out of `logdir`, `out_dir` or `out_model`
            is specified.
    """
    if out_model is None:
        # Resolve the target directory first so a missing destination fails
        # fast. Wrap in Path(): out_dir may be given as a plain str.
        if out_dir is not None:
            output = Path(out_dir)
        elif logdir is not None:
            output = Path(logdir) / "trace"
        else:
            raise ValueError(
                "One of `logdir`, `out_dir` or `out_model` "
                "should be specified"
            )

        file_name = get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            additional_string=checkpoint_name,
        )
        output.mkdir(exist_ok=True, parents=True)
        out_model = str(output / file_name)
    else:
        out_model = str(out_model)

    jit.save(model, out_model)