

Python torch.int8 Code Examples

This article collects typical usage examples of the torch.int8 dtype in Python, drawn from open-source projects. If you are wondering how torch.int8 works, how to use it, or what it looks like in real code, the curated examples below should help. You can also explore further usage examples from the torch module it belongs to.


Fifteen torch.int8 code examples are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
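Before the examples, a minimal sketch of what torch.int8 actually is: not a method but a dtype object that requests signed 8-bit integer storage for a tensor. The snippet below is illustrative only and is not taken from any of the projects listed here.

import torch

# Request signed 8-bit storage; representable values are -128..127.
x = torch.tensor([-1, 0, 127], dtype=torch.int8)
print(x.dtype)                      # torch.int8
print(torch.iinfo(torch.int8).min)  # -128
print(torch.iinfo(torch.int8).max)  # 127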

Example 1: torch_dtype_to_np_dtype

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def torch_dtype_to_np_dtype(dtype):
    dtype_dict = {
            torch.bool    : np.dtype(np.bool_),  # the np.bool alias was removed in NumPy 1.24; use np.bool_
            torch.uint8   : np.dtype(np.uint8),
            torch.int8    : np.dtype(np.int8),
            torch.int16   : np.dtype(np.int16),
            torch.short   : np.dtype(np.int16),
            torch.int32   : np.dtype(np.int32),
            torch.int     : np.dtype(np.int32),
            torch.int64   : np.dtype(np.int64),
            torch.long    : np.dtype(np.int64),
            torch.float16 : np.dtype(np.float16),
            torch.half    : np.dtype(np.float16),
            torch.float32 : np.dtype(np.float32),
            torch.float   : np.dtype(np.float32),
            torch.float64 : np.dtype(np.float64),
            torch.double  : np.dtype(np.float64),
            }
    return dtype_dict[dtype]


# ---------------------- InferenceEngine internal types ------------------------ 
Developer: pfnet-research, Project: chainer-compiler, Lines: 24, Source: types.py
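A quick usage sketch for the function above (illustrative, assuming import torch and import numpy as np are in scope):

assert torch_dtype_to_np_dtype(torch.int8) == np.dtype(np.int8)
assert torch_dtype_to_np_dtype(torch.float) == np.dtype(np.float32)  # torch.float is an alias of torch.float32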

Example 2: update_dtype

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def update_dtype(self, old_dtype):
        updated = {}
        for k, v in old_dtype.items():
            if v == np.float32:
                dt = torch.float32
            elif v == np.float64:
                dt = torch.float64
            elif v == np.float16:
                dt = torch.float16
            elif v == np.uint8:
                dt = torch.uint8
            elif v == np.int8:
                dt = torch.int8
            elif v == np.int16:
                dt = torch.int16
            elif v == np.int32:
                dt = torch.int32
            elif v == np.int64:  # the original repeated the np.int16 branch here; np.int64 was evidently intended
                dt = torch.int64
            else:
                raise ValueError("Unsupported dtype {}".format(v))
            updated[k] = dt
        return updated 
Developer: heronsystems, Project: adeptRL, Lines: 25, Source: ops.py
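As an aside, the same NumPy-to-torch dtype mapping can be derived without a hand-written if/elif chain by letting torch infer the dtype from an empty NumPy array. This is a sketch, not part of the adeptRL code:

import numpy as np
import torch

def np_dtype_to_torch(np_dtype):
    # torch.from_numpy selects the matching torch dtype for us.
    return torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype

assert np_dtype_to_torch(np.int8) is torch.int8
assert np_dtype_to_torch(np.float32) is torch.float32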

Example 3: broadcast_obj

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def broadcast_obj(self, obj, src, group=None):
        """Broadcasts a given object to all parties."""
        if group is None:
            group = self.main_group

        if self.rank == src:
            assert obj is not None, "src party must provide obj for broadcast"
            buf = pickle.dumps(obj)
            size = torch.tensor(len(buf), dtype=torch.int32)
            arr = torch.from_numpy(numpy.frombuffer(buf, dtype=numpy.int8))

            dist.broadcast(size, src, group=group)
            dist.broadcast(arr, src, group=group)
        else:
            size = torch.tensor(1, dtype=torch.int32)
            dist.broadcast(size, src, group=group)

            data = torch.empty(size=(size,), dtype=torch.int8)
            dist.broadcast(data, src, group=group)
            buf = data.numpy().tobytes()
            obj = serial.restricted_loads(buf)
        return obj 
Developer: facebookresearch, Project: CrypTen, Lines: 24, Source: distributed_communicator.py

Example 4: testDtypes

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def testDtypes(self):
    # Spot check a few.
    config_str = """
      # Test without torch prefix, but using the
      # prefix is strongly recommended!
      configurable.float32 = %float32
      # Test with torch prefix.
      configurable.int8 = %torch.int8
      configurable.float16 = %torch.float16
    """
    config.parse_config(config_str)

    vals = configurable()
    # pylint: disable=E1101
    self.assertIs(vals['float32'], torch.float32)
    self.assertIs(vals['int8'], torch.int8)
    self.assertIs(vals['float16'], torch.float16)
    # pylint: enable=E1101
Developer: google, Project: gin-config, Lines: 20, Source: external_configurables_test.py

Example 5: sanitize_infinity

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def sanitize_infinity(dtype):
    """
    Returns largest possible value for the specified dtype.

    Parameters:
    -----------
    dtype: torch dtype

    Returns:
    --------
    large_enough: largest possible value for the given dtype
    """
    if dtype is torch.int8:
        large_enough = (1 << 7) - 1
    elif dtype is torch.int16:
        large_enough = (1 << 15) - 1
    elif dtype is torch.int32:
        large_enough = (1 << 31) - 1
    elif dtype is torch.int64:
        large_enough = (1 << 63) - 1
    else:
        large_enough = float("inf")

    return large_enough 
Developer: helmholtz-analytics, Project: heat, Lines: 26, Source: constants.py
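For the integer branches above, torch.iinfo returns the same bounds without hard-coding bit widths; a minimal check (illustrative, not part of the heat sources):

import torch

assert torch.iinfo(torch.int8).max == (1 << 7) - 1    # 127
assert torch.iinfo(torch.int64).max == (1 << 63) - 1  # 9223372036854775807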

Example 6: pytorch_dtype_to_type

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def pytorch_dtype_to_type(dtype):
    """Map a pytorch dtype to a myia type."""
    import torch

    _type_map = {
        torch.int8: Int[8],
        torch.int16: Int[16],
        torch.int32: Int[32],
        torch.int64: Int[64],
        torch.uint8: UInt[8],
        torch.float16: Float[16],
        torch.float32: Float[32],
        torch.float64: Float[64],
        torch.bool: Bool,
    }
    if dtype not in _type_map:
        raise TypeError(f"Unsupported dtype {dtype}")
    return _type_map[dtype] 
Developer: mila-iqia, Project: myia, Lines: 20, Source: pytorch_abstract_types.py

Example 7: test_export_to_8bit_with_bias

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def test_export_to_8bit_with_bias(self):
        qlinear = QuantizedLinear(10, 5, mode="EMA")
        qlinear.eval()
        state_dict = qlinear.state_dict()
        self.assertTrue("weight" in state_dict)
        self.assertTrue("bias" in state_dict)
        self.assertTrue("quantized_weight" not in state_dict)
        self.assertTrue("_quantized_bias" not in state_dict)
        self.assertTrue("bias_scale" not in state_dict)
        qlinear.mode_8bit = True
        state_dict = qlinear.state_dict()
        self.assertTrue("weight" not in state_dict)
        self.assertTrue("bias" not in state_dict)
        self.assertTrue("quantized_weight" in state_dict)
        self.assertTrue(state_dict["quantized_weight"].dtype == torch.int8)
        self.assertTrue("_quantized_bias" in state_dict)
        self.assertTrue(state_dict["_quantized_bias"].dtype == torch.int32)
        self.assertTrue("bias_scale" in state_dict)
        qlinear.mode_8bit = False
        state_dict = qlinear.state_dict()
        self.assertTrue("weight" in state_dict)
        self.assertTrue("bias" in state_dict)
        self.assertTrue("quantized_weight" not in state_dict)
        self.assertTrue("_quantized_bias" not in state_dict)
        self.assertTrue("bias_scale" not in state_dict) 
Developer: NervanaSystems, Project: nlp-architect, Lines: 27, Source: test_quantization.py

Example 8: test_export_to_8bit_without_bias

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def test_export_to_8bit_without_bias(self):
        qlinear = QuantizedLinear(10, 5, bias=False, mode="EMA")
        qlinear.eval()
        qlinear.mode_8bit = True
        state_dict = qlinear.state_dict()
        self.assertTrue("weight" not in state_dict)
        self.assertTrue("bias" not in state_dict)
        self.assertTrue("quantized_weight" in state_dict)
        self.assertTrue(state_dict["quantized_weight"].dtype == torch.int8)
        self.assertTrue("_quantized_bias" not in state_dict)
        self.assertTrue("bias_scale" not in state_dict)
        qlinear.mode_8bit = False
        state_dict = qlinear.state_dict()
        self.assertTrue("weight" in state_dict)
        self.assertTrue("bias" not in state_dict)
        self.assertTrue("quantized_weight" not in state_dict)
        self.assertTrue("_quantized_bias" not in state_dict)
        self.assertTrue("bias_scale" not in state_dict) 
Developer: NervanaSystems, Project: nlp-architect, Lines: 20, Source: test_quantization.py

Example 9: _convert_dtype_value

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def _convert_dtype_value(val):
    """converts a PyTorch the PyTorch numeric type id to a torch scalar type."""
    convert_torch_dtype_map = {7:"torch.float64",
                               6:"torch.float32",
                               5:"torch.float16",
                               4:"torch.int64",
                               3:"torch.int32",
                               2:"torch.int16",
                               1:"torch.int8",
                               0:"torch.unit8",
                               None:"torch.int64"} # Default is torch.int64
    if val in convert_torch_dtype_map:
        return _convert_data_type(convert_torch_dtype_map[val])
    else:
        msg = "Torch data type value %d is not handled yet." % (val)
        raise NotImplementedError(msg) 
Developer: apache, Project: incubator-tvm, Lines: 18, Source: pytorch.py

Example 10: _create_typed_const

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def _create_typed_const(data, dtype):
    """create a (scalar) constant of given value and dtype.
       dtype should be a TVM dtype"""

    if dtype == "float64":
        typed_data = _expr.const(np.float64(data), dtype=dtype)
    elif dtype == "float32":
        typed_data = _expr.const(np.float32(data), dtype=dtype)
    elif dtype == "float16":
        typed_data = _expr.const(np.float16(data), dtype=dtype)
    elif dtype == "int64":
        typed_data = _expr.const(np.int64(data), dtype=dtype)
    elif dtype == "int32":
        typed_data = _expr.const(np.int32(data), dtype=dtype)
    elif dtype == "int16":
        typed_data = _expr.const(np.int16(data), dtype=dtype)
    elif dtype == "int8":
        typed_data = _expr.const(np.int8(data), dtype=dtype)
    elif dtype == "uint8":
        typed_data = _expr.const(np.uint8(data), dtype=dtype)
    else:
        raise NotImplementedError("input_type {} is not handled yet".format(dtype))
    return typed_data 
Developer: apache, Project: incubator-tvm, Lines: 25, Source: pytorch.py

Example 11: test_forward_logical_not

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def test_forward_logical_not():
    torch.set_grad_enabled(False)

    class LogicalNot1(Module):
        def forward(self, *args):
            return torch.logical_not(args[0])

    input_data = torch.tensor([True, False])
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
    verify_model(LogicalNot1().float().eval(), input_data=input_data) 
Developer: apache, Project: incubator-tvm, Lines: 20, Source: test_forward.py

Example 12: test_forward_bitwise_xor

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def test_forward_bitwise_xor():
    torch.set_grad_enabled(False)

    class BitwiseXor1(Module):
        def forward(self, *args):
            return torch.bitwise_xor(args[0], args[1])

    class BitwiseXor2(Module):
        def forward(self, *args):
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.bitwise_xor(args[0], rhs)

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([True, True, False])
    rhs = torch.tensor([False, True, False])
    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(BitwiseXor2().float().eval(), input_data=[lhs]) 
Developer: apache, Project: incubator-tvm, Lines: 26, Source: test_forward.py

Example 13: test_forward_logical_xor

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def test_forward_logical_xor():
    torch.set_grad_enabled(False)

    class LogicalXor1(Module):
        def forward(self, *args):
            return torch.logical_xor(args[0], args[1])

    class LogicalXor2(Module):
        def forward(self, *args):
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.logical_xor(args[0], rhs)

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([True, True, False])
    rhs = torch.tensor([False, True, False])
    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(LogicalXor2().float().eval(), input_data=[lhs]) 
Developer: apache, Project: incubator-tvm, Lines: 26, Source: test_forward.py

Example 14: data_type_dict

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def data_type_dict():
    return {'float16' : th.float16,
            'float32' : th.float32,
            'float64' : th.float64,
            'uint8'   : th.uint8,
            'int8'    : th.int8,
            'int16'   : th.int16,
            'int32'   : th.int32,
            'int64'   : th.int64,
            'bool'    : th.bool} 
Developer: dmlc, Project: dgl, Lines: 12, Source: tensor.py

Example 15: torch_dtype_to_trt

# Required module import: import torch [as alias]
# Or: from torch import int8 [as alias]
def torch_dtype_to_trt(dtype):
    if dtype == torch.int8:
        return trt.int8
    elif dtype == torch.int32:
        return trt.int32
    elif dtype == torch.float16:
        return trt.float16
    elif dtype == torch.float32:
        return trt.float32
    else:
        raise TypeError('%s is not supported by tensorrt' % dtype) 
Developer: tensorboy, Project: centerpose, Lines: 13, Source: tensorrt_model.py
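A usage sketch for the converter above (illustrative; it assumes the tensorrt package is installed and imported as trt, as the snippet requires):

import torch
import tensorrt as trt

assert torch_dtype_to_trt(torch.int8) == trt.int8
assert torch_dtype_to_trt(torch.float16) == trt.float16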


Note: the torch.int8 examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by various developers; copyright in the source code remains with the original authors, and distribution and use are subject to each project's license. Do not reproduce without permission.