當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Tensorflow.js tf.gather()用法及代碼示例

Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.gather() 函數用於根據規定的索引從規定的張量 x 軸收集片段。

用法:

tf.gather(x, indices, axis?, batchDims?)

Parameters: 

  • x:它是要收集其片段的指定輸入張量,它可以是 tf.Tensor、TypedArray 或 Array 類型。
  • indices:它是要提取的值的規定索引,它可以是 tf.Tensor、TypedArray 或 Array 類型。
  • axis:要選擇的值是指定的軸。默認值為零,並且它是數字類型。但是,此參數是可選的。
  • batchDims:它是批量大小的規定數量,應小於或等於規定的等級,即 index 。它的默認值為零。此外,返回的輸出必須具有 x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:] 的形狀。它是類型號並且是可選的。

返回值:它返回tf.Tensor對象。



範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input and indices
const y = tf.tensor1d([1, 6, 7, 8]);
const ind = tf.tensor1d([1, 6, 2], 'int32');
  
// Calling tf.gather() method and
// Printing output
y.gather(ind).print();

輸出:

Tensor
    [6, NaN, 7]

範例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input, indices, axis,
// and batchdims
const y = tf.tensor2d([7, 8, 12, 13], [4, 1]);
const ind = tf.tensor1d([2, 3, 0], 'int32');
const axis = 1;
const batchdims = -1;
  
// Calling tf.gather() method
var res = tf.gather(y, ind, axis, batchdims);
  
// Printing output
res.print();

輸出:

Tensor
    [[12 , 13 , 7 ],
     [13 , NaN, 8 ],
     [NaN, NaN, 12],
     [NaN, NaN, 13]]

參考: https://js.tensorflow.org/api/latest/#gather

相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.gather() Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。