This article collects typical usage examples of the Python function tensorflow.contrib.seq2seq.python.ops.decoder.dynamic_decode. If you are unsure what dynamic_decode does, how to call it, or simply want to see it used in real code, the hand-picked examples below should help.
The following section presents 10 code examples of dynamic_decode, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
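Before the examples, here is a minimal self-contained sketch (not taken from any of the examples below) of how dynamic_decode is typically called, assuming TensorFlow 1.x with tf.contrib available; all sizes (batch_size, max_time, vocab_size, and so on) are illustrative placeholders. dynamic_decode unrolls the given decoder step by step and returns the decoder outputs, the final cell state, and the per-example decoded lengths.

# A minimal sketch of calling dynamic_decode (TensorFlow 1.x with tf.contrib).
# All sizes below are illustrative placeholders, not values from the examples
# on this page.
import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, max_time, embed_dim, num_units, vocab_size = 4, 7, 8, 16, 30

# Random "embedded" decoder inputs and their true lengths.
decoder_inputs = tf.constant(
    np.random.randn(batch_size, max_time, embed_dim).astype(np.float32))
sequence_length = tf.constant([7, 5, 3, 2], dtype=tf.int32)

cell = tf.nn.rnn_cell.LSTMCell(num_units)
helper = seq2seq.TrainingHelper(decoder_inputs, sequence_length)
my_decoder = seq2seq.BasicDecoder(
    cell=cell,
    helper=helper,
    initial_state=cell.zero_state(batch_size, tf.float32),
    output_layer=tf.layers.Dense(vocab_size))

# dynamic_decode unrolls the decoder and returns
# (outputs, final_state, final_sequence_lengths).
outputs, final_state, final_lengths = seq2seq.dynamic_decode(
    my_decoder, impute_finished=True, maximum_iterations=max_time)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  rnn_out, lengths = sess.run([outputs.rnn_output, final_lengths])
  print(rnn_out.shape)  # (batch_size, decoded_time, vocab_size)
  print(lengths)        # per-example decoded lengths

The examples below use the same pattern but call the lower-level module directly as decoder.dynamic_decode.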
Example 1: testLuongScaledDType
def testLuongScaledDType(self):
  # Test case for GitHub issue 18099
  for dtype in [np.float16, np.float32, np.float64]:
    num_units = 128
    encoder_outputs = array_ops.placeholder(dtype, shape=[64, None, 256])
    encoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
    decoder_inputs = array_ops.placeholder(dtype, shape=[64, None, 128])
    decoder_sequence_length = array_ops.placeholder(dtypes.int32, shape=[64])
    batch_size = 64
    attention_mechanism = wrapper.LuongAttention(
        num_units=num_units,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length,
        scale=True,
        dtype=dtype,
    )
    cell = rnn_cell.LSTMCell(num_units)
    cell = wrapper.AttentionWrapper(cell, attention_mechanism)
    helper = helper_py.TrainingHelper(decoder_inputs,
                                      decoder_sequence_length)
    my_decoder = basic_decoder.BasicDecoder(
        cell=cell,
        helper=helper,
        initial_state=cell.zero_state(
            dtype=dtype, batch_size=batch_size))

    final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
    self.assertTrue(
        isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
    self.assertEqual(final_outputs.rnn_output.dtype, dtype)
    self.assertTrue(
        isinstance(final_state, wrapper.AttentionWrapperState))
    self.assertTrue(
        isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))
Example 2: _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN
def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
    self, use_sequence_length):
  sequence_length = [3, 4, 3, 1, 0]
  batch_size = 5
  max_time = 8
  input_depth = 7
  cell_depth = 10
  max_out = max(sequence_length)

  with self.session(use_gpu=True) as sess:
    inputs = np.random.randn(batch_size, max_time,
                             input_depth).astype(np.float32)

    cell = rnn_cell.LSTMCell(cell_depth)
    zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size)
    helper = helper_py.TrainingHelper(inputs, sequence_length)
    my_decoder = basic_decoder.BasicDecoder(
        cell=cell, helper=helper, initial_state=zero_state)

    # Match the variable scope of dynamic_rnn below so we end up
    # using the same variables
    with vs.variable_scope("root") as scope:
      final_decoder_outputs, final_decoder_state, _ = decoder.dynamic_decode(
          my_decoder,
          # impute_finished=True ensures outputs and final state
          # match those of dynamic_rnn called with sequence_length not None
          impute_finished=use_sequence_length,
          scope=scope)

    with vs.variable_scope(scope, reuse=True) as scope:
      final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn(
          cell,
          inputs,
          sequence_length=sequence_length if use_sequence_length else None,
          initial_state=zero_state,
          scope=scope)

    sess.run(variables.global_variables_initializer())
    sess_results = sess.run({
        "final_decoder_outputs": final_decoder_outputs,
        "final_decoder_state": final_decoder_state,
        "final_rnn_outputs": final_rnn_outputs,
        "final_rnn_state": final_rnn_state
    })

    # Decoder only runs out to max_out; ensure values are identical
    # to dynamic_rnn, which also zeros out outputs and passes along state.
    self.assertAllClose(sess_results["final_decoder_outputs"].rnn_output,
                        sess_results["final_rnn_outputs"][:, 0:max_out, :])
    if use_sequence_length:
      self.assertAllClose(sess_results["final_decoder_state"],
                          sess_results["final_rnn_state"])
Example 3: inference
def inference(self, inputs, train=True):
  config = self.config

  # extract character representations from embedding
  with tf.variable_scope('embedding', initializer=tf.contrib.layers.xavier_initializer()):
    embedding = tf.get_variable('embedding',
        shape=(config.vocab_size, config.embed_dim), dtype=tf.float32)
    embedded_inputs = tf.nn.embedding_lookup(embedding, inputs['text'])

  # extract speaker embedding if multi-speaker
  with tf.variable_scope('speaker'):
    if config.num_speakers > 1:
      speaker_embed = tf.get_variable('speaker_embed',
          shape=(config.num_speakers, config.speaker_embed_dim), dtype=tf.float32)
      speaker_embed = \
          tf.nn.embedding_lookup(speaker_embed, inputs['speaker'])
    else:
      speaker_embed = None

  # process text input with CBHG module
  with tf.variable_scope('encoder'):
    pre_out = self.pre_net(embedded_inputs, dropout=config.char_dropout_prob, train=train)
    tf.summary.histogram('pre_net_out', pre_out)
    encoded = ops.CBHG(pre_out, speaker_embed, K=16, c=[128, 128, 128], gru_units=128)

  # pass through attention based decoder
  with tf.variable_scope('decoder'):
    dec = self.create_decoder(encoded, inputs, speaker_embed, train)
    (seq2seq_output, _), attention_state, _ = \
        decoder.dynamic_decode(dec, maximum_iterations=config.max_decode_iter)
    self.alignments = tf.transpose(attention_state.alignment_history.stack(), [1, 0, 2])
    tf.summary.histogram('seq2seq_output', seq2seq_output)

  # use second CBHG module to process mel features into a linear spectrogram
  with tf.variable_scope('post-process'):
    # reshape to account for r value
    post_input = tf.reshape(seq2seq_output,
        (tf.shape(seq2seq_output)[0], -1, config.mel_features))

    output = ops.CBHG(post_input, K=8, c=[128, 256, 80], gru_units=128)
    output = tf.layers.dense(output, units=config.fft_size)

    # reshape back to r frame representation
    output = tf.reshape(output, (tf.shape(output)[0], -1, config.fft_size * config.r))
    tf.summary.histogram('output', output)

  return seq2seq_output, output
Example 4: _init_decoder
def _init_decoder(self):
  with tf.variable_scope('Decoder') as scope:
    self.fc_layer = Dense(self.vocab_size)

    if self.is_inference:
      self.start_tokens = tf.placeholder(tf.int32, shape=[None], name='start_tokens')
      self.end_token = tf.placeholder(tf.int32, name='end_token')
      self.helper = seq2seq.GreedyEmbeddingHelper(
          embedding=self.embedding_matrix,
          start_tokens=self.start_tokens,
          end_token=self.end_token
      )
    else:
      self.helper = seq2seq.TrainingHelper(
          inputs=self.decoder_train_inputs_embedded,
          sequence_length=self.decoder_train_length,
          time_major=True
      )

    self.decoder = seq2seq.BasicDecoder(
        cell=self.decoder_cell,
        helper=self.helper,
        initial_state=self.encoder_state,
        output_layer=self.fc_layer
    )

    (self.decoder_outputs_train,
     self.decoder_state_train,
     self.decoder_context_state_train
    ) = (
        decoder.dynamic_decode(
            self.decoder,
            output_time_major=True)
    )

    self.logits = self.decoder_outputs_train.rnn_output
    if not self.is_inference:
      self.decoder_prediction_inference = tf.argmax(self.logits, axis=-1, name='decoder_prediction_inference')
      self.decoder_prediction_train = tf.argmax(self.decoder_outputs_train.rnn_output, axis=-1, name='decoder_prediction_train')
      self._init_optimizer()
    else:
      self.prob = tf.nn.softmax(self.logits)
Example 5: _testWithAttention
def _testWithAttention(self,
                       create_attention_mechanism,
                       expected_final_output,
                       expected_final_state,
                       attention_mechanism_depth=3,
                       alignment_history=False,
                       expected_final_alignment_history=None,
                       attention_layer_size=6,
                       name=""):
  encoder_sequence_length = [3, 2, 3, 1, 0]
  decoder_sequence_length = [2, 0, 1, 2, 3]
  batch_size = 5
  encoder_max_time = 8
  decoder_max_time = 4
  input_depth = 7
  encoder_output_depth = 10
  cell_depth = 9

  if attention_layer_size is not None:
    attention_depth = attention_layer_size
  else:
    attention_depth = encoder_output_depth

  decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                   input_depth).astype(np.float32)
  encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                    encoder_output_depth).astype(np.float32)

  attention_mechanism = create_attention_mechanism(
      num_units=attention_mechanism_depth,
      memory=encoder_outputs,
      memory_sequence_length=encoder_sequence_length)

  with self.test_session(use_gpu=True) as sess:
    with vs.variable_scope(
        "root",
        initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
      cell = core_rnn_cell.LSTMCell(cell_depth)
      cell = wrapper.AttentionWrapper(
          cell,
          attention_mechanism,
          attention_layer_size=attention_layer_size,
          alignment_history=alignment_history)
      helper = helper_py.TrainingHelper(decoder_inputs,
                                        decoder_sequence_length)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))
      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

    self.assertTrue(
        isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
    self.assertTrue(
        isinstance(final_state, wrapper.AttentionWrapperState))
    self.assertTrue(
        isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))

    self.assertEqual((batch_size, None, attention_depth),
                     tuple(final_outputs.rnn_output.get_shape().as_list()))
    self.assertEqual((batch_size, None),
                     tuple(final_outputs.sample_id.get_shape().as_list()))

    self.assertEqual((batch_size, attention_depth),
                     tuple(final_state.attention.get_shape().as_list()))
    self.assertEqual((batch_size, cell_depth),
                     tuple(final_state.cell_state.c.get_shape().as_list()))
    self.assertEqual((batch_size, cell_depth),
                     tuple(final_state.cell_state.h.get_shape().as_list()))

    if alignment_history:
      state_alignment_history = final_state.alignment_history.stack()
      # Remove the history from final_state for purposes of the
      # remainder of the tests.
      final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
      self.assertEqual((None, batch_size, encoder_max_time),
                       tuple(state_alignment_history.get_shape().as_list()))
    else:
      state_alignment_history = ()

    sess.run(variables.global_variables_initializer())
    sess_results = sess.run({
        "final_outputs": final_outputs,
        "final_state": final_state,
        "state_alignment_history": state_alignment_history,
    })

    print("Copy/paste (%s)\nexpected_final_output = " % name,
          sess_results["final_outputs"])
    sys.stdout.flush()
    print("Copy/paste (%s)\nexpected_final_state = " % name,
          sess_results["final_state"])
    sys.stdout.flush()
    print("Copy/paste (%s)\nexpected_final_alignment_history = " % name,
          np.asarray(sess_results["state_alignment_history"]))
    sys.stdout.flush()
    nest.map_structure(self.assertAllClose, expected_final_output,
                       sess_results["final_outputs"])
    # ... (the rest of this example is omitted here) ...
Example 6: _testWithAttention
def _testWithAttention(self,
                       create_attention_mechanism,
                       expected_final_outputs,
                       expected_final_state,
                       attention_mechanism_depth=3):
  encoder_sequence_length = [3, 2, 3, 1, 0]
  decoder_sequence_length = [2, 0, 1, 2, 3]
  batch_size = 5
  encoder_max_time = 8
  decoder_max_time = 4
  input_depth = 7
  encoder_output_depth = 10
  cell_depth = 9
  attention_depth = 6

  decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                   input_depth).astype(np.float32)
  encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                    encoder_output_depth).astype(np.float32)

  attention_mechanism = create_attention_mechanism(
      num_units=attention_mechanism_depth,
      memory=encoder_outputs,
      memory_sequence_length=encoder_sequence_length)

  with self.test_session() as sess:
    with vs.variable_scope(
        "root",
        initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
      cell = core_rnn_cell.LSTMCell(cell_depth)
      cell = wrapper.DynamicAttentionWrapper(
          cell, attention_mechanism, attention_size=attention_depth)
      helper = helper_py.TrainingHelper(decoder_inputs,
                                        decoder_sequence_length)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))

      final_outputs, final_state = decoder.dynamic_decode(my_decoder)

    self.assertTrue(
        isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
    self.assertTrue(
        isinstance(final_state, wrapper.DynamicAttentionWrapperState))
    self.assertTrue(
        isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))

    self.assertEqual((batch_size, None, attention_depth),
                     tuple(final_outputs.rnn_output.get_shape().as_list()))
    self.assertEqual((batch_size, None),
                     tuple(final_outputs.sample_id.get_shape().as_list()))

    self.assertEqual((batch_size, attention_depth),
                     tuple(final_state.attention.get_shape().as_list()))
    self.assertEqual((batch_size, cell_depth),
                     tuple(final_state.cell_state.c.get_shape().as_list()))
    self.assertEqual((batch_size, cell_depth),
                     tuple(final_state.cell_state.h.get_shape().as_list()))

    sess.run(variables.global_variables_initializer())
    sess_results = sess.run({
        "final_outputs": final_outputs,
        "final_state": final_state
    })

    nest.map_structure(self.assertAllClose, expected_final_outputs,
                       sess_results["final_outputs"])
    nest.map_structure(self.assertAllClose, expected_final_state,
                       sess_results["final_state"])
Example 7: _testDynamicDecodeRNN
def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):
  sequence_length = [3, 4, 3, 1, 0]
  batch_size = 5
  max_time = 8
  input_depth = 7
  cell_depth = 10
  max_out = max(sequence_length)

  with self.session(use_gpu=True) as sess:
    if time_major:
      inputs = np.random.randn(max_time, batch_size,
                               input_depth).astype(np.float32)
    else:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)

    cell = rnn_cell.LSTMCell(cell_depth)
    helper = helper_py.TrainingHelper(
        inputs, sequence_length, time_major=time_major)
    my_decoder = basic_decoder.BasicDecoder(
        cell=cell,
        helper=helper,
        initial_state=cell.zero_state(
            dtype=dtypes.float32, batch_size=batch_size))

    final_outputs, final_state, final_sequence_length = (
        decoder.dynamic_decode(my_decoder, output_time_major=time_major,
                               maximum_iterations=maximum_iterations))

    def _t(shape):
      if time_major:
        return (shape[1], shape[0]) + shape[2:]
      return shape

    self.assertTrue(
        isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
    self.assertTrue(isinstance(final_state, rnn_cell.LSTMStateTuple))

    self.assertEqual(
        (batch_size,),
        tuple(final_sequence_length.get_shape().as_list()))
    self.assertEqual(
        _t((batch_size, None, cell_depth)),
        tuple(final_outputs.rnn_output.get_shape().as_list()))
    self.assertEqual(
        _t((batch_size, None)),
        tuple(final_outputs.sample_id.get_shape().as_list()))

    sess.run(variables.global_variables_initializer())
    sess_results = sess.run({
        "final_outputs": final_outputs,
        "final_state": final_state,
        "final_sequence_length": final_sequence_length,
    })

    # Mostly a smoke test
    time_steps = max_out
    expected_length = sequence_length
    if maximum_iterations is not None:
      time_steps = min(max_out, maximum_iterations)
      expected_length = [min(x, maximum_iterations) for x in expected_length]
    self.assertEqual(
        _t((batch_size, time_steps, cell_depth)),
        sess_results["final_outputs"].rnn_output.shape)
    self.assertEqual(
        _t((batch_size, time_steps)),
        sess_results["final_outputs"].sample_id.shape)
    self.assertItemsEqual(expected_length,
                          sess_results["final_sequence_length"])
Example 8: _testWithMaybeMultiAttention
def _testWithMaybeMultiAttention(self,
                                 is_multi,
                                 create_attention_mechanisms,
                                 expected_final_output,
                                 expected_final_state,
                                 attention_mechanism_depths,
                                 alignment_history=False,
                                 expected_final_alignment_history=None,
                                 attention_layer_sizes=None,
                                 attention_layers=None,
                                 name=''):
  # Allow is_multi to be True with a single mechanism to enable test for
  # passing in a single mechanism in a list.
  assert len(create_attention_mechanisms) == 1 or is_multi
  encoder_sequence_length = [3, 2, 3, 1, 1]
  decoder_sequence_length = [2, 0, 1, 2, 3]
  batch_size = 5
  encoder_max_time = 8
  decoder_max_time = 4
  input_depth = 7
  encoder_output_depth = 10
  cell_depth = 9

  if attention_layer_sizes is not None:
    # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
    attention_depth = sum([attention_layer_size or encoder_output_depth
                           for attention_layer_size in attention_layer_sizes])
  elif attention_layers is not None:
    # Compute sum of attention_layers output depth.
    attention_depth = sum(
        attention_layer.compute_output_shape(
            [batch_size, cell_depth + encoder_output_depth])[-1].value
        for attention_layer in attention_layers)
  else:
    attention_depth = encoder_output_depth * len(create_attention_mechanisms)

  decoder_inputs = array_ops.placeholder_with_default(
      np.random.randn(batch_size, decoder_max_time,
                      input_depth).astype(np.float32),
      shape=(None, None, input_depth))
  encoder_outputs = array_ops.placeholder_with_default(
      np.random.randn(batch_size, encoder_max_time,
                      encoder_output_depth).astype(np.float32),
      shape=(None, None, encoder_output_depth))

  attention_mechanisms = [
      creator(num_units=depth,
              memory=encoder_outputs,
              memory_sequence_length=encoder_sequence_length)
      for creator, depth in zip(create_attention_mechanisms,
                                attention_mechanism_depths)]

  with self.test_session(use_gpu=True) as sess:
    with vs.variable_scope(
        'root',
        initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
      attention_layer_size = attention_layer_sizes
      attention_layer = attention_layers
      if not is_multi:
        if attention_layer_size is not None:
          attention_layer_size = attention_layer_size[0]
        if attention_layer is not None:
          attention_layer = attention_layer[0]
      cell = rnn_cell.LSTMCell(cell_depth)
      cell = wrapper.AttentionWrapper(
          cell,
          attention_mechanisms if is_multi else attention_mechanisms[0],
          attention_layer_size=attention_layer_size,
          alignment_history=alignment_history,
          attention_layer=attention_layer)
      helper = helper_py.TrainingHelper(decoder_inputs,
                                        decoder_sequence_length)
      my_decoder = basic_decoder.BasicDecoder(
          cell=cell,
          helper=helper,
          initial_state=cell.zero_state(
              dtype=dtypes.float32, batch_size=batch_size))
      final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

    self.assertTrue(
        isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
    self.assertTrue(
        isinstance(final_state, wrapper.AttentionWrapperState))
    self.assertTrue(
        isinstance(final_state.cell_state, rnn_cell.LSTMStateTuple))

    self.assertEqual((batch_size, None, attention_depth),
                     tuple(final_outputs.rnn_output.get_shape().as_list()))
    self.assertEqual((batch_size, None),
                     tuple(final_outputs.sample_id.get_shape().as_list()))

    self.assertEqual((batch_size, attention_depth),
                     tuple(final_state.attention.get_shape().as_list()))
    self.assertEqual((batch_size, cell_depth),
                     tuple(final_state.cell_state.c.get_shape().as_list()))
    self.assertEqual((batch_size, cell_depth),
                     tuple(final_state.cell_state.h.get_shape().as_list()))

    if alignment_history:
      # ... (the rest of this example is omitted here) ...
Example 9: _testDynamicDecodeRNN
def _testDynamicDecodeRNN(self, time_major, has_attention,
                          with_alignment_history=False):
  encoder_sequence_length = np.array([3, 2, 3, 1, 1])
  decoder_sequence_length = np.array([2, 0, 1, 2, 3])
  batch_size = 5
  decoder_max_time = 4
  input_depth = 7
  cell_depth = 9
  attention_depth = 6
  vocab_size = 20
  end_token = vocab_size - 1
  start_token = 0
  embedding_dim = 50
  max_out = max(decoder_sequence_length)
  output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
  beam_width = 3

  with self.cached_session() as sess:
    batch_size_tensor = constant_op.constant(batch_size)
    embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    cell = rnn_cell.LSTMCell(cell_depth)
    initial_state = cell.zero_state(batch_size, dtypes.float32)
    coverage_penalty_weight = 0.0
    if has_attention:
      coverage_penalty_weight = 0.2
      inputs = array_ops.placeholder_with_default(
          np.random.randn(batch_size, decoder_max_time, input_depth).astype(
              np.float32),
          shape=(None, None, input_depth))
      tiled_inputs = beam_search_decoder.tile_batch(
          inputs, multiplier=beam_width)
      tiled_sequence_length = beam_search_decoder.tile_batch(
          encoder_sequence_length, multiplier=beam_width)
      attention_mechanism = attention_wrapper.BahdanauAttention(
          num_units=attention_depth,
          memory=tiled_inputs,
          memory_sequence_length=tiled_sequence_length)
      initial_state = beam_search_decoder.tile_batch(
          initial_state, multiplier=beam_width)
      cell = attention_wrapper.AttentionWrapper(
          cell=cell,
          attention_mechanism=attention_mechanism,
          attention_layer_size=attention_depth,
          alignment_history=with_alignment_history)
    cell_state = cell.zero_state(
        dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
    if has_attention:
      cell_state = cell_state.clone(cell_state=initial_state)
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        embedding=embedding,
        start_tokens=array_ops.fill([batch_size_tensor], start_token),
        end_token=end_token,
        initial_state=cell_state,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
        coverage_penalty_weight=coverage_penalty_weight)

    final_outputs, final_state, final_sequence_lengths = (
        decoder.dynamic_decode(
            bsd, output_time_major=time_major, maximum_iterations=max_out))

    def _t(shape):
      if time_major:
        return (shape[1], shape[0]) + shape[2:]
      return shape

    self.assertTrue(
        isinstance(final_outputs,
                   beam_search_decoder.FinalBeamSearchDecoderOutput))
    self.assertTrue(
        isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    self.assertEqual(
        _t((batch_size, None, beam_width)),
        tuple(beam_search_decoder_output.scores.get_shape().as_list()))
    self.assertEqual(
        _t((batch_size, None, beam_width)),
        tuple(final_outputs.predicted_ids.get_shape().as_list()))

    sess.run(variables.global_variables_initializer())
    sess_results = sess.run({
        'final_outputs': final_outputs,
        'final_state': final_state,
        'final_sequence_lengths': final_sequence_lengths
    })

    max_sequence_length = np.max(sess_results['final_sequence_lengths'])

    # A smoke test
    self.assertEqual(
        _t((batch_size, max_sequence_length, beam_width)),
        sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
    self.assertEqual(
        _t((batch_size, max_sequence_length, beam_width)), sess_results[
            'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
Example 10: _testDynamicDecodeRNN
def _testDynamicDecodeRNN(self, time_major, has_attention):
  encoder_sequence_length = [3, 2, 3, 1, 0]
  decoder_sequence_length = [2, 0, 1, 2, 3]
  batch_size = 5
  decoder_max_time = 4
  input_depth = 7
  cell_depth = 9
  attention_depth = 6
  vocab_size = 20
  end_token = vocab_size - 1
  start_token = 0
  embedding_dim = 50
  max_out = max(decoder_sequence_length)
  output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
  beam_width = 3

  with self.test_session() as sess:
    embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
    cell = core_rnn_cell.LSTMCell(cell_depth)
    if has_attention:
      inputs = np.random.randn(batch_size, decoder_max_time,
                               input_depth).astype(np.float32)
      attention_mechanism = attention_wrapper.BahdanauAttention(
          num_units=attention_depth,
          memory=inputs,
          memory_sequence_length=encoder_sequence_length)
      cell = attention_wrapper.AttentionWrapper(
          cell=cell,
          attention_mechanism=attention_mechanism,
          attention_size=attention_depth,
          alignment_history=False)
    cell_state = cell.zero_state(
        dtype=dtypes.float32, batch_size=batch_size * beam_width)
    bsd = beam_search_decoder.BeamSearchDecoder(
        cell=cell,
        embedding=embedding,
        start_tokens=batch_size * [start_token],
        end_token=end_token,
        initial_state=cell_state,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0)

    final_outputs, final_state = decoder.dynamic_decode(
        bsd, output_time_major=time_major, maximum_iterations=max_out)

    def _t(shape):
      if time_major:
        return (shape[1], shape[0]) + shape[2:]
      return shape

    self.assertTrue(
        isinstance(final_outputs,
                   beam_search_decoder.FinalBeamSearchDecoderOutput))
    self.assertTrue(
        isinstance(final_state, beam_search_decoder.BeamSearchDecoderState))

    beam_search_decoder_output = final_outputs.beam_search_decoder_output
    self.assertEqual(
        _t((batch_size, None, beam_width)),
        tuple(beam_search_decoder_output.scores.get_shape().as_list()))
    self.assertEqual(
        _t((batch_size, None, beam_width)),
        tuple(final_outputs.predicted_ids.get_shape().as_list()))

    sess.run(variables.global_variables_initializer())
    sess_results = sess.run({
        'final_outputs': final_outputs,
        'final_state': final_state
    })

    # Mostly a smoke test
    time_steps = max_out
    self.assertEqual(
        _t((batch_size, time_steps, beam_width)),
        sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
    self.assertEqual(
        _t((batch_size, time_steps, beam_width)), sess_results[
            'final_outputs'].beam_search_decoder_output.predicted_ids.shape)