当前位置: 首页>>代码示例>>Python>>正文

Python torch.jit 模块代码示例

本文整理汇总了 Python 中 torch.jit 模块的典型用法代码示例。如果您正苦于以下问题：Python torch.jit 的具体用法是什么？Python torch.jit 怎么用？有哪些 Python torch.jit 的使用示例？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解 torch.jit 所在的 torch 包的用法示例。


示例1: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def __init__(self, models, tgt_dict, max_iter=1, quantize=True, check_trace=True):
    """Build an iterative-refinement generator and trace it with TorchScript.

    Args:
        models: model(s) handed to ``IterativeRefinementGenerator``; also kept
            on ``self.models``.
        tgt_dict: target dictionary forwarded to the generator.
        max_iter: forwarded to the generator as ``max_iter``.
        quantize: when true, every ``torch.nn.Linear`` in the generator is
            dynamically quantized to ``torch.qint8`` (in place) before tracing.
        check_trace: forwarded to ``torch.jit.trace``.
    """
    # NOTE(review): if the enclosing class is a torch.nn.Module, a
    # super().__init__() call is presumably needed before assigning the traced
    # module to self.generator — confirm against the class definition.
    # Small dummy inputs: a (1, 2) token tensor and its length, used only to
    # drive the trace.
    src_tokens = torch.tensor([[4, 2]])
    src_lengths = torch.tensor([2])
    self.models = models

    generator = IterativeRefinementGenerator(
        self.models, tgt_dict, max_iter=max_iter
    )
    if quantize:
        generator = torch.quantization.quantize_dynamic(
            generator, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
        )
    enc_inputs = (src_tokens, src_lengths)
    self.generator = torch.jit.trace(
        generator, enc_inputs, _force_outplace=True, check_trace=check_trace
    )

示例2: conv_flop_jit

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def conv_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for convolution using torch script.

    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object before convolution.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object after convolution.

    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs of Convolution should be a list of length 12 or 13. They represent:
    # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding,
    # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn,
    # 10) deterministic_cudnn, 11) user_enabled_cudnn and, in newer PyTorch
    # releases, 12) allow_tf32. Only the first two are read below, so both
    # arities are handled identically.
    assert len(inputs) in (12, 13), len(inputs)
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    return conv_flop_count(x_shape, w_shape, out_shape)

示例3: batchnorm_flop_jit

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def batchnorm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for batch norm.

    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object before batch norm.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object after batch norm.

    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs[0] contains the shape of the input.
    input_shape = get_shape(inputs[0])
    # Only inputs of rank 2 through 5 are accepted (presumably the ranks
    # BatchNorm1d/2d/3d operate on).
    assert 2 <= len(input_shape) <= 5, input_shape
    # A fixed cost of 4 flops is charged per input element.
    flop = prod(input_shape) * 4
    return Counter({"batchnorm": flop})

示例4: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def forward(self, enc, dec):
    """Forward pass.

    Args:
        enc: Tensor from the encoder pathway.
        dec: Tensor from the decoder pathway (to be upconv'd).

    Returns:
        The output of the final ``conv2 -> norm2 -> act2`` stage.
    """
    updec = self.upconv(dec)
    # autocrop is defined elsewhere; presumably it crops the two tensors to
    # matching spatial sizes.
    enc, updec = autocrop(enc, updec)
    # The attention module receives the original `dec`, not the upconv'd one.
    genc, att = self.attention(enc, dec)
    # Keep the attention map on the module only in eager mode; the assignment
    # is skipped while compiling with TorchScript.
    if not torch.jit.is_scripting():
        self.att = att
    updec = self.norm0(updec)
    updec = self.act0(updec)
    if self.merge_mode == 'concat':
        # Concatenate along dim 1 (presumably the channel dim).
        mrg = torch.cat((updec, genc), 1)
    else:
        mrg = updec + genc
    y = self.conv1(mrg)
    y = self.norm1(y)
    y = self.act1(y)
    y = self.conv2(y)
    y = self.norm2(y)
    y = self.act2(y)
    return y

