当前位置: 首页>>代码示例>>Python>>正文


Python mxnet.module.Module 类代码示例

本文整理汇总了 Python 中 mxnet.module.Module 类的典型用法代码示例。如果您想了解 module.Module 的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该类所在模块 mxnet.module 的其他用法示例。


下文共展示了 module.Module 类的 8 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Python 代码示例。

示例1: __init__

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             logger=logging, context=mx.cpu(), work_load_list=None,
             fixed_param_names=None, state_names=None, group2ctxs=None,
             compression_params=None, update_freq=None):
    """Create an SVRGModule and the auxiliary module it keeps alongside.

    Every argument except ``update_freq`` is forwarded unchanged both to the
    parent module constructor and to the auxiliary ``mx.mod.Module``.

    Parameters
    ----------
    update_freq : int
        Positive integer stored as ``self.update_freq``. Per the error messages
        below, it is the frequency for calculating full gradients. The default
        ``None`` is rejected, so callers must always supply it.

    Raises
    ------
    TypeError
        If ``update_freq`` is not an integer. ``bool`` is rejected too, even
        though it is a subclass of ``int``.
    ValueError
        If ``update_freq`` is zero or negative.
    """
    super(SVRGModule, self).__init__(symbol, data_names=data_names, label_names=label_names, logger=logger,
                                     context=context, work_load_list=work_load_list,
                                     fixed_param_names=fixed_param_names, state_names=state_names,
                                     group2ctxs=group2ctxs, compression_params=compression_params)

    # isinstance(True, int) is True, so reject bool explicitly; otherwise
    # update_freq=True would silently be treated as a frequency of 1.
    if isinstance(update_freq, bool) or not isinstance(update_freq, int):
        raise TypeError("update_freq in SVRGModule must be an integer to represent the frequency for "
                        "calculating full gradients")
    if update_freq <= 0:
        raise ValueError("update_freq in SVRGModule must be a positive integer to represent the frequency for "
                         "calculating full gradients")
    self.update_freq = update_freq

    # Auxiliary module built from the same symbol and configuration. Arguments
    # are passed by keyword so this call does not depend on the positional
    # parameter order of Module's constructor.
    self._mod_aux = mx.mod.Module(symbol, data_names=data_names, label_names=label_names,
                                  logger=logger, context=context, work_load_list=work_load_list,
                                  fixed_param_names=fixed_param_names, state_names=state_names,
                                  group2ctxs=group2ctxs, compression_params=compression_params)

    # NOTE(review): initialised to None here; where it gets populated is not
    # visible in this snippet.
    self._param_dict = None
    # Number of device contexts. ``self._context`` is set by the parent
    # constructor (presumably a list of contexts; confirm against the parent class).
    self._ctx_len = len(self._context)
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:26,代码来源:svrg_module.py

示例2: __init__

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def __init__(self, symbol, data_names, label_names,
             data_shapes, label_shapes, logger=logging,
             context=mx.cpu(), work_load_list=None, fixed_param_names=None):
    """Record the solver configuration and wrap it in an MXNet ``Module``.

    Each constructor argument is stored as an attribute of the same name.
    ``data_shapes`` and ``label_shapes`` are only stored here and are not
    passed to ``Module``. Passing ``logger=None`` makes the solver use the
    root logger, switched to INFO level.
    """
    self.symbol, self.data_names, self.label_names = symbol, data_names, label_names
    self.data_shapes, self.label_shapes = data_shapes, label_shapes
    self.context = context
    self.work_load_list = work_load_list
    self.fixed_param_names = fixed_param_names

    if logger is not None:
        self.logger = logger
    else:
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        self.logger = root_logger

    self.module = Module(
        symbol=self.symbol,
        data_names=self.data_names,
        label_names=self.label_names,
        logger=self.logger,
        context=self.context,
        work_load_list=self.work_load_list,
        fixed_param_names=self.fixed_param_names,
    )
开发者ID:TuSimple,项目名称:sparse-structure-selection,代码行数:22,代码来源:solver.py

