當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch frombuffer用法及代碼示例


本文簡要介紹python語言中 torch.frombuffer 的用法。

用法:

torch.frombuffer(buffer, *, dtype, count=- 1, offset=0, requires_grad=False) → Tensor

參數

buffer(object) -公開緩衝區接口的 Python 對象。

關鍵字參數

  • dtype(torch.dtype) -返回張量的所需數據類型。

  • count(int,可選的) -要讀取的所需元素的數量。如果為負,則將讀取所有元素(直到緩衝區末尾)。默認值:-1。

  • offset(int,可選的) -在緩衝區開始處要跳過的字節數。默認值:0。

  • requires_grad(bool,可選的) -如果 autograd 應該在返回的張量上記錄操作。默認值:False

從實現 Python 緩衝區協議的對象創建一維 Tensor

跳過緩衝區中的第一個 offset 字節,並將其餘原始字節解釋為帶有 count 元素的 dtype 類型的一維張量。

請注意,以下任一條件必須為真:

1. count 是一個非零正數,緩衝區中的總字節數小於 offset 加上 count 乘以 dtype 的大小(以字節為單位)。

2、count為負數,緩衝區的長度(字節數)減去offsetdtype的大小(以字節為單位)的倍數。

返回的張量和緩衝區共享相同的內存。對張量的修改將反映在緩衝區中,反之亦然。返回的張量不可調整大小。

注意

此函數增加擁有共享內存的對象的引用計數。因此,在返回的張量超出範圍之前,不會釋放此類內存。

警告

當傳遞一個實現緩衝區協議的對象時,該函數的行為是未定義的,該對象的數據不在 CPU 上。這樣做很可能導致分段錯誤。

警告

此函數不會嘗試推斷 dtype(因此,它不是可選的)。傳遞與其源不同的dtype 可能會導致意外行為。

例子:

>>> import array
>>> a = array.array('i', [1, 2, 3])
>>> t = torch.frombuffer(a, dtype=torch.int32)
>>> t
tensor([ 1,  2,  3])
>>> t[0] = -1
>>> a
array([-1,  2,  3])

>>> # Interprets the signed char bytes as 32-bit integers.
>>> # Each 4 signed char elements will be interpreted as
>>> # 1 signed 32-bit integer.
>>> import array
>>> a = array.array('b', [-1, 0, 0, 0])
>>> torch.frombuffer(a, dtype=torch.int32)
tensor([255], dtype=torch.int32)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.frombuffer。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。