当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch frombuffer用法及代码示例


本文简要介绍python语言中 torch.frombuffer 的用法。

用法:

torch.frombuffer(buffer, *, dtype, count=- 1, offset=0, requires_grad=False) → Tensor

参数

buffer(object) -公开缓冲区接口的 Python 对象。

关键字参数

  • dtype(torch.dtype) -返回张量的所需数据类型。

  • count(int,可选的) -要读取的所需元素的数量。如果为负,则将读取所有元素(直到缓冲区末尾)。默认值:-1。

  • offset(int,可选的) -在缓冲区开始处要跳过的字节数。默认值:0。

  • requires_grad(bool,可选的) -如果 autograd 应该在返回的张量上记录操作。默认值:False

从实现 Python 缓冲区协议的对象创建一维 Tensor

跳过缓冲区中的第一个 offset 字节,并将其余原始字节解释为带有 count 元素的 dtype 类型的一维张量。

请注意,以下任一条件必须为真:

1. count 是一个非零正数,缓冲区中的总字节数小于 offset 加上 count 乘以 dtype 的大小(以字节为单位)。

2、count为负数,缓冲区的长度(字节数)减去offsetdtype的大小(以字节为单位)的倍数。

返回的张量和缓冲区共享相同的内存。对张量的修改将反映在缓冲区中,反之亦然。返回的张量不可调整大小。

注意

此函数增加拥有共享内存的对象的引用计数。因此,在返回的张量超出范围之前,不会释放此类内存。

警告

当传递一个实现缓冲区协议的对象时,该函数的行为是未定义的,该对象的数据不在 CPU 上。这样做很可能导致分段错误。

警告

此函数不会尝试推断 dtype(因此,它不是可选的)。传递与其源不同的dtype 可能会导致意外行为。

例子:

>>> import array
>>> a = array.array('i', [1, 2, 3])
>>> t = torch.frombuffer(a, dtype=torch.int32)
>>> t
tensor([ 1,  2,  3])
>>> t[0] = -1
>>> a
array([-1,  2,  3])

>>> # Interprets the signed char bytes as 32-bit integers.
>>> # Each 4 signed char elements will be interpreted as
>>> # 1 signed 32-bit integer.
>>> import array
>>> a = array.array('b', [-1, 0, 0, 0])
>>> torch.frombuffer(a, dtype=torch.int32)
tensor([255], dtype=torch.int32)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.frombuffer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。