當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.nn.raw_rnn用法及代碼示例


創建由 RNNCell cell 和循環函數 loop_fn 指定的 RNN

用法

tf.compat.v1.nn.raw_rnn(
    cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None
)

參數

  • cell RNNCell 的一個實例。
  • loop_fn 接受輸入的可調用對象(time, cell_output, cell_state, loop_state)並返回元組(finished, next_input, next_cell_state, emit_output, next_loop_state).這裏time是一個 int32 標量Tensor,cell_output是一個Tensor或(可能是嵌套的)張量元組,由下式確定cell.output_size, 和cell_state是一個Tensor或(可能是嵌套的)張量元組,由loop_fn在第一次調用時(並且應該匹配cell.state_size)。輸出是:finished, 一個布爾值Tensor形狀的[batch_size],next_input: 下一個要輸入的輸入cell,next_cell_state: 下一個要喂到的狀態cell, 和emit_output: 本次迭代存儲的輸出。注意emit_output應該是一個Tensor或(可能是嵌套的)張量元組,聚合在emit_ta在 - 的裏麵while_loop.第一次調用loop_fn, 這emit_output對應於emit_structure然後用於確定zero_tensor為了emit_ta(默認為cell.output_size)。對於後續調用loop_fn, 這emit_output對應於要在emit_ta.參數cell_state並輸出next_cell_state可以是單個或(可能是嵌套的)張量元組。參數loop_state並輸出next_loop_state可以是單個或(可能是嵌套的)元組TensorTensorArray對象。最後一個參數可能會被忽略loop_fn並且返回值可能是None.如果不是None,那麽loop_state將通過 RNN 循環傳播,純粹由loop_fn跟蹤自己的狀態。這next_loop_state返回的參數可能是None.第一次調用loop_fn將會time = 0,cell_output = None,cell_state = None, 和loop_state = None.對於這個電話:next_cell_statevalue 應該是用來初始化單元格狀態的值。它可能是前一個 RNN 的最終狀態,也可能是cell.zero_state().它應該是張量的(可能是嵌套的)元組結構。如果cell.state_size是一個整數,這必須是Tensor適當的類型和形狀[batch_size, cell.state_size].如果cell.state_size是一個TensorShape, 這一定是Tensor適當的類型和形狀[batch_size] + cell.state_size.如果cell.state_size是一個(可能是嵌套的)整數元組或TensorShape,這將是一個具有相應形狀的元組。這emit_output值可能是None或張量的(可能是嵌套的)元組結構,例如,(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1)).如果這是第一個emit_output返回值為None,那麽emit_ta的結果raw_rnn將具有相同的結構和數據類型cell.output_size.否則emit_ta將具有相同的結構、形狀(前麵帶有batch_size維度),並且 dtypes 為emit_output.返回的實際值emit_output在這個初始化調用被忽略。請注意,此發射結構必須在所有時間步中保持一致。
  • parallel_iterations (默認值:32)。並行運行的迭代次數。那些沒有任何時間依賴性並且可以並行運行的操作將是。此參數以時間換空間。值 >> 1 使用更多內存但花費的時間更少,而較小的值使用更少的內存但計算時間更長。
  • swap_memory 透明地交換前向推理中產生的張量,但需要從 GPU 到 CPU 的反向支撐。這允許訓練通常不適合單個 GPU 的 RNN,而性能損失非常小(或沒有)。
  • scope 創建的子圖的變量範圍;默認為"rnn"。

返回

  • 一個元組(emit_ta, final_state, final_loop_state)其中:

    emit_ta :RNN 輸出 TensorArray 。如果 loop_fn 在初始化期間為 emit_output 返回一組(可能是嵌套的)張量(輸入 time = 0cell_output = Noneloop_state = None ),則 emit_ta 將具有相同的結構、dtypes和形狀改為emit_output。如果 loop_fn 在此調用期間返回 emit_output = None,則使用 cell.output_size 的結構: 如果 cell.output_size 是(可能嵌套的)整數元組或 TensorShape 對象,則 emit_ta 將是具有與 cell.output_size 相同的結構,包含 TensorArrays,其元素的形狀對應於 cell.output_size 中的形狀數據。

    final_state :最終的細胞狀態。如果 cell.state_size 是 int,則其形狀為 [batch_size, cell.state_size] 。如果它是 TensorShape ,它將被塑造成 [batch_size] + cell.state_size 。如果它是一個(可能是嵌套的)整數元組或 TensorShape ,這將是一個具有相應形狀的元組。

    final_loop_stateloop_fn 返回的最終循環狀態。

拋出

  • TypeError 如果 cell 不是 RNNCell 的實例,或者 loop_fn 不是 callable

注意:此方法仍在測試中,API 可能會更改。**

此函數是dynamic_rnn 的更原始版本,每次迭代都提供對輸入的更直接訪問。它還提供了對何時開始和結束讀取序列以及為輸出發出什麽內容的更多控製。

例如,它可以用於實現 seq2seq 模型的動態解碼器。

大多數操作不是使用Tensor 對象,而是直接使用TensorArray 對象。

raw_rnn 的操作,在pseudo-code 中,基本如下:

time = tf.constant(0, dtype=tf.int32)
(finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
    time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
while not all(finished):
  (output, cell_state) = cell(next_input, state)
  (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
      time=time + 1, cell_output=output, cell_state=cell_state,
      loop_state=loop_state)
  # Emit zeros and copy forward state for minibatch entries that are finished.
  state = tf.where(finished, state, next_state)
  emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
  emit_ta = emit_ta.write(time, emit)
  # If any new minibatch entries are marked as finished, mark these.
  finished = tf.logical_or(finished, next_finished)
  time += 1
return (emit_ta, state, loop_state)

具有附加屬性,輸出和狀態可能是(可能是嵌套的)元組,由 cell.output_sizecell.state_size 確定,因此最終的 stateemit_ta 本身可能是元組。

通過raw_rnndynamic_rnn 的簡單實現如下所示:

inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
                        dtype=tf.float32)
sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs)

cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)

def loop_fn(time, cell_output, cell_state, loop_state):
  emit_output = cell_output  # == None for time == 0
  if cell_output is None: # time == 0
    next_cell_state = cell.zero_state(batch_size, tf.float32)
  else:
    next_cell_state = cell_state
  elements_finished = (time >= sequence_length)
  finished = tf.reduce_all(elements_finished)
  next_input = tf.cond(
      finished,
      lambda:tf.zeros([batch_size, input_depth], dtype=tf.float32),
      lambda:inputs_ta.read(time))
  next_loop_state = None
  return (elements_finished, next_input, next_cell_state,
          emit_output, next_loop_state)

outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.nn.raw_rnn。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。