當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.module.BaseModule用法及代碼示例


用法:

class mxnet.module.BaseModule(logger=<module 'logging' from '/work/conda_env/lib/python3.9/logging/__init__.py'>)

基礎:object

模塊的基類。

一個模塊代表一個計算組件。可以將模塊視為一台計算機。模塊可以執行前向和後向傳遞並更新模型中的參數。我們的目標是使 API 易於使用,特別是在我們需要使用命令式 API 來處理多個模塊(例如隨機深度網絡)的情況下。

一個模塊有幾種狀態:

  • 初始狀態:尚未分配內存,因此模塊尚未準備好進行計算。
  • 綁定:輸入、輸出和參數的形狀都是已知的,內存已分配,模塊已準備好進行計算。
  • 參數被初始化:對於帶參數的模塊,在初始化參數之前進行計算可能會導致未定義的輸出。
  • 已安裝優化器:可以將優化器安裝到模塊中。在此之後,可以在計算梯度後根據優化器更新模塊的參數(forward-backward)。

屬性

data_names

此模塊所需數據的名稱列表。

data_shapes

指定此模塊的數據輸入的(名稱,形狀)對列表。

label_shapes

(名稱,形狀)對的列表,指定此模塊的標簽輸入。

output_names

此模塊輸出的名稱列表。

output_shapes

(名稱,形狀)對的列表,指定此模塊的輸出。

symbol

獲取與此模塊關聯的符號。

為了使模塊與其他模塊進行交互,它必須能夠在其初始狀態(綁定之前)報告以下信息:

  • data_names:類型字符串列表,指示所需輸入數據的名稱。
  • output_names :類型字符串列表,指示所需輸出的名稱。

綁定後,一個模塊應該能夠上報以下更豐富的信息:

  • 狀態信息
    • bindedbool ,指示是否已分配計算所需的內存緩衝區。
    • for_training : 模塊是否需要訓練。
    • params_initialized : bool ,表示該模塊的參數是否已經初始化。
    • optimizer_initialized : bool ,指示是否定義和初始化優化器。
    • inputs_need_gradbool,指示是否需要關於輸入數據的梯度。在實現模塊組合時可能很有用。
  • 輸入/輸出信息
    • data_shapes(name, shape) 的列表。理論上,由於分配了內存,我們可以直接提供數據數組。但是在數據並行的情況下,數據數組的形狀可能與從外部世界看到的不同。
    • label_shapes(name, shape) 的列表。這可能是[] 如果模塊不需要標簽(例如,它在頂部不包含損失函數),或者模塊未綁定到訓練。
    • output_shapes :模塊輸出的(name, shape) 列表。
  • 參數(用於帶參數的模塊)
    • get_params() :返回一個元組 (arg_params, aux_params) 。其中每一個都是名稱到NDArray 映射的字典。那些 NDArray 總是存在於 CPU 上。用於計算的實際參數可能存在於其他設備(GPU)上,此函數將檢索(副本)最新參數。
    • set_params(arg_params, aux_params) :將參數分配給進行計算的設備。
    • init_params(...):更靈活的接口來分配或初始化參數。
  • 設置
    • bind():準備計算環境。
    • init_optimizer() :安裝優化器以更新參數。
    • prepare() :根據當前數據批次準備模塊。
  • 計算
    • forward(data_batch):正向操作。
    • backward(out_grads=None):向後操作。
    • update() : 根據安裝的優化器更新參數。
    • get_outputs() : 獲取前一個正向操作的輸出。
    • get_input_grads() :獲取相對於在先前反向操作中計算的輸入的梯度。
    • update_metric(metric, labels, pre_sliced=False) :更新前向計算結果的性能指標。
  • 其他屬性(主要是為了向後兼容)
    • symbol :此模塊的基礎符號圖(如果有) 此屬性不一定是恒定的。例如,對於 BucketingModule ,此屬性隻是使用的 current 符號。對於其他模塊,這個值可能沒有很好地定義。

當這些intermediate-level API 被正確實現時,以下高級 API 將自動可用於模塊:

  • fit :在數據集上訓練模塊參數。
  • predict :對數據集運行預測並收集輸出。
  • score :對數據集運行預測並評估性能。

例子

>>> # An example of creating a mxnet module.
>>> import mxnet as mx
>>> data = mx.symbol.Variable('data')
>>> fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
>>> act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
>>> fc2  = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
>>> act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
>>> fc3  = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
>>> out  = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')
>>> mod = mx.mod.Module(out)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.module.BaseModule。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。