示例3: bind

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def bind(self, data_shapes, label_shapes=None, for_training=True,
         inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
    """Bind executors for this module and, when training, for the auxiliary module.

    Binding must happen before any computation is performed with the SVRGModule.

    Parameters
    ----------
    data_shapes : list of (str, tuple)
        Typically is ``data_iter.provide_data``.
    label_shapes : list of (str, tuple)
        Typically is ``data_iter.provide_label``.
    for_training : bool
        Default is ``True``. Whether the executors should be bound for training.
        The auxiliary module is bound only when this is true.
    inputs_need_grad : bool
        Default is ``False``. Whether the gradients to the input data need to be computed.
        Typically this is not needed, but it may be needed when composing modules.
    force_rebind : bool
        Default is ``False``. Binding does nothing if the executors are already
        bound, unless this is ``True``, in which case they are rebound.
    shared_module : Module
        Default is ``None``. Used in bucketing. When not ``None``, the shared module
        corresponds to a different bucket: a module with a different symbol but the
        same set of parameters (e.g. unrolled RNNs of different lengths).
    grad_req : str
        Default is ``'write'``. The gradient requirement, forwarded unchanged to both
        ``bind`` calls.
    """
    bind_args = (data_shapes, label_shapes, for_training, inputs_need_grad,
                 force_rebind, shared_module, grad_req)

    # Force rebinding is typically used when switching from training to prediction.
    super(SVRGModule, self).bind(*bind_args)

    # The auxiliary module is only needed for training.
    if not for_training:
        return
    self._mod_aux.bind(*bind_args)
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:35,代码来源:svrg_module.py

示例4: check_qsym_forward

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape):
    """Run an inference forward pass of ``qsym`` on ``batch`` and return its outputs.

    A Module is created on ``mx.current_context()`` and bound for prediction with
    a single ``'data'`` input of ``data_shape`` and a ``'softmax_label'`` label of
    ``label_shape``. It is then loaded with the given quantized arg/aux params and
    run forward. Every output is waited on before it is returned.
    """
    mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
    mod.bind(for_training=False,
             data_shapes=[('data', data_shape)],
             label_shapes=[('softmax_label', label_shape)])
    mod.set_params(qarg_params, qaux_params)
    mod.forward(batch, is_train=False)
    # Fetch the outputs once, so the arrays synchronized below are exactly the
    # ones returned. This avoids relying on a second get_outputs() call
    # returning the same array objects.
    outputs = mod.get_outputs()
    for output in outputs:
        output.wait_to_read()
    return outputs
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:12,代码来源:test_subgraph.py

示例5: __init__

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def __init__(self, symbol, data_names, label_names,
             context=mx.cpu(), max_data_shapes=None,
             provide_data=None, provide_label=None,
             arg_params=None, aux_params=None):
    """Build an inference-only Module, bind it and load the given parameters.

    ``provide_data`` / ``provide_label`` are used as the bind shapes, and
    ``arg_params`` / ``aux_params`` initialise the parameters.
    ``max_data_shapes`` is accepted for interface compatibility but is not used.
    """
    self._mod = mod = Module(symbol, data_names, label_names, context=context)
    mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=False)
    mod.init_params(arg_params=arg_params, aux_params=aux_params)
开发者ID:deepinsight,项目名称:insightface,代码行数:11,代码来源:tester.py

示例6: demo_net

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def demo_net(sym, class_names, args):
    """Run the detector on one image and print (and optionally draw) its detections.

    Parameters
    ----------
    sym : Symbol
        Graph wrapped in a ``Module`` for inference.
    class_names : sequence
        Indexed by integer class id when printing / drawing.
    args : namespace
        Parsed options. Reads ``gpu``, ``image``, ``img_short_side``, ``img_long_side``,
        ``img_pixel_means``, ``img_pixel_stds``, ``params``, ``rcnn_bbox_stds``,
        ``rcnn_nms_thresh``, ``rcnn_conf_thresh``, ``vis_thresh`` and ``vis``.
    """
    print('called with args\n{}'.format(pprint.pformat(vars(args))))

    # A falsy args.gpu selects the CPU; otherwise it is converted to an int device id.
    ctx = mx.gpu(int(args.gpu)) if args.gpu else mx.cpu(0)

    im_tensor, im_info, im_orig = load_test(args.image, short=args.img_short_side,
                                            max_size=args.img_long_side,
                                            mean=args.img_pixel_means, std=args.img_pixel_stds)
    data_batch = generate_batch(im_tensor, im_info)
    arg_params, aux_params = load_param(args.params, ctx=ctx)

    # Bind for the largest possible input: a square whose side is the long side.
    long_side = args.img_long_side
    input_shapes = [('data', (1, 3, long_side, long_side)), ('im_info', (1, 3))]
    check_shape(sym, input_shapes, arg_params, aux_params)

    mod = Module(sym, data_names=['data', 'im_info'], label_names=None, context=ctx)
    mod.bind(data_shapes=input_shapes, label_shapes=None, for_training=False)
    mod.init_params(arg_params=arg_params, aux_params=aux_params)

    mod.forward(data_batch)
    rois, scores, bbox_deltas = mod.get_outputs()
    # Drop the first roi column (presumably the batch index; TODO confirm) and
    # keep element 0 of the other outputs.
    det = im_detect(rois[:, 1:], scores[0], bbox_deltas[0], im_info[0],
                    bbox_stds=args.rcnn_bbox_stds, nms_thresh=args.rcnn_nms_thresh,
                    conf_thresh=args.rcnn_conf_thresh)

    # Each row is unpacked as [cls, conf, x1, y1, x2, y2]. Class id 0 is skipped
    # (presumably background).
    for cls, conf, x1, y1, x2, y2 in det:
        if cls > 0 and conf > args.vis_thresh:
            print(class_names[int(cls)], conf, [x1, y1, x2, y2])

    if args.vis:
        vis_detection(im_orig, det, class_names, thresh=args.vis_thresh)
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:57,代码来源:demo.py

