創建由 RNNCell cell
和循環函數 loop_fn
指定的 RNN
。
用法
tf.compat.v1.nn.raw_rnn(
cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None
)
參數
-
cell
RNNCell 的一個實例。 -
loop_fn
接受輸入的可調用對象(time, cell_output, cell_state, loop_state)
並返回元組(finished, next_input, next_cell_state, emit_output, next_loop_state)
.這裏time
是一個 int32 標量Tensor
,cell_output
是一個Tensor
或(可能是嵌套的)張量元組,由下式確定cell.output_size
, 和cell_state
是一個Tensor
或(可能是嵌套的)張量元組,由loop_fn
在第一次調用時(並且應該匹配cell.state_size
)。輸出是:finished
, 一個布爾值Tensor
形狀的[batch_size]
,next_input
: 下一個要輸入的輸入cell
,next_cell_state
: 下一個要喂到的狀態cell
, 和emit_output
: 本次迭代存儲的輸出。注意emit_output
應該是一個Tensor
或(可能是嵌套的)張量元組,聚合在emit_ta
在 - 的裏麵while_loop
.第一次調用loop_fn
, 這emit_output
對應於emit_structure
然後用於確定zero_tensor
為了emit_ta
(默認為cell.output_size
)。對於後續調用loop_fn
, 這emit_output
對應於要在emit_ta
.參數cell_state
並輸出next_cell_state
可以是單個或(可能是嵌套的)張量元組。參數loop_state
並輸出next_loop_state
可以是單個或(可能是嵌套的)元組Tensor
和TensorArray
對象。最後一個參數可能會被忽略loop_fn
並且返回值可能是None
.如果不是None
,那麽loop_state
將通過 RNN 循環傳播,純粹由loop_fn
跟蹤自己的狀態。這next_loop_state
返回的參數可能是None
.第一次調用loop_fn
將會time = 0
,cell_output = None
,cell_state = None
, 和loop_state = None
.對於這個電話:next_cell_state
value 應該是用來初始化單元格狀態的值。它可能是前一個 RNN 的最終狀態,也可能是cell.zero_state()
.它應該是張量的(可能是嵌套的)元組結構。如果cell.state_size
是一個整數,這必須是Tensor
適當的類型和形狀[batch_size, cell.state_size]
.如果cell.state_size
是一個TensorShape
, 這一定是Tensor
適當的類型和形狀[batch_size] + cell.state_size
.如果cell.state_size
是一個(可能是嵌套的)整數元組或TensorShape
,這將是一個具有相應形狀的元組。這emit_output
值可能是None
或張量的(可能是嵌套的)元組結構,例如,(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))
.如果這是第一個emit_output
返回值為None
,那麽emit_ta
的結果raw_rnn
將具有相同的結構和數據類型cell.output_size
.否則emit_ta
將具有相同的結構、形狀(前麵帶有batch_size
維度),並且 dtypes 為emit_output
.返回的實際值emit_output
在這個初始化調用被忽略。請注意,此發射結構必須在所有時間步中保持一致。 -
parallel_iterations
(默認值:32)。並行運行的迭代次數。那些沒有任何時間依賴性並且可以並行運行的操作將是。此參數以時間換空間。值 >> 1 使用更多內存但花費的時間更少,而較小的值使用更少的內存但計算時間更長。 -
swap_memory
透明地交換前向推理中產生的張量,但需要從 GPU 到 CPU 的反向支撐。這允許訓練通常不適合單個 GPU 的 RNN,而性能損失非常小(或沒有)。 -
scope
創建的子圖的變量範圍;默認為"rnn"。
返回
-
一個元組
(emit_ta, final_state, final_loop_state)
其中:emit_ta
:RNN 輸出TensorArray
。如果loop_fn
在初始化期間為emit_output
返回一組(可能是嵌套的)張量(輸入time = 0
、cell_output = None
和loop_state = None
),則emit_ta
將具有相同的結構、dtypes和形狀改為emit_output
。如果loop_fn
在此調用期間返回emit_output = None
,則使用cell.output_size
的結構: 如果cell.output_size
是(可能嵌套的)整數元組或TensorShape
對象,則emit_ta
將是具有與cell.output_size
相同的結構,包含 TensorArrays,其元素的形狀對應於cell.output_size
中的形狀數據。final_state
:最終的細胞狀態。如果cell.state_size
是 int,則其形狀為[batch_size, cell.state_size]
。如果它是TensorShape
,它將被塑造成[batch_size] + cell.state_size
。如果它是一個(可能是嵌套的)整數元組或TensorShape
,這將是一個具有相應形狀的元組。final_loop_state
:loop_fn
返回的最終循環狀態。
拋出
-
TypeError
如果cell
不是 RNNCell 的實例,或者loop_fn
不是callable
。
注意:此方法仍在測試中,API 可能會更改。**
此函數是dynamic_rnn
的更原始版本,每次迭代都提供對輸入的更直接訪問。它還提供了對何時開始和結束讀取序列以及為輸出發出什麽內容的更多控製。
例如,它可以用於實現 seq2seq 模型的動態解碼器。
大多數操作不是使用Tensor
對象,而是直接使用TensorArray
對象。
raw_rnn
的操作,在pseudo-code 中,基本如下:
time = tf.constant(0, dtype=tf.int32)
(finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
while not all(finished):
(output, cell_state) = cell(next_input, state)
(next_finished, next_input, next_state, emit, loop_state) = loop_fn(
time=time + 1, cell_output=output, cell_state=cell_state,
loop_state=loop_state)
# Emit zeros and copy forward state for minibatch entries that are finished.
state = tf.where(finished, state, next_state)
emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
emit_ta = emit_ta.write(time, emit)
# If any new minibatch entries are marked as finished, mark these.
finished = tf.logical_or(finished, next_finished)
time += 1
return (emit_ta, state, loop_state)
具有附加屬性,輸出和狀態可能是(可能是嵌套的)元組,由 cell.output_size
和 cell.state_size
確定,因此最終的 state
和 emit_ta
本身可能是元組。
通過raw_rnn
的dynamic_rnn
的簡單實現如下所示:
inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
dtype=tf.float32)
sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs)
cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
def loop_fn(time, cell_output, cell_state, loop_state):
emit_output = cell_output # == None for time == 0
if cell_output is None: # time == 0
next_cell_state = cell.zero_state(batch_size, tf.float32)
else:
next_cell_state = cell_state
elements_finished = (time >= sequence_length)
finished = tf.reduce_all(elements_finished)
next_input = tf.cond(
finished,
lambda:tf.zeros([batch_size, input_depth], dtype=tf.float32),
lambda:inputs_ta.read(time))
next_loop_state = None
return (elements_finished, next_input, next_cell_state,
emit_output, next_loop_state)
outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()
相關用法
- Python tf.compat.v1.nn.rnn_cell.MultiRNNCell用法及代碼示例
- Python tf.compat.v1.nn.static_rnn用法及代碼示例
- Python tf.compat.v1.nn.sufficient_statistics用法及代碼示例
- Python tf.compat.v1.nn.dynamic_rnn用法及代碼示例
- Python tf.compat.v1.nn.embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.separable_conv2d用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d_native用法及代碼示例
- Python tf.compat.v1.nn.weighted_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d用法及代碼示例
- Python tf.compat.v1.nn.convolution用法及代碼示例
- Python tf.compat.v1.nn.conv2d用法及代碼示例
- Python tf.compat.v1.nn.safe_embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.nce_loss用法及代碼示例
- Python tf.compat.v1.nn.sampled_softmax_loss用法及代碼示例
- Python tf.compat.v1.nn.pool用法及代碼示例
- Python tf.compat.v1.nn.sigmoid_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.ctc_loss用法及代碼示例
- Python tf.compat.v1.nn.erosion2d用法及代碼示例
- Python tf.compat.v1.nn.dilation2d用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.nn.raw_rnn。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。