当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.nn.raw_rnn用法及代码示例


创建由 RNNCell cell 和循环函数 loop_fn 指定的 RNN

用法

tf.compat.v1.nn.raw_rnn(
    cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None
)

参数

  • cell RNNCell 的一个实例。
  • loop_fn 接受输入的可调用对象(time, cell_output, cell_state, loop_state)并返回元组(finished, next_input, next_cell_state, emit_output, next_loop_state).这里time是一个 int32 标量Tensor,cell_output是一个Tensor或(可能是嵌套的)张量元组,由下式确定cell.output_size, 和cell_state是一个Tensor或(可能是嵌套的)张量元组,由loop_fn在第一次调用时(并且应该匹配cell.state_size)。输出是:finished, 一个布尔值Tensor形状的[batch_size],next_input: 下一个要输入的输入cell,next_cell_state: 下一个要喂到的状态cell, 和emit_output: 本次迭代存储的输出。注意emit_output应该是一个Tensor或(可能是嵌套的)张量元组,聚合在emit_ta在 - 的里面while_loop.第一次调用loop_fn, 这emit_output对应于emit_structure然后用于确定zero_tensor为了emit_ta(默认为cell.output_size)。对于后续调用loop_fn, 这emit_output对应于要在emit_ta.参数cell_state并输出next_cell_state可以是单个或(可能是嵌套的)张量元组。参数loop_state并输出next_loop_state可以是单个或(可能是嵌套的)元组TensorTensorArray对象。最后一个参数可能会被忽略loop_fn并且返回值可能是None.如果不是None,那么loop_state将通过 RNN 循环传播,纯粹由loop_fn跟踪自己的状态。这next_loop_state返回的参数可能是None.第一次调用loop_fn将会time = 0,cell_output = None,cell_state = None, 和loop_state = None.对于这个电话:next_cell_statevalue 应该是用来初始化单元格状态的值。它可能是前一个 RNN 的最终状态,也可能是cell.zero_state().它应该是张量的(可能是嵌套的)元组结构。如果cell.state_size是一个整数,这必须是Tensor适当的类型和形状[batch_size, cell.state_size].如果cell.state_size是一个TensorShape, 这一定是Tensor适当的类型和形状[batch_size] + cell.state_size.如果cell.state_size是一个(可能是嵌套的)整数元组或TensorShape,这将是一个具有相应形状的元组。这emit_output值可能是None或张量的(可能是嵌套的)元组结构,例如,(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1)).如果这是第一个emit_output返回值为None,那么emit_ta的结果raw_rnn将具有相同的结构和数据类型cell.output_size.否则emit_ta将具有相同的结构、形状(前面带有batch_size维度),并且 dtypes 为emit_output.返回的实际值emit_output在这个初始化调用被忽略。请注意,此发射结构必须在所有时间步中保持一致。
  • parallel_iterations (默认值:32)。并行运行的迭代次数。那些没有任何时间依赖性并且可以并行运行的操作将是。此参数以时间换空间。值 >> 1 使用更多内存但花费的时间更少,而较小的值使用更少的内存但计算时间更长。
  • swap_memory 透明地交换前向推理中产生的张量,但需要从 GPU 到 CPU 的反向支撑。这允许训练通常不适合单个 GPU 的 RNN,而性能损失非常小(或没有)。
  • scope 创建的子图的变量范围;默认为"rnn"。

返回

  • 一个元组(emit_ta, final_state, final_loop_state)其中:

    emit_ta :RNN 输出 TensorArray 。如果 loop_fn 在初始化期间为 emit_output 返回一组(可能是嵌套的)张量(输入 time = 0cell_output = Noneloop_state = None ),则 emit_ta 将具有相同的结构、dtypes和形状改为emit_output。如果 loop_fn 在此调用期间返回 emit_output = None,则使用 cell.output_size 的结构: 如果 cell.output_size 是(可能嵌套的)整数元组或 TensorShape 对象,则 emit_ta 将是具有与 cell.output_size 相同的结构,包含 TensorArrays,其元素的形状对应于 cell.output_size 中的形状数据。

    final_state :最终的细胞状态。如果 cell.state_size 是 int,则其形状为 [batch_size, cell.state_size] 。如果它是 TensorShape ,它将被塑造成 [batch_size] + cell.state_size 。如果它是一个(可能是嵌套的)整数元组或 TensorShape ,这将是一个具有相应形状的元组。

    final_loop_stateloop_fn 返回的最终循环状态。

抛出

  • TypeError 如果 cell 不是 RNNCell 的实例,或者 loop_fn 不是 callable

注意:此方法仍在测试中,API 可能会更改。**

此函数是dynamic_rnn 的更原始版本,每次迭代都提供对输入的更直接访问。它还提供了对何时开始和结束读取序列以及为输出发出什么内容的更多控制。

例如,它可以用于实现 seq2seq 模型的动态解码器。

大多数操作不是使用Tensor 对象,而是直接使用TensorArray 对象。

raw_rnn 的操作,在pseudo-code 中,基本如下:

time = tf.constant(0, dtype=tf.int32)
(finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
    time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
while not all(finished):
  (output, cell_state) = cell(next_input, state)
  (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
      time=time + 1, cell_output=output, cell_state=cell_state,
      loop_state=loop_state)
  # Emit zeros and copy forward state for minibatch entries that are finished.
  state = tf.where(finished, state, next_state)
  emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
  emit_ta = emit_ta.write(time, emit)
  # If any new minibatch entries are marked as finished, mark these.
  finished = tf.logical_or(finished, next_finished)
  time += 1
return (emit_ta, state, loop_state)

具有附加属性,输出和状态可能是(可能是嵌套的)元组,由 cell.output_sizecell.state_size 确定,因此最终的 stateemit_ta 本身可能是元组。

通过raw_rnndynamic_rnn 的简单实现如下所示:

inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
                        dtype=tf.float32)
sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs)

cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)

def loop_fn(time, cell_output, cell_state, loop_state):
  emit_output = cell_output  # == None for time == 0
  if cell_output is None: # time == 0
    next_cell_state = cell.zero_state(batch_size, tf.float32)
  else:
    next_cell_state = cell_state
  elements_finished = (time >= sequence_length)
  finished = tf.reduce_all(elements_finished)
  next_input = tf.cond(
      finished,
      lambda:tf.zeros([batch_size, input_depth], dtype=tf.float32),
      lambda:inputs_ta.read(time))
  next_loop_state = None
  return (elements_finished, next_input, next_cell_state,
          emit_output, next_loop_state)

outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.nn.raw_rnn。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。