當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python mxnet.io.NDArrayIter用法及代碼示例

用法:

class mxnet.io.NDArrayIter(data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', label_name='softmax_label')

參數

  • data(array or list of array or dict of string to array) - 輸入數據。
  • label(array or list of array or dict of string to array, optional) - 輸入標簽。
  • batch_size(int) - 數據的批量大小。
  • shuffle(bool, optional) - 是否打亂數據。隻有在沒有使用 h5py.Dataset 輸入時才支持。
  • last_batch_handle(str, optional) - 如何處理最後一批。此參數可以是‘pad’, ‘discard’ 或‘roll_over’。如果‘pad’,最後一批將從頭開始填充數據如果‘discard’,最後一批將被丟棄如果‘roll_over’,剩餘元素將滾動到下一次迭代並注意它是預期的用於訓練,如果用於預測可能會導致問題。
  • data_name(str, optional) - 數據名稱。
  • label_name(str, optional) - 標簽名稱。

基礎:mxnet.io.io.DataIter

返回 mx.nd.NDArraynumpy.ndarrayh5py.Dataset mx.nd.sparse.CSRNDArrayscipy.sparse.csr_matrix 的迭代器。

例子

屬性

provide_data

此迭代器提供的數據的名稱和形狀。

provide_label

此迭代器提供的標簽的名稱和形狀。

>>> data = np.arange(40).reshape((10,2,2))
>>> labels = np.ones([10, 1])
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
>>> for batch in dataiter:
...     print batch.data[0].asnumpy()
...     batch.data[0].shape
...
[[[ 36.  37.]
  [ 38.  39.]]
 [[ 16.  17.]
  [ 18.  19.]]
 [[ 12.  13.]
  [ 14.  15.]]]
(3L, 2L, 2L)
[[[ 32.  33.]
  [ 34.  35.]]
 [[  4.   5.]
  [  6.   7.]]
 [[ 24.  25.]
  [ 26.  27.]]]
(3L, 2L, 2L)
[[[  8.   9.]
  [ 10.  11.]]
 [[ 20.  21.]
  [ 22.  23.]]
 [[ 28.  29.]
  [ 30.  31.]]]
(3L, 2L, 2L)
>>> dataiter.provide_data # Returns a list of `DataDesc`
[DataDesc[data,(3, 2L, 2L),<type 'numpy.float32'>,NCHW]]
>>> dataiter.provide_label # Returns a list of `DataDesc`
[DataDesc[softmax_label,(3, 1L),<type 'numpy.float32'>,NCHW]]

在上麵的例子中,數據被洗牌,因為shuffle參數設置為True,其餘的例子被丟棄,因為last_batch_handle參數設置為discard

last_batch_handle 參數的用法:

>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='pad')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx  # Padding added after the examples read are over. So, 10/3+1 batches are created.
4
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx # Remaining examples are discarded. So, 10/3 batches are created.
3
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
>>> batchidx = 0
>>> for batch in dataiter:
...     batchidx += 1
...
>>> batchidx # Remaining examples are rolled over to the next iteration.
3
>>> dataiter.reset()
>>> dataiter.next().data[0].asnumpy()
[[[ 36.  37.]
  [ 38.  39.]]
 [[ 0.  1.]
  [ 2.  3.]]
 [[ 4.  5.]
  [ 6.  7.]]]
(3L, 2L, 2L)

NDArrayIter 還支持多個輸入和標簽。

>>> data = {'data1':np.zeros(shape=(10,2,2)), 'data2':np.zeros(shape=(20,2,2))}
>>> label = {'label1':np.zeros(shape=(10,1)), 'label2':np.zeros(shape=(20,1))}
>>> dataiter = mx.io.NDArrayIter(data, label, 3, True, last_batch_handle='discard')

NDArrayIter 還支持 mx.nd.sparse.CSRNDArray 並將 last_batch_handle 設置為 discard

>>> csr_data = mx.nd.array(np.arange(40).reshape((10,4))).tostype('csr')
>>> labels = np.ones([10, 1])
>>> dataiter = mx.io.NDArrayIter(csr_data, labels, 3, last_batch_handle='discard')
>>> [batch.data[0] for batch in dataiter]
[
<CSRNDArray 3x4 @cpu(0)>,
<CSRNDArray 3x4 @cpu(0)>,
<CSRNDArray 3x4 @cpu(0)>]

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.io.NDArrayIter。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。