本文整理汇总了Python中torch.int8（PyTorch 的 8 位有符号整数数据类型）的典型用法代码示例。如果您正苦于以下问题:Python torch.int8的具体用法?Python torch.int8怎么用?Python torch.int8使用的例子?那么恭喜您, 这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解torch.int8所在模块torch的用法示例。
在下文中一共展示了torch.int8的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: torch_dtype_to_np_dtype
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def torch_dtype_to_np_dtype(dtype):
    """Return the NumPy dtype equivalent to a torch dtype.

    Args:
        dtype: a ``torch.dtype``. Aliases such as ``torch.half``,
            ``torch.short``, ``torch.int``, ``torch.long``, ``torch.float``
            and ``torch.double`` are the same objects as the canonical
            names listed below, so they resolve through the same entries.

    Returns:
        The matching ``numpy.dtype``.

    Raises:
        KeyError: if ``dtype`` has no entry in the table.
    """
    # ``np.bool`` was only an alias of the builtin ``bool``; it was deprecated
    # in NumPy 1.20 and removed in 1.24, so ``np.bool_`` is used instead.
    dtype_dict = {
        torch.bool: np.dtype(np.bool_),
        torch.uint8: np.dtype(np.uint8),
        torch.int8: np.dtype(np.int8),
        torch.int16: np.dtype(np.int16),
        torch.int32: np.dtype(np.int32),
        torch.int64: np.dtype(np.int64),
        torch.float16: np.dtype(np.float16),
        torch.float32: np.dtype(np.float32),
        torch.float64: np.dtype(np.float64),
    }
    return dtype_dict[dtype]
# ---------------------- InferenceEngine internal types ------------------------
示例2: update_dtype
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def update_dtype(self, old_dtype):
    """Translate a mapping of NumPy dtypes into the equivalent torch dtypes.

    Args:
        old_dtype: dict whose values are NumPy scalar types or ``np.dtype``
            objects (e.g. ``np.float32`` or ``np.dtype("float32")``).

    Returns:
        A new dict with the same keys and the matching ``torch.dtype`` values.

    Raises:
        ValueError: if a value has no torch counterpart in the table below.
    """
    # Ordered (numpy, torch) pairs, compared with ``==`` so that both numpy
    # scalar types and ``np.dtype`` instances match. The original chain tested
    # ``np.int16`` twice and never handled ``np.int64``.
    dtype_pairs = (
        (np.float32, torch.float32),
        (np.float64, torch.float64),
        (np.float16, torch.float16),
        (np.uint8, torch.uint8),
        (np.int8, torch.int8),
        (np.int16, torch.int16),
        (np.int32, torch.int32),
        (np.int64, torch.int64),
        (np.bool_, torch.bool),
    )
    updated = {}
    for key, np_dtype in old_dtype.items():
        for candidate, torch_dtype in dtype_pairs:
            if np_dtype == candidate:
                updated[key] = torch_dtype
                break
        else:
            raise ValueError("Unsupported dtype {}".format(np_dtype))
    return updated
示例3: broadcast_obj
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def broadcast_obj(self, obj, src, group=None):
    """Broadcast a picklable object from party ``src`` to all parties in ``group``.

    The source pickles ``obj`` and sends it in two broadcasts: first the
    payload length as an int32 tensor, then the bytes as an int8 tensor.
    Receivers allocate a buffer of that length, receive into it and decode it
    with ``serial.restricted_loads``.

    Args:
        obj: the object to send; required (non-None) on ``src``, ignored
            elsewhere.
        src: rank of the sending party.
        group: process group to broadcast over; defaults to
            ``self.main_group``.

    Returns:
        On ``src``, the ``obj`` that was passed in; on every other party, the
        object decoded from the received bytes.
    """
    if group is None:
        group = self.main_group
    if self.rank == src:
        assert obj is not None, "src party must provide obj for broadcast"
        buf = pickle.dumps(obj)
        # NOTE: the length travels as int32, so payloads must stay below 2 GiB.
        size = torch.tensor(len(buf), dtype=torch.int32)
        # Copy into a bytearray so the NumPy view is writable; torch.from_numpy
        # warns on read-only arrays such as one backed by immutable ``bytes``.
        arr = torch.from_numpy(numpy.frombuffer(bytearray(buf), dtype=numpy.int8))
        dist.broadcast(size, src, group=group)
        dist.broadcast(arr, src, group=group)
    else:
        # Placeholder that the first broadcast overwrites with the real length.
        size = torch.tensor(1, dtype=torch.int32)
        dist.broadcast(size, src, group=group)
        # Use a plain Python int for the shape rather than the 0-d tensor.
        data = torch.empty(size=(int(size.item()),), dtype=torch.int8)
        dist.broadcast(data, src, group=group)
        buf = data.numpy().tobytes()
        # Decode with the restricted loader rather than pickle.loads because
        # the bytes come from another party.
        obj = serial.restricted_loads(buf)
    return obj
