當前位置: 首頁>>代碼示例>>Python>>正文


Python search.BeamSearch類代碼示例

本文整理匯總了Python中blocks.search.BeamSearch的典型用法代碼示例。如果您正苦於以下問題:Python BeamSearch類的具體用法?Python BeamSearch怎麼用?Python BeamSearch使用的例子?那麼,這裡精選的類代碼示例或許可以為您提供幫助。


在下文中一共展示了BeamSearch類的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: test_beam_search

def test_beam_search():
    """Test beam search using the model from the reverse_words demo.

    Ideally this test should be done with a trained model, but so far
    only with a randomly initialized one. So it does not really test
    the ability to find the best output sequence, but only correctness
    of returned costs.

    """
    rng = numpy.random.RandomState(1234)
    alphabet_size = 20
    beam_size = 10
    length = 15

    reverser = WordReverser(10, alphabet_size)
    reverser.weights_init = reverser.biases_init = IsotropicGaussian(0.5)
    reverser.initialize()

    inputs = tensor.lmatrix('inputs')
    samples, = VariableFilter(bricks=[reverser.generator], name="outputs")(
        ComputationGraph(reverser.generate(inputs)))

    # One random sequence of `length` symbols, repeated `beam_size` times
    # and transposed, i.e. an array of shape (length, beam_size).
    input_vals = numpy.tile(rng.randint(alphabet_size, size=(length,)),
                            (beam_size, 1)).T

    # Use `beam_size` rather than repeating the literal 10, so the beam
    # stays in step with the number of input copies built above.
    search = BeamSearch(beam_size, samples)
    results, mask, costs = search.search({inputs: input_vals},
                                         0, 3 * length)

    # Recompute the costs with the model itself and compare them with the
    # costs reported by the search.
    true_costs = reverser.cost(
        input_vals, numpy.ones((length, beam_size), dtype=floatX),
        results, mask).eval()
    true_costs = (true_costs * mask).sum(axis=0)
    assert_allclose(costs, true_costs, rtol=1e-5)
開發者ID:kelvinxu,項目名稱:blocks,代碼行數:34,代碼來源:test_search.py

示例2: test_beam_search

def test_beam_search():
    """Exercise beam search on a model similar to the reverse_words demo.

    The generator is randomly initialized rather than trained, so this
    does not check that the best output sequence is found. It checks
    that the costs returned by the search agree with the costs computed
    by the model itself.

    """
    rng = numpy.random.RandomState(1234)
    alphabet_size = 20
    beam_size = 10
    length = 15
    max_length = 3 * length

    simple_generator = SimpleGenerator(10, alphabet_size, seed=1234)
    simple_generator.weights_init = IsotropicGaussian(0.5)
    simple_generator.biases_init = IsotropicGaussian(0.5)
    simple_generator.initialize()

    inputs = tensor.lmatrix('inputs')
    outputs_filter = VariableFilter(
        applications=[simple_generator.generator.generate],
        name="outputs")
    samples, = outputs_filter(
        ComputationGraph(simple_generator.generate(inputs)))

    # A single random sequence, copied once per beam entry; the final
    # array has shape (length, beam_size).
    single_input = rng.randint(alphabet_size, size=(length,))
    input_vals = numpy.tile(single_input, (beam_size, 1)).T

    search = BeamSearch(samples)
    results, mask, costs = search.search(
        {inputs: input_vals}, 0, max_length, as_arrays=True)
    # Just check sum
    assert results.sum() == 2816

    ones_mask = numpy.ones((length, beam_size),
                           dtype=theano.config.floatX)
    true_costs = simple_generator.cost(
        input_vals, ones_mask, results, mask).eval()
    true_costs = (true_costs * mask).sum(axis=0)
    assert_allclose(costs.sum(axis=0), true_costs, rtol=1e-5)

    # Repeat the search without `as_arrays=True`: the result unpacks into
    # sequences and costs, and every sequence must equal the matching
    # column of the array result trimmed to its mask length.
    results2, costs2 = search.search({inputs: input_vals},
                                     0, max_length)
    for i, sequence in enumerate(results2):
        assert sequence == list(results.T[i, :mask.T[i].sum()])
開發者ID:vikkamath,項目名稱:blocks,代碼行數:46,代碼來源:test_search.py

