

Python seq2seq.AttentionWrapperState Code Examples

This article collects typical usage examples of tensorflow.contrib.seq2seq.AttentionWrapperState in Python. AttentionWrapperState is the state namedtuple produced and consumed by AttentionWrapper; if you are unsure what it is for or how to use it, the curated code examples below should help. You can also explore further usage examples from the tensorflow.contrib.seq2seq module.


The following presents 8 code examples of seq2seq.AttentionWrapperState, ordered by popularity.
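Before diving into the examples, here is a minimal sketch of where an AttentionWrapperState usually comes from: it is the state namedtuple returned by tf.contrib.seq2seq.AttentionWrapper. The snippet assumes a TF 1.x environment with tf.contrib available; all shapes and variable names are illustrative only.

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

batch_size, num_units, enc_len = 32, 128, 50
# Fake encoder results standing in for a real encoder.
encoder_outputs = tf.random_normal([batch_size, enc_len, num_units])
encoder_state = rnn.LSTMStateTuple(tf.random_normal([batch_size, num_units]),
                                   tf.random_normal([batch_size, num_units]))

attention_mechanism = seq2seq.LuongAttention(num_units, memory=encoder_outputs)
cell = seq2seq.AttentionWrapper(rnn.LSTMCell(num_units), attention_mechanism,
                                attention_layer_size=num_units)

# zero_state returns an AttentionWrapperState namedtuple with fields such as
# cell_state, attention, time, alignments and alignment_history.
init_state = cell.zero_state(batch_size, tf.float32)

# A common pattern (see Example 1 below): keep the attention fields but
# replace the inner cell state with the encoder's final state.
init_state = init_state.clone(cell_state=encoder_state)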

Example 1: _build_init_state

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def _build_init_state(self, batch_size, enc_state, rnn_cell, mode, hparams):
    """Builds initial states for the given RNN cells."""
    del mode  # Unused.

    # Build init state.
    init_state = rnn_cell.zero_state(batch_size, tf.float32)

    if hparams.pass_hidden_state:
      # Non-GNMT RNN cells return an AttentionWrapperState.
      if isinstance(init_state, contrib_seq2seq.AttentionWrapperState):
        init_state = init_state.clone(cell_state=enc_state)
      # GNMT RNN cell returns a tuple state.
      elif isinstance(init_state, tuple):
        init_state = tuple(
            zs.clone(cell_state=es) if isinstance(
                zs, contrib_seq2seq.AttentionWrapperState) else es
            for zs, es in zip(init_state, enc_state))
      else:
        ValueError("RNN cell returns zero states of unknown type: %s"
                   % str(type(init_state)))

    return init_state 
Developer: google-research, Project: language, Lines: 24, Source: decoders.py

Example 2: wrapper

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def wrapper(self, state):
        """Some RNN states are wrapped in namedtuples. 
        (TensorFlow decision, definitely not mine...). 
        
        This is here for derived classes to specify their wrapper state. 
        Some examples: LSTMStateTuple and AttentionWrapperState.
        
        Args:
            state: tensor state tuple, will be unpacked into the wrapper tuple.
        """
        if self._wrapper is None:
            return state
        else:
            return self._wrapper(*state) 
Developer: mckinziebrandon, Project: DeepChatModels, Lines: 16, Source: _rnn.py
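As a quick illustration of what this helper does, the standalone sketch below re-packs a flat (c, h) pair into an LSTMStateTuple by star-unpacking, which is exactly what `self._wrapper(*state)` performs when a derived class sets `_wrapper = LSTMStateTuple`. Tensors and shapes here are made up for illustration.

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple

c = tf.zeros([4, 8])
h = tf.zeros([4, 8])
flat_state = (c, h)

wrapper = LSTMStateTuple       # what a derived class would assign to self._wrapper
state = wrapper(*flat_state)   # -> LSTMStateTuple(c=c, h=h)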

Example 3: zero_state

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def zero_state(self, batch_size, dtype):
        with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)
            error_message = (
                "zero_state of AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.")
            with tf.control_dependencies(
                [tf.assert_equal(batch_size,
                    self._attention_mechanism.batch_size,
                    message=error_message)]):
                cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="checked_cell_state"),
                    cell_state)
            alignment_history = ()

            _zero_state_tensors = rnn_cell_impl._zero_state_tensors
            return AttentionWrapperState(
                cell_state=cell_state,
                time=tf.zeros([], dtype=tf.int32),
                attention=_zero_state_tensors(
                    self._attention_size, batch_size, dtype),
                alignments=self._attention_mechanism.initial_alignments(
                    batch_size, dtype),
                alignment_history=alignment_history) 
Developer: mckinziebrandon, Project: DeepChatModels, Lines: 30, Source: _rnn.py

Example 4: state_size

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def state_size(self):
        return AttentionWrapperState(
            cell_state=self._cell.state_size,
            attention=self._attention_size,
            time=tf.TensorShape([]),
            alignments=self._attention_mechanism.alignments_size,
            alignment_history=()) 
Developer: mckinziebrandon, Project: DeepChatModels, Lines: 9, Source: _rnn.py

Example 5: shape

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def shape(self):
        return AttentionWrapperState(
            cell_state=self._cell.shape,
            attention=tf.TensorShape([None, self._attention_size]),
            time=tf.TensorShape(None),
            alignments=tf.TensorShape([None, None]),
            alignment_history=()) 
Developer: mckinziebrandon, Project: DeepChatModels, Lines: 9, Source: _rnn.py

Example 6: call

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def call(self, inputs, state):
        """First computes the cell state and output in the usual way, 
        then works through the attention pipeline:
            h --> a --> c --> h_tilde
        using the naming/notation from Luong et al., 2015.

        Args:
            inputs: `2-D` tensor with shape `[batch_size x input_size]`.
            state: An instance of `AttentionWrapperState` containing the 
                tensors from the prev timestep.
     
        Returns:
            A tuple `(attention_or_cell_output, next_state)`, where:
            - `attention_or_cell_output` is either the attention vector or the
                cell output, depending on `output_attention`.
            - `next_state` is an instance of `AttentionWrapperState`
                containing the state calculated at this time step.
        """

        # Concatenate the previous h_tilde with inputs (input-feeding).
        cell_inputs = tf.concat([inputs, state.attention], -1)

        # 1. (hidden) Compute the hidden state (cell_output).
        cell_output, next_cell_state = self._cell(cell_inputs,
                                                  state.cell_state)

        # 2. (align) Compute the normalized alignment scores. [B, L_enc].
        # where L_enc is the max seq len in the encoder outputs for the (B)atch.
        score = self._attention_mechanism(
            cell_output, previous_alignments=state.alignments)
        alignments = tf.nn.softmax(score)

        # Reshape from [B, L_enc] to [B, 1, L_enc]
        expanded_alignments = tf.expand_dims(alignments, 1)
        # (Possibly projected) encoder outputs: [B, L_enc, state_size]
        encoder_outputs = self._attention_mechanism.values
        # 3 (context) Take inner prod. [B, 1, state size].
        context = tf.matmul(expanded_alignments, encoder_outputs)
        context = tf.squeeze(context, [1])

        # 4 (h_tilde) Compute tanh(W [c, h]).
        attention = self._attention_layer(
            tf.concat([cell_output, context], -1))

        next_state = AttentionWrapperState(
            cell_state=next_cell_state,
            attention=attention,
            time=state.time + 1,
            alignments=alignments,
            alignment_history=())

        return attention, next_state 
Developer: mckinziebrandon, Project: DeepChatModels, Lines: 53, Source: _rnn.py
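For reference, the pipeline in this example mirrors the global attention equations of Luong et al. (2015). In that notation, with h_t the decoder hidden state (cell_output) and \bar{h}_s the encoder outputs:

\[
a_t(s) = \operatorname{softmax}_s\big(\operatorname{score}(h_t, \bar{h}_s)\big), \qquad
c_t = \sum_s a_t(s)\,\bar{h}_s, \qquad
\tilde{h}_t = \tanh\big(W_c\,[c_t;\, h_t]\big)
\]

Here \tilde{h}_t is the `attention` tensor returned above, which is concatenated with the next input (input feeding) at the following time step.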

Example 7: call

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def call(self, inputs, state):
    if not isinstance(state, seq2seq.AttentionWrapperState):
      raise TypeError("Expected state to be instance of AttentionWrapperState. "
                      "Received type %s instead."  % type(state))

    if self._is_multi:
      previous_alignments = state.alignments
      previous_alignment_history = state.alignment_history
    else:
      previous_alignments = [state.alignments]
      previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
      if isinstance(self._cell, rnn.LSTMCell):
        rnn_cell_state = state.cell_state.h
      else:
        rnn_cell_state = state.cell_state
      attention, alignments = _compute_attention(
          attention_mechanism, rnn_cell_state, previous_alignments[i],
          self._attention_layers[i] if self._attention_layers else None)
      alignment_history = previous_alignment_history[i].write(
          state.time, alignments) if self._alignment_history else ()

      all_alignments.append(alignments)
      all_histories.append(alignment_history)
      all_attentions.append(attention)

    attention = array_ops.concat(all_attentions, 1)

    cell_inputs = self._cell_input_fn(inputs, attention)
    cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories))
    
    if self._output_attention:
      return attention, next_state
    else:
      return cell_output, next_state 
Developer: bgshih, Project: aster, Lines: 48, Source: sync_attention_wrapper.py

Example 8: call

# Required import: from tensorflow.contrib import seq2seq [as alias]
# Or: from tensorflow.contrib.seq2seq import AttentionWrapperState [as alias]
def call(self, inputs, state):
    if not isinstance(state, seq2seq.AttentionWrapperState):
      raise TypeError("Expected state to be instance of AttentionWrapperState. "
                      "Received type %s instead."  % type(state))

    if self._is_multi:
      previous_alignments = state.alignments
      previous_alignment_history = state.alignment_history
    else:
      previous_alignments = [state.alignments]
      previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_attention_states = []
    all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
      if isinstance(self._cell, rnn.LSTMCell):
        rnn_cell_state = state.cell_state.h
      else:
        rnn_cell_state = state.cell_state
      attention, alignments, next_attention_state = _compute_attention(
          attention_mechanism, rnn_cell_state, previous_alignments[i],
          self._attention_layers[i] if self._attention_layers else None)
      alignment_history = previous_alignment_history[i].write(
          state.time, alignments) if self._alignment_history else ()

      all_attention_states.append(next_attention_state)
      all_alignments.append(alignments)
      all_histories.append(alignment_history)
      all_attentions.append(attention)

    attention = array_ops.concat(all_attentions, 1)

    cell_inputs = self._cell_input_fn(inputs, attention)
    cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        attention_state=self._item_or_tuple(all_attention_states),
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories))
    
    if self._output_attention:
      return attention, next_state
    else:
      return cell_output, next_state 
Developer: huizhang0110, Project: AON, Lines: 51, Source: sync_attention_wrapper.py
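Examples 7 and 8 are nearly identical; the difference is that Example 8 targets a newer tf.contrib.seq2seq API in which _compute_attention also returns a per-mechanism attention state and AttentionWrapperState carries an extra attention_state field. As a rough sketch (whether the field exists depends on your TF 1.x version; shapes are illustrative), such a state can be built directly like this:

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, num_units, enc_len = 4, 8, 10
state = seq2seq.AttentionWrapperState(
    cell_state=tf.zeros([batch_size, num_units]),
    attention=tf.zeros([batch_size, num_units]),
    time=tf.zeros([], dtype=tf.int32),
    alignments=tf.zeros([batch_size, enc_len]),
    attention_state=tf.zeros([batch_size, enc_len]),  # absent in older TF versions
    alignment_history=())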


Note: The tensorflow.contrib.seq2seq.AttentionWrapperState examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright remains with the original authors, and any distribution or use should follow each project's License. Do not reproduce this article without permission.