本文整理汇总了Python中caffe2.python.cnn.CNNModelHelper类的典型用法代码示例。如果您正苦于以下问题：Python cnn.CNNModelHelper的具体用法？Python cnn.CNNModelHelper怎么用？Python cnn.CNNModelHelper使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块caffe2.python.cnn的用法示例。
在下文中一共展示了cnn.CNNModelHelper类的6个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def __init__(self, **kwargs):
    """Initialize a DetectionModelHelper on top of CNNModelHelper.

    The keyword arguments ``train`` (default False) and ``num_classes``
    (required, must be > 0) are consumed here. Every other keyword argument
    is forwarded to ``CNNModelHelper.__init__``. Any caller-supplied
    ``order`` is overridden with 'NCHW', and any caller-supplied
    ``cudnn_exhaustive_search`` is overridden with False.

    Raises:
        AssertionError: if ``num_classes`` is missing or not > 0.
    """
    # Handle args specific to the DetectionModelHelper. pop() removes them
    # so they are not passed through to CNNModelHelper.
    self.train = kwargs.pop('train', False)
    self.num_classes = kwargs.pop('num_classes', -1)
    # NOTE(review): assert is stripped under `python -O`. It is kept as an
    # assert so callers continue to see AssertionError on bad input.
    assert self.num_classes > 0, 'num_classes must be > 0'
    kwargs['order'] = 'NCHW'
    # Defensively set cudnn_exhaustive_search to False in case the default
    # changes in CNNModelHelper. The detection code uses variable size
    # inputs that might not play nicely with cudnn_exhaustive_search.
    kwargs['cudnn_exhaustive_search'] = False
    super(DetectionModelHelper, self).__init__(**kwargs)
    self.roi_data_loader = None
    self.losses = []
    self.metrics = []
    self.do_not_update_params = []  # Params in this list are not updated
    self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
    # Four workers per configured GPU.
    self.net.Proto().num_workers = cfg.NUM_GPUS * 4
    # Snapshot of the current use_cudnn setting. Presumably restored after
    # a temporary toggle elsewhere; confirm against callers.
    self.prev_use_cudnn = self.use_cudnn
    self.gn_params = []  # Params in this list are GroupNorm parameters
示例2: __init__
# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def __init__(self, **kwargs):
    """Initialize a DetectionModelHelper on top of CNNModelHelper.

    The keyword arguments ``train`` (default False) and ``num_classes``
    (required, must be > 0) are consumed here. Every other keyword argument
    is forwarded to ``CNNModelHelper.__init__``. Any caller-supplied
    ``order`` is overridden with 'NCHW', and any caller-supplied
    ``cudnn_exhaustive_search`` is overridden with False.

    Raises:
        AssertionError: if ``num_classes`` is missing or not > 0.
    """
    # Handle args specific to the DetectionModelHelper. pop() removes them
    # so they are not passed through to CNNModelHelper.
    self.train = kwargs.pop('train', False)
    self.num_classes = kwargs.pop('num_classes', -1)
    # NOTE(review): assert is stripped under `python -O`. It is kept as an
    # assert so callers continue to see AssertionError on bad input.
    assert self.num_classes > 0, 'num_classes must be > 0'
    kwargs['order'] = 'NCHW'
    # Defensively set cudnn_exhaustive_search to False in case the default
    # changes in CNNModelHelper. The detection code uses variable size
    # inputs that might not play nicely with cudnn_exhaustive_search.
    kwargs['cudnn_exhaustive_search'] = False
    super(DetectionModelHelper, self).__init__(**kwargs)
    self.roi_data_loader = None
    self.losses = []
    self.metrics = []
    self.do_not_update_params = []  # Params in this list are not updated
    self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
    # Four workers per configured GPU.
    self.net.Proto().num_workers = cfg.NUM_GPUS * 4
    # Snapshot of the current use_cudnn setting. Presumably restored after
    # a temporary toggle elsewhere; confirm against callers.
    self.prev_use_cudnn = self.use_cudnn
    self.gn_params = []  # Params in this list are GroupNorm parameters
    self.stage_params = {}  # Params in this dict are updated with scalars
