用法:
class mxnet.module.BaseModule(logger=<module 'logging' from '/work/conda_env/lib/python3.9/logging/__init__.py'>)
基礎:
object
模塊的基類。
一個模塊代表一個計算組件。可以將模塊視為一台計算機。模塊可以執行前向和後向傳遞並更新模型中的參數。我們的目標是使 API 易於使用,特別是在我們需要使用命令式 API 來處理多個模塊(例如隨機深度網絡)的情況下。
一個模塊有幾種狀態:
- 初始狀態:尚未分配內存,因此模塊尚未準備好進行計算。
- 綁定:輸入、輸出和參數的形狀都是已知的,內存已分配,模塊已準備好進行計算。
- 參數被初始化:對於帶參數的模塊,在初始化參數之前進行計算可能會導致未定義的輸出。
- 已安裝優化器:可以將優化器安裝到模塊中。在此之後,可以在計算梯度後根據優化器更新模塊的參數(forward-backward)。
屬性
此模塊所需數據的名稱列表。
指定此模塊的數據輸入的(名稱,形狀)對列表。
(名稱,形狀)對的列表,指定此模塊的標簽輸入。
此模塊輸出的名稱列表。
(名稱,形狀)對的列表,指定此模塊的輸出。
獲取與此模塊關聯的符號。
為了使模塊與其他模塊進行交互,它必須能夠在其初始狀態(綁定之前)報告以下信息:
data_names
:類型字符串列表,指示所需輸入數據的名稱。output_names
:類型字符串列表,指示所需輸出的名稱。
綁定後,一個模塊應該能夠上報以下更豐富的信息:
- 狀態信息
binded
:bool
,指示是否已分配計算所需的內存緩衝區。for_training
: 模塊是否需要訓練。params_initialized
:bool
,表示該模塊的參數是否已經初始化。optimizer_initialized
:bool
,指示是否定義和初始化優化器。inputs_need_grad
:bool
,指示是否需要關於輸入數據的梯度。在實現模塊組合時可能很有用。
- 輸入/輸出信息
data_shapes
:(name, shape)
的列表。理論上,由於分配了內存,我們可以直接提供數據數組。但是在數據並行的情況下,數據數組的形狀可能與從外部世界看到的不同。label_shapes
:(name, shape)
的列表。這可能是[]
如果模塊不需要標簽(例如,它在頂部不包含損失函數),或者模塊未綁定到訓練。output_shapes
:模塊輸出的(name, shape)
列表。
- 參數(用於帶參數的模塊)
get_params()
:返回一個元組(arg_params, aux_params)
。其中每一個都是名稱到NDArray
映射的字典。那些NDArray
總是存在於 CPU 上。用於計算的實際參數可能存在於其他設備(GPU)上,此函數將檢索(副本)最新參數。set_params(arg_params, aux_params)
:將參數分配給進行計算的設備。init_params(...)
:更靈活的接口來分配或初始化參數。
- 設置
bind()
:準備計算環境。init_optimizer()
:安裝優化器以更新參數。prepare()
:根據當前數據批次準備模塊。
- 計算
forward(data_batch)
:正向操作。backward(out_grads=None)
:向後操作。update()
: 根據安裝的優化器更新參數。get_outputs()
: 獲取前一個正向操作的輸出。get_input_grads()
:獲取相對於在先前反向操作中計算的輸入的梯度。update_metric(metric, labels, pre_sliced=False)
:更新前向計算結果的性能指標。
- 其他屬性(主要是為了向後兼容)
symbol
:此模塊的基礎符號圖(如果有) 此屬性不一定是恒定的。例如,對於BucketingModule
,此屬性隻是使用的current
符號。對於其他模塊,這個值可能沒有很好地定義。
當這些intermediate-level API 被正確實現時,以下高級 API 將自動可用於模塊:
fit
:在數據集上訓練模塊參數。predict
:對數據集運行預測並收集輸出。score
:對數據集運行預測並評估性能。
例子:
>>> # An example of creating a mxnet module. >>> import mxnet as mx >>> data = mx.symbol.Variable('data') >>> fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) >>> act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") >>> fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) >>> act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") >>> fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) >>> out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax') >>> mod = mx.mod.Module(out)
相關用法
- Python mxnet.module.BaseModule.get_outputs用法及代碼示例
- Python mxnet.module.BaseModule.forward用法及代碼示例
- Python mxnet.module.BaseModule.bind用法及代碼示例
- Python mxnet.module.BaseModule.init_params用法及代碼示例
- Python mxnet.module.BaseModule.get_params用法及代碼示例
- Python mxnet.module.BaseModule.set_params用法及代碼示例
- Python mxnet.module.BaseModule.update用法及代碼示例
- Python mxnet.module.BaseModule.iter_predict用法及代碼示例
- Python mxnet.module.BaseModule.save_params用法及代碼示例
- Python mxnet.module.BaseModule.init_optimizer用法及代碼示例
- Python mxnet.module.BaseModule.score用法及代碼示例
- Python mxnet.module.BaseModule.fit用法及代碼示例
- Python mxnet.module.BaseModule.update_metric用法及代碼示例
- Python mxnet.module.BaseModule.predict用法及代碼示例
- Python mxnet.module.BaseModule.get_input_grads用法及代碼示例
- Python mxnet.module.BaseModule.backward用法及代碼示例
- Python mxnet.module.BaseModule.load_params用法及代碼示例
- Python mxnet.module.BucketingModule.set_params用法及代碼示例
- Python mxnet.module.SequentialModule.add用法及代碼示例
- Python mxnet.module.Module.set_params用法及代碼示例
注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.module.BaseModule。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。