This page collects typical usage examples of the Python method tensorflow.contrib.seq2seq.AttentionWrapperState. If you have been wondering what seq2seq.AttentionWrapperState does, how to call it, or what working code that uses it looks like, the curated examples below should help. You can also look further into the enclosing module, tensorflow.contrib.seq2seq, for more context.
The following 8 code examples of seq2seq.AttentionWrapperState are shown, ordered by popularity by default.
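For orientation, here is a minimal sketch (assuming a TF 1.x environment where tf.contrib is still available; the shapes, the LSTM cell, and the Luong mechanism are illustrative choices, not taken from the examples) showing where an AttentionWrapperState usually comes from: it is the state namedtuple produced by the stock AttentionWrapper.

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

batch_size, num_units, max_time = 32, 128, 50
encoder_outputs = tf.zeros([batch_size, max_time, num_units])  # placeholder memory

attention_mechanism = seq2seq.LuongAttention(
    num_units=num_units, memory=encoder_outputs)
cell = seq2seq.AttentionWrapper(rnn.BasicLSTMCell(num_units), attention_mechanism)

# zero_state returns an AttentionWrapperState whose fields
# (cell_state, attention, time, alignments, alignment_history, ...)
# are exactly what the examples below read, clone, and rebuild.
init_state = cell.zero_state(batch_size, tf.float32)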
Example 1: _build_init_state
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def _build_init_state(self, batch_size, enc_state, rnn_cell, mode, hparams):
  """Builds initial states for the given RNN cells."""
  del mode  # Unused.
  # Build init state.
  init_state = rnn_cell.zero_state(batch_size, tf.float32)
  if hparams.pass_hidden_state:
    # Non-GNMT RNN cell returns an AttentionWrapperState.
    if isinstance(init_state, contrib_seq2seq.AttentionWrapperState):
      init_state = init_state.clone(cell_state=enc_state)
    # GNMT RNN cell returns a tuple state.
    elif isinstance(init_state, tuple):
      init_state = tuple(
          zs.clone(cell_state=es) if isinstance(
              zs, contrib_seq2seq.AttentionWrapperState) else es
          for zs, es in zip(init_state, enc_state))
    else:
      raise ValueError("RNN cell returns zero states of unknown type: %s"
                       % str(type(init_state)))
  return init_state
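Outside the helper above, the same pass-hidden-state idea is simply AttentionWrapperState.clone with the encoder's final state swapped in. A hedged sketch (decoder_cell, encoder_state, and batch_size are placeholder names, not defined in the example):

init_state = decoder_cell.zero_state(batch_size, tf.float32)
if isinstance(init_state, tf.contrib.seq2seq.AttentionWrapperState):
  # clone() keeps the zero-initialized time/attention/alignments fields
  # and only overrides cell_state.
  init_state = init_state.clone(cell_state=encoder_state)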
Example 2: wrapper
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def wrapper(self, state):
  """Some RNN states are wrapped in namedtuples
  (TensorFlow decision, definitely not mine...).

  This is here for derived classes to specify their wrapper state.
  Some examples: LSTMStateTuple and AttentionWrapperState.

  Args:
    state: tensor state tuple, will be unpacked into the wrapper tuple.
  """
  if self._wrapper is None:
    return state
  else:
    return self._wrapper(*state)
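To make the re-wrapping idea concrete, a hypothetical illustration (c_tensor and h_tensor are placeholders): if self._wrapper were tf.contrib.rnn.LSTMStateTuple, rebuilding the namedtuple from a flat (c, h) pair would look like this.

from tensorflow.contrib import rnn

flat_state = (c_tensor, h_tensor)          # flat tuple coming out of the graph
wrapped = rnn.LSTMStateTuple(*flat_state)  # what self._wrapper(*state) produces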
Example 3: zero_state
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def zero_state(self, batch_size, dtype):
  with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
    if self._initial_cell_state is not None:
      cell_state = self._initial_cell_state
    else:
      cell_state = self._cell.zero_state(batch_size, dtype)
    error_message = (
        "zero_state of AttentionWrapper %s: " % self._base_name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the requested batch size.")
    with tf.control_dependencies(
        [tf.assert_equal(batch_size,
                         self._attention_mechanism.batch_size,
                         message=error_message)]):
      cell_state = nest.map_structure(
          lambda s: tf.identity(s, name="checked_cell_state"),
          cell_state)
    alignment_history = ()
    _zero_state_tensors = rnn_cell_impl._zero_state_tensors
    return AttentionWrapperState(
        cell_state=cell_state,
        time=tf.zeros([], dtype=tf.int32),
        attention=_zero_state_tensors(self._attention_size, batch_size,
                                      dtype),
        alignments=self._attention_mechanism.initial_alignments(
            batch_size, dtype),
        alignment_history=alignment_history)
Example 4: state_size
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def state_size(self):
  return AttentionWrapperState(
      cell_state=self._cell.state_size,
      attention=self._attention_size,
      time=tf.TensorShape([]),
      alignments=self._attention_mechanism.alignments_size,
      alignment_history=())
Example 5: shape
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def shape(self):
  return AttentionWrapperState(
      cell_state=self._cell.shape,
      attention=tf.TensorShape([None, self._attention_size]),
      time=tf.TensorShape(None),
      alignments=tf.TensorShape([None, None]),
      alignment_history=())
Example 6: call
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def call(self, inputs, state):
  """First computes the cell state and output in the usual way,
  then works through the attention pipeline:
      h --> a --> c --> h_tilde
  using the naming/notation from Luong et al., 2015.

  Args:
    inputs: `2-D` tensor with shape `[batch_size x input_size]`.
    state: An instance of `AttentionWrapperState` containing the
      tensors from the previous timestep.

  Returns:
    A tuple `(attention_or_cell_output, next_state)`, where:
    - `attention_or_cell_output` depends on `output_attention`.
    - `next_state` is an instance of `AttentionWrapperState`
      containing the state calculated at this time step.
  """
  # Concatenate the previous h_tilde with the inputs (input feeding).
  cell_inputs = tf.concat([inputs, state.attention], -1)
  # 1. (hidden) Compute the hidden state (cell_output).
  cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)
  # 2. (align) Compute the normalized alignment scores: [B, L_enc],
  # where L_enc is the max sequence length in the encoder outputs for the (B)atch.
  score = self._attention_mechanism(
      cell_output, previous_alignments=state.alignments)
  alignments = tf.nn.softmax(score)
  # Reshape from [B, L_enc] to [B, 1, L_enc].
  expanded_alignments = tf.expand_dims(alignments, 1)
  # (Possibly projected) encoder outputs: [B, L_enc, state_size].
  encoder_outputs = self._attention_mechanism.values
  # 3. (context) Take the inner product: [B, 1, state_size].
  context = tf.matmul(expanded_alignments, encoder_outputs)
  context = tf.squeeze(context, [1])
  # 4. (h_tilde) Compute tanh(W [c, h]).
  attention = self._attention_layer(tf.concat([cell_output, context], -1))
  next_state = AttentionWrapperState(
      cell_state=next_cell_state,
      attention=attention,
      time=state.time + 1,
      alignments=alignments,
      alignment_history=())
  return attention, next_state
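The attention pipeline in steps 2-4 can be reproduced with plain tensor ops. The sketch below uses made-up shapes, a simple dot-product score standing in for self._attention_mechanism, and a fresh dense layer standing in for self._attention_layer; it only mirrors the align -> context -> h_tilde flow, not the wrapper class itself.

import tensorflow as tf

B, L_enc, H = 8, 20, 64
cell_output = tf.random_normal([B, H])            # stand-in for the decoder hidden state
encoder_outputs = tf.random_normal([B, L_enc, H])  # stand-in for the attention memory

score = tf.matmul(tf.expand_dims(cell_output, 1),            # [B, 1, H]
                  encoder_outputs, transpose_b=True)          # -> [B, 1, L_enc]
alignments = tf.nn.softmax(tf.squeeze(score, [1]))            # [B, L_enc]
context = tf.squeeze(
    tf.matmul(tf.expand_dims(alignments, 1), encoder_outputs), [1])   # [B, H]
h_tilde = tf.layers.dense(
    tf.concat([cell_output, context], -1), H, activation=tf.tanh)     # [B, H]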
Example 7: call
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def call(self, inputs, state):
  if not isinstance(state, seq2seq.AttentionWrapperState):
    raise TypeError("Expected state to be instance of AttentionWrapperState. "
                    "Received type %s instead." % type(state))
  if self._is_multi:
    previous_alignments = state.alignments
    previous_alignment_history = state.alignment_history
  else:
    previous_alignments = [state.alignments]
    previous_alignment_history = [state.alignment_history]
  all_alignments = []
  all_attentions = []
  all_histories = []
  for i, attention_mechanism in enumerate(self._attention_mechanisms):
    if isinstance(self._cell, rnn.LSTMCell):
      rnn_cell_state = state.cell_state.h
    else:
      rnn_cell_state = state.cell_state
    attention, alignments = _compute_attention(
        attention_mechanism, rnn_cell_state, previous_alignments[i],
        self._attention_layers[i] if self._attention_layers else None)
    alignment_history = previous_alignment_history[i].write(
        state.time, alignments) if self._alignment_history else ()
    all_alignments.append(alignments)
    all_histories.append(alignment_history)
    all_attentions.append(attention)
  attention = array_ops.concat(all_attentions, 1)
  cell_inputs = self._cell_input_fn(inputs, attention)
  cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)
  next_state = seq2seq.AttentionWrapperState(
      time=state.time + 1,
      cell_state=next_cell_state,
      attention=attention,
      alignments=self._item_or_tuple(all_alignments),
      alignment_history=self._item_or_tuple(all_histories))
  if self._output_attention:
    return attention, next_state
  else:
    return cell_output, next_state
Example 8: call
# Module to import: from tensorflow.contrib import seq2seq [as an alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as an alias]
def call(self, inputs, state):
  if not isinstance(state, seq2seq.AttentionWrapperState):
    raise TypeError("Expected state to be instance of AttentionWrapperState. "
                    "Received type %s instead." % type(state))
  if self._is_multi:
    previous_alignments = state.alignments
    previous_alignment_history = state.alignment_history
  else:
    previous_alignments = [state.alignments]
    previous_alignment_history = [state.alignment_history]
  all_alignments = []
  all_attentions = []
  all_attention_states = []
  all_histories = []
  for i, attention_mechanism in enumerate(self._attention_mechanisms):
    if isinstance(self._cell, rnn.LSTMCell):
      rnn_cell_state = state.cell_state.h
    else:
      rnn_cell_state = state.cell_state
    attention, alignments, next_attention_state = _compute_attention(
        attention_mechanism, rnn_cell_state, previous_alignments[i],
        self._attention_layers[i] if self._attention_layers else None)
    alignment_history = previous_alignment_history[i].write(
        state.time, alignments) if self._alignment_history else ()
    all_attention_states.append(next_attention_state)
    all_alignments.append(alignments)
    all_histories.append(alignment_history)
    all_attentions.append(attention)
  attention = array_ops.concat(all_attentions, 1)
  cell_inputs = self._cell_input_fn(inputs, attention)
  cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)
  next_state = seq2seq.AttentionWrapperState(
      time=state.time + 1,
      cell_state=next_cell_state,
      attention=attention,
      attention_state=self._item_or_tuple(all_attention_states),
      alignments=self._item_or_tuple(all_alignments),
      alignment_history=self._item_or_tuple(all_histories))
  if self._output_attention:
    return attention, next_state
  else:
    return cell_output, next_state
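Example 8 is nearly identical to Example 7; the visible difference is that its _compute_attention also returns a next attention state, which is threaded into the attention_state field of AttentionWrapperState via self._item_or_tuple. This appears to target a later revision of tf.contrib.seq2seq in which AttentionWrapperState gained an attention_state field, so which of the two variants applies depends on the TensorFlow 1.x version in use.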