示例3: generate

        def generate(input_):
            """Produce trimmed output sequences and costs for an input.

            Hides the differences between the two generation modes. In
            "beam_search" mode a `BeamSearch` is built with
            ``input_.shape[1]`` as beam size and run for at most
            ``3 * input_.shape[0]`` steps. Otherwise the compiled model
            function is called on the input directly.

            Returns
            -------
            outputs : list of lists
                Output sequences, each cut after its first ``'</S>'``
                code (kept in full when no such code occurs).
            costs : list
                The negative log-likelihood of generating the respective
                sequences.

            """
            if mode != "beam_search":
                # The compiled function yields five values; only the
                # second (outputs) and the fifth (costs) are used.
                _1, outputs, _2, _3, costs = (
                    model.get_theano_function()(input_))
                costs = costs.T
            else:
                outputs_filter = VariableFilter(
                    bricks=[reverser.generator], name="outputs")
                samples, = outputs_filter(ComputationGraph(generated[1]))
                # NOTE: a new `BeamSearch` is created on every call, so
                # its functions are recompiled each time. Reuse a single
                # `BeamSearch` object if speed matters.
                beam_search = BeamSearch(input_.shape[1], samples)
                outputs, _, costs = beam_search.search(
                    {chars: input_}, char2code['</S>'],
                    3 * input_.shape[0])

            costs = list(costs)
            trimmed = []
            for i, column in enumerate(outputs.T):
                symbols = list(column)
                try:
                    # Keep everything up to and including the end marker.
                    true_length = symbols.index(char2code['</S>']) + 1
                except ValueError:
                    true_length = len(symbols)
                trimmed.append(symbols[:true_length])
                if mode == "sample":
                    # Per-step costs: sum only over the kept prefix.
                    costs[i] = costs[i][:true_length].sum()
            return trimmed, costs
開發者ID:kelvinxu,項目名稱:blocks,代碼行數:44,代碼來源:__init__.py

示例4: __init__

 def __init__(self, eol_symbol, beam_size, x, x_mask, samples,
              phoneme_dict=None, black_list=None, language_model=False):
     """Store the decoding settings and compile a beam search.

     ``black_list`` defaults to a fresh empty list. When
     ``language_model`` is truthy a `BeamSearchLM` backed by a
     `TrigramLanguageModel` is built; otherwise a plain `BeamSearch`.
     Either way the search object is compiled before returning.
     """
     self.black_list = [] if black_list is None else black_list
     self.x = x
     self.x_mask = x_mask
     self.eol_symbol = eol_symbol
     self.beam_size = beam_size
     if not language_model:
         self.beam_search = BeamSearch(beam_size, samples)
     else:
         lm = TrigramLanguageModel()
         # Map each position in the unigram sequence to its entry.
         ind_to_word = {index: word
                        for index, word in enumerate(lm.unigrams)}
         self.beam_search = BeamSearchLM(
             lm, 1., ind_to_word, beam_size, samples)
     self.beam_search.compile()
     self.phoneme_dict = phoneme_dict
開發者ID:EricDoug,項目名稱:recurrent-batch-normalization,代碼行數:19,代碼來源:ctc_monitoring.py

示例5: __init__

    def __init__(
        self,
        source_sentence,
        samples,
        model,
        data_stream,
        config,
        n_best=1,
        track_n_models=1,
        trg_ivocab=None,
        src_eos_idx=-1,
        trg_eos_idx=-1,
        **kwargs
    ):
        """Store validation settings, build the beam search, reload scores.

        Parameters
        ----------
        source_sentence, model, n_best
            Stored on the instance unchanged.
        samples
            Stored, and passed to `BeamSearch` as ``samples``.
        data_stream
            Stored. Its ``dataset`` must provide ``dictionary``,
            ``unk_token`` and ``eos_token``.
        config
            Mapping read here for the keys ``"beam_size"``,
            ``"bleu_script"``, ``"val_set_grndtruth"``, ``"saveto"`` and
            ``"reload"``, plus the optional ``"val_set_out"``.
        track_n_models
            Upper bound on how many reloaded scores become `ModelInfo`
            entries in ``best_models``.
        trg_ivocab
            Stored on the instance unchanged.
        src_eos_idx, trg_eos_idx
            Stored. ``src_eos_idx`` is also used as ``eos_idx``.
        **kwargs
            Forwarded to the superclass constructor.

        """
        super(BleuValidator, self).__init__(**kwargs)
        self.source_sentence = source_sentence
        self.samples = samples
        self.model = model
        self.data_stream = data_stream
        self.config = config
        self.n_best = n_best
        self.track_n_models = track_n_models
        self.verbose = config.get("val_set_out", None)

        self.src_eos_idx = src_eos_idx
        self.trg_eos_idx = trg_eos_idx

        # Helpers
        self.vocab = data_stream.dataset.dictionary
        self.trg_ivocab = trg_ivocab
        self.unk_sym = data_stream.dataset.unk_token
        self.eos_sym = data_stream.dataset.eos_token
        self.unk_idx = self.vocab[self.unk_sym]
        # The EOS index is supplied by the caller instead of being looked
        # up in the vocabulary via `eos_sym`.
        self.eos_idx = self.src_eos_idx
        self.best_models = []
        self.val_bleu_curve = []
        self.beam_search = BeamSearch(beam_size=self.config["beam_size"], samples=samples)
        self.multibleu_cmd = ["perl", self.config["bleu_script"], self.config["val_set_grndtruth"], "<"]

        # Create saving directory if it does not exist
        if not os.path.exists(self.config["saveto"]):
            os.makedirs(self.config["saveto"])

        if self.config["reload"]:
            try:
                bleu_score = numpy.load(os.path.join(self.config["saveto"], "val_bleu_scores.npz"))
                self.val_bleu_curve = bleu_score["bleu_scores"].tolist()

                # Track n best previous bleu scores
                for i, bleu in enumerate(sorted(self.val_bleu_curve, reverse=True)):
                    if i < self.track_n_models:
                        self.best_models.append(ModelInfo(bleu))
                logger.info("BleuScores Reloaded")
            except Exception:
                # Reloading is best effort: a missing or unreadable score
                # file must not abort construction. `Exception` (instead
                # of a bare `except`) lets KeyboardInterrupt and
                # SystemExit propagate.
                logger.info("BleuScores not Found")
