创建由 RNNCell cell
和循环函数 loop_fn
指定的 RNN
。
用法
tf.compat.v1.nn.raw_rnn(
cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None
)
参数
-
cell
RNNCell 的一个实例。 -
loop_fn
接受输入的可调用对象(time, cell_output, cell_state, loop_state)
并返回元组(finished, next_input, next_cell_state, emit_output, next_loop_state)
.这里time
是一个 int32 标量Tensor
,cell_output
是一个Tensor
或(可能是嵌套的)张量元组,由下式确定cell.output_size
, 和cell_state
是一个Tensor
或(可能是嵌套的)张量元组,由loop_fn
在第一次调用时(并且应该匹配cell.state_size
)。输出是:finished
, 一个布尔值Tensor
形状的[batch_size]
,next_input
: 下一个要输入的输入cell
,next_cell_state
: 下一个要喂到的状态cell
, 和emit_output
: 本次迭代存储的输出。注意emit_output
应该是一个Tensor
或(可能是嵌套的)张量元组,聚合在emit_ta
在 - 的里面while_loop
.第一次调用loop_fn
, 这emit_output
对应于emit_structure
然后用于确定zero_tensor
为了emit_ta
(默认为cell.output_size
)。对于后续调用loop_fn
, 这emit_output
对应于要在emit_ta
.参数cell_state
并输出next_cell_state
可以是单个或(可能是嵌套的)张量元组。参数loop_state
并输出next_loop_state
可以是单个或(可能是嵌套的)元组Tensor
和TensorArray
对象。最后一个参数可能会被忽略loop_fn
并且返回值可能是None
.如果不是None
,那么loop_state
将通过 RNN 循环传播,纯粹由loop_fn
跟踪自己的状态。这next_loop_state
返回的参数可能是None
.第一次调用loop_fn
将会time = 0
,cell_output = None
,cell_state = None
, 和loop_state = None
.对于这个电话:next_cell_state
value 应该是用来初始化单元格状态的值。它可能是前一个 RNN 的最终状态,也可能是cell.zero_state()
.它应该是张量的(可能是嵌套的)元组结构。如果cell.state_size
是一个整数,这必须是Tensor
适当的类型和形状[batch_size, cell.state_size]
.如果cell.state_size
是一个TensorShape
, 这一定是Tensor
适当的类型和形状[batch_size] + cell.state_size
.如果cell.state_size
是一个(可能是嵌套的)整数元组或TensorShape
,这将是一个具有相应形状的元组。这emit_output
值可能是None
或张量的(可能是嵌套的)元组结构,例如,(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))
.如果这是第一个emit_output
返回值为None
,那么emit_ta
的结果raw_rnn
将具有相同的结构和数据类型cell.output_size
.否则emit_ta
将具有相同的结构、形状(前面带有batch_size
维度),并且 dtypes 为emit_output
.返回的实际值emit_output
在这个初始化调用被忽略。请注意,此发射结构必须在所有时间步中保持一致。 -
parallel_iterations
(默认值:32)。并行运行的迭代次数。那些没有任何时间依赖性并且可以并行运行的操作将是。此参数以时间换空间。值 >> 1 使用更多内存但花费的时间更少,而较小的值使用更少的内存但计算时间更长。 -
swap_memory
透明地交换前向推理中产生的张量,但需要从 GPU 到 CPU 的反向支撑。这允许训练通常不适合单个 GPU 的 RNN,而性能损失非常小(或没有)。 -
scope
创建的子图的变量范围;默认为"rnn"。
返回
-
一个元组
(emit_ta, final_state, final_loop_state)
其中:emit_ta
:RNN 输出TensorArray
。如果loop_fn
在初始化期间为emit_output
返回一组(可能是嵌套的)张量(输入time = 0
、cell_output = None
和loop_state = None
),则emit_ta
将具有相同的结构、dtypes和形状改为emit_output
。如果loop_fn
在此调用期间返回emit_output = None
,则使用cell.output_size
的结构: 如果cell.output_size
是(可能嵌套的)整数元组或TensorShape
对象,则emit_ta
将是具有与cell.output_size
相同的结构,包含 TensorArrays,其元素的形状对应于cell.output_size
中的形状数据。final_state
:最终的细胞状态。如果cell.state_size
是 int,则其形状为[batch_size, cell.state_size]
。如果它是TensorShape
,它将被塑造成[batch_size] + cell.state_size
。如果它是一个(可能是嵌套的)整数元组或TensorShape
,这将是一个具有相应形状的元组。final_loop_state
:loop_fn
返回的最终循环状态。
抛出
-
TypeError
如果cell
不是 RNNCell 的实例,或者loop_fn
不是callable
。
注意:此方法仍在测试中,API 可能会更改。**
此函数是dynamic_rnn
的更原始版本,每次迭代都提供对输入的更直接访问。它还提供了对何时开始和结束读取序列以及为输出发出什么内容的更多控制。
例如,它可以用于实现 seq2seq 模型的动态解码器。
大多数操作不是使用Tensor
对象,而是直接使用TensorArray
对象。
raw_rnn
的操作,在pseudo-code 中,基本如下:
time = tf.constant(0, dtype=tf.int32)
(finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
while not all(finished):
(output, cell_state) = cell(next_input, state)
(next_finished, next_input, next_state, emit, loop_state) = loop_fn(
time=time + 1, cell_output=output, cell_state=cell_state,
loop_state=loop_state)
# Emit zeros and copy forward state for minibatch entries that are finished.
state = tf.where(finished, state, next_state)
emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
emit_ta = emit_ta.write(time, emit)
# If any new minibatch entries are marked as finished, mark these.
finished = tf.logical_or(finished, next_finished)
time += 1
return (emit_ta, state, loop_state)
具有附加属性,输出和状态可能是(可能是嵌套的)元组,由 cell.output_size
和 cell.state_size
确定,因此最终的 state
和 emit_ta
本身可能是元组。
通过raw_rnn
的dynamic_rnn
的简单实现如下所示:
inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
dtype=tf.float32)
sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs)
cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
def loop_fn(time, cell_output, cell_state, loop_state):
emit_output = cell_output # == None for time == 0
if cell_output is None: # time == 0
next_cell_state = cell.zero_state(batch_size, tf.float32)
else:
next_cell_state = cell_state
elements_finished = (time >= sequence_length)
finished = tf.reduce_all(elements_finished)
next_input = tf.cond(
finished,
lambda:tf.zeros([batch_size, input_depth], dtype=tf.float32),
lambda:inputs_ta.read(time))
next_loop_state = None
return (elements_finished, next_input, next_cell_state,
emit_output, next_loop_state)
outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()
相关用法
- Python tf.compat.v1.nn.rnn_cell.MultiRNNCell用法及代码示例
- Python tf.compat.v1.nn.static_rnn用法及代码示例
- Python tf.compat.v1.nn.sufficient_statistics用法及代码示例
- Python tf.compat.v1.nn.dynamic_rnn用法及代码示例
- Python tf.compat.v1.nn.embedding_lookup_sparse用法及代码示例
- Python tf.compat.v1.nn.separable_conv2d用法及代码示例
- Python tf.compat.v1.nn.depthwise_conv2d_native用法及代码示例
- Python tf.compat.v1.nn.weighted_cross_entropy_with_logits用法及代码示例
- Python tf.compat.v1.nn.depthwise_conv2d用法及代码示例
- Python tf.compat.v1.nn.convolution用法及代码示例
- Python tf.compat.v1.nn.conv2d用法及代码示例
- Python tf.compat.v1.nn.safe_embedding_lookup_sparse用法及代码示例
- Python tf.compat.v1.nn.nce_loss用法及代码示例
- Python tf.compat.v1.nn.sampled_softmax_loss用法及代码示例
- Python tf.compat.v1.nn.pool用法及代码示例
- Python tf.compat.v1.nn.sigmoid_cross_entropy_with_logits用法及代码示例
- Python tf.compat.v1.nn.ctc_loss用法及代码示例
- Python tf.compat.v1.nn.erosion2d用法及代码示例
- Python tf.compat.v1.nn.dilation2d用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.nn.raw_rnn。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。