當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Tensorflow.js tf.truncatedNormal()用法及代碼示例

Tensorflow.js是由Google開發的開源庫,用於在瀏覽器或節點環境中運行機器學習模型以及深度學習神經網絡。

.truncatedNormal()函數用於查找tf.Tensor以及從截斷的正態分布求值的值。此外,此處作為輸出產生的值在規定的平均值和標準偏差的支持下遵循正態分布,但那些大小相對於平均值大於2個標準偏差的值將被丟棄並再次選擇。

用法:

tf.truncatedNormal(shape, mean?, stdDev?, dtype?, seed?)

參數:

  • shape:它是一個數組,其中包含描述輸出張量形狀的整數,並且類型為number []。
  • mean:它是正態分布的規定平均值,並且是類型編號。
  • stdDev:它是正態分布的標準偏差,是類型編號。
  • dtype:它是返回的輸出張量的聲明數據類型,可以是float32或int32類型。
  • seed:是說明的種子,有助於隨機數生成器,並且是類型數。

返回值:它返回tf.Tensor對象。



範例1:

Javascript


// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling truncatedNormal() method and
// Printing output
tf.truncatedNormal([3, 4]).print();

輸出:

Tensor
    [[-0.0277713, -0.4777073, -0.3911407, 1.85613   ],
     [-0.0667888, -0.0867875, 0.8295102 , -0.5933844],
     [0.5160138 , 0.7871808 , 0.6818511 , 1.2177598 ]]

範例2:

Javascript


// Importing the tensorflow.js library 
import * as tf from "@tensorflow/tfjs"
  
// Defining shape
var sh = [3, 2];
var mean = 4;
var st_dev = 5;
var dtyp = 'int32';
  
// Calling truncatedNormal() method
var res = tf.truncatedNormal(sh, mean, st_dev, dtyp);
  
// Printing output
res.print();

輸出:

Tensor
    [[-1, -5],
     [4 , 4 ],
     [11, 2 ]]

參考: https://js.tensorflow.org/api/latest/#truncatedNormal

相關用法


注:本文由純淨天空篩選整理自nidhi1352singh大神的英文原創作品 Tensorflow.js tf.truncatedNormal() Function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。