本文整理汇总了 Python 中 torch.jit 模块的典型用法代码示例。如果您正苦于以下问题:torch.jit 的具体用法是什么?torch.jit 怎么用?有哪些 torch.jit 的使用例子?那么这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解 torch.jit 所在的 torch 包的用法示例。
在下文中一共展示了torch.jit方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def __init__(self, models, tgt_dict, max_iter=1, quantize=True, check_trace=True):
    """Build an iterative-refinement generator and keep its traced form.

    Args:
        models: Models handed to ``IterativeRefinementGenerator``; also
            stored on ``self.models``.
        tgt_dict: Target dictionary forwarded to the generator.
        max_iter: Forwarded to the generator as ``max_iter``.
        quantize: When true, the generator's ``torch.nn.Linear`` layers are
            dynamically quantized to ``torch.qint8`` (in place) before
            tracing.
        check_trace: Forwarded to ``torch.jit.trace``.
    """
    super().__init__()
    self.models = models

    refiner = IterativeRefinementGenerator(
        self.models, tgt_dict, max_iter=max_iter
    )
    if quantize:
        refiner = torch.quantization.quantize_dynamic(
            refiner, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
        )

    # Fixed dummy inputs used only as the tracing example: a (1, 2) token
    # tensor and the matching length tensor.
    example_tokens = torch.tensor([[4, 2]])
    example_lengths = torch.tensor([2])
    self.generator = torch.jit.trace(
        refiner,
        (example_tokens, example_lengths),
        _force_outplace=True,
        check_trace=check_trace,
    )
示例2: conv_flop_jit
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def conv_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops of a convolution node captured by torch script.

    Args:
        inputs (list(torch._C.Value)): The jit values feeding the
            convolution; the first two are the input tensor and the filter.
        outputs (list(torch._C.Value)): The jit values produced by the
            convolution; only the first one is inspected.
    Returns:
        Counter: The result of ``conv_flop_count`` for the input, filter and
            output shapes.
    """
    # A convolution node is expected to carry exactly 12 inputs:
    # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding,
    # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn,
    # 10) deterministic_cudnn and 11) user_enabled_cudnn.
    num_inputs = len(inputs)
    assert num_inputs == 12, num_inputs

    feature_map = inputs[0]
    kernel = inputs[1]
    return conv_flop_count(
        get_shape(feature_map), get_shape(kernel), get_shape(outputs[0])
    )
示例3: batchnorm_flop_jit
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def batchnorm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops of a batch norm node.

    Args:
        inputs (list(torch._C.Value)): The jit values feeding the batch
            norm; only the first one (the input tensor) is inspected.
        outputs (list(torch._C.Value)): The jit values produced by the batch
            norm. Unused here.
    Returns:
        Counter: A Counter with the single key ``"batchnorm"`` mapped to
            four times the number of elements of the input tensor.
    """
    shape = get_shape(inputs[0])
    rank = len(shape)
    # Only inputs of rank 2 through 5 are accepted.
    assert 2 <= rank <= 5, shape
    return Counter({"batchnorm": prod(shape) * 4})
示例4: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def forward(self, enc, dec):
    """Merge an encoder feature map with an upsampled decoder feature map.

    Arguments:
        enc: Tensor from the encoder pathway
        dec: Tensor from the decoder pathway (to be upconv'd)

    Returns:
        The output of the second conv/norm/activation stage applied to the
        merged features.
    """
    upsampled = self.upconv(dec)
    enc, upsampled = autocrop(enc, upsampled)

    gated_enc, attention_map = self.attention(enc, dec)
    # The attention map is stored on the module only outside of scripting.
    if not torch.jit.is_scripting():
        self.att = attention_map

    upsampled = self.act0(self.norm0(upsampled))

    if self.merge_mode == 'concat':
        merged = torch.cat((upsampled, gated_enc), 1)
    else:
        merged = upsampled + gated_enc

    hidden = self.act1(self.norm1(self.conv1(merged)))
    return self.act2(self.norm2(self.conv2(hidden)))
示例5: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def __init__(self,
             module,
             example_inputs):
    """Trace ``module`` and keep both the trace and its parsed graph.

    Args:
        module: Object passed to ``torch.jit.trace``.
        example_inputs: A single example input, or a list/tuple of them.
            A lone input is wrapped in a list after tracing.

    Raises:
        AssertionError: If the parsed graph's input-node count differs from
            the number of example inputs (plus one when ``module`` is a
            ``torch.nn.Module``).
    """
    super().__init__()
    self.module = module

    module_is_nn = isinstance(module, torch.nn.Module)
    traced = torch.jit.trace(module, example_inputs, True)

    # Normalise only after tracing so torch.jit.trace still receives the
    # caller's original object.
    if not isinstance(example_inputs, (list, tuple)):
        example_inputs = [example_inputs]
    num_examples = len(example_inputs)

    parsed_graph = parse(
        traced.graph,
        num_examples,
        omit_useless_nodes=False,
        is_class=module_is_nn,
    )
    self.graph = parsed_graph
    self.trace = traced
    self.example_inputs = example_inputs

    # NOTE(review): the extra node for nn.Module is presumably the module
    # itself - confirm against `parse`.
    expected_inputs = num_examples + int(module_is_nn)
    actual_inputs = len(parsed_graph.get_input_nodes_dict())
    msg = "input mismatch. this may due to some input isn't used in graph"
    assert expected_inputs == actual_inputs, msg
