用法:
class mxnet.module.BaseModule(logger=<module 'logging' from '/work/conda_env/lib/python3.9/logging/__init__.py'>)
基础:
object
模块的基类。
一个模块代表一个计算组件。可以将模块视为一台计算机。模块可以执行前向和后向传递并更新模型中的参数。我们的目标是使 API 易于使用,特别是在我们需要使用命令式 API 来处理多个模块(例如随机深度网络)的情况下。
一个模块有几种状态:
- 初始状态:尚未分配内存,因此模块尚未准备好进行计算。
- 绑定:输入、输出和参数的形状都是已知的,内存已分配,模块已准备好进行计算。
- 参数被初始化:对于带参数的模块,在初始化参数之前进行计算可能会导致未定义的输出。
- 已安装优化器:可以将优化器安装到模块中。在此之后,可以在计算梯度后根据优化器更新模块的参数(forward-backward)。
属性
此模块所需数据的名称列表。
指定此模块的数据输入的(名称,形状)对列表。
(名称,形状)对的列表,指定此模块的标签输入。
此模块输出的名称列表。
(名称,形状)对的列表,指定此模块的输出。
获取与此模块关联的符号。
为了使模块与其他模块进行交互,它必须能够在其初始状态(绑定之前)报告以下信息:
data_names
:类型字符串列表,指示所需输入数据的名称。output_names
:类型字符串列表,指示所需输出的名称。
绑定后,一个模块应该能够上报以下更丰富的信息:
- 状态信息
binded
:bool
,指示是否已分配计算所需的内存缓冲区。for_training
: 模块是否需要训练。params_initialized
:bool
,表示该模块的参数是否已经初始化。optimizer_initialized
:bool
,指示是否定义和初始化优化器。inputs_need_grad
:bool
,指示是否需要关于输入数据的梯度。在实现模块组合时可能很有用。
- 输入/输出信息
data_shapes
:(name, shape)
的列表。理论上,由于分配了内存,我们可以直接提供数据数组。但是在数据并行的情况下,数据数组的形状可能与从外部世界看到的不同。label_shapes
:(name, shape)
的列表。这可能是[]
如果模块不需要标签(例如,它在顶部不包含损失函数),或者模块未绑定到训练。output_shapes
:模块输出的(name, shape)
列表。
- 参数(用于带参数的模块)
get_params()
:返回一个元组(arg_params, aux_params)
。其中每一个都是名称到NDArray
映射的字典。那些NDArray
总是存在于 CPU 上。用于计算的实际参数可能存在于其他设备(GPU)上,此函数将检索(副本)最新参数。set_params(arg_params, aux_params)
:将参数分配给进行计算的设备。init_params(...)
:更灵活的接口来分配或初始化参数。
- 设置
bind()
:准备计算环境。init_optimizer()
:安装优化器以更新参数。prepare()
:根据当前数据批次准备模块。
- 计算
forward(data_batch)
:正向操作。backward(out_grads=None)
:向后操作。update()
: 根据安装的优化器更新参数。get_outputs()
: 获取前一个正向操作的输出。get_input_grads()
:获取相对于在先前反向操作中计算的输入的梯度。update_metric(metric, labels, pre_sliced=False)
:更新前向计算结果的性能指标。
- 其他属性(主要是为了向后兼容)
symbol
:此模块的基础符号图(如果有) 此属性不一定是恒定的。例如,对于BucketingModule
,此属性只是使用的current
符号。对于其他模块,这个值可能没有很好地定义。
当这些intermediate-level API 被正确实现时,以下高级 API 将自动可用于模块:
fit
:在数据集上训练模块参数。predict
:对数据集运行预测并收集输出。score
:对数据集运行预测并评估性能。
例子:
>>> # An example of creating a mxnet module. >>> import mxnet as mx >>> data = mx.symbol.Variable('data') >>> fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) >>> act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") >>> fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) >>> act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") >>> fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) >>> out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax') >>> mod = mx.mod.Module(out)
相关用法
- Python mxnet.module.BaseModule.get_outputs用法及代码示例
- Python mxnet.module.BaseModule.forward用法及代码示例
- Python mxnet.module.BaseModule.bind用法及代码示例
- Python mxnet.module.BaseModule.init_params用法及代码示例
- Python mxnet.module.BaseModule.get_params用法及代码示例
- Python mxnet.module.BaseModule.set_params用法及代码示例
- Python mxnet.module.BaseModule.update用法及代码示例
- Python mxnet.module.BaseModule.iter_predict用法及代码示例
- Python mxnet.module.BaseModule.save_params用法及代码示例
- Python mxnet.module.BaseModule.init_optimizer用法及代码示例
- Python mxnet.module.BaseModule.score用法及代码示例
- Python mxnet.module.BaseModule.fit用法及代码示例
- Python mxnet.module.BaseModule.update_metric用法及代码示例
- Python mxnet.module.BaseModule.predict用法及代码示例
- Python mxnet.module.BaseModule.get_input_grads用法及代码示例
- Python mxnet.module.BaseModule.backward用法及代码示例
- Python mxnet.module.BaseModule.load_params用法及代码示例
- Python mxnet.module.BucketingModule.set_params用法及代码示例
- Python mxnet.module.SequentialModule.add用法及代码示例
- Python mxnet.module.Module.set_params用法及代码示例
注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.module.BaseModule。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。