当前位置: 首页>>代码示例>>C++>>正文


C++ NDArray::CopyTo方法代码示例

本文整理汇总了C++中NDArray::CopyTo方法的典型用法代码示例。如果您正苦于以下问题:C++ NDArray::CopyTo方法的具体用法?C++ NDArray::CopyTo怎么用?C++ NDArray::CopyTo使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在NDArray的用法示例。


在下文中一共展示了NDArray::CopyTo方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: main


//.........这里部分代码省略.........
        // arg.first is parameter name, and arg.second is the value
        initializer(arg.first, &arg.second);
    }

    // Create sgd optimiz er
    Optimizer* opt = OptimizerRegistry::Find("sgd");
    opt->SetParam("rescale_grad", 1.0/batch_size)
    ->SetParam("lr", learning_rate)
    ->SetParam("wd", weight_decay);

    // Create executor by binding parameters to the model
    auto *exec = net.SimpleBind(ctx, args);
    auto arg_names = net.ListArguments();

    // Start training
    for (int iter = 0; iter < max_epoch; ++iter) {
        int samples = 0;
        train_iter.Reset();

        auto tic = std::chrono::system_clock::now();
        while (train_iter.Next()) {
            samples += batch_size;
            auto data_batch = train_iter.GetDataBatch();

            /*
             * The shape of data_batch.data is (batch_size, (num_mnist_features + 1))
             * Need to reshape this data so that label column can be extracted from this data.
             */
            NDArray reshapedData = data_batch.data.Reshape(Shape((num_mnist_features + 1),
                                                                 batch_size));

            /*
             * Extract the label data by slicing the first column of the data and
             * copy it to "label" arg.
             */
            reshapedData.Slice(0, 1).Reshape(Shape(batch_size)).CopyTo(&args["label"]);

            /*
             * Extract the feature data by slicing the columns 1 to 785 of the data and
             * copy it to "data" arg.
             */
            reshapedData.Slice(1, (num_mnist_features + 1)).Reshape(Shape(batch_size,
                                                                         num_mnist_features))
                                                           .CopyTo(&args["data"]);

            exec->Forward(true);

            // Compute gradients
            exec->Backward();
            // Update parameters
            for (size_t i = 0; i < arg_names.size(); ++i) {
                if (arg_names[i] == "data" || arg_names[i] == "label") continue;
                opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
            }
        }
        auto toc = std::chrono::system_clock::now();

        Accuracy acc;
        val_iter.Reset();
        while (val_iter.Next()) {
            auto data_batch = val_iter.GetDataBatch();

            /*
             * The shape of data_batch.data is (batch_size, (num_mnist_features + 1))
             * Need to reshape this data so that label column can be extracted from this data.
             */
            NDArray reshapedData = data_batch.data.Reshape(Shape((num_mnist_features + 1),
                                                                 batch_size));

            /*
             * Extract the label data by slicing the first column of the data and
             * copy it to "label" arg.
             */
            NDArray labelData = reshapedData.Slice(0, 1).Reshape(Shape(batch_size));
            labelData.CopyTo(&args["label"]);

            /*
             * Extract the feature data by slicing the columns 1 to 785 of the data and
             * copy it to "data" arg.
             */
            reshapedData.Slice(1, (num_mnist_features + 1)).Reshape(Shape(batch_size,
                                                                         num_mnist_features))
                                                                   .CopyTo(&args["data"]);

            // Forward pass is enough as no gradient is needed when evaluating
            exec->Forward(false);
            acc.Update(labelData, exec->outputs[0]);
        }
        float duration = std::chrono::duration_cast<std::chrono::milliseconds>
        (toc - tic).count() / 1000.0;
        LG << "Epoch[" << iter << "]  " << samples/duration << " samples/sec Accuracy: "
        << acc.Get();
    }

    delete exec;
    delete opt;
    MXNotifyShutdown();
    CATCH
    return 0;
}
开发者ID:dmlc,项目名称:mxnet,代码行数:101,代码来源:mlp_csv.cpp


注:本文中的NDArray::CopyTo方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。