本文整理汇总了C++中shared_ptr::Backward方法的典型用法代码示例。如果您正苦于以下问题:C++ shared_ptr::Backward方法的具体用法?C++ shared_ptr::Backward怎么用?C++ shared_ptr::Backward使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类shared_ptr
的用法示例。
在下文中一共展示了shared_ptr::Backward方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。
示例1: do_backward
static mxArray* do_backward(const mxArray* const top_diff) {
const vector<Blob<float>*>& output_blobs = net_->output_blobs();
const vector<Blob<float>*>& input_blobs = net_->input_blobs();
if (static_cast<unsigned int>(mxGetDimensions(top_diff)[0]) !=
output_blobs.size()) {
mex_error("Invalid input size");
}
// First, copy the output diff
for (unsigned int i = 0; i < output_blobs.size(); ++i) {
const mxArray* const elem = mxGetCell(top_diff, i);
const float* const data_ptr =
reinterpret_cast<const float* const>(mxGetPr(elem));
switch (Caffe::mode()) {
case Caffe::CPU:
caffe_copy(output_blobs[i]->count(), data_ptr,
output_blobs[i]->mutable_cpu_diff());
break;
case Caffe::GPU:
caffe_copy(output_blobs[i]->count(), data_ptr,
output_blobs[i]->mutable_gpu_diff());
break;
default:
mex_error("Unknown Caffe mode");
} // switch (Caffe::mode())
}
// LOG(INFO) << "Start";
net_->Backward();
// LOG(INFO) << "End";
mxArray* mx_out = mxCreateCellMatrix(input_blobs.size(), 1);
for (unsigned int i = 0; i < input_blobs.size(); ++i) {
// internally data is stored as (width, height, channels, num)
// where width is the fastest dimension
mwSize dims[4] = {input_blobs[i]->width(), input_blobs[i]->height(),
input_blobs[i]->channels(), input_blobs[i]->num()};
mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL);
mxSetCell(mx_out, i, mx_blob);
float* data_ptr = reinterpret_cast<float*>(mxGetPr(mx_blob));
switch (Caffe::mode()) {
case Caffe::CPU:
caffe_copy(input_blobs[i]->count(), input_blobs[i]->cpu_diff(), data_ptr);
break;
case Caffe::GPU:
caffe_copy(input_blobs[i]->count(), input_blobs[i]->gpu_diff(), data_ptr);
break;
default:
mex_error("Unknown Caffe mode");
} // switch (Caffe::mode())
}
return mx_out;
}
示例2: backward
static void backward(MEX_ARGS) {
float loss = net_->Backward();
}