示例6: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def forward(self, input, states):
    # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
    """Feed ``input`` through every layer of ``self.layers`` in order.

    ``states`` holds one state tuple per layer; layer ``k`` receives
    ``states[k]``. Returns the last layer's output together with the list of
    states returned by each layer.
    """
    next_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
    hidden = input
    # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
    # (the layer index is tracked by hand instead of using enumerate)
    layer_idx = 0
    for layer in self.layers:
        hidden, layer_state = layer(hidden, states[layer_idx])
        next_states.append(layer_state)
        layer_idx += 1
    return hidden, next_states
# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
示例7: import_model
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def import_model(path=None):
    """
    Load a model (as ScriptModule) from file.

    Parameters
    ----------
    path : str
        Path to where the model is saved. When ``None``, the return value of
        ``get_model_path()`` is used instead.

    Returns
    -------
    torch.jit.ScriptModule
        The object returned by ``torch.jit.load`` for that path.
    """
    if path is None:
        path = get_model_path()
    return torch.jit.load(path)
示例8: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def forward(self, src_tokens, src_lengths):
    """Fork every model's encoder on the inputs and gather the results.

    Each model in ``self.models`` is switched to evaluation mode and its
    encoder is launched through ``torch.jit._fork``. The resulting futures
    are handed, together with the original ``src_tokens``, to
    ``self.get_outputs``.
    """
    # (seq_length, batch_size) for compatibility with Caffe2
    tokens_seq_first = src_tokens.t()

    encoder_futures = []
    for member in self.models:
        # evaluation mode
        member.eval()
        pending = torch.jit._fork(member.encoder, tokens_seq_first, src_lengths)
        encoder_futures.append(pending)

    return self.get_outputs(src_tokens, encoder_futures)
示例9: save_to_pytorch
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def save_to_pytorch(self, output_path):
    """Serialize this module to ``output_path`` with ``torch.jit.save``.

    Every submodule exposing a ``_pack`` method is packed before saving and
    every submodule exposing ``_unpack`` is unpacked afterwards. The unpack
    step runs even when saving raises, so a failed save does not leave the
    module in its packed state.

    Args:
        output_path: Destination passed straight to ``torch.jit.save``.
    """

    def pack(s):
        if hasattr(s, "_pack"):
            s._pack()

    def unpack(s):
        if hasattr(s, "_unpack"):
            s._unpack()

    self.apply(pack)
    try:
        torch.jit.save(self, output_path)
    finally:
        # Always restore the unpacked state, including on a failed save.
        self.apply(unpack)
示例10: _get_all_end_states
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def _get_all_end_states(
    self,
    beam_tokens: Tensor,
    beam_scores: Tensor,
    beam_prev_indices: Tensor,
    num_steps: int,
) -> Tensor:
    """Collect finished hypotheses and return them sorted by score.

    Positions ``1 .. num_steps + 1`` are scanned for each of the
    ``self.beam_size`` hypotheses. A hypothesis is offered to
    ``self._add_to_end_states`` when its token equals ``self.eos_token_id``
    or when the final position (``num_steps + 1``) is reached. Its score is
    divided by ``position ** length_penalty`` when ``self.length_penalty``
    is non-zero.

    Returns:
        The collected end states stacked into one tensor, with rows ordered
        by descending value of column 0. Each candidate is submitted as
        ``[score, position, hyp_index]``.
    """
    min_score = float("inf")
    min_index = -1
    end_states = torch.jit.annotate(List[Tensor], [])

    final_position = num_steps + 1
    for position in range(1, final_position + 1):
        for hyp_index in range(self.beam_size):
            hit_eos = bool(beam_tokens[position][hyp_index] == self.eos_token_id)
            if hit_eos or bool(position == final_position):
                hypo_score = float(beam_scores[position][hyp_index])
                if bool(self.length_penalty != 0):
                    hypo_score = hypo_score / float(position) ** float(
                        self.length_penalty
                    )
                candidate = torch.tensor(
                    [hypo_score, float(position), float(hyp_index)]
                )
                end_states, min_score, min_index = self._add_to_end_states(
                    end_states,
                    min_score,
                    candidate,
                    min_index,
                )

    stacked_states = torch.stack(end_states)
    # Order rows by their first column (the score), best first.
    _, score_order = stacked_states[:, 0].sort(dim=0, descending=True)
    return stacked_states[score_order, :]
