当前位置: 首页>>代码示例>>Python>>正文


Python MLP.mean_square_error方法代码示例

本文整理汇总了Python中mlp.MLP.mean_square_error方法的典型用法代码示例。如果您正苦于以下问题:Python MLP.mean_square_error方法的具体用法?Python MLP.mean_square_error怎么用?Python MLP.mean_square_error使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在mlp.MLP的用法示例。


在下文中一共展示了MLP.mean_square_error方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_cA

# 需要导入模块: from mlp import MLP [as 别名]
# 或者: from mlp.MLP import mean_square_error [as 别名]
def test_cA(learning_rate=0.002, training_epochs=10,
            dataset='test2.pkl.gz', n_epochs=100,
            batch_size=5, contraction_level=0.001):


    datasets = load_data(dataset)

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]

    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size
    # validation/testの時はミニバッチを使わない
    n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] #/ batch_size
    n_test_batches = test_set_x.get_value(borrow=True).shape[0] #/ batch_size

    print '... building the model'

    # allocate symbolic variables for the data
    index = T.lscalar()  # index to a [mini]batch
    x = T.matrix('x')  # the data is presented as rasterized images
    y = T.vector('y')  # the labels are presented as 1D vector of
                        # [int] labels

    # 乱数シード
    rng = np.random.RandomState(1234)

    regressor = MLP(rng=rng, input=x, n_in=1631,
                        n_hidden=800)


    ca = cA(numpy_rng=rng, input=x,
            n_visible=1631, n_hidden=800, n_batchsize=batch_size,
            W = regressor.hiddenLayer.W, bhid = regressor.hiddenLayer.b)

    cost, updates = ca.get_cost_updates(contraction_level=contraction_level,
                                        learning_rate=0.001)

    train_ca = theano.function([index], [T.mean(ca.L_rec), ca.L_jacob],
                               updates=updates,
                               givens={x: train_set_x[index * batch_size:
                                                    (index + 1) * batch_size]})

    start_time = time.clock()

    ############
    # TRAINING #
    ############
    print '... training the model'
    epoch = 0

    # go through training epochs
    for epoch in xrange(training_epochs):
        # go through trainng set
        c = []
        for batch_index in xrange(n_train_batches):
            c.append(train_ca(batch_index))

        c_array = np.vstack(c)
        print 'Training epoch %d, reconstruction cost ' % epoch, np.mean(
            c_array[0]), ' jacobian norm ', np.mean(np.sqrt(c_array[1]))
        a = regressor.hiddenLayer.W.eval()
        print np.sum(a)

    end_time = time.clock()

    training_time = (end_time - start_time)

    print >> sys.stderr, ('The code for file ' + os.path.split(__file__)[1] +
                          ' ran for %.2fm' % ((training_time) / 60.))

    cost = regressor.mean_square_error(y)


    gparams = []
    for param in regressor.params:
        gparam = T.grad(cost, param)
        gparams.append(gparam)

    updates = []

    for param, gparam in zip(regressor.params, gparams):
        updates.append((param, param - learning_rate * gparam))

    # 学習
    train_model = theano.function(inputs=[index], outputs=cost,
            updates=updates,
            givens={
                x: train_set_x[index * batch_size:(index + 1) * batch_size],
                y: train_set_y[index * batch_size:(index + 1) * batch_size]})

    # validation
    validate_model = theano.function(inputs=[index],
            outputs=regressor.prediction,
            givens={
                x: valid_set_x[index:(index + 1)]})
            
    # validationデータセット
    valid_datasets = load_valid_data(dataset)
#.........这里部分代码省略.........
开发者ID:mottodora,项目名称:DNN-MultipleReg,代码行数:103,代码来源:supervised_cA.py


注:本文中的mlp.MLP.mean_square_error方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。