当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.gather()用法及代码示例


Tensorflow.js是由Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.gather() 函数用于根据规定的索引从规定的张量 x 轴收集片段。

用法:

tf.gather(x, indices, axis?, batchDims?)

Parameters: 

  • x:它是要收集其片段的指定输入张量,它可以是 tf.Tensor、TypedArray 或 Array 类型。
  • indices:它是要提取的值的规定索引,它可以是 tf.Tensor、TypedArray 或 Array 类型。
  • axis:要选择的值是指定的轴。默认值为零,并且它是数字类型。但是,此参数是可选的。
  • batchDims:它是批量大小的规定数量,应小于或等于规定的等级,即 index 。它的默认值为零。此外,返回的输出必须具有 x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:] 的形状。它是类型号并且是可选的。

返回值:它返回tf.Tensor对象。



范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input and indices
const y = tf.tensor1d([1, 6, 7, 8]);
const ind = tf.tensor1d([1, 6, 2], 'int32');
  
// Calling tf.gather() method and
// Printing output
y.gather(ind).print();

输出:

Tensor
    [6, NaN, 7]

范例2:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Defining tensor input, indices, axis,
// and batchdims
const y = tf.tensor2d([7, 8, 12, 13], [4, 1]);
const ind = tf.tensor1d([2, 3, 0], 'int32');
const axis = 1;
const batchdims = -1;
  
// Calling tf.gather() method
var res = tf.gather(y, ind, axis, batchdims);
  
// Printing output
res.print();

输出:

Tensor
    [[12 , 13 , 7 ],
     [13 , NaN, 8 ],
     [NaN, NaN, 12],
     [NaN, NaN, 13]]

参考: https://js.tensorflow.org/api/latest/#gather

相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.gather() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。