示例11: generic_activation_jit
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def generic_activation_jit(
    op_name: str,
) -> typing.Callable[[typing.List[object], typing.List[object]], typing.Counter[str]]:
    """
    Build a handle that counts activations from an operation's output shape.

    Args:
        op_name (str): The name of the operation; used as the Counter key.
    Returns:
        typing.Callable: A handle taking ``(inputs, outputs)`` and returning
            a Counter that maps ``op_name`` to the number of elements in the
            first output.
    """

    def _activation_handle(
        inputs: typing.List[object], outputs: typing.List[object]
    ) -> typing.Counter[str]:
        """
        Count the elements of the first output; ``inputs`` is ignored.

        Args:
            inputs (list(torch._C.Value)): Unused.
            outputs (list(torch._C.Value)): The output jit values; only the
                first one is inspected.
        Returns:
            Counter: ``{op_name: number of output elements}``.
        """
        return Counter({op_name: prod(get_shape(outputs[0]))})

    return _activation_handle
示例12: get_shape
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def get_shape(val: object) -> typing.List[int]:
    """
    Get the shape from a jit value object.

    Args:
        val (torch._C.Value): jit value object.
    Returns:
        list(int): the sizes reported by ``val.type().sizes()``.
    Raises:
        ValueError: if ``val.isCompleteTensor()`` is false, i.e. the value
            does not describe a complete tensor.
    """
    if not val.isCompleteTensor():  # pyre-ignore
        # Name the offending value so a failed flop count can be traced back
        # to the graph node that caused it.
        raise ValueError(
            "Cannot get shape: jit value is not a complete tensor: {!r}".format(val)
        )
    return val.type().sizes()  # pyre-ignore
示例13: addmm_flop_jit
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def addmm_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops of an addmm node (fully connected layer) in torch script.

    Args:
        inputs (list(torch._C.Value)): The jit values feeding the node; the
            second and third entries are the two matrices being multiplied.
        outputs (list(torch._C.Value)): The jit values produced by the node.
            Unused here.
    Returns:
        Counter: A Counter with the single key ``"addmm"`` mapped to
            ``batch_size * input_dim * output_dim``.
    """
    # Count flop for nn.Linear.
    # inputs is a list of length 3; entry 0 is skipped and entries 1 and 2
    # are the two matrix operands.
    shapes = [get_shape(value) for value in inputs[1:3]]

    # shapes[0]: [batch size, input feature dimension]
    lhs_shape = shapes[0]
    assert len(lhs_shape) == 2, lhs_shape
    # shapes[1]: second matrix; its second dimension is used as the output
    # feature dimension.
    rhs_shape = shapes[1]
    assert len(rhs_shape) == 2, rhs_shape

    batch_size, input_dim = lhs_shape
    output_dim = rhs_shape[1]
    return Counter({"addmm": batch_size * input_dim * output_dim})
示例14: matmul_flop_jit
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def matmul_flop_jit(
    inputs: typing.List[object], outputs: typing.List[object]
) -> typing.Counter[str]:
    """
    Count the flops for matmul.

    The left operand may have any number of dimensions; the right operand
    must be a 2-d matrix whose first dimension matches the last dimension of
    the left operand. For a 2-d left operand ``[B, K]`` and right operand
    ``[K, N]`` the count is ``B * K * N``; leading dimensions of an N-d left
    operand are all included in the count.

    Args:
        inputs (list(torch._C.Value)): The two jit values being multiplied.
        outputs (list(torch._C.Value)): The jit values produced by matmul.
            Unused here.
    Returns:
        Counter: A Counter with the single key ``"matmul"`` mapped to the
            flop count.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert len(input_shapes[1]) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][0], input_shapes
    lhs_shape, rhs_shape = input_shapes

    # Every element of the left operand is multiplied once per output column
    # of the right operand. Taking the product over ALL left dimensions
    # (rather than only the first one) keeps the 2-d result unchanged while
    # counting batched / N-d left operands correctly.
    lhs_elements = 1
    for dim in lhs_shape:
        lhs_elements *= dim
    flop = lhs_elements * rhs_shape[1]
    flop_counter = Counter({"matmul": flop})
    return flop_counter
示例15: __call__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import jit [as 别名]
def __call__(self, *args, **kwargs):
    """Trace the wrapped model using this call's arguments as examples.

    ``self.model`` is wrapped in ``_ForwardOverrideModel`` for
    ``self.method_name``. The keyword arguments are used as example inputs
    when any are given, otherwise the positional arguments are used. The
    outcome of ``torch.jit.trace`` is stored on ``self.tracing_result``;
    nothing is returned.
    """
    wrapped_model = _ForwardOverrideModel(self.model, self.method_name)
    call_arguments = kwargs if kwargs else args
    # noinspection PyTypeChecker
    self.tracing_result = torch.jit.trace(
        wrapped_model, example_inputs={self.method_name: call_arguments}
    )