当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.module.BaseModule用法及代码示例


用法:

class mxnet.module.BaseModule(logger=<module 'logging' from '/work/conda_env/lib/python3.9/logging/__init__.py'>)

基础:object

模块的基类。

一个模块代表一个计算组件。可以将模块视为一台计算机。模块可以执行前向和后向传递并更新模型中的参数。我们的目标是使 API 易于使用,特别是在我们需要使用命令式 API 来处理多个模块(例如随机深度网络)的情况下。

一个模块有几种状态:

  • 初始状态:尚未分配内存,因此模块尚未准备好进行计算。
  • 绑定:输入、输出和参数的形状都是已知的,内存已分配,模块已准备好进行计算。
  • 参数被初始化:对于带参数的模块,在初始化参数之前进行计算可能会导致未定义的输出。
  • 已安装优化器:可以将优化器安装到模块中。在此之后,可以在计算梯度后根据优化器更新模块的参数(forward-backward)。

属性

data_names

此模块所需数据的名称列表。

data_shapes

指定此模块的数据输入的(名称,形状)对列表。

label_shapes

(名称,形状)对的列表,指定此模块的标签输入。

output_names

此模块输出的名称列表。

output_shapes

(名称,形状)对的列表,指定此模块的输出。

symbol

获取与此模块关联的符号。

为了使模块与其他模块进行交互,它必须能够在其初始状态(绑定之前)报告以下信息:

  • data_names:类型字符串列表,指示所需输入数据的名称。
  • output_names :类型字符串列表,指示所需输出的名称。

绑定后,一个模块应该能够上报以下更丰富的信息:

  • 状态信息
    • bindedbool ,指示是否已分配计算所需的内存缓冲区。
    • for_training : 模块是否需要训练。
    • params_initialized : bool ,表示该模块的参数是否已经初始化。
    • optimizer_initialized : bool ,指示是否定义和初始化优化器。
    • inputs_need_gradbool,指示是否需要关于输入数据的梯度。在实现模块组合时可能很有用。
  • 输入/输出信息
    • data_shapes(name, shape) 的列表。理论上,由于分配了内存,我们可以直接提供数据数组。但是在数据并行的情况下,数据数组的形状可能与从外部世界看到的不同。
    • label_shapes(name, shape) 的列表。这可能是[] 如果模块不需要标签(例如,它在顶部不包含损失函数),或者模块未绑定到训练。
    • output_shapes :模块输出的(name, shape) 列表。
  • 参数(用于带参数的模块)
    • get_params() :返回一个元组 (arg_params, aux_params) 。其中每一个都是名称到NDArray 映射的字典。那些 NDArray 总是存在于 CPU 上。用于计算的实际参数可能存在于其他设备(GPU)上,此函数将检索(副本)最新参数。
    • set_params(arg_params, aux_params) :将参数分配给进行计算的设备。
    • init_params(...):更灵活的接口来分配或初始化参数。
  • 设置
    • bind():准备计算环境。
    • init_optimizer() :安装优化器以更新参数。
    • prepare() :根据当前数据批次准备模块。
  • 计算
    • forward(data_batch):正向操作。
    • backward(out_grads=None):向后操作。
    • update() : 根据安装的优化器更新参数。
    • get_outputs() : 获取前一个正向操作的输出。
    • get_input_grads() :获取相对于在先前反向操作中计算的输入的梯度。
    • update_metric(metric, labels, pre_sliced=False) :更新前向计算结果的性能指标。
  • 其他属性(主要是为了向后兼容)
    • symbol :此模块的基础符号图(如果有) 此属性不一定是恒定的。例如,对于 BucketingModule ,此属性只是使用的 current 符号。对于其他模块,这个值可能没有很好地定义。

当这些intermediate-level API 被正确实现时,以下高级 API 将自动可用于模块:

  • fit :在数据集上训练模块参数。
  • predict :对数据集运行预测并收集输出。
  • score :对数据集运行预测并评估性能。

例子

>>> # An example of creating a mxnet module.
>>> import mxnet as mx
>>> data = mx.symbol.Variable('data')
>>> fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
>>> act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
>>> fc2  = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
>>> act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
>>> fc3  = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
>>> out  = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')
>>> mod = mx.mod.Module(out)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.module.BaseModule。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。