示例4: testDtypes
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def testDtypes(self):
    """Check that torch dtype macros in a parsed config resolve to torch dtypes.

    Binds three parameters of ``configurable`` to ``%``-macros (one without
    the ``torch.`` prefix, two with it) and asserts each value is the
    identical ``torch`` dtype object.
    """
    # Spot check a few.
    config_str = """
# Test without torch prefix, but using the
# prefix is strongly recommended!
configurable.float32 = %float32
# Test with torch prefix.
configurable.int8 = %torch.int8
configurable.float16 = %torch.float16
"""
    config.parse_config(config_str)
    vals = configurable()
    # E1101 (no-member): pylint cannot see torch's dtype attributes.
    # pylint: disable=E1101
    self.assertIs(vals['float32'], torch.float32)
    self.assertIs(vals['int8'], torch.int8)
    self.assertIs(vals['float16'], torch.float16)
    # Re-enable the check; the original repeated "disable" here, leaking the
    # suppression to everything after this test.
    # pylint: enable=E1101
示例5: sanitize_infinity
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def sanitize_infinity(dtype):
    """
    Return the value to use as "infinity" for the given dtype.

    Parameters:
    -----------
    dtype: torch dtype

    Returns:
    --------
    large_enough: for torch.int8/int16/int32/int64, the largest value the
        signed type can hold; for every other dtype (floats, uint8, bool,
        ...), ``float("inf")``.
    """
    # (dtype, bit width) for the signed integer types; matched by identity.
    signed_int_bits = (
        (torch.int8, 8),
        (torch.int16, 16),
        (torch.int32, 32),
        (torch.int64, 64),
    )
    for int_dtype, bits in signed_int_bits:
        if dtype is int_dtype:
            # Largest two's-complement value representable in ``bits`` bits.
            return (1 << (bits - 1)) - 1
    return float("inf")
示例6: pytorch_dtype_to_type
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def pytorch_dtype_to_type(dtype):
    """Map a pytorch dtype to a myia type.

    Supported: int8/16/32/64, uint8, float16/32/64 and bool.

    Raises:
        TypeError: for any dtype not in the table (e.g. torch.bfloat16).
    """
    import torch

    myia_type_by_dtype = {
        torch.int8: Int[8],
        torch.int16: Int[16],
        torch.int32: Int[32],
        torch.int64: Int[64],
        torch.uint8: UInt[8],
        torch.float16: Float[16],
        torch.float32: Float[32],
        torch.float64: Float[64],
        torch.bool: Bool,
    }
    # Every value in the table is a myia type object, so None means "missing".
    myia_type = myia_type_by_dtype.get(dtype)
    if myia_type is None:
        raise TypeError(f"Unsupported dtype {dtype}")
    return myia_type
