当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Tensorflow.js tf.truncatedNormal()用法及代码示例


Tensorflow.js是由Google开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。

.truncatedNormal()函数用于查找tf.Tensor以及从截断的正态分布求值的值。此外,此处作为输出产生的值在规定的平均值和标准偏差的支持下遵循正态分布,但那些大小相对于平均值大于2个标准偏差的值将被丢弃并再次选择。

用法:

tf.truncatedNormal(shape, mean?, stdDev?, dtype?, seed?)

参数:

  • shape:它是一个数组,其中包含描述输出张量形状的整数,并且类型为number []。
  • mean:它是正态分布的规定平均值,并且是类型编号。
  • stdDev:它是正态分布的标准偏差,是类型编号。
  • dtype:它是返回的输出张量的声明数据类型,可以是float32或int32类型。
  • seed:是说明的种子,有助于随机数生成器,并且是类型数。

返回值:它返回tf.Tensor对象。



范例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling truncatedNormal() method and
// Printing output
tf.truncatedNormal([3, 4]).print();

输出:

Tensor
    [[-0.0277713, -0.4777073, -0.3911407, 1.85613   ],
     [-0.0667888, -0.0867875, 0.8295102 , -0.5933844],
     [0.5160138 , 0.7871808 , 0.6818511 , 1.2177598 ]]

范例2:

Javascript


// Importing the tensorflow.js library 
import * as tf from "@tensorflow/tfjs"
  
// Defining shape
var sh = [3, 2];
var mean = 4;
var st_dev = 5;
var dtyp = 'int32';
  
// Calling truncatedNormal() method
var res = tf.truncatedNormal(sh, mean, st_dev, dtyp);
  
// Printing output
res.print();

输出:

Tensor
    [[-1, -5],
     [4 , 4 ],
     [11, 2 ]]

参考: https://js.tensorflow.org/api/latest/#truncatedNormal

相关用法


注:本文由纯净天空筛选整理自nidhi1352singh大神的英文原创作品 Tensorflow.js tf.truncatedNormal() Function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。