示例7: test_net

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def test_net(sym, imdb, args):
    """Run the detector over every image of ``imdb`` and evaluate the detections.

    Parameters
    ----------
    sym : Symbol
        Graph wrapped in a ``Module`` for inference.
    imdb : dataset
        Provides ``roidb``, ``num_images``, ``num_classes`` and ``evaluate_detections``.
    args : namespace
        Parsed options. Reads ``gpu``, ``img_short_side``, ``img_long_side``,
        ``img_pixel_means``, ``img_pixel_stds``, ``params``, ``rcnn_bbox_stds``,
        ``rcnn_nms_thresh`` and ``rcnn_conf_thresh``.
    """
    logger.info('called with args\n{}'.format(pprint.pformat(vars(args))))

    ctx = mx.gpu(args.gpu)

    # One image per batch, so the enumerate() index below is also the image index.
    test_data = TestLoader(imdb.roidb, batch_size=1, short=args.img_short_side, max_size=args.img_long_side,
                           mean=args.img_pixel_means, std=args.img_pixel_stds)

    arg_params, aux_params = load_param(args.params, ctx=ctx)

    # Bind for the largest possible input: a square whose side is the long side.
    data_names = ['data', 'im_info']
    label_names = None
    data_shapes = [('data', (1, 3, args.img_long_side, args.img_long_side)), ('im_info', (1, 3))]
    label_shapes = None

    check_shape(sym, data_shapes, arg_params, aux_params)

    mod = Module(sym, data_names, label_names, context=ctx)
    mod.bind(data_shapes, label_shapes, for_training=False)
    mod.init_params(arg_params=arg_params, aux_params=aux_params)

    # all_boxes[cls][image] = N x 5 array of detections as (x1, y1, x2, y2, score).
    all_boxes = [[[] for _ in range(imdb.num_images)]
                 for _ in range(imdb.num_classes)]

    with tqdm(total=imdb.num_images) as pbar:
        for i, data_batch in enumerate(test_data):
            im_info = data_batch.data[1][0]
            mod.forward(data_batch)
            rois, scores, bbox_deltas = mod.get_outputs()
            rois = rois[:, 1:]
            scores = scores[0]
            bbox_deltas = bbox_deltas[0]

            det = im_detect(rois, scores, bbox_deltas, im_info,
                            bbox_stds=args.rcnn_bbox_stds, nms_thresh=args.rcnn_nms_thresh,
                            conf_thresh=args.rcnn_conf_thresh)
            # Reorder det's columns to (x1, y1, x2, y2, score) once per image,
            # instead of once per class inside the loop below.
            boxes_with_scores = np.concatenate((det[:, -4:], det[:, [1]]), axis=-1)
            det_classes = det[:, 0]
            # Class 0 is skipped, matching the range starting at 1.
            for j in range(1, imdb.num_classes):
                indexes = np.where(det_classes == j)[0]
                all_boxes[j][i] = boxes_with_scores[indexes, :]
            pbar.update(data_batch.data[0].shape[0])

    imdb.evaluate_detections(all_boxes)
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:57,代码来源:test.py

示例8: test_quantize_model

# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def test_quantize_model():
    """Quantize an fp32 model without and with naive calibration and verify the results."""

    def check_params(params, qparams, qsym=None):
        # Without a quantized symbol the params must pass through unchanged;
        # with one, they are compared against _quantize_params(qsym, params).
        if qsym is None:
            expected = params
        else:
            expected = mx.contrib.quant._quantize_params(qsym, params)
        assert len(qparams) == len(expected)
        for name, value in expected.items():
            assert name in qparams
            assert same(value.asnumpy(), qparams[name].asnumpy())

    def check_qsym_calibrated(qsym):
        # Every node whose name contains 'requantize_' must carry calibration ranges.
        for node_name, node_attrs in qsym.attr_dict().items():
            if 'requantize_' in node_name:
                assert 'min_calib_range' in node_attrs
                assert 'max_calib_range' in node_attrs

    sym = get_fp32_sym()
    mod = Module(symbol=sym)
    batch_size = 4
    data_shape = (batch_size, 4, 10, 10)
    label_shape = (batch_size, 10)
    mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
    mod.init_params()
    arg_params, aux_params = mod.get_params()

    # Pass 1: quantize without calibration.
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym, arg_params=arg_params, aux_params=aux_params,
        ctx=mx.current_context(), calib_mode='none')
    check_params(arg_params, qarg_params, qsym)
    check_params(aux_params, qaux_params)

    # Pass 2: naive calibration on random data wrapped in NDArrayIter + DummyIter.
    calib_iter = DummyIter(NDArrayIter(data=mx.nd.random.uniform(shape=data_shape)))
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym, arg_params=arg_params, aux_params=aux_params,
        ctx=mx.current_context(), calib_mode='naive',
        calib_data=calib_iter, num_calib_examples=20)
    check_params(arg_params, qarg_params, qsym)
    check_params(aux_params, qaux_params)
    check_qsym_calibrated(qsym)
开发者ID:mahyarnajibi,项目名称:SNIPER-mxnet,代码行数:52,代码来源:test_quantization.py


注:本文中的mxnet.module.Module方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。