示例7: test_export_to_8bit_with_bias
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def test_export_to_8bit_with_bias(self):
    """Check which tensors ``state_dict`` exports as ``mode_8bit`` is toggled.

    With ``mode_8bit`` off, the float ``weight`` and ``bias`` are exported.
    With it on, they are replaced by ``quantized_weight`` (int8),
    ``_quantized_bias`` (int32) and ``bias_scale``. Switching it back off
    must restore the float export.
    """
    qlinear = QuantizedLinear(10, 5, mode="EMA")
    qlinear.eval()
    float_keys = ("weight", "bias")
    quantized_keys = ("quantized_weight", "_quantized_bias", "bias_scale")

    state_dict = qlinear.state_dict()
    for key in float_keys:
        self.assertIn(key, state_dict)
    for key in quantized_keys:
        self.assertNotIn(key, state_dict)

    qlinear.mode_8bit = True
    state_dict = qlinear.state_dict()
    for key in float_keys:
        self.assertNotIn(key, state_dict)
    for key in quantized_keys:
        self.assertIn(key, state_dict)
    self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
    self.assertEqual(state_dict["_quantized_bias"].dtype, torch.int32)

    qlinear.mode_8bit = False
    state_dict = qlinear.state_dict()
    for key in float_keys:
        self.assertIn(key, state_dict)
    for key in quantized_keys:
        self.assertNotIn(key, state_dict)
示例8: test_export_to_8bit_without_bias
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def test_export_to_8bit_without_bias(self):
    """Check the ``state_dict`` export of a bias-free layer as ``mode_8bit`` toggles.

    With ``mode_8bit`` on, only ``quantized_weight`` (int8) is exported; no
    bias-related entries appear. With it off again, only the float ``weight``
    is exported.
    """
    qlinear = QuantizedLinear(10, 5, bias=False, mode="EMA")
    qlinear.eval()

    qlinear.mode_8bit = True
    state_dict = qlinear.state_dict()
    self.assertNotIn("weight", state_dict)
    self.assertNotIn("bias", state_dict)
    self.assertIn("quantized_weight", state_dict)
    self.assertEqual(state_dict["quantized_weight"].dtype, torch.int8)
    # Constructed with bias=False, so neither bias entry should be exported.
    self.assertNotIn("_quantized_bias", state_dict)
    self.assertNotIn("bias_scale", state_dict)

    qlinear.mode_8bit = False
    state_dict = qlinear.state_dict()
    self.assertIn("weight", state_dict)
    for key in ("bias", "quantized_weight", "_quantized_bias", "bias_scale"):
        self.assertNotIn(key, state_dict)
示例9: _convert_dtype_value
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def _convert_dtype_value(val):
"""converts a PyTorch the PyTorch numeric type id to a torch scalar type."""
convert_torch_dtype_map = {7:"torch.float64",
6:"torch.float32",
5:"torch.float16",
4:"torch.int64",
3:"torch.int32",
2:"torch.int16",
1:"torch.int8",
0:"torch.unit8",
None:"torch.int64"} # Default is torch.int64
if val in convert_torch_dtype_map:
return _convert_data_type(convert_torch_dtype_map[val])
else:
msg = "Torch data type value %d is not handled yet." % (val)
raise NotImplementedError(msg)
示例10: _create_typed_const
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def _create_typed_const(data, dtype):
"""create a (scalar) constant of given value and dtype.
dtype should be a TVM dtype"""
if dtype == "float64":
typed_data = _expr.const(np.float64(data), dtype=dtype)
elif dtype == "float32":
typed_data = _expr.const(np.float32(data), dtype=dtype)
elif dtype == "float16":
typed_data = _expr.const(np.float16(data), dtype=dtype)
elif dtype == "int64":
typed_data = _expr.const(np.int64(data), dtype=dtype)
elif dtype == "int32":
typed_data = _expr.const(np.int32(data), dtype=dtype)
elif dtype == "int16":
typed_data = _expr.const(np.int16(data), dtype=dtype)
elif dtype == "int8":
typed_data = _expr.const(np.int8(data), dtype=dtype)
elif dtype == "uint8":
typed_data = _expr.const(np.uint8(data), dtype=dtype)
else:
raise NotImplementedError("input_type {} is not handled yet".format(dtype))
return typed_data
示例11: test_forward_logical_not
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def test_forward_logical_not():
    """Check conversion of ``torch.logical_not`` on bool, int8, double and int32 inputs."""
    torch.set_grad_enabled(False)

    class LogicalNot1(Module):
        def forward(self, *args):
            return torch.logical_not(args[0])

    # One single-input case per dtype; each is checked with ``verify_model``,
    # which is defined elsewhere in this test module.
    cases = [
        torch.tensor([True, False]),
        torch.tensor([0, 1, -10], dtype=torch.int8),
        torch.tensor([0., 1.5, -10.], dtype=torch.double),
        torch.tensor([0., 1., -10.], dtype=torch.int32),
    ]
    for input_data in cases:
        verify_model(LogicalNot1().float().eval(), input_data=input_data)