開發者ID:rizar,項目名稱:NMT,代碼行數:55,代碼來源:sampling.py

示例6: __init__

 def __init__(self, eol_symbol, beam_size, x, x_mask, samples,
              phoneme_dict=None, black_list=None):
     """Store the decoding settings and compile a `BeamSearch`.

     ``black_list`` defaults to a fresh empty list when not given.
     """
     self.black_list = [] if black_list is None else black_list
     self.x = x
     self.x_mask = x_mask
     self.eol_symbol = eol_symbol
     self.beam_size = beam_size
     # Build the search object first, then compile it before storing
     # the remaining attribute.
     searcher = BeamSearch(beam_size, samples)
     self.beam_search = searcher
     self.beam_search.compile()
     self.phoneme_dict = phoneme_dict
開發者ID:EricDoug,項目名稱:recurrent-batch-normalization,代碼行數:13,代碼來源:monitoring.py

示例7: init_beam_search

    def init_beam_search(self, beam_size):
        """Record the beam size and compile a beam search for it.

        The search is built on the "outputs" variable of the generator's
        ``generate`` application, taken from the generation graph.

        See Blocks issue #500.

        """
        self.beam_size = beam_size
        generated = self.get_generate_graph()
        outputs_filter = VariableFilter(
            applications=[self.generator.generate], name="outputs")
        # Exactly one matching variable is expected; the unpacking fails
        # otherwise.
        samples, = outputs_filter(ComputationGraph(generated['outputs']))
        self._beam_search = BeamSearch(beam_size, samples)
        self._beam_search.compile()
開發者ID:ZhangAustin,項目名稱:attention-lvcsr,代碼行數:13,代碼來源:recognizer.py

