当前位置: 首页>>代码示例>>Python>>正文


Python cnn.CNNModelHelper方法代码示例

本文整理汇总了Python中caffe2.python.cnn.CNNModelHelper方法的典型用法代码示例。如果您正苦于以下问题：Python cnn.CNNModelHelper方法的具体用法？Python cnn.CNNModelHelper怎么用？Python cnn.CNNModelHelper使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块caffe2.python.cnn的用法示例。


在下文中一共展示了cnn.CNNModelHelper方法的6个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def __init__(self, **kwargs):
        """Initialize a DetectionModelHelper on top of CNNModelHelper.

        The detection-specific keyword arguments ``train`` and ``num_classes``
        are consumed here; all remaining keyword arguments are forwarded to the
        parent constructor. ``order`` is always forced to 'NCHW' and
        ``cudnn_exhaustive_search`` is always forced to False, overriding any
        caller-supplied values.

        Args:
            train: whether the model is built for training (default False).
            num_classes: number of classes; must be > 0 (no usable default).
            **kwargs: any other arguments, passed through to the parent class.

        Raises:
            AssertionError: if ``num_classes`` is missing or not > 0.
        """
        # Pop the DetectionModelHelper-specific args so they are not forwarded;
        # everything else passes through to CNNModelHelper.
        self.train = kwargs.pop('train', False)
        self.num_classes = kwargs.pop('num_classes', -1)
        # Explicit check instead of a bare `assert`, so validation still runs
        # under `python -O`. AssertionError is kept for callers that catch it.
        if not (self.num_classes > 0):
            raise AssertionError('num_classes must be > 0')
        kwargs['order'] = 'NCHW'
        # Defensively set cudnn_exhaustive_search to False in case the default
        # changes in CNNModelHelper. The detection code uses variable size
        # inputs that might not play nicely with cudnn_exhaustive_search.
        kwargs['cudnn_exhaustive_search'] = False
        super(DetectionModelHelper, self).__init__(**kwargs)
        self.roi_data_loader = None
        self.losses = []
        self.metrics = []
        self.do_not_update_params = []  # Params in this list are not updated
        # Net execution type and worker count come from the global `cfg`;
        # 4 workers are requested per configured GPU.
        self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
        self.net.Proto().num_workers = cfg.NUM_GPUS * 4
        # Snapshot of the parent's use_cudnn flag; presumably restored after
        # temporary overrides elsewhere — confirm against callers.
        self.prev_use_cudnn = self.use_cudnn
        self.gn_params = []  # Params in this list are GroupNorm parameters
开发者ID:yihui-he,项目名称:KL-Loss,代码行数:25,代码来源:detector.py

示例2: __init__

# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def __init__(self, **kwargs):
        """Initialize a DetectionModelHelper on top of CNNModelHelper.

        The detection-specific keyword arguments ``train`` and ``num_classes``
        are consumed here; all remaining keyword arguments are forwarded to the
        parent constructor. ``order`` is always forced to 'NCHW' and
        ``cudnn_exhaustive_search`` is always forced to False, overriding any
        caller-supplied values.

        Args:
            train: whether the model is built for training (default False).
            num_classes: number of classes; must be > 0 (no usable default).
            **kwargs: any other arguments, passed through to the parent class.

        Raises:
            AssertionError: if ``num_classes`` is missing or not > 0.
        """
        # Pop the DetectionModelHelper-specific args so they are not forwarded;
        # everything else passes through to CNNModelHelper.
        self.train = kwargs.pop('train', False)
        self.num_classes = kwargs.pop('num_classes', -1)
        # Explicit check instead of a bare `assert`, so validation still runs
        # under `python -O`. AssertionError is kept for callers that catch it.
        if not (self.num_classes > 0):
            raise AssertionError('num_classes must be > 0')
        kwargs['order'] = 'NCHW'
        # Defensively set cudnn_exhaustive_search to False in case the default
        # changes in CNNModelHelper. The detection code uses variable size
        # inputs that might not play nicely with cudnn_exhaustive_search.
        kwargs['cudnn_exhaustive_search'] = False
        super(DetectionModelHelper, self).__init__(**kwargs)
        self.roi_data_loader = None
        self.losses = []
        self.metrics = []
        self.do_not_update_params = []  # Params in this list are not updated
        # Net execution type and worker count come from the global `cfg`;
        # 4 workers are requested per configured GPU.
        self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
        self.net.Proto().num_workers = cfg.NUM_GPUS * 4
        # Snapshot of the parent's use_cudnn flag; presumably restored after
        # temporary overrides elsewhere — confirm against callers.
        self.prev_use_cudnn = self.use_cudnn
        self.gn_params = []  # Params in this list are GroupNorm parameters
        self.stage_params = {}  # Params in this dict are updated with scalars