示例12: test_forward_bitwise_xor
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def test_forward_bitwise_xor():
    """Check conversion of ``torch.bitwise_xor`` with two inputs and with a constant rhs."""
    torch.set_grad_enabled(False)

    class BitwiseXor1(Module):
        def forward(self, *args):
            return torch.bitwise_xor(args[0], args[1])

    class BitwiseXor2(Module):
        def forward(self, *args):
            # rhs is a tensor built inside forward; it is moved to CUDA when
            # available, presumably to share the inputs' device.
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.bitwise_xor(args[0], rhs)

    def int8_tensor(values):
        # A fresh tensor per call, as the original built a new one each time.
        return torch.tensor(values, dtype=torch.int8)

    verify_model(BitwiseXor1().float().eval(),
                 input_data=[int8_tensor([-1, -2, 3]), int8_tensor([1, 0, 3])])
    verify_model(BitwiseXor1().float().eval(),
                 input_data=[torch.tensor([True, True, False]),
                             torch.tensor([False, True, False])])
    verify_model(BitwiseXor2().float().eval(),
                 input_data=[int8_tensor([-1, -2, 3])])
示例13: test_forward_logical_xor
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def test_forward_logical_xor():
    """Check conversion of ``torch.logical_xor`` with two inputs and with a constant rhs."""
    torch.set_grad_enabled(False)

    class LogicalXor1(Module):
        def forward(self, *args):
            return torch.logical_xor(args[0], args[1])

    class LogicalXor2(Module):
        def forward(self, *args):
            # rhs is a tensor built inside forward; it is moved to CUDA when
            # available, presumably to share the inputs' device.
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.logical_xor(args[0], rhs)

    def int8_tensor(values):
        # A fresh tensor per call, as the original built a new one each time.
        return torch.tensor(values, dtype=torch.int8)

    verify_model(LogicalXor1().float().eval(),
                 input_data=[int8_tensor([-1, -2, 3]), int8_tensor([1, 0, 3])])
    verify_model(LogicalXor1().float().eval(),
                 input_data=[torch.tensor([True, True, False]),
                             torch.tensor([False, True, False])])
    verify_model(LogicalXor2().float().eval(),
                 input_data=[int8_tensor([-1, -2, 3])])
示例14: data_type_dict
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def data_type_dict():
    """Return a new dict mapping dtype name strings to the matching ``th`` dtypes.

    Each name is the attribute of the same name on the ``th`` (torch) module.
    """
    names = ('float16', 'float32', 'float64',
             'uint8', 'int8', 'int16', 'int32', 'int64',
             'bool')
    return {name: getattr(th, name) for name in names}
示例15: torch_dtype_to_trt
# 需要导入模块: import torch [as 别名]
# 或者: from torch import int8 [as 别名]
def torch_dtype_to_trt(dtype):
    """Return the TensorRT dtype matching a torch dtype.

    Supported: int8, int32, float16 and float32.

    Raises:
        TypeError: for any other dtype.
    """
    # (torch dtype, attribute name on ``trt``), checked in order with ``==``.
    supported = (
        (torch.int8, "int8"),
        (torch.int32, "int32"),
        (torch.float16, "float16"),
        (torch.float32, "float32"),
    )
    for torch_dtype, trt_name in supported:
        if dtype == torch_dtype:
            return getattr(trt, trt_name)
    raise TypeError('%s is not supported by tensorrt' % dtype)