示例8: __init__

    def __init__(self, source_sentence, samples, model, data_stream,
                 config, n_best=1, track_n_models=1, trg_ivocab=None,
                 normalize=True, store_full_main_loop=False, **kwargs):
        """Store validation settings, build the beam search, reload scores.

        Parameters
        ----------
        source_sentence, model, data_stream, n_best, normalize,
        store_full_main_loop, trg_ivocab
            Stored on the instance unchanged.
        samples
            Stored, and passed to `BeamSearch` as ``samples``.
        config
            Mapping read here for the keys ``'bleu_script'``,
            ``'val_set_grndtruth'``, ``'saveto'`` and ``'reload'``, plus
            the optional ``'val_set_out'``. ``'bleu_script'`` is used as
            a ``%`` format string that receives the ground-truth path.
        track_n_models
            Upper bound on how many reloaded scores become `ModelInfo`
            entries in ``best_models``.
        **kwargs
            Forwarded to the superclass constructor.

        """
        # TODO: change config structure
        super(BleuValidator, self).__init__(**kwargs)
        self.store_full_main_loop = store_full_main_loop
        self.source_sentence = source_sentence
        self.samples = samples
        self.model = model
        self.data_stream = data_stream
        self.config = config
        self.n_best = n_best
        self.track_n_models = track_n_models
        self.normalize = normalize
        self.verbose = config.get('val_set_out', None)

        # Helpers
        # The vocabulary lookups through `data_stream.dataset` are
        # disabled here; the UNK and EOS indices are hard-coded instead.
        self.trg_ivocab = trg_ivocab
        self.unk_idx = 0 # fs439: TODO hardcoded
        self.eos_idx = 2 # fs439: TODO hardcoded
        self.best_models = []
        self.val_bleu_curve = []
        self.beam_search = BeamSearch(samples=samples)
        # Fill the ground-truth path into the command template, then split
        # the result on whitespace into an argument list.
        self.multibleu_cmd = (self.config['bleu_script'] % self.config['val_set_grndtruth']).split()
        print("BLEU command: %s" % self.multibleu_cmd)

        # Create saving directory if it does not exist
        if not os.path.exists(self.config['saveto']):
            os.makedirs(self.config['saveto'])

        if self.config['reload']:
            try:
                bleu_score = numpy.load(os.path.join(self.config['saveto'],
                                        'val_bleu_scores.npz'))
                self.val_bleu_curve = bleu_score['bleu_scores'].tolist()

                # Track n best previous bleu scores
                for i, bleu in enumerate(
                        sorted(self.val_bleu_curve, reverse=True)):
                    if i < self.track_n_models:
                        self.best_models.append(ModelInfo(bleu))
                logger.info("BleuScores Reloaded")
            except Exception:
                # Reloading is best effort: a missing or unreadable score
                # file must not abort construction. `Exception` (instead
                # of a bare `except`) lets KeyboardInterrupt and
                # SystemExit propagate.
                logger.info("BleuScores not Found")
開發者ID:chagge,項目名稱:sgnmt,代碼行數:49,代碼來源:sampling.py

示例9: set_up_decoder

 def set_up_decoder(self, nmt_model_path):
     """Build the NMT model from ``self.config`` and prepare decoding.

     Creates and sets up the model, loads its weights, and selects the
     sparse feature maps and the matching beam search implementation.
     This method basically corresponds to
     ``blocks.machine_translation.main``.

     Args:
         nmt_model_path (string):  Path to the NMT model file (.npz)
     """
     self.nmt_model = NMTModel(self.config)
     self.nmt_model.set_up()
     LoadNMTUtils(nmt_model_path,
                  self.config['saveto'],
                  self.nmt_model.search_model).load_weights()

     # Fall back to a flat feature map when the config entry is falsy.
     self.src_sparse_feat_map = (self.config['src_sparse_feat_map']
                                 or FlatSparseFeatMap())

     if not self.config['trg_sparse_feat_map']:
         self.trg_sparse_feat_map = FlatSparseFeatMap()
         self.beam_search = BeamSearch(samples=self.nmt_model.samples)
         return

     # A target feature map is configured: use the sparse beam search.
     self.trg_sparse_feat_map = self.config['trg_sparse_feat_map']
     self.beam_search = SparseBeamSearch(
         samples=self.nmt_model.samples,
         trg_sparse_feat_map=self.trg_sparse_feat_map)
開發者ID:ucam-smt,項目名稱:sgnmt,代碼行數:24,代碼來源:vanilla_decoder.py

示例10: init_beam_search

    def init_beam_search(self, beam_size):
        """Compile a beam search for ``beam_size`` unless one already fits.

        Nothing happens when a compiled search exists and the stored
        beam size equals the requested one.

        See Blocks issue #500.

        """
        has_search = hasattr(self, '_beam_search')
        if has_search and self.beam_size == beam_size:
            # Same beam size as before: keep the compiled search.
            return

        self.beam_size = beam_size
        graph_outputs = self.get_generate_graph(use_mask=False, n_steps=3)
        cg = ComputationGraph(graph_outputs.values())
        outputs_filter = VariableFilter(
            applications=[self.generator.generate], name="outputs")
        # Exactly one matching variable is expected; the unpacking fails
        # otherwise.
        samples, = outputs_filter(cg)
        self._beam_search = BeamSearch(beam_size, samples)
        self._beam_search.compile()
開發者ID:DingKe,項目名稱:attention-lvcsr,代碼行數:16,代碼來源:recognizer.py

