当前位置: 首页>>代码示例>>Python>>正文


Python jit.trace方法代码示例

本文整理汇总了Python中torch.jit.trace方法的典型用法代码示例。如果您想了解torch.jit.trace方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.jit的其他用法示例。


下文共展示了torch.jit.trace方法的6个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: __call__

# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def __call__(self, *args, **kwargs):
    """Trace ``self.model``'s ``method_name`` method, then run ``forward``.

    Keyword-only calls are converted to a positional tuple ordered by the
    traced method's parameter names (as returned by ``_get_input_argnames``)
    and passed to ``jit.trace``. Positional calls, or a failure on the
    keyword path, use the legacy ``jit.trace(method_model, *args, **kwargs)``.

    The traced module is stored in ``self.tracing_result``.

    Returns:
        The result of ``self.model.forward(*args, **kwargs)``.
    """
    method_model = _ForwardOverrideModel(self.model, self.method_name)

    if args:
        # Only the legacy path supports positional inputs. This is an explicit
        # test rather than an ``assert`` so it also holds under ``python -O``.
        self.tracing_result = jit.trace(method_model, *args, **kwargs)
    else:
        try:
            fn = getattr(self.model, self.method_name)
            method_argnames = _get_input_argnames(fn=fn, exclude=["self"])
            # Reorder keyword inputs to match the method's parameter order.
            method_input = tuple(kwargs[name] for name in method_argnames)
            self.tracing_result = jit.trace(method_model, method_input)
        except Exception:
            # for backward compatibility
            self.tracing_result = jit.trace(method_model, *args, **kwargs)

    # NOTE(review): this always runs ``forward``, even when ``method_name``
    # names a different method — confirm that is intended.
    output = self.model.forward(*args, **kwargs)

    return output
开发者ID:catalyst-team,项目名称:catalyst,代码行数:19,代码来源:trace.py

示例2: translate

# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def translate(self):
    """Translate ``self.plan`` into TorchScript with ``jit.trace``.

    A copy of the plan, with ``forward`` cleared and input-type validation
    disabled, is traced on the plan's dummy arguments. If the plan has
    parameters, they are passed to the traced function as an extra trailing
    argument rather than being stored in the TorchScript graph.

    Returns:
        ``self.plan``, with its ``torchscript`` attribute set to the traced result.
    """
    translation_plan = self.plan.copy()
    translation_plan.forward = None

    args = translation_plan.create_dummy_args()

    # jit.trace clones input args and can change their type, so we have to skip types check
    # TODO see if type check can be made less strict,
    #  e.g. tensor/custom tensor/nn.Parameter could be considered same type
    translation_plan.validate_input_types = False

    # To avoid storing Plan state tensors in torchscript, they will be sent as parameters
    # we trace wrapper func, which accepts state parameters as last arg
    # and sets them into the Plan before executing the Plan
    def wrap_stateful_plan(*call_args):
        role = translation_plan.role
        state = call_args[-1]
        # Check the container type before calling len(state), so a trailing
        # argument without a length is skipped instead of raising TypeError.
        if isinstance(state, (list, tuple)) and (
            0 < len(role.state.state_placeholders) == len(state)
        ):
            state_placeholders = tuple(
                role.placeholders[ph.id.value] for ph in role.state.state_placeholders
            )
            PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state)
            PlaceHolder.instantiate_placeholders(state_placeholders, state)

        # Everything except the trailing state argument goes to the plan.
        return translation_plan(*call_args[:-1])

    plan_params = translation_plan.parameters()
    if len(plan_params) > 0:
        torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params))
    else:
        torchscript_plan = jit.trace(translation_plan, args)

    self.plan.torchscript = torchscript_plan
    return self.plan
开发者ID:OpenMined,项目名称:PySyft,代码行数:38,代码来源:torchscript.py

示例3: __init__

# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def __init__(self, window_size=3):
    """Build a traced per-channel Gaussian blur for 3-channel inputs.

    The blur is an ``nn.Conv2d`` with 3 input/output channels and
    ``groups=3``, placed on ``device``. Its weight is set from
    ``create_gaussian_window(window_size, 3)`` and its bias to zeros. It is
    traced with a random float32 example of shape ``(3, 3, 16, 16)`` and kept
    as ``self.gaussian_blur``.

    Args:
        window_size: Kernel side length; padding is ``window_size // 2``.
    """
    super(SSIM, self).__init__()

    n_channels = 3
    # Build the kernel state before the conv, as the weights are overwritten below.
    kernel_state = {
        'weight': create_gaussian_window(window_size, n_channels).float(),
        'bias': torch.zeros(n_channels),
    }
    blur = nn.Conv2d(
        n_channels,
        n_channels,
        window_size,
        padding=window_size // 2,
        groups=n_channels,
    ).to(device)
    blur.load_state_dict(kernel_state)

    example_input = torch.rand(3, n_channels, 16, 16, dtype=torch.float32, device=device)
    self.gaussian_blur = trace(blur, example_input)
开发者ID:ClementPinard,项目名称:unsupervised-depthnet,代码行数:10,代码来源:ssim.py

示例4: export_model

# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def export_model(model, path=None, input_shape=(1, 3, 64, 64)):
    """
    Exports the model. If the model is a `ScriptModule`, it is saved as is. If not,
    it is traced (with the given input_shape) and the resulting ScriptModule is saved
    (this requires the `input_shape`, which defaults to the competition default).

    The model is deep-copied, moved to CPU and put in eval mode before
    tracing/saving, so the caller's model is left untouched.

    Parameters
    ----------
    model : torch.nn.Module or torch.jit.ScriptModule
        Pytorch Module or a ScriptModule.
    path : str
        Path to the file where the model is saved. Defaults to the value set by the
        `get_model_path` function above.
    input_shape : tuple or list
        Shape of the input to trace the module with. This is only required if model is not a
        torch.jit.ScriptModule. The tracing input is ``torch.zeros(*input_shape)``.

    Returns
    -------
    str
        Path to where the model is saved.

    Raises
    ------
    AssertionError
        If `model` is not a ScriptModule and `input_shape` is None.
    """
    path = get_model_path() if path is None else path
    # deepcopy/.cpu()/.eval() preserve the module type, so checking the
    # original is equivalent to checking the copy.
    needs_trace = not isinstance(model, torch.jit.ScriptModule)
    if needs_trace and input_shape is None:
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        # AssertionError is kept for callers that already catch it. The check
        # runs before the deep copy so a bad call fails cheaply.
        raise AssertionError("`input_shape` must be provided since model is not a "
                             "`ScriptModule`.")
    model = deepcopy(model).cpu().eval()
    if needs_trace:
        traced_model = trace(model, torch.zeros(*input_shape))
    else:
        traced_model = model
    torch.jit.save(traced_model, path)
    return path
开发者ID:amir-abdi,项目名称:disentanglement-pytorch,代码行数:34,代码来源:utils_pytorch.py

示例5: trace_model

# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def trace_model(
    model: Model,
    predict_fn: Callable,
    batch=None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> jit.ScriptModule:
    """Traces model using runner and batch.

    Args:
        model: Model to trace
        predict_fn: Function to run prediction with the model provided,
            takes model, inputs parameters
        batch: Batch to trace the model
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): Apex FP16 init level, optional
        device (str): Torch device; the model is moved there only when
            ``opt_level`` is set (before AMP initialization)
        predict_params (dict): additional parameters for model forward

    Returns:
        jit.ScriptModule: Traced model

    Raises:
        ValueError: if ``batch`` or ``predict_fn`` is missing, or if
          ``mode`` is neither ``'eval'`` nor ``'train'``.
    """
    if batch is None or predict_fn is None:
        raise ValueError("Both batch and predict_fn must be specified.")

    if mode not in ["train", "eval"]:
        raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")

    predict_params = predict_params or {}

    if opt_level is not None:
        assert_fp16_available()
        # If traced in AMP we need to initialize the model before calling
        # the jit
        # https://github.com/NVIDIA/apex/issues/303#issuecomment-493142950
        from apex import amp

        model = model.to(device)
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)

    # Build the wrapper only after AMP initialization, so the tracer wraps the
    # module ``amp.initialize`` returned. That is the same object that is put
    # into ``mode`` and given ``requires_grad`` below.
    tracer = _TracingModelWrapper(model, method_name)

    getattr(model, mode)()
    set_requires_grad(model, requires_grad=requires_grad)

    # predict_fn drives the tracer; the wrapper records the traced module.
    predict_fn(tracer, batch, **predict_params)

    return tracer.tracing_result
开发者ID:catalyst-team,项目名称:catalyst,代码行数:60,代码来源:trace.py

示例6: save_traced_model

# 需要导入模块: from torch import jit [as 别名]
# 或者: from torch.jit import trace [as 别名]
def save_traced_model(
    model: jit.ScriptModule,
    logdir: Union[str, Path] = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    out_dir: Union[str, Path] = None,
    out_model: Union[str, Path] = None,
    checkpoint_name: str = None,
) -> None:
    """Saves traced model.

    The destination is chosen by priority: ``out_model`` if given; otherwise
    ``<out_dir>/<name>``; otherwise ``<logdir>/trace/<name>``. Here ``<name>``
    comes from ``get_trace_name``, and the directory is created if missing.

    Args:
        model (ScriptModule): Traced model
        logdir (Union[str, Path]): Path to experiment
        method_name (str): Name of the method was traced
        mode (str): Model's mode - `train` or `eval`
        requires_grad (bool): Whether model was traced with require_grad or not
        opt_level (str): Apex FP16 init level used during tracing
        out_dir (Union[str, Path]): Directory to save model to
            (overrides logdir)
        out_model (Union[str, Path]): Path to save model to
            (overrides logdir & out_dir)
        checkpoint_name (str): Checkpoint name used to restore the model

    Raises:
        ValueError: if nothing out of `logdir`, `out_dir` or `out_model`
          is specified.
    """
    if out_model is None:
        # Resolve (and validate) the output directory first. Wrapping in Path
        # lets ``out_dir`` be a plain string, as its type hint allows.
        if out_dir is not None:
            output = Path(out_dir)
        elif logdir is not None:
            output = Path(logdir) / "trace"
        else:
            raise ValueError(
                "One of `logdir`, `out_dir` or `out_model` "
                "should be specified"
            )

        file_name = get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            additional_string=checkpoint_name,
        )

        output.mkdir(exist_ok=True, parents=True)

        out_model = str(output / file_name)
    else:
        out_model = str(out_model)

    jit.save(model, out_model)
开发者ID:catalyst-team,项目名称:catalyst,代码行数:57,代码来源:trace.py


注：本文中的torch.jit.trace方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。