This article collects typical usage examples of the Python method tensorflow.python.ops.rnn_cell.MultiRNNCell. If you are wondering what rnn_cell.MultiRNNCell does, how to call it, or what it looks like in real code, the curated examples below may help. You can also explore other usage examples from its containing module, tensorflow.python.ops.rnn_cell.
Nine code examples of the rnn_cell.MultiRNNCell method are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the site recommend better Python code examples.
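As quick orientation before the individual examples, the sketch below shows the pattern they all share: build one RNNCell per layer and stack the layers with MultiRNNCell before unrolling. It targets the TF 1.x-era API used throughout this page; the sizes and the placeholder are illustrative assumptions only.

import tensorflow as tf
from tensorflow.python.ops import rnn_cell

num_units, num_layers = 128, 2                 # illustrative sizes
batch_size, max_time, input_dim = 32, 20, 50

# One fresh LSTMCell per layer; MultiRNNCell stacks them into a single composite cell.
cells = [rnn_cell.LSTMCell(num_units) for _ in range(num_layers)]
stacked_cell = rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_dim])
outputs, final_state = tf.nn.dynamic_rnn(stacked_cell, inputs, dtype=tf.float32)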
Example 1: _get_rnn_cell
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def _get_rnn_cell(cell_type, num_units, num_layers):
    """Constructs and returns an `RNNCell`.

    Args:
        cell_type: either a string identifying the `RNNCell` type, or a subclass of
            `RNNCell`.
        num_units: the number of units in the `RNNCell`.
        num_layers: the number of layers in the RNN.

    Returns:
        An initialized `RNNCell`.

    Raises:
        ValueError: `cell_type` is an invalid `RNNCell` name.
        TypeError: `cell_type` is not a string or a subclass of `RNNCell`.
    """
    if isinstance(cell_type, str):
        cell_type = _CELL_TYPES.get(cell_type)
        if cell_type is None:
            raise ValueError('The supported cell types are {}; got {}'.format(
                list(_CELL_TYPES.keys()), cell_type))
    if not issubclass(cell_type, rnn_cell.RNNCell):
        raise TypeError(
            'cell_type must be a subclass of RNNCell or one of {}.'.format(
                list(_CELL_TYPES.keys())))
    cell = cell_type(num_units=num_units)
    if num_layers > 1:
        cell = rnn_cell.MultiRNNCell(
            [cell] * num_layers, state_is_tuple=True)
    return cell
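A minimal usage sketch for _get_rnn_cell, assuming a _CELL_TYPES registry that maps short names to cell classes. The registry is referenced but not shown in the example above, so its contents here are an illustrative guess:

from tensorflow.python.ops import rnn_cell

# Assumed registry; the real _CELL_TYPES mapping is defined elsewhere in the source module.
_CELL_TYPES = {
    'basic_rnn': rnn_cell.BasicRNNCell,
    'gru': rnn_cell.GRUCell,
    'lstm': rnn_cell.LSTMCell,
}

single_gru = _get_rnn_cell('gru', num_units=64, num_layers=1)     # a plain GRUCell
stacked_lstm = _get_rnn_cell('lstm', num_units=64, num_layers=3)  # a MultiRNNCell of 3 LSTM layers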
Example 2: save_variables_list
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def save_variables_list(self):
    # Return a list of the trainable variables created within the rnnlm model.
    # This consists of the two projection softmax variables (softmax_w and softmax_b),
    # the embedding, and all of the weights and biases in the MultiRNNCell model.
    # Save only the trainable variables and the placeholders needed to resume training;
    # discard the rest, including optimizer state.
    save_vars = set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='rnnlm'))
    save_vars.update({self.lr, self.global_epoch_fraction, self.global_seconds_elapsed})
    return list(save_vars)
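The list this method returns is typically passed to a tf.train.Saver so that checkpoints carry only what is needed to resume training. A hedged sketch, assuming the enclosing model object is called model and the checkpoint path is arbitrary:

import tensorflow as tf

# Save and restore only the variables the model reports as worth keeping.
saver = tf.train.Saver(model.save_variables_list(), max_to_keep=3)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training loop ...
    saver.save(sess, 'checkpoints/rnnlm.ckpt')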
Example 3: _build_pre
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def _build_pre(self):
    self.dimA = 20
    self.cellA = MultiRNNCell([LSTMCell(self.dimA)] * 2)
    self.b1 = 0.95
    self.b2 = 0.95
    self.lr = 0.1
    self.eps = 1e-8
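Note that `[LSTMCell(self.dimA)] * 2` places the same cell object in the list twice. Early TensorFlow releases tolerated this, but later 1.x releases either share weights across the two layers or raise an error when one cell instance is reused. A more defensive variant of the same setup, offered as a sketch rather than a claim about what the original project intended:

def _build_pre(self):
    self.dimA = 20
    # One independent LSTMCell per layer avoids accidental weight sharing
    # between the two stacked layers on newer TensorFlow 1.x releases.
    self.cellA = MultiRNNCell([LSTMCell(self.dimA) for _ in range(2)])
    self.b1 = 0.95
    self.b2 = 0.95
    self.lr = 0.1
    self.eps = 1e-8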
Example 4: _build_pre
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def _build_pre(self):
    self.dimH = 20
    self.cellH = MultiRNNCell([LSTMCell(self.dimH)] * 2)
    self.lr = 0.1
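Once cellH exists, its initial state usually comes from zero_state before the cell is stepped or unrolled. A minimal sketch inside the same class; the batch size and input placeholder are illustrative assumptions:

import tensorflow as tf

batch_size, input_dim = 32, 20                              # illustrative sizes
x_t = tf.placeholder(tf.float32, [batch_size, input_dim])   # one time step of input

# zero_state returns a tuple of LSTMStateTuples, one per stacked layer.
init_state = self.cellH.zero_state(batch_size=batch_size, dtype=tf.float32)
output, new_state = self.cellH(x_t, init_state)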
Example 5: _create_encoder_cell
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def _create_encoder_cell(self):
    return MultiRNNCell([self._create_rnn_cell() for _ in range(self.cfg.num_layers)])
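A hedged sketch of how this encoder cell might be unrolled to produce the enc_outputs and enc_states consumed by the decoder in the next example; the embedded source tensor and sequence-length tensor are assumptions, not details from the original project:

import tensorflow as tf

enc_cell = self._create_encoder_cell()
# src_emb: [batch_size, max_src_len, embedding_dim]; src_seq_len: [batch_size]
self.enc_outputs, self.enc_states = tf.nn.dynamic_rnn(
    enc_cell, src_emb, sequence_length=src_seq_len, dtype=tf.float32)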
Example 6: _create_decoder_cell
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def _create_decoder_cell(self):
    enc_outputs, enc_states, enc_seq_len = self.enc_outputs, self.enc_states, self.enc_seq_len
    batch_size = self.batch_size * self.cfg.beam_size if self.use_beam_search else self.batch_size
    with tf.variable_scope("attention"):
        if self.cfg.attention == "luong":  # Luong attention mechanism
            attention_mechanism = LuongAttention(num_units=self.cfg.num_units, memory=enc_outputs,
                                                 memory_sequence_length=enc_seq_len)
        else:  # default to the Bahdanau attention mechanism
            attention_mechanism = BahdanauAttention(num_units=self.cfg.num_units, memory=enc_outputs,
                                                    memory_sequence_length=enc_seq_len)

    def cell_input_fn(inputs, attention):  # cell input function that keeps the input/output dimensions the same
        # reference: https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/AttentionWrapper
        if not self.cfg.use_attention_input_feeding:
            return inputs
        input_project = tf.layers.Dense(self.cfg.num_units, dtype=tf.float32, name='attn_input_feeding')
        return input_project(tf.concat([inputs, attention], axis=-1))

    if self.cfg.top_attention:  # apply the attention mechanism only on the top decoder layer
        cells = [self._create_rnn_cell() for _ in range(self.cfg.num_layers)]
        cells[-1] = AttentionWrapper(cells[-1], attention_mechanism=attention_mechanism, name="Attention_Wrapper",
                                     attention_layer_size=self.cfg.num_units, initial_cell_state=enc_states[-1],
                                     cell_input_fn=cell_input_fn)
        initial_state = [state for state in enc_states]
        initial_state[-1] = cells[-1].zero_state(batch_size=batch_size, dtype=tf.float32)
        dec_init_states = tuple(initial_state)
        cells = MultiRNNCell(cells)
    else:
        cells = MultiRNNCell([self._create_rnn_cell() for _ in range(self.cfg.num_layers)])
        cells = AttentionWrapper(cells, attention_mechanism=attention_mechanism, name="Attention_Wrapper",
                                 attention_layer_size=self.cfg.num_units, initial_cell_state=enc_states,
                                 cell_input_fn=cell_input_fn)
        dec_init_states = cells.zero_state(batch_size=batch_size, dtype=tf.float32).clone(cell_state=enc_states)
    return cells, dec_init_states
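A hedged sketch of how the returned (cells, dec_init_states) pair could feed a training-time decoder via the TF 1.x tf.contrib.seq2seq API; the embedded decoder inputs, vocabulary size, and maximum_iterations setting are assumptions rather than details from the original project:

import tensorflow as tf
from tensorflow.contrib.seq2seq import TrainingHelper, BasicDecoder, dynamic_decode

cells, dec_init_states = self._create_decoder_cell()

# dec_input_emb: embedded target tokens, [batch_size, max_dec_len, embedding_dim]
helper = TrainingHelper(inputs=dec_input_emb, sequence_length=dec_seq_len)
output_layer = tf.layers.Dense(self.cfg.vocab_size, name="output_projection")
decoder = BasicDecoder(cells, helper, initial_state=dec_init_states, output_layer=output_layer)
dec_outputs, dec_final_state, _ = dynamic_decode(decoder, maximum_iterations=self.cfg.maximum_iterations)
logits = dec_outputs.rnn_output  # [batch_size, max_dec_len, vocab_size]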
Example 7: _create_rnn_cell
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def _create_rnn_cell(self):
    if self.cfg["num_layers"] is None or self.cfg["num_layers"] <= 1:
        return self._create_single_rnn_cell(self.cfg["num_units"])
    else:
        if self.cfg["use_stack_rnn"]:
            return [self._create_single_rnn_cell(self.cfg["num_units"]) for _ in range(self.cfg["num_layers"])]
        else:
            return MultiRNNCell([self._create_single_rnn_cell(self.cfg["num_units"])
                                 for _ in range(self.cfg["num_layers"])])
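When use_stack_rnn is enabled, the method returns a plain list of cells rather than a MultiRNNCell, which is the shape tf.contrib.rnn.stack_bidirectional_dynamic_rnn expects for its per-layer forward and backward cells. A hedged usage sketch; the word_embeddings and seq_len tensors are assumptions:

import tensorflow as tf

cells_fw = self._create_rnn_cell()  # a list of cells when use_stack_rnn is True
cells_bw = self._create_rnn_cell()  # a second, independent list for the backward direction
outputs, states_fw, states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
    cells_fw, cells_bw, inputs=word_embeddings, sequence_length=seq_len, dtype=tf.float32)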
Example 8: build_cell
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def build_cell(hidden_units, depth=1):
    cell_list = [build_single_cell(hidden_units) for i in range(depth)]
    return MultiRNNCell(cell_list)
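The build_single_cell helper is not shown in this example; the sketch below is one plausible implementation, assuming an LSTM layer with optional dropout (the keep_prob parameter is hypothetical):

import tensorflow as tf

def build_single_cell(hidden_units, keep_prob=1.0):
    # Hypothetical per-layer factory: one LSTMCell, optionally wrapped with dropout.
    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    if keep_prob < 1.0:
        cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)
    return cell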
Example 9: build_cell
# Required import: from tensorflow.python.ops import rnn_cell [as alias]
# Or: from tensorflow.python.ops.rnn_cell import MultiRNNCell [as alias]
def build_cell(hidden_units, depth=1):
    cell_list = [build_single_cell(hidden_units) for i in range(depth)]
    return MultiRNNCell(cell_list)