創建由 RNNCell cell 和循環函數 loop_fn 指定的 RNN。
用法
tf.compat.v1.nn.raw_rnn(
cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None
)參數
-
cellRNNCell 的一個實例。 -
loop_fn接受輸入的可調用對象(time, cell_output, cell_state, loop_state)並返回元組(finished, next_input, next_cell_state, emit_output, next_loop_state).這裏time是一個 int32 標量Tensor,cell_output是一個Tensor或(可能是嵌套的)張量元組,由下式確定cell.output_size, 和cell_state是一個Tensor或(可能是嵌套的)張量元組,由loop_fn在第一次調用時(並且應該匹配cell.state_size)。輸出是:finished, 一個布爾值Tensor形狀的[batch_size],next_input: 下一個要輸入的輸入cell,next_cell_state: 下一個要喂到的狀態cell, 和emit_output: 本次迭代存儲的輸出。注意emit_output應該是一個Tensor或(可能是嵌套的)張量元組,聚合在emit_ta在 - 的裏麵while_loop.第一次調用loop_fn, 這emit_output對應於emit_structure然後用於確定zero_tensor為了emit_ta(默認為cell.output_size)。對於後續調用loop_fn, 這emit_output對應於要在emit_ta.參數cell_state並輸出next_cell_state可以是單個或(可能是嵌套的)張量元組。參數loop_state並輸出next_loop_state可以是單個或(可能是嵌套的)元組Tensor和TensorArray對象。最後一個參數可能會被忽略loop_fn並且返回值可能是None.如果不是None,那麽loop_state將通過 RNN 循環傳播,純粹由loop_fn跟蹤自己的狀態。這next_loop_state返回的參數可能是None.第一次調用loop_fn將會time = 0,cell_output = None,cell_state = None, 和loop_state = None.對於這個電話:next_cell_statevalue 應該是用來初始化單元格狀態的值。它可能是前一個 RNN 的最終狀態,也可能是cell.zero_state().它應該是張量的(可能是嵌套的)元組結構。如果cell.state_size是一個整數,這必須是Tensor適當的類型和形狀[batch_size, cell.state_size].如果cell.state_size是一個TensorShape, 這一定是Tensor適當的類型和形狀[batch_size] + cell.state_size.如果cell.state_size是一個(可能是嵌套的)整數元組或TensorShape,這將是一個具有相應形狀的元組。這emit_output值可能是None或張量的(可能是嵌套的)元組結構,例如,(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1)).如果這是第一個emit_output返回值為None,那麽emit_ta的結果raw_rnn將具有相同的結構和數據類型cell.output_size.否則emit_ta將具有相同的結構、形狀(前麵帶有batch_size維度),並且 dtypes 為emit_output.返回的實際值emit_output在這個初始化調用被忽略。請注意,此發射結構必須在所有時間步中保持一致。 -
parallel_iterations(默認值:32)。並行運行的迭代次數。那些沒有任何時間依賴性並且可以並行運行的操作將是。此參數以時間換空間。值 >> 1 使用更多內存但花費的時間更少,而較小的值使用更少的內存但計算時間更長。 -
swap_memory透明地交換前向推理中產生的張量,但需要從 GPU 到 CPU 的反向支撐。這允許訓練通常不適合單個 GPU 的 RNN,而性能損失非常小(或沒有)。 -
scope創建的子圖的變量範圍;默認為"rnn"。
返回
-
一個元組
(emit_ta, final_state, final_loop_state)其中:emit_ta:RNN 輸出TensorArray。如果loop_fn在初始化期間為emit_output返回一組(可能是嵌套的)張量(輸入time = 0、cell_output = None和loop_state = None),則emit_ta將具有相同的結構、dtypes和形狀改為emit_output。如果loop_fn在此調用期間返回emit_output = None,則使用cell.output_size的結構: 如果cell.output_size是(可能嵌套的)整數元組或TensorShape對象,則emit_ta將是具有與cell.output_size相同的結構,包含 TensorArrays,其元素的形狀對應於cell.output_size中的形狀數據。final_state:最終的細胞狀態。如果cell.state_size是 int,則其形狀為[batch_size, cell.state_size]。如果它是TensorShape,它將被塑造成[batch_size] + cell.state_size。如果它是一個(可能是嵌套的)整數元組或TensorShape,這將是一個具有相應形狀的元組。final_loop_state:loop_fn返回的最終循環狀態。
拋出
-
TypeError如果cell不是 RNNCell 的實例,或者loop_fn不是callable。
注意:此方法仍在測試中,API 可能會更改。**
此函數是dynamic_rnn 的更原始版本,每次迭代都提供對輸入的更直接訪問。它還提供了對何時開始和結束讀取序列以及為輸出發出什麽內容的更多控製。
例如,它可以用於實現 seq2seq 模型的動態解碼器。
大多數操作不是使用Tensor 對象,而是直接使用TensorArray 對象。
raw_rnn 的操作,在pseudo-code 中,基本如下:
time = tf.constant(0, dtype=tf.int32)
(finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
while not all(finished):
(output, cell_state) = cell(next_input, state)
(next_finished, next_input, next_state, emit, loop_state) = loop_fn(
time=time + 1, cell_output=output, cell_state=cell_state,
loop_state=loop_state)
# Emit zeros and copy forward state for minibatch entries that are finished.
state = tf.where(finished, state, next_state)
emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
emit_ta = emit_ta.write(time, emit)
# If any new minibatch entries are marked as finished, mark these.
finished = tf.logical_or(finished, next_finished)
time += 1
return (emit_ta, state, loop_state)
具有附加屬性,輸出和狀態可能是(可能是嵌套的)元組,由 cell.output_size 和 cell.state_size 確定,因此最終的 state 和 emit_ta 本身可能是元組。
通過raw_rnn 的dynamic_rnn 的簡單實現如下所示:
inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
dtype=tf.float32)
sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs)
cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
def loop_fn(time, cell_output, cell_state, loop_state):
emit_output = cell_output # == None for time == 0
if cell_output is None: # time == 0
next_cell_state = cell.zero_state(batch_size, tf.float32)
else:
next_cell_state = cell_state
elements_finished = (time >= sequence_length)
finished = tf.reduce_all(elements_finished)
next_input = tf.cond(
finished,
lambda:tf.zeros([batch_size, input_depth], dtype=tf.float32),
lambda:inputs_ta.read(time))
next_loop_state = None
return (elements_finished, next_input, next_cell_state,
emit_output, next_loop_state)
outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()
相關用法
- Python tf.compat.v1.nn.rnn_cell.MultiRNNCell用法及代碼示例
- Python tf.compat.v1.nn.static_rnn用法及代碼示例
- Python tf.compat.v1.nn.sufficient_statistics用法及代碼示例
- Python tf.compat.v1.nn.dynamic_rnn用法及代碼示例
- Python tf.compat.v1.nn.embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.separable_conv2d用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d_native用法及代碼示例
- Python tf.compat.v1.nn.weighted_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d用法及代碼示例
- Python tf.compat.v1.nn.convolution用法及代碼示例
- Python tf.compat.v1.nn.conv2d用法及代碼示例
- Python tf.compat.v1.nn.safe_embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.nce_loss用法及代碼示例
- Python tf.compat.v1.nn.sampled_softmax_loss用法及代碼示例
- Python tf.compat.v1.nn.pool用法及代碼示例
- Python tf.compat.v1.nn.sigmoid_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.ctc_loss用法及代碼示例
- Python tf.compat.v1.nn.erosion2d用法及代碼示例
- Python tf.compat.v1.nn.dilation2d用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.nn.raw_rnn。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。
