本文整理汇总了C++中NetParameter::state方法的典型用法代码示例。如果您正苦于以下问题：C++ NetParameter::state方法的具体用法？C++ NetParameter::state怎么用？C++ NetParameter::state使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类NetParameter的用法示例。
在下文中一共展示了NetParameter::state方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的C++代码示例。
示例1: state
void Net<Dtype>::filterNet(const NetParameter& param, NetParameter* filtered_param){
    // Produce *filtered_param: a copy of `param` whose layer list keeps only
    // the layers whose include/exclude rules accept the net state stored in
    // param.state().
    //
    // Per-layer policy (a layer may carry include rules OR exclude rules,
    // never both -- enforced by the CHECK below):
    //   * no include rules  -> kept, unless some exclude rule matches;
    //   * has include rules -> dropped, unless some include rule matches.
    // Whether a rule matches is decided by stateMeetRule(); rules are tested
    // in order and testing stops at the first match.
    //
    // param          : source net definition (read-only).
    // filtered_param : output; receives every non-layer field of `param`
    //                  plus the surviving layers, in their original order.
    NetState net_state(param.state());

    // Full copy first so all non-layer fields carry over, then rebuild the
    // layer list from scratch below.
    filtered_param->CopyFrom(param);
    filtered_param->clear_layer();

    const int num_layers = param.layer_size();
    for (int idx = 0; idx < num_layers; ++idx){
        const LayerParameter& candidate = param.layer(idx);
        const string& candidate_name = candidate.name();
        const int num_include = candidate.include_size();
        const int num_exclude = candidate.exclude_size();
        CHECK(num_include == 0 || num_exclude == 0)
            << "Specify either include or exclude rules.";

        bool keep;
        if (num_include == 0){
            // Included by default; the first matching exclude rule drops it.
            keep = true;
            for (int r = 0; r < num_exclude; ++r){
                if (stateMeetRule(net_state, candidate.exclude(r), candidate_name)){
                    keep = false;
                    break;
                }
            }
        }
        else{
            // Excluded by default; the first matching include rule keeps it.
            keep = false;
            for (int r = 0; r < num_include; ++r){
                if (stateMeetRule(net_state, candidate.include(r), candidate_name)){
                    keep = true;
                    break;
                }
            }
        }

        if (!keep) continue;
        filtered_param->add_layer()->CopyFrom(candidate);
    }
}
示例2: LOG
void Net<Dtype>::Init(const NetParameter& in_param){
CHECK(Dragon::get_root_solver() || root_net)
<< "Root net need to be set for all non-root solvers.";
phase = in_param.state().phase();
NetParameter filtered_param, param;
// filter for unqualified LayerParameters(e.g Test DataLayer)
filterNet(in_param, &filtered_param);
insertSplits(filtered_param, ¶m);
name = param.name();
LOG_IF(INFO, Dragon::get_root_solver())
<< "Initialize net from parameters: ";/*<< endl << param.DebugString();*/
map<string, int> blob_name_to_idx;
set<string> available_blobs;
CHECK_EQ(param.input_size(), param.input_shape_size())<< "input blob_shape must specify a blob.";
memory_used = 0;
// check and stuff virtual input blobs firstly [Viewing Mode Only]
for (int input_id=0; input_id < param.input_size(); input_id++){
const int layer_id = -1;
// net_input.push_back(.....virtual blob.....)
appendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx);
}
// stuff real blobs for each layer then [Traning/Testing/Viewing Mode]
bottom_vecs.resize(param.layer_size());
bottom_id_vecs.resize(param.layer_size());
bottoms_need_backward.resize(param.layer_size());
top_vecs.resize(param.layer_size());
top_id_vecs.resize(param.layer_size());
param_id_vecs.resize(param.layer_size());
for (int layer_id = 0; layer_id < param.layer_size(); layer_id++){
bool share_from_root = !Dragon::get_root_solver()
&& root_net->layers[layer_id]->shareInParallel();
// copy net phase to layer if not set
if (!param.layer(layer_id).has_phase())
param.mutable_layer(layer_id)->set_phase(phase);
const LayerParameter& layer_param = param.layer(layer_id);
if (share_from_root){
LOG(INFO) << "Share Layer: " << layer_param.name() << " from the root net.";
// share layer by pointer
layers.push_back(root_net->layers[layer_id]);
layers[layer_id]->setShared(true);
}
else{
// use layer factory to create a pointer
// layer type is referred by layer_param->type()
// see more in layer_factory.hpp
layers.push_back(LayerFactory<Dtype>::createLayer(layer_param));
}
layer_names.push_back(layer_param.name());
LOG_IF(INFO, Dragon::get_root_solver()) << "Create Layer: " << layer_param.name();
bool need_bp = false;
// stuff bottom blobs
for (int bottom_id = 0; bottom_id < layer_param.bottom_size(); bottom_id++){
const int blob_id = appendBottom(param, layer_id, bottom_id, &available_blobs, &blob_name_to_idx);
// check whether a bottom need back propogation
need_bp |= blobs_need_backward[blob_id];
}
// stuff top blobs
for (int top_id = 0; top_id < layer_param.top_size(); top_id++)
appendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx);
// auto top blobs
// NOT_IMPLEMENTED;
Layer<Dtype>* layer = layers[layer_id].get();
// setup for layer
if (share_from_root){
const vector<Blob<Dtype>*> base_top = root_net->top_vecs[layer_id];
const vector<Blob<Dtype>*> this_top = this->top_vecs[layer_id];
// reshape solely after root_net finishing
for (int top_id = 0; top_id < base_top.size(); top_id++){
this_top[top_id]->reshapeLike(*base_top[top_id]);
}
}
else layer->setup(bottom_vecs[layer_id], top_vecs[layer_id]);
LOG_IF(INFO, Dragon::get_root_solver()) << "Setup Layer: " << layer_param.name();
for (int top_id = 0; top_id < top_vecs[layer_id].size(); top_id++){
// extend size to max number of blobs if necessary
if (blobs_loss_weight.size() <= top_id_vecs[layer_id][top_id])
blobs_loss_weight.resize(top_id_vecs[layer_id][top_id] + 1, Dtype(0));
// store global loss weights from each layer each blob
blobs_loss_weight[top_id_vecs[layer_id][top_id]] = layer->getLoss(top_id);
LOG_IF(INFO, Dragon::get_root_solver())
<< "Top shape: " << top_vecs[layer_id][top_id]->shape_string();
if (layer->getLoss(top_id)) LOG_IF(INFO, Dragon::get_root_solver())
<< " with loss weight " << layer->getLoss(top_id);
// sum up for training parameter statistic
memory_used += top_vecs[layer_id][top_id]->count();
}
LOG_IF(INFO, Dragon::get_root_solver())
<< "Memory required for Data: " << memory_used*sizeof(Dtype);
const int param_size = layer_param.param_size();
// blobs_size will be set after layer->setup()
const int param_blobs_size = layer->getBlobs().size();
CHECK_LE(param_size, param_blobs_size)<< "Too many params specify for layer.";
// use if do not specify hyperparameter
// lr_mult=decay_mult=1.0
ParamSpec default_hyperparameter;
for (int param_id = 0; param_id < param_blobs_size; param_id++){
const ParamSpec* hyperparameter = param_id < param_size ?
&layer_param.param(param_id) : &default_hyperparameter;
const bool param_need_bp = hyperparameter->lr_mult() != 0;
// check whether a param blob need back propogation [default=true]
//.........这里部分代码省略.........