开发者ID:fyangneil,项目名称:Clustered-Object-Detection-in-Aerial-Image,代码行数:26,代码来源:detector.py

示例3: __init__

# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def __init__(self, **kwargs):
        """Initialize a DetectionModelHelper on top of CNNModelHelper.

        The detection-specific keyword arguments ``train`` and ``num_classes``
        are consumed here; all remaining keyword arguments are forwarded to the
        parent constructor. ``order`` is always forced to 'NCHW' and
        ``cudnn_exhaustive_search`` is always forced to False, overriding any
        caller-supplied values.

        Args:
            train: whether the model is built for training (default False).
            num_classes: number of classes; must be > 0 (no usable default).
            **kwargs: any other arguments, passed through to the parent class.

        Raises:
            AssertionError: if ``num_classes`` is missing or not > 0.
        """
        # Pop the DetectionModelHelper-specific args so they are not forwarded;
        # everything else passes through to CNNModelHelper.
        self.train = kwargs.pop('train', False)
        self.num_classes = kwargs.pop('num_classes', -1)
        # Explicit check instead of a bare `assert`, so validation still runs
        # under `python -O`. AssertionError is kept for callers that catch it.
        if not (self.num_classes > 0):
            raise AssertionError('num_classes must be > 0')
        kwargs['order'] = 'NCHW'
        # Defensively set cudnn_exhaustive_search to False in case the default
        # changes in CNNModelHelper. The detection code uses variable size
        # inputs that might not play nicely with cudnn_exhaustive_search.
        kwargs['cudnn_exhaustive_search'] = False
        super(DetectionModelHelper, self).__init__(**kwargs)
        self.roi_data_loader = None
        self.losses = []
        self.metrics = []
        self.do_not_update_params = []  # Params in this list are not updated
        # Net execution type and worker count come from the global `cfg`;
        # 4 workers are requested per configured GPU.
        self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
        self.net.Proto().num_workers = cfg.NUM_GPUS * 4
        # Snapshot of the parent's use_cudnn flag; presumably restored after
        # temporary overrides elsewhere — confirm against callers.
        self.prev_use_cudnn = self.use_cudnn
开发者ID:lvpengyuan,项目名称:masktextspotter.caffe2,代码行数:24,代码来源:detector.py

示例4: write_graph

# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def write_graph(self, model_or_nets_or_protos=None, **kwargs):
        '''Write graph to the summary.

        Args:
            model_or_nets_or_protos: a ``cnn.CNNModelHelper``, a non-empty
                list of ``core.Net`` objects, or a non-empty list of
                ``caffe2_pb2.NetDef`` protos. For lists, the type of the first
                element decides which converter handles the whole list.
            **kwargs: forwarded to the selected graph-conversion helper.

        Raises:
            NotImplementedError: if the argument is not one of the supported
                inputs (including ``None`` and an empty list).
        '''
        if isinstance(model_or_nets_or_protos, cnn.CNNModelHelper):
            current_graph, track_blob_names = model_to_graph(model_or_nets_or_protos, **kwargs)
        elif isinstance(model_or_nets_or_protos, list):
            # Guard before indexing: an empty list used to fail with an
            # opaque IndexError on `[0]`.
            if not model_or_nets_or_protos:
                raise NotImplementedError('Cannot write a graph from an empty list')
            first = model_or_nets_or_protos[0]
            if isinstance(first, core.Net):
                current_graph, track_blob_names = nets_to_graph(model_or_nets_or_protos, **kwargs)
            elif isinstance(first, caffe2_pb2.NetDef):
                current_graph, track_blob_names = protos_to_graph(model_or_nets_or_protos, **kwargs)
            else:
                raise NotImplementedError(
                    'Unsupported list element type: %s' % type(first).__name__)
        else:
            raise NotImplementedError(
                'Unsupported graph input type: %s'
                % type(model_or_nets_or_protos).__name__)
        self._file_writer.add_graph(current_graph)
        self._track_blob_names = track_blob_names
        # Once the graph is built, one can just map the blobs
        self.check_names()
        self.sort_out_names()
开发者ID:endernewton,项目名称:c2board,代码行数:20,代码来源:writer.py

示例5: test_simple_cnnmodel

# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def test_simple_cnnmodel(self):
        """Convert a small cnn.CNNModelHelper model to a graph and compare it.

        Builds a conv -> relu -> maxpool -> fc -> softmax -> cross-entropy ->
        averaged-loss net (model name "overfeat") and converts it with
        tb.model_to_graph_def using no shapes and show_simplified=False. The
        result is checked with compare_proto (presumably against a stored
        expected proto — confirm in the test helpers).
        """
        model = cnn.CNNModelHelper("NCHW", name="overfeat")
        workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32))
        # `np.int` was a deprecated alias of the builtin `int` and was removed
        # in NumPy 1.24; `int` yields the exact same (platform default) dtype.
        workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int))
        with core.NameScope("conv1"):
            conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4)
            relu1 = model.Relu(conv1, conv1)  # in-place: output blob is conv1
            pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2)
        with core.NameScope("classifier"):
            fc = model.FC(pool1, "fc", 4096, 1000)
            pred = model.Softmax(fc, "pred")
            xent = model.LabelCrossEntropy([pred, "label"], "xent")
            # Called only for the op it adds to the net; its output is unused.
            model.AveragedLoss(xent, "loss")

        blob_name_tracker = {}
        graph = tb.model_to_graph_def(
            model,
            blob_name_tracker=blob_name_tracker,
            shapes={},
            show_simplified=False,
        )

        compare_proto(graph, self)

    # cnn.CNNModelHelper is deprecated, so we also test with
    # model_helper.ModelHelper. The model used in this test is taken from the
    # Caffe2 MNIST tutorial. Also use show_simplified=False here. 
开发者ID:lanpa,项目名称:tensorboardX,代码行数:29,代码来源:test_caffe2.py

示例6: model_build

# 需要导入模块: from caffe2.python import cnn [as 别名]
# 或者: from caffe2.python.cnn import CNNModelHelper [as 别名]
def model_build(networkDefinition, gpu_id=0):
    """Create a CNNModelHelper named "pose" and build a detection net on it.

    The helper is created with cuDNN enabled and cuDNN exhaustive search
    disabled, and the requested GPU id is recorded on it before building.

    Args:
        networkDefinition: network description, forwarded unchanged to
            build_generic_detection_model.
        gpu_id: value stored as the helper's ``target_gpu_id`` (default 0).

    Returns:
        Whatever build_generic_detection_model returns for the helper.
    """
    helper_options = {
        'name': "pose",
        'use_cudnn': True,
        'cudnn_exhaustive_search': False,
    }
    pose_model = cnn.CNNModelHelper(**helper_options)
    pose_model.target_gpu_id = gpu_id
    built = build_generic_detection_model(pose_model, networkDefinition)
    return built
开发者ID:eddieyi,项目名称:caffe2-pose-estimation,代码行数:6,代码来源:infer.py


注:本文中的caffe2.python.cnn.CNNModelHelper方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。