当前位置: 首页>>代码示例>>Python>>正文


Python SimpleRecurrent.initial_state方法代码示例

本文整理汇总了Python中blocks.bricks.recurrent.SimpleRecurrent.initial_state方法的典型用法代码示例。如果您正苦于以下问题:Python SimpleRecurrent.initial_state方法的具体用法?Python SimpleRecurrent.initial_state怎么用?Python SimpleRecurrent.initial_state使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在blocks.bricks.recurrent.SimpleRecurrent的用法示例。


在下文中一共展示了SimpleRecurrent.initial_state方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from blocks.bricks.recurrent import SimpleRecurrent [as 别名]
# 或者: from blocks.bricks.recurrent.SimpleRecurrent import initial_state [as 别名]

#.........这里部分代码省略.........
        weights_init=Uniform(std=0.01),
        biases_init=Constant(0.)
    )
    score_layer.initialize()

    embedding = (linear_embedding.apply(x_int[:-1])
                 * tensor.shape_padright(m.T[1:]))
    rnn_out = rnn.apply(inputs=embedding, mask=m.T[1:])
    probs = softmax(
        sequence_map(score_layer.apply, rnn_out, mask=m.T[1:])[0]
    )
    idx_mask = m.T[1:].nonzero()
    cost = CategoricalCrossEntropy().apply(
        x_int[1:][idx_mask[0], idx_mask[1]],
        probs[idx_mask[0], idx_mask[1]]
    )
    cost.name = 'cost'
    misclassification = MisclassificationRate().apply(
        x_int[1:][idx_mask[0], idx_mask[1]],
        probs[idx_mask[0], idx_mask[1]]
    )
    misclassification.name = 'misclassification'

    cg = ComputationGraph([cost])
    params = cg.parameters

    algorithm = GradientDescent(
        cost=cost,
        params=params,
        step_rule=Adam()
    )

    train_data_stream = Padding(
        data_stream=DataStream(
            dataset=train_dataset,
            iteration_scheme=BatchwiseShuffledScheme(
                examples=train_dataset.num_examples,
                batch_size=10,
            )
        ),
        mask_sources=('features',)
    )

    model = Model(cost)

    extensions = []
    extensions.append(Timing())
    extensions.append(FinishAfter(after_n_epochs=num_epochs))
    extensions.append(TrainingDataMonitoring(
        [cost, misclassification],
        prefix='train',
        after_epoch=True))

    batch_size = 10
    length = 30
    trng = MRG_RandomStreams(18032015)
    u = trng.uniform(size=(length, batch_size, n_voc))
    gumbel_noise = -tensor.log(-tensor.log(u))
    init_samples = (tensor.log(init_probs).dimshuffle(('x', 0))
                    + gumbel_noise[0]).argmax(axis=-1)
    init_states = rnn.initial_state('states', batch_size)

    def sampling_step(g_noise, states, samples_step):
        embedding_step = linear_embedding.apply(samples_step)
        next_states = rnn.apply(inputs=embedding_step,
                                            states=states,
                                            iterate=False)
        probs_step = softmax(score_layer.apply(next_states))
        next_samples = (tensor.log(probs_step)
                        + g_noise).argmax(axis=-1)

        return next_states, next_samples

    [_, samples], _ = theano.scan(
        fn=sampling_step,
        sequences=[gumbel_noise[1:]],
        outputs_info=[init_states, init_samples]
    )

    sampling = theano.function([], samples.owner.inputs[0].T)

    plotters = []
    plotters.append(Plotter(
        channels=[['train_cost', 'train_misclassification']],
        titles=['Costs']))

    extensions.append(PlotManager('Language modelling example',
                                  plotters=plotters,
                                  after_epoch=True,
                                  after_training=True))
    extensions.append(Printing())
    extensions.append(PrintSamples(sampler=sampling,
                                   voc=train_dataset.inv_dict))

    main_loop = MainLoop(model=model,
                         data_stream=train_data_stream,
                         algorithm=algorithm,
                         extensions=extensions)

    main_loop.run()
开发者ID:dmitriy-serdyuk,项目名称:dl_tutorials,代码行数:104,代码来源:rnn_nlp_main.py


注:本文中的blocks.bricks.recurrent.SimpleRecurrent.initial_state方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。