当前位置: 首页>>代码示例>>Python>>正文


Python RBM.reconstruct方法代码示例

本文整理汇总了Python中rbm.RBM.reconstruct方法的典型用法代码示例。如果您正苦于以下问题:Python RBM.reconstruct方法的具体用法?Python RBM.reconstruct怎么用?Python RBM.reconstruct使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在rbm.RBM的用法示例。


在下文中一共展示了RBM.reconstruct方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: toy_test

# 需要导入模块: from rbm import RBM [as 别名]
# 或者: from rbm.RBM import reconstruct [as 别名]

#.........这里部分代码省略.........
    # each cell represents the number of times a term occurs
    #                          d1 d2 d3 d4 d5
    toy_data = numpy.asarray([[0, 2, 0, 1, 0],
                              [9, 0, 3, 1, 1],
                              [4, 1, 1, 2, 1],
                              [10, 10, 1, 1, 0],
                              [1, 0, 8, 0, 10],
                              [0, 1, 10, 1, 0],
                              [1, 0, 2, 6, 1],
                              [0, 0, 1, 0, 0],
                              [1, 0, 0, 0, 0],
                              [1, 0, 1, 0, 0],
                              [1, 1, 0, 0, 1],
                              [10, 2, 0, 1, 0],
                              [0, 0, 1, 0, 10],
                              [1, 0, 0, 3, 0],
                              [0, 0, 2, 0, 1],
                              [10, 0, 1, 0, 0],
                              [0, 1, 0, 0, 0],
                              [0, 1, 0, 1, 0],
                              [1, 0, 1, 0, 0],
                              [1, 0, 0, 0, 1],
                              [1, 0, 1, 0, 0],
                              [0, 0, 1, 0, 0]])

    # from rbm import RBM
    from rbm_variants import RBM_Orthogonal as RBM
    # from rbm_variants import PoissonRBM as RBM


    train_x = toSharedX(toy_data, name="toy_data")

    n_vis = train_x.get_value(borrow=True).shape[1]

    n_samples = train_x.get_value(borrow=True).shape[0]

    if batch_size >= n_samples:
        batch_size = n_samples

    n_train_batches = n_samples / batch_size


    # construct the RBM class
    rbm = RBM(n_visible=n_vis, n_hidden=n_hidden, isPCD=isPCD)
    train_fn = rbm.get_train_fn(train_x, batch_size)

    print "... projecting"
    print rbm.project(train_x, hidSample=1)

    #################################
    #     Training the RBM          #
    #################################
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)
    os.chdir(output_folder)

    plotting_time = 0.
    start_time = time.clock()
    import PIL.Image
    from visualizer import tile_raster_images

    # go through training epochs
    for epoch in xrange(training_epochs):

        # go through the training set
        mean_cost = []
        for batch_index in xrange(n_train_batches):
            # for each batch, we extract the gibbs chain
            new_cost = train_fn(index=batch_index, lr=learning_rate)
            mean_cost += [new_cost]

        print 'Training epoch %d, cost is ' % epoch, numpy.mean(mean_cost)

        if numpy.mean(mean_cost) >= 0:
            break

        # W shape is [784 500]
        # Plot filters after each training epoch
        plotting_start = time.clock()
        # Construct image from the weight matrix
        image = PIL.Image.fromarray(tile_raster_images(
            X=rbm.W.get_value(borrow=True).T,
            # weight is [n_vis, n_hidden]
            # so, among 'n_hidden' rows,
            # each row corresponds to propdown one hidden unit
            img_shape=(1, n_vis), tile_shape=(n_hidden, 1),
            tile_spacing=(1, 1)))
        image.save('filters_at_epoch_%i.png' % epoch)
        plotting_stop = time.clock()
        plotting_time += (plotting_stop - plotting_start)

    end_time = time.clock()
    pretraining_time = (end_time - start_time) - plotting_time
    print ('Training took %f minutes' % (pretraining_time / 60.))

    print "... projecting"
    print rbm.project(train_x, hidSample=1)

    print "... reconstructing"
    print rbm.reconstruct(train_x, showSample=1) * train_x.get_value(borrow=True)
开发者ID:bboalimoe,项目名称:DENA,代码行数:104,代码来源:rbm_variants.py


注:本文中的rbm.RBM.reconstruct方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。