当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.io.NDArrayIter用法及代码示例


用法:

class mxnet.io.NDArrayIter(data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', label_name='softmax_label')

参数

  • data(array or list of array or dict of string to array) - 输入数据。
  • label(array or list of array or dict of string to array, optional) - 输入标签。
  • batch_size(int) - 数据的批量大小。
  • shuffle(bool, optional) - 是否打乱数据。只有在没有使用 h5py.Dataset 输入时才支持。
  • last_batch_handle(str, optional) - 如何处理最后一批。此参数可以是‘pad’, ‘discard’ 或‘roll_over’。如果‘pad’,最后一批将从头开始填充数据如果‘discard’,最后一批将被丢弃如果‘roll_over’,剩余元素将滚动到下一次迭代并注意它是预期的用于训练,如果用于预测可能会导致问题。
  • data_name(str, optional) - 数据名称。
  • label_name(str, optional) - 标签名称。

基础:mxnet.io.io.DataIter

返回 mx.nd.NDArraynumpy.ndarrayh5py.Dataset mx.nd.sparse.CSRNDArrayscipy.sparse.csr_matrix 的迭代器。

例子

属性

provide_data

此迭代器提供的数据的名称和形状。

provide_label

此迭代器提供的标签的名称和形状。

>>> data = np.arange(40).reshape((10,2,2))
>>> labels = np.ones([10, 1])
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
>>> for batch in dataiter:
...     print batch.data[0].asnumpy()
...     batch.data[0].shape
...
[[[ 36.  37.]
  [ 38.  39.]]
 [[ 16.  17.]
  [ 18.  19.]]
 [[ 12.  13.]
  [ 14.  15.]]]
(3L, 2L, 2L)
[[[ 32.  33.]
  [ 34.  35.]]
 [[  4.   5.]
  [  6.   7.]]
 [[ 24.  25.]
  [ 26.  27.]]]
(3L, 2L, 2L)
[[[  8.   9.]
  [ 10.  11.]]
 [[ 20.  21.]
  [ 22.  23.]]
 [[ 28.  29.]
  [ 30.  31.]]]
(3L, 2L, 2L)
>>> dataiter.provide_data # Returns a list of `DataDesc`
[DataDesc[data,(3, 2L, 2L),<type 'numpy.float32'>,NCHW]]
>>> dataiter.provide_label # Returns a list of `DataDesc`
[DataDesc[softmax_label,(3, 1L),<type 'numpy.float32'>,NCHW]]

在上面的例子中,数据被洗牌,因为shuffle参数设置为True,其余的例子被丢弃,因为last_batch_handle参数设置为discard

last_batch_handle 参数的用法:

>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='pad')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx  # Padding added after the examples read are over. So, 10/3+1 batches are created.
4
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx # Remaining examples are discarded. So, 10/3 batches are created.
3
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx # Remaining examples are rolled over to the next iteration.
3
>>> dataiter.reset()
>>> dataiter.next().data[0].asnumpy()
[[[ 36.  37.]
  [ 38.  39.]]
 [[ 0.  1.]
  [ 2.  3.]]
 [[ 4.  5.]
  [ 6.  7.]]]
(3L, 2L, 2L)

NDArrayIter 还支持多个输入和标签。

>>> data = {'data1':np.zeros(shape=(10,2,2)), 'data2':np.zeros(shape=(20,2,2))}
>>> label = {'label1':np.zeros(shape=(10,1)), 'label2':np.zeros(shape=(20,1))}
>>> dataiter = mx.io.NDArrayIter(data, label, 3, True, last_batch_handle='discard')

NDArrayIter 还支持 mx.nd.sparse.CSRNDArray 并将 last_batch_handle 设置为 discard

>>> csr_data = mx.nd.array(np.arange(40).reshape((10,4))).tostype('csr')
>>> labels = np.ones([10, 1])
>>> dataiter = mx.io.NDArrayIter(csr_data, labels, 3, last_batch_handle='discard')
>>> [batch.data[0] for batch in dataiter]
[
<CSRNDArray 3x4 @cpu(0)>,
<CSRNDArray 3x4 @cpu(0)>,
<CSRNDArray 3x4 @cpu(0)>]

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.io.NDArrayIter。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。