示例3: __init__
# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def __init__(self, **kwargs):
    """Initialize a DetectionModelHelper on top of CNNModelHelper.

    The keyword arguments ``train`` (default False) and ``num_classes``
    (required, must be > 0) are consumed here. Every other keyword argument
    is forwarded to ``CNNModelHelper.__init__``. Any caller-supplied
    ``order`` is overridden with 'NCHW', and any caller-supplied
    ``cudnn_exhaustive_search`` is overridden with False.

    Raises:
        AssertionError: if ``num_classes`` is missing or not > 0.
    """
    # Handle args specific to the DetectionModelHelper. pop() removes them
    # so they are not passed through to CNNModelHelper.
    self.train = kwargs.pop('train', False)
    self.num_classes = kwargs.pop('num_classes', -1)
    # NOTE(review): assert is stripped under `python -O`. It is kept as an
    # assert so callers continue to see AssertionError on bad input.
    assert self.num_classes > 0, 'num_classes must be > 0'
    kwargs['order'] = 'NCHW'
    # Defensively set cudnn_exhaustive_search to False in case the default
    # changes in CNNModelHelper. The detection code uses variable size
    # inputs that might not play nicely with cudnn_exhaustive_search.
    kwargs['cudnn_exhaustive_search'] = False
    super(DetectionModelHelper, self).__init__(**kwargs)
    self.roi_data_loader = None
    self.losses = []
    self.metrics = []
    self.do_not_update_params = []  # Params in this list are not updated
    self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
    # Four workers per configured GPU.
    self.net.Proto().num_workers = cfg.NUM_GPUS * 4
    # Snapshot of the current use_cudnn setting. Presumably restored after
    # a temporary toggle elsewhere; confirm against callers.
    self.prev_use_cudnn = self.use_cudnn
示例4: write_graph
# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def write_graph(self, model_or_nets_or_protos=None, **kwargs):
    '''Write graph to the summary.

    Args:
        model_or_nets_or_protos: a ``cnn.CNNModelHelper``, or a non-empty
            list whose first element is a ``core.Net`` or a
            ``caffe2_pb2.NetDef``. For a list, the conversion helper is
            chosen from the type of the first element only.
        **kwargs: forwarded unchanged to the graph-conversion helper.

    Raises:
        ValueError: if an empty list is given.
        NotImplementedError: for any other unsupported input, including the
            default ``None``.
    '''
    source = model_or_nets_or_protos
    if isinstance(source, cnn.CNNModelHelper):
        current_graph, track_blob_names = model_to_graph(source, **kwargs)
    elif isinstance(source, list):
        # Without this guard, source[0] below raises a bare IndexError.
        if not source:
            raise ValueError('write_graph() got an empty list of nets/protos')
        first = source[0]
        if isinstance(first, core.Net):
            current_graph, track_blob_names = nets_to_graph(source, **kwargs)
        elif isinstance(first, caffe2_pb2.NetDef):
            current_graph, track_blob_names = protos_to_graph(source, **kwargs)
        else:
            raise NotImplementedError(
                'write_graph() does not support lists of {}'.format(
                    type(first).__name__))
    else:
        raise NotImplementedError(
            'write_graph() does not support {}'.format(type(source).__name__))
    self._file_writer.add_graph(current_graph)
    self._track_blob_names = track_blob_names
    # Once the graph is built, one can just map the blobs
    self.check_names()
    self.sort_out_names()
示例5: test_simple_cnnmodel
# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def test_simple_cnnmodel(self):
    """Build a small CNNModelHelper model named "overfeat" and check its
    graph (converted with show_simplified=False) via compare_proto.

    compare_proto presumably diffs against a stored expected proto; confirm.
    """
    model = cnn.CNNModelHelper("NCHW", name="overfeat")
    workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32))
    # `np.int` was a deprecated alias of the builtin `int` (NumPy 1.20) and
    # was removed in NumPy 1.24. Using `int` keeps the exact same dtype.
    workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int))
    with core.NameScope("conv1"):
        conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4)
        relu1 = model.Relu(conv1, conv1)
        pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2)
    with core.NameScope("classifier"):
        fc = model.FC(pool1, "fc", 4096, 1000)
        pred = model.Softmax(fc, "pred")
        xent = model.LabelCrossEntropy([pred, "label"], "xent")
        # Called for its side effect of adding the loss op to the model.
        model.AveragedLoss(xent, "loss")
    blob_name_tracker = {}
    graph = tb.model_to_graph_def(
        model,
        blob_name_tracker=blob_name_tracker,
        shapes={},
        show_simplified=False,
    )
    compare_proto(graph, self)
# cnn.CNNModelHelper is deprecated, so we also test with
# model_helper.ModelHelper. The model used in this test is taken from the
# Caffe2 MNIST tutorial. Also use show_simplified=False here.
示例6: model_build
# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def model_build(networkDefinition, gpu_id=0):
    """Create a CNNModelHelper named "pose" and delegate network building.

    The helper is created with cuDNN enabled and cuDNN exhaustive search
    disabled, and its ``target_gpu_id`` is set to ``gpu_id``. Returns
    whatever ``build_generic_detection_model(model, networkDefinition)``
    returns.
    """
    helper_options = dict(
        name="pose",
        use_cudnn=True,
        cudnn_exhaustive_search=False,
    )
    pose_model = cnn.CNNModelHelper(**helper_options)
    pose_model.target_gpu_id = gpu_id
    return build_generic_detection_model(pose_model, networkDefinition)