當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch sparse_csr_tensor用法及代碼示例


本文簡要介紹python語言中 torch.sparse_csr_tensor 的用法。

用法:

torch.sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) → Tensor

參數

  • crow_indices(array_like) -大小為 size[0] + 1 的一維數組。最後一個元素是非零的數量。此張量根據給定行的開始位置對值和col_indices 中的索引進行編碼。張量中的每個連續數字減去它之前的數字表示給定行中的元素數。

  • col_indices(array_like) -values 中每個元素的列坐標。與值具有相同長度的嚴格一維張量。

  • values(array_list) -張量的初始值。可以是列表、元組、NumPy ndarray、標量和其他類型。

  • size(列表,元組,torch.Size, 可選的) -稀疏張量的大小。如果未提供,大小將被推斷為足以容納所有非零元素的最小大小。

關鍵字參數

  • dtype(torch.dtype, 可選的) -返回張量的所需數據類型。默認值:如果無,則從 values 推斷數據類型。

  • device(torch.device, 可選的) -返回張量的所需設備。默認值:如果沒有,則使用當前設備作為默認張量類型(請參閱 torch.set_default_tensor_type() )。對於 CPU 張量類型,device 將是 CPU;對於 CUDA 張量類型,device 將是當前的 CUDA 設備。

  • requires_grad(bool,可選的) -如果 autograd 應該在返回的張量上記錄操作。默認值:False

在給定的 crow_indicescol_indices 處使用指定值構造 CSR(壓縮稀疏行)中的稀疏張量。 CSR 格式的稀疏矩陣乘法運算通常比 COO 格式的稀疏張量更快。讓您看一下關於索引數據類型的注釋。

例子::
>>> crow_indices = [0, 2, 4]
>>> col_indices = [0, 1, 0, 1]
>>> values = [1, 2, 3, 4]
>>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
...                         torch.tensor(col_indices, dtype=torch.int64),
...                         torch.tensor(values), dtype=torch.double)
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
       dtype=torch.float64, layout=torch.sparse_csr)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.sparse_csr_tensor。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。