示例11: __init__

    def __init__(self, source_sentence, samples, model, data_stream, ground_truth, config,
                 val_out=None, val_best_out=None, n_best=1, normalize=True, **kwargs):
        """Store the evaluation settings and build the beam search.

        ``config`` is read for the keys ``'unk_id'``, ``'eos_id'`` and
        ``'bleu_script'``. ``val_best_out`` is only kept when ``val_out``
        is truthy. Remaining keyword arguments go to the superclass.
        """
        # TODO: change config structure
        super(BleuEvaluator, self).__init__(**kwargs)

        # Plain copies of the constructor arguments.
        self.source_sentence = source_sentence
        self.samples = samples
        self.model = model
        self.data_stream = data_stream
        self.config = config
        self.n_best = n_best
        self.normalize = normalize

        # Output targets: without `val_out`, `val_best_out` takes the
        # (falsy) value of `val_out` instead of its own.
        self.val_out = val_out
        self.val_best_out = val_best_out if val_out else val_out
        self.bleu_scores = []

        self.trg_ivocab = None
        self.unk_id = config['unk_id']
        self.eos_id = config['eos_id']
        self.beam_search = BeamSearch(samples=samples)
        self.multibleu_cmd = [
            'perl',
            self.config['bleu_script'],
            ground_truth,
            '<',
        ]
開發者ID:eske,項目名稱:blocks-examples,代碼行數:20,代碼來源:sampling.py

示例12: BleuValidator