示例5: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def __init__(self,
             # NOTE(review): the rest of this parameter list is missing from this
             # snippet, so it does not parse as shown. The body reads `module` and
             # `example_inputs`, so those are presumably parameters — TODO restore
             # the original signature.

        # Keep a handle on the object being traced.
        self.module = module
        # True when `module` is an nn.Module instance; passed to `parse` and
        # counted as one extra graph input in the final check.
        is_class = isinstance(module, torch.nn.Module)

        # NOTE(review): in torch.jit.trace's documented signature the third
        # positional parameter is `optimize` (deprecated), not `check_trace` —
        # confirm which one this `True` is meant to set.
        trace = torch.jit.trace(module, example_inputs, True)
        # Wrap a single example input in a list so len() counts inputs below
        # (the trace above already received the original value).
        if not isinstance(example_inputs, (list, tuple)):
            example_inputs = [example_inputs]
        # `parse` is defined elsewhere; it turns the traced graph into the
        # object stored on self.graph.
        graph_py = parse(
            trace.graph, len(example_inputs), omit_useless_nodes=False, is_class=is_class)
        self.graph = graph_py
        self.trace = trace
        self.example_inputs = example_inputs

        # Sanity check: the parsed graph must have one input node per example
        # input, plus one more when tracing an nn.Module (presumably its `self`).
        msg = "input mismatch. this may due to some input isn't used in graph"
        assert len(example_inputs) + int(is_class) == len(graph_py.get_input_nodes_dict()), msg

示例6: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def forward(self, input, states):
    # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
    """Feed ``input`` through every layer of ``self.layers`` in sequence.

    Layer ``k`` is called with the previous layer's output and ``states[k]``
    and returns ``(output, state)``. The returned states are collected in
    layer order.

    Returns:
        ``(output, new_states)``: the last layer's output (``input`` itself when
        there are no layers) and one ``(Tensor, Tensor)`` state per layer.
    """
    # Annotated so TorchScript knows the element type of the empty list.
    new_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    hidden = input
    # Manual counter rather than enumerate():
    # https://github.com/pytorch/pytorch/issues/14471
    layer_idx = 0
    for layer in self.layers:
        hidden, layer_state = layer(hidden, states[layer_idx])
        new_states.append(layer_state)
        layer_idx += 1
    return hidden, new_states

# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733 

示例7: import_model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def import_model(path=None):
    """
    Imports a model (as ScriptModule) from file.

    Parameters
    ----------
    path : str, optional
        Path to where the model is saved. Defaults to the return value of
        ``get_model_path()``.

    Returns
    -------
    torch.jit.ScriptModule
        The model loaded from ``path`` with ``torch.jit.load``.
    """
    path = get_model_path() if path is None else path
    return torch.jit.load(path)

示例8: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def forward(self, src_tokens, src_lengths):
    """Launch every model's encoder with ``torch.jit._fork`` and hand the
    futures to ``self.get_outputs``.

    Args:
        src_tokens: token tensor; it is transposed with ``.t()`` to
            (seq_length, batch_size) before being passed to the encoders.
        src_lengths: passed unchanged to each encoder.

    Returns:
        Whatever ``self.get_outputs(src_tokens, futures)`` returns. ``futures``
        holds one future per model, in ``self.models`` order.
    """
    # (seq_length, batch_size) for compatibility with Caffe2
    src_tokens_seq_first = src_tokens.t()

    futures = []
    for model in self.models:
        # evaluation mode
        model.eval()

        futures.append(
            torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
        )

    return self.get_outputs(src_tokens, futures)

示例9: save_to_pytorch

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def save_to_pytorch(self, output_path):
        """Save this module to ``output_path`` with ``torch.jit.save``.

        NOTE(review): this snippet is incomplete and does not parse as shown.
        The bodies of the two nested helpers are missing, and neither helper is
        called in the visible code. Their ``hasattr`` guards suggest they call
        ``s._pack()`` / ``s._unpack()`` on objects that define them, presumably
        around the save. Restore from the original source before relying on
        this.
        """
        def pack(s):
            # Only acts on objects that define `_pack` (body missing here).
            if hasattr(s, "_pack"):

        def unpack(s):
            # Only acts on objects that define `_unpack` (body missing here).
            if hasattr(s, "_unpack"):

        torch.jit.save(self, output_path)

示例10: _get_all_end_states

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def _get_all_end_states(
        # NOTE(review): the body reads self.beam_size, self.eos_token_id,
        # self.length_penalty and self._add_to_end_states, so a leading `self`
        # parameter appears to have been lost from this snippet — TODO restore.
        beam_tokens: Tensor,
        beam_scores: Tensor,
        beam_prev_indices: Tensor,
        num_steps: int,
    ) -> Tensor:
        """Collect finished hypotheses from the beam history and sort them by score.

        Visits positions 1 .. num_steps + 1. At each position, a hypothesis is
        offered to ``self._add_to_end_states`` as a
        ``[score, position, hyp_index]`` tensor when its token equals
        ``self.eos_token_id``, or when the position is the final one
        (``num_steps + 1``).

        Returns:
            The collected end states stacked into one tensor, rows sorted by
            column 0 (the score) in descending order.

        NOTE(review): several lines are missing from this snippet, so it does
        not parse as shown: the closing ``):`` of the ``if`` condition, the
        rest of the length-penalty expression, and the remaining arguments to
        ``_add_to_end_states``. ``beam_prev_indices`` is not used in the
        visible code.
        """
        min_score = float("inf")
        min_index = -1
        # Typed empty list so TorchScript knows the element type.
        end_states = torch.jit.annotate(List[Tensor], [])

        position = 1
        while bool(position <= num_steps + 1):
            for hyp_index in range(self.beam_size):
                # Finished hypothesis: emitted EOS, or we are at the last position.
                if bool(beam_tokens[position][hyp_index] == self.eos_token_id) or bool(
                    position == num_steps + 1
                    hypo_score = float(beam_scores[position][hyp_index])
                    # Length normalization: divide by position ** (exponent missing
                    # here; presumably self.length_penalty — TODO confirm).
                    if bool(self.length_penalty != 0):
                        hypo_score = hypo_score / float(position) ** float(
                    end_states, min_score, min_index = self._add_to_end_states(
                        torch.tensor([hypo_score, float(position), float(hyp_index)]),
            position = position + 1

        end_states = torch.stack(end_states)

        # Reorder rows by score (column 0), highest first.
        _, sorted_end_state_indices = end_states[:, 0].sort(dim=0, descending=True)
        end_states = end_states[sorted_end_state_indices, :]
        return end_states

示例11: generic_activation_jit

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def generic_activation_jit(
    op_name: str,
) -> typing.Callable[[typing.List[object], typing.List[object]], typing.Counter[str]]:
    """
    Return a handle that counts the number of activations of the specified
    operation from its output shape.

    Args:
        op_name (str): The name of the operation; used as the key of the
            Counter produced by the handle.

    Returns:
        typing.Callable: A handle ``(inputs, outputs) -> Counter`` mapping
            ``op_name`` to the activation count of ``outputs[0]``. ``inputs``
            is accepted for interface compatibility but not used.
    """

    def _generic_activation_jit(outputs: typing.List[object]) -> int:
        """
        Count the activations of an operation from its output shape.

        Args:
            outputs (list(torch._C.Value)): The output shape in the form of a list
                of jit object.

        Returns:
            int: Total number of activations, i.e. the product of the first
                output's dimensions.
        """
        out_shape = get_shape(outputs[0])
        ac_count = prod(out_shape)
        return ac_count

    return lambda inputs, outputs: Counter({op_name: _generic_activation_jit(outputs)})

示例12: get_shape

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def get_shape(val: object) -> typing.List[int]:
    """
    Return the static shape of a jit value.

    Args:
        val (torch._C.Value): jit value object.

    Returns:
        list(int): the sizes reported by ``val.type().sizes()``.

    Raises:
        ValueError: if ``val`` does not have a complete tensor type, so no
            static shape is available.
    """
    if val.isCompleteTensor():  # pyre-ignore
        return val.type().sizes()  # pyre-ignore
    raise ValueError(f"Expected a jit value with a complete tensor type, got {val!r}")

示例13: addmm_flop_jit

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def addmm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for fully connected layers with torch script.

    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object.

    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Count flop for nn.Linear.
    # Following torch.addmm(input, mat1, mat2, ...), inputs[0] is the bias and
    # inputs[1] / inputs[2] are the two matrices; any later inputs (presumably
    # the scalar beta/alpha) are not needed here.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    # Inner dimensions must agree, as in matmul_flop_jit.
    assert input_shapes[0][1] == input_shapes[1][0], input_shapes
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    # One multiply-accumulate per (row, input feature, output feature).
    flop = batch_size * input_dim * output_dim
    return Counter({"addmm": flop})

示例14: matmul_flop_jit

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def matmul_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    This method counts the flops for matmul.

    Args:
        inputs (list(torch._C.Value)): The input shape in the form of a list of
            jit object before matmul.
        outputs (list(torch._C.Value)): The output shape in the form of a list
            of jit object after matmul.

    Returns:
        Counter: A Counter dictionary that records the number of flops for each
            operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    # The second operand must be at least 2-D: [..., K, M].
    assert len(input_shapes[1]) >= 2, input_shapes
    # Contraction dim: last dim of the first operand vs. second-to-last of the
    # second operand.
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    # Every element of the first operand is multiplied against each of the M
    # output columns: [N, K] @ [K, M] -> N*K*M and
    # [B, N, K] @ [K, M] -> B*N*K*M (all leading dims are counted).
    # NOTE: when only the second operand carries batch dims
    # ([N, K] @ [B, K, M]), the broadcast batch size is not counted.
    flop = prod(input_shapes[0]) * input_shapes[1][-1]
    return Counter({"matmul": flop})

示例15: __call__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def __call__(self, *args, **kwargs):
        """Wrap ``self.model`` / ``self.method_name`` in a ``_ForwardOverrideModel``
        and trace it with ``torch.jit.trace``, storing the result on
        ``self.tracing_result``.

        The call's keyword arguments are used as the example inputs when any
        are given; otherwise the positional ``args`` tuple is used.

        NOTE(review): the closing ``}`` / ``)`` lines, and anything after them
        (e.g. a return value), are missing from this snippet, so it does not
        parse as shown — TODO restore.
        """
        # _ForwardOverrideModel is defined elsewhere; presumably its forward()
        # dispatches to self.model's `self.method_name` method.
        method_model = _ForwardOverrideModel(self.model, self.method_name)
        # Example inputs keyed by the method name.
        example_inputs = {
            self.method_name: kwargs if len(kwargs) > 0 else args

        # NOTE(review): torch.jit.trace documents `example_inputs` as a tuple or
        # Tensor; a {method_name: inputs} dict is the input format of
        # torch.jit.trace_module — confirm this is the intended call.
        # noinspection PyTypeChecker
        self.tracing_result = torch.jit.trace(
            method_model, example_inputs=example_inputs