class BleuValidator(SimpleExtension, SamplingBase):
    def __init__(
        self,
        source_sentence,
        samples,
        model,
        data_stream,
        config,
        n_best=1,
        track_n_models=1,
        trg_ivocab=None,
        src_eos_idx=-1,
        trg_eos_idx=-1,
        **kwargs
    ):
        """Store validation settings, build the beam search, reload scores.

        Parameters
        ----------
        source_sentence, model, n_best
            Stored on the instance unchanged.
        samples
            Stored, and passed to `BeamSearch` as ``samples``.
        data_stream
            Stored. Its ``dataset`` must provide ``dictionary``,
            ``unk_token`` and ``eos_token``.
        config
            Mapping read here for the keys ``"beam_size"``,
            ``"bleu_script"``, ``"val_set_grndtruth"``, ``"saveto"`` and
            ``"reload"``, plus the optional ``"val_set_out"``.
        track_n_models
            Upper bound on how many reloaded scores become `ModelInfo`
            entries in ``best_models``.
        trg_ivocab
            Stored on the instance unchanged.
        src_eos_idx, trg_eos_idx
            Stored. ``src_eos_idx`` is also used as ``eos_idx``.
        **kwargs
            Forwarded to the superclass constructor.

        """
        super(BleuValidator, self).__init__(**kwargs)
        self.source_sentence = source_sentence
        self.samples = samples
        self.model = model
        self.data_stream = data_stream
        self.config = config
        self.n_best = n_best
        self.track_n_models = track_n_models
        self.verbose = config.get("val_set_out", None)

        self.src_eos_idx = src_eos_idx
        self.trg_eos_idx = trg_eos_idx

        # Helpers
        self.vocab = data_stream.dataset.dictionary
        self.trg_ivocab = trg_ivocab
        self.unk_sym = data_stream.dataset.unk_token
        self.eos_sym = data_stream.dataset.eos_token
        self.unk_idx = self.vocab[self.unk_sym]
        # The EOS index is supplied by the caller instead of being looked
        # up in the vocabulary via `eos_sym`.
        self.eos_idx = self.src_eos_idx
        self.best_models = []
        self.val_bleu_curve = []
        self.beam_search = BeamSearch(beam_size=self.config["beam_size"], samples=samples)
        # NOTE(review): the trailing "<" is a plain list element here, not
        # a shell redirect -- confirm it is intended.
        self.multibleu_cmd = ["perl", self.config["bleu_script"], self.config["val_set_grndtruth"], "<"]

        # Create saving directory if it does not exist
        if not os.path.exists(self.config["saveto"]):
            os.makedirs(self.config["saveto"])

        if self.config["reload"]:
            try:
                bleu_score = numpy.load(os.path.join(self.config["saveto"], "val_bleu_scores.npz"))
                self.val_bleu_curve = bleu_score["bleu_scores"].tolist()

                # Track n best previous bleu scores
                for i, bleu in enumerate(sorted(self.val_bleu_curve, reverse=True)):
                    if i < self.track_n_models:
                        self.best_models.append(ModelInfo(bleu))
                logger.info("BleuScores Reloaded")
            except Exception:
                # Reloading is best effort: a missing or unreadable score
                # file must not abort construction. `Exception` (instead
                # of a bare `except`) lets KeyboardInterrupt and
                # SystemExit propagate.
                logger.info("BleuScores not Found")

    def do(self, which_callback, *args):
        """Run validation once the burn-in period is over.

        Returns without doing anything while the number of completed
        iterations is at most ``config["val_burn_in"]``. Afterwards the
        main loop's parameter values are copied into ``self.model``, the
        model is evaluated, and the evaluation result is handed to
        ``_save_model``.
        """
        iterations_done = self.main_loop.status["iterations_done"]
        if iterations_done <= self.config["val_burn_in"]:
            # Still within the validation burn-in.
            return

        # Bring the validation model up to date with the trained one.
        current_values = self.main_loop.model.get_param_values()
        self.model.set_param_values(current_values)

        # Evaluate, then let `_save_model` decide what to do with it.
        evaluation = self._evaluate_model()
        self._save_model(evaluation)

    def _evaluate_model(self):

        logger.info("Started Validation: ")
        val_start_time = time.time()
        mb_subprocess = Popen(self.multibleu_cmd, stdin=PIPE, stdout=PIPE)
        total_cost = 0.0

        # Get target vocabulary
        if not self.trg_ivocab:
            sources = self._get_attr_rec(self.main_loop, "data_stream")
            trg_vocab = sources.data_streams[1].dataset.dictionary
            self.trg_ivocab = {v: k for k, v in trg_vocab.items()}

        if self.verbose:
            ftrans = open(self.config["val_set_out"], "w")

        for i, line in enumerate(self.data_stream.get_epoch_iterator()):
            """
            Load the sentence, retrieve the sample, write to file
            """

            line[0][-1] = self.src_eos_idx
            seq = self._oov_to_unk(line[0])
            input_ = numpy.tile(seq, (self.config["beam_size"], 1))

            # draw sample, checking to ensure we don't get an empty string back
            trans, costs = self.beam_search.search(
                input_values={self.source_sentence: input_},
                max_length=3 * len(seq),
                eol_symbol=self.trg_eos_idx,
                ignore_first_eol=True,
#.........這裏部分代碼省略.........
開發者ID:rizar,項目名稱:NMT,代碼行數:101,代碼來源:sampling.py

示例13: SpeechRecognizer


#.........這裏部分代碼省略.........
            prediction_variable = tensor.lvector('prediction')
            if prediction is not None:
                input_variables.append(prediction_variable)
                cg = self.get_cost_graph(
                    batch=False, prediction=prediction_variable[:, None])
            else:
                cg = self.get_cost_graph(batch=False)
            cost = cg.outputs[0]

            weights, = VariableFilter(
                bricks=[self.generator], name="weights")(cg)

            energies = VariableFilter(
                bricks=[self.generator], name="energies")(cg)
            energies_output = [energies[0][:, 0, :] if energies
                               else tensor.zeros_like(weights)]

            states, = VariableFilter(
                applications=[self.encoder.apply], roles=[OUTPUT],
                name="encoded")(cg)

            ctc_matrix_output = []
            # Temporarily disabled for compatibility with LM code
            # if len(self.generator.readout.source_names) == 1:
            #    ctc_matrix_output = [
            #        self.generator.readout.readout(weighted_averages=states)[:, 0, :]]

            self._analyze = theano.function(
                input_variables,
                [cost[:, 0], weights[:, 0, :]] + energies_output + ctc_matrix_output,
                on_unused_input='warn')
        return self._analyze(**input_values_dict)

    def init_beam_search(self, beam_size):
        """Compile a beam search for ``beam_size`` unless one already fits.

        Nothing happens when a compiled search exists and the stored
        beam size equals the requested one.

        See Blocks issue #500.

        """
        has_search = hasattr(self, '_beam_search')
        if has_search and self.beam_size == beam_size:
            # Same beam size as before: keep the compiled search.
            return

        self.beam_size = beam_size
        graph_outputs = self.get_generate_graph(use_mask=False, n_steps=3)
        cg = ComputationGraph(graph_outputs.values())
        outputs_filter = VariableFilter(
            applications=[self.generator.generate], name="outputs")
        # Exactly one matching variable is expected; the unpacking fails
        # otherwise.
        samples, = outputs_filter(cg)
        self._beam_search = BeamSearch(beam_size, samples)
        self._beam_search.compile()

    def beam_search(self, inputs, **kwargs):
        """Run the compiled beam search on named inputs.

        ``inputs`` is copied into a dict keyed by input name. Every
        variable in ``self.inputs`` takes its value from that dict, with
        a new axis inserted at position 1. Names left over afterwards
        raise an ``Exception``. Extra keyword arguments are passed on to
        the underlying search, whose two results are returned.
        """
        # After unpickling, `self.beam_size` is still available but the
        # compiled search is not, so make sure one exists.
        self.init_beam_search(self.beam_size)

        remaining = dict(inputs)
        max_length = int(self.bottom.num_time_steps(**remaining) /
                         self.max_decoded_length_scale)

        search_inputs = {}
        for variable in self.inputs.values():
            value = remaining.pop(variable.name)
            search_inputs[variable] = value[:, numpy.newaxis, ...]

        if remaining:
            raise Exception(
                'Unknown inputs passed to beam search: {}'.format(
                    remaining.keys()))

        outputs, search_costs = self._beam_search.search(
            search_inputs, self.eos_label,
            max_length,
            ignore_first_eol=self.data_prepend_eos,
            **kwargs)
        return outputs, search_costs

    def init_generate(self):
        """Compile the generation function and cache it on the instance.

        The function is built from the ``'outputs'`` entry of the
        generation graph (requested with ``use_mask=False``) and stored
        as ``self._do_generate``.
        """
        outputs = self.get_generate_graph(use_mask=False)['outputs']
        self._do_generate = ComputationGraph(outputs).get_theano_function()

    def sample(self, inputs, n_steps=None):
        """Generate from a single input and return the first result.

        The generation function is compiled on first use. When
        ``n_steps`` is ``None`` it is derived from the number of time
        steps of the batch divided by ``self.max_decoded_length_scale``,
        truncated to an int.
        """
        if not hasattr(self, '_do_generate'):
            self.init_generate()

        batch, unused_mask = self.bottom.single_to_batch_inputs(inputs)
        if n_steps is None:
            # Computed before 'n_steps' is added to the batch, so the
            # batch passed here holds only the converted inputs.
            n_steps = int(self.bottom.num_time_steps(**batch) /
                          self.max_decoded_length_scale)
        batch['n_steps'] = n_steps

        results = self._do_generate(**batch)
        return results[0]

    def __getstate__(self):
        """Return a copy of the instance dict without compiled functions.

        The ``_analyze`` and ``_beam_search`` entries are left out when
        present; everything else is kept as is.
        """
        excluded = ('_analyze', '_beam_search')
        return {name: value
                for name, value in self.__dict__.items()
                if name not in excluded}

    def __setstate__(self, state):
        """Restore the instance dict and drop the emitter's cached RNG.

        After updating ``__dict__`` from ``state``, the ``_theano_rng``
        attribute of ``self.generator.readout.emitter`` is deleted when
        it exists.
        """
        self.__dict__.update(state)
        # To use bricks used on a GPU first on a CPU later
        try:
            emitter = self.generator.readout.emitter
            del emitter._theano_rng
        except AttributeError:
            # Some link of the attribute chain, or the RNG itself, is
            # absent: nothing to remove. Catching only AttributeError
            # (instead of a bare `except`) stops unrelated errors and
            # KeyboardInterrupt from being silently discarded.
            pass
開發者ID:DingKe,項目名稱:attention-lvcsr,代碼行數:101,代碼來源:recognizer.py

示例14: test_beam_search_smallest

def test_beam_search_smallest():
    """Check `BeamSearch._smallest` on a small 2x3 matrix with k=2.

    The two smallest entries are 1 at (1, 0) and 2 at (1, 1), so the
    expected row indices are [1, 1] and the column indices [0, 1].
    """
    matrix = numpy.array([[3, 6, 4], [1, 2, 7]])
    expected_indices = numpy.array([[1, 1], [0, 1]])
    expected_minima = [1, 2]

    ind, mins = BeamSearch._smallest(matrix, 2)

    assert numpy.all(numpy.array(ind) == expected_indices)
    assert numpy.all(mins == expected_minima)
開發者ID:kelvinxu,項目名稱:blocks,代碼行數:5,代碼來源:test_search.py

示例15: BleuValidator

class BleuValidator(SimpleExtension, SamplingBase):
    # TODO: a lot has been changed in NMT, sync respectively
    """Implements early stopping based on BLEU score."""

    def __init__(self, source_sentence, samples, model, data_stream,
                 config, n_best=1, track_n_models=1,
                 normalize=True, **kwargs):
        """Store validation settings, build the beam search, reload scores.

        Parameters
        ----------
        source_sentence, model, n_best, normalize
            Stored on the instance unchanged.
        samples
            Stored, and passed to `BeamSearch` as ``samples``.
        data_stream
            Stored. Its ``dataset`` must provide ``dictionary``,
            ``unk_token`` and ``eos_token``; the UNK and EOS indices are
            looked up in that dictionary.
        config
            Mapping read here for the keys ``'bleu_script'``,
            ``'val_set_grndtruth'``, ``'saveto'`` and ``'reload'``, plus
            the optional ``'val_set_out'``.
        track_n_models
            Upper bound on how many reloaded scores become `ModelInfo`
            entries in ``best_models``.
        **kwargs
            Forwarded to the superclass constructor.

        """
        # TODO: change config structure
        super(BleuValidator, self).__init__(**kwargs)
        self.source_sentence = source_sentence
        self.samples = samples
        self.model = model
        self.data_stream = data_stream
        self.config = config
        self.n_best = n_best
        self.track_n_models = track_n_models
        self.normalize = normalize
        self.verbose = config.get('val_set_out', None)

        # Helpers
        self.vocab = data_stream.dataset.dictionary
        self.unk_sym = data_stream.dataset.unk_token
        self.eos_sym = data_stream.dataset.eos_token
        self.unk_idx = self.vocab[self.unk_sym]
        self.eos_idx = self.vocab[self.eos_sym]
        self.best_models = []
        self.val_bleu_curve = []
        self.beam_search = BeamSearch(samples=samples)
        # NOTE(review): the trailing '<' is a plain list element here, not
        # a shell redirect -- confirm it is intended.
        self.multibleu_cmd = ['perl', self.config['bleu_script'],
                              self.config['val_set_grndtruth'], '<']

        # Create saving directory if it does not exist
        if not os.path.exists(self.config['saveto']):
            os.makedirs(self.config['saveto'])

        if self.config['reload']:
            try:
                bleu_score = numpy.load(os.path.join(self.config['saveto'],
                                        'val_bleu_scores.npz'))
                self.val_bleu_curve = bleu_score['bleu_scores'].tolist()

                # Track n best previous bleu scores
                for i, bleu in enumerate(
                        sorted(self.val_bleu_curve, reverse=True)):
                    if i < self.track_n_models:
                        self.best_models.append(ModelInfo(bleu))
                logger.info("BleuScores Reloaded")
            except Exception:
                # Reloading is best effort: a missing or unreadable score
                # file must not abort construction. `Exception` (instead
                # of a bare `except`) lets KeyboardInterrupt and
                # SystemExit propagate.
                logger.info("BleuScores not Found")

    def do(self, which_callback, *args):
        """Run validation once the burn-in period is over.

        Returns without doing anything while the number of completed
        iterations is at most ``config['val_burn_in']``. Afterwards the
        model is evaluated and the result is handed to ``_save_model``.
        """
        iterations_done = self.main_loop.status['iterations_done']
        burn_in = self.config['val_burn_in']
        if iterations_done <= burn_in:
            # Still within the validation burn-in.
            return

        # Evaluate, then let `_save_model` decide what to do with it.
        evaluation = self._evaluate_model()
        self._save_model(evaluation)

    def _evaluate_model(self):

        logger.info("Started Validation: ")
        val_start_time = time.time()
        mb_subprocess = Popen(self.multibleu_cmd, stdin=PIPE, stdout=PIPE)
        total_cost = 0.0

        # Get target vocabulary
        sources = self._get_attr_rec(self.main_loop, 'data_stream')
        trg_vocab = sources.data_streams[1].dataset.dictionary
        self.trg_ivocab = {v: k for k, v in trg_vocab.items()}
        trg_eos_sym = sources.data_streams[1].dataset.eos_token
        self.trg_eos_idx = trg_vocab[trg_eos_sym]

        if self.verbose:
            ftrans = open(self.config['val_set_out'], 'w')

        for i, line in enumerate(self.data_stream.get_epoch_iterator()):
            """
            Load the sentence, retrieve the sample, write to file
            """

            seq = self._oov_to_unk(
                line[0], self.config['src_vocab_size'], self.unk_idx)
            input_ = numpy.tile(seq, (self.config['beam_size'], 1))

            # draw sample, checking to ensure we don't get an empty string back
            trans, costs = \
                self.beam_search.search(
                    input_values={self.source_sentence: input_},
                    max_length=3*len(seq), eol_symbol=self.trg_eos_idx,
                    ignore_first_eol=True)

            # normalize costs according to the sequence lengths
            if self.normalize:
                lengths = numpy.array([len(s) for s in trans])
                costs = costs / lengths

            nbest_idx = numpy.argsort(costs)[:self.n_best]
            for j, best in enumerate(nbest_idx):
#.........這裏部分代碼省略.........
開發者ID:MLDL,項目名稱:blocks-examples,代碼行數:101,代碼來源:sampling.py


注:本文中的blocks.search.BeamSearch類示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。