本文整理汇总了Python中mxnet.module.Module方法的典型用法代码示例。如果您正苦于以下问题：Python module.Module方法的具体用法？Python module.Module怎么用？Python module.Module使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块mxnet.module的用法示例。
在下文中一共展示了module.Module方法的8个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             logger=logging, context=mx.cpu(), work_load_list=None,
             fixed_param_names=None, state_names=None, group2ctxs=None,
             compression_params=None, update_freq=None):
    """Initialize the SVRGModule and its auxiliary module.

    Every argument except ``update_freq`` is forwarded both to the parent
    class constructor and to an auxiliary ``mx.mod.Module`` built from the
    same symbol (stored as ``self._mod_aux``).

    Parameters
    ----------
    symbol, data_names, label_names, logger, context, work_load_list,
    fixed_param_names, state_names, group2ctxs, compression_params
        Same meaning as for ``mx.mod.Module``; passed through unchanged.
    update_freq : int
        Positive integer giving the frequency for calculating full gradients.
        There is no usable default: leaving it as ``None`` raises ``TypeError``.

    Raises
    ------
    TypeError
        If ``update_freq`` is not an ``int`` (a ``bool`` is rejected too).
    ValueError
        If ``update_freq`` is zero or negative.
    """
    super(SVRGModule, self).__init__(symbol, data_names=data_names, label_names=label_names, logger=logger,
                                     context=context, work_load_list=work_load_list,
                                     fixed_param_names=fixed_param_names, state_names=state_names,
                                     group2ctxs=group2ctxs, compression_params=compression_params)
    # Type check update_freq. ``bool`` is a subclass of ``int`` in Python, so it
    # must be excluded explicitly; otherwise ``update_freq=True`` would silently
    # be accepted as a frequency of 1.
    if not isinstance(update_freq, int) or isinstance(update_freq, bool):
        raise TypeError("update_freq in SVRGModule must be an integer to represent the frequency for "
                        "calculating full gradients")
    if update_freq <= 0:
        raise ValueError("update_freq in SVRGModule must be a positive integer to represent the frequency for "
                         "calculating full gradients")
    self.update_freq = update_freq
    # Keyword arguments keep this call correct even if the positional order of
    # Module's constructor ever changes.
    self._mod_aux = mx.mod.Module(symbol, data_names=data_names, label_names=label_names,
                                  logger=logger, context=context, work_load_list=work_load_list,
                                  fixed_param_names=fixed_param_names, state_names=state_names,
                                  group2ctxs=group2ctxs, compression_params=compression_params)
    self._param_dict = None
    # NOTE(review): self._context is presumably populated by the parent
    # constructor as a list of contexts — confirm.
    self._ctx_len = len(self._context)
示例2: __init__
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def __init__(self, symbol, data_names, label_names,
             data_shapes, label_shapes, logger=logging,
             context=mx.cpu(), work_load_list=None, fixed_param_names=None):
    """Record the configuration and wrap it in a ``Module``.

    All constructor arguments are kept as same-named attributes, and
    ``self.module`` is a ``Module`` built from them. ``data_shapes`` and
    ``label_shapes`` are only stored here; they are not passed to ``Module``.

    Parameters
    ----------
    logger : logging.Logger or module, optional
        Defaults to the ``logging`` module itself. Passing ``None`` explicitly
        selects the root logger and sets its level to ``INFO`` (this mutates
        the process-wide root logger).
    """
    self.symbol, self.data_names, self.label_names = symbol, data_names, label_names
    self.data_shapes, self.label_shapes = data_shapes, label_shapes
    self.context = context
    self.work_load_list = work_load_list
    self.fixed_param_names = fixed_param_names

    if logger is None:
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        logger = root_logger
    self.logger = logger

    self.module = Module(
        symbol=symbol,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=context,
        work_load_list=work_load_list,
        fixed_param_names=fixed_param_names,
    )
示例3: bind
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def bind(self, data_shapes, label_shapes=None, for_training=True,
         inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
    """Binds the symbols to construct executors for both modules. This is necessary before one
    can perform computation with the SVRGModule.

    The main module is always bound. The auxiliary module (``self._mod_aux``) is
    bound with exactly the same arguments, but only when ``for_training`` is ``True``.

    Parameters
    ----------
    data_shapes : list of (str, tuple)
        Typically is ``data_iter.provide_data``.
    label_shapes : list of (str, tuple)
        Typically is ``data_iter.provide_label``.
    for_training : bool
        Default is ``True``. Whether the executors should be bound for training.
    inputs_need_grad : bool
        Default is ``False``. Whether the gradients to the input data need to be computed.
        Typically this is not needed. But this might be needed when implementing composition
        of modules.
    force_rebind : bool
        Default is ``False``. This function does nothing if the executors are already
        bound. But with this ``True``, the executors will be forced to rebind.
    shared_module : Module
        Default is ``None``. This is used in bucketing. When not ``None``, the shared module
        essentially corresponds to a different bucket -- a module with different symbol
        but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
    grad_req : str, list of str, dict of str to str
        Requirement for gradient accumulation. Can be ``'write'``, ``'add'``, or
        ``'null'`` (default ``'write'``). Can be specified globally (str) or for
        each argument (list, dict).
    """
    # force rebinding is typically used when one wants to switch from
    # training to prediction phase.
    bind_kwargs = dict(label_shapes=label_shapes, for_training=for_training,
                       inputs_need_grad=inputs_need_grad, force_rebind=force_rebind,
                       shared_module=shared_module, grad_req=grad_req)
    super(SVRGModule, self).bind(data_shapes, **bind_kwargs)
    if for_training:
        self._mod_aux.bind(data_shapes, **bind_kwargs)
示例4: check_qsym_forward
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape):
    """Run one inference pass of a quantized symbol and return its outputs.

    Binds ``qsym`` on ``mx.current_context()`` with a single ``'data'`` input of
    shape ``data_shape`` and a ``'softmax_label'`` label of shape ``label_shape``.
    It then loads ``qarg_params`` / ``qaux_params`` and runs ``batch`` forward in
    inference mode. It calls ``wait_to_read()`` on every output and finally
    returns ``mod.get_outputs()``.
    """
    input_shapes = [('data', data_shape)]
    target_shapes = [('softmax_label', label_shape)]

    mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
    mod.bind(data_shapes=input_shapes,
             label_shapes=target_shapes,
             for_training=False)
    mod.set_params(qarg_params, qaux_params)
    mod.forward(batch, is_train=False)

    # Block until each output array is ready before handing the outputs back.
    for arr in mod.get_outputs():
        arr.wait_to_read()
    return mod.get_outputs()
示例5: __init__
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def __init__(self, symbol, data_names, label_names,
             context=mx.cpu(), max_data_shapes=None,
             provide_data=None, provide_label=None,
             arg_params=None, aux_params=None):
    """Build an inference-only ``Module`` and load the given parameters into it.

    Parameters
    ----------
    symbol : Symbol
        Network to run.
    data_names, label_names : list of str
        Input names passed to ``Module``.
    context : Context
        Device to bind on (default ``mx.cpu()``).
    max_data_shapes : optional
        Accepted for backward compatibility but currently IGNORED: the module is
        bound with ``provide_data`` exactly as given.
    provide_data, provide_label : list of (str, tuple)
        Shapes used to bind the module (``for_training=False``).
    arg_params, aux_params : dict
        Parameters handed to ``init_params``.
    """
    # NOTE(review): max_data_shapes presumably belonged to an earlier
    # MutableModule-based implementation; it has no effect with a plain Module.
    self._mod = Module(symbol, data_names, label_names, context=context)
    self._mod.bind(provide_data, provide_label, for_training=False)
    self._mod.init_params(arg_params=arg_params, aux_params=aux_params)
示例6: demo_net
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def demo_net(sym, class_names, args):
    """Run the detector on a single image and report the detections.

    Loads ``args.image`` and binds ``sym`` for inference with the parameters
    from ``args.params``. It prints each detection whose class id is positive
    and whose confidence exceeds ``args.vis_thresh``. When ``args.vis`` is set,
    it also draws the detections.

    Parameters
    ----------
    sym : Symbol
        Network whose three outputs are unpacked as (rois, scores, bbox_deltas).
    class_names : sequence of str
        Labels indexed by integer class id.
    args : argparse.Namespace
        Parsed options: image path, preprocessing settings, thresholds, device.
    """
    print('called with args\n{}'.format(pprint.pformat(vars(args))))

    # A truthy args.gpu selects that GPU index; otherwise fall back to CPU 0.
    ctx = mx.gpu(int(args.gpu)) if args.gpu else mx.cpu(0)

    # Preprocess the single test image and wrap it into a data batch.
    im_tensor, im_info, im_orig = load_test(args.image, short=args.img_short_side, max_size=args.img_long_side,
                                            mean=args.img_pixel_means, std=args.img_pixel_stds)
    data_batch = generate_batch(im_tensor, im_info)

    arg_params, aux_params = load_param(args.params, ctx=ctx)

    # Bind with the largest shape any input can take: a square of side img_long_side.
    side = args.img_long_side
    data_names = ['data', 'im_info']
    data_shapes = [('data', (1, 3, side, side)), ('im_info', (1, 3))]
    check_shape(sym, data_shapes, arg_params, aux_params)

    mod = Module(sym, data_names, None, context=ctx)
    mod.bind(data_shapes, None, for_training=False)
    mod.init_params(arg_params=arg_params, aux_params=aux_params)

    mod.forward(data_batch)
    rois, scores, bbox_deltas = mod.get_outputs()
    # Drop the first roi column (presumably the batch index) and take element 0
    # of the remaining per-batch outputs.
    det = im_detect(rois[:, 1:], scores[0], bbox_deltas[0], im_info[0],
                    bbox_stds=args.rcnn_bbox_stds, nms_thresh=args.rcnn_nms_thresh,
                    conf_thresh=args.rcnn_conf_thresh)

    for cls, conf, x1, y1, x2, y2 in det:
        if cls > 0 and conf > args.vis_thresh:
            print(class_names[int(cls)], conf, [x1, y1, x2, y2])

    if args.vis:
        vis_detection(im_orig, det, class_names, thresh=args.vis_thresh)
示例7: test_net
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def test_net(sym, imdb, args):
    """Run the detector over every image of ``imdb`` and evaluate the results.

    Detections are collected per class and per image into ``all_boxes`` and
    passed to ``imdb.evaluate_detections``.

    Parameters
    ----------
    sym : Symbol
        Network whose three outputs are unpacked as (rois, scores, bbox_deltas).
    imdb : object
        Dataset exposing ``roidb``, ``num_images``, ``num_classes`` and
        ``evaluate_detections``.
    args : argparse.Namespace
        Parsed options: GPU id, params path, preprocessing settings, thresholds.
    """
    logger.info('called with args\n{}'.format(pprint.pformat(vars(args))))
    # int() matches demo_net's handling of args.gpu and is a no-op for ints.
    ctx = mx.gpu(int(args.gpu))
    # load testing data (batch_size=1, so the enumerate index is the image index)
    test_data = TestLoader(imdb.roidb, batch_size=1, short=args.img_short_side, max_size=args.img_long_side,
                           mean=args.img_pixel_means, std=args.img_pixel_stds)
    arg_params, aux_params = load_param(args.params, ctx=ctx)
    # Bind with the largest shape any input can take: a square of side img_long_side.
    data_names = ['data', 'im_info']
    label_names = None
    data_shapes = [('data', (1, 3, args.img_long_side, args.img_long_side)), ('im_info', (1, 3))]
    label_shapes = None
    check_shape(sym, data_shapes, arg_params, aux_params)
    mod = Module(sym, data_names, label_names, context=ctx)
    mod.bind(data_shapes, label_shapes, for_training=False)
    mod.init_params(arg_params=arg_params, aux_params=aux_params)
    # all detections are collected into:
    # all_boxes[cls][image] = N x 5 array of detections in
    # (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(imdb.num_images)]
                 for _ in range(imdb.num_classes)]
    with tqdm(total=imdb.num_images) as pbar:
        for i, data_batch in enumerate(test_data):
            im_info = data_batch.data[1][0]
            mod.forward(data_batch)
            rois, scores, bbox_deltas = mod.get_outputs()
            rois = rois[:, 1:]
            scores = scores[0]
            bbox_deltas = bbox_deltas[0]
            det = im_detect(rois, scores, bbox_deltas, im_info,
                            bbox_stds=args.rcnn_bbox_stds, nms_thresh=args.rcnn_nms_thresh,
                            conf_thresh=args.rcnn_conf_thresh)
            # det rows are (cls, score, x1, y1, x2, y2). Reorder to
            # (x1, y1, x2, y2, score) once per image rather than once per class.
            cls_ids = det[:, 0]
            boxes = np.concatenate((det[:, -4:], det[:, [1]]), axis=-1)
            for j in range(1, imdb.num_classes):
                indexes = np.where(cls_ids == j)[0]
                all_boxes[j][i] = boxes[indexes, :]
            pbar.update(data_batch.data[0].shape[0])
    imdb.evaluate_detections(all_boxes)
示例8: test_quantize_model
# 需要导入模块: from mxnet import module [as 别名]
# 或者: from mxnet.module import Module [as 别名]
def test_quantize_model():
    """Quantize a small fp32 model without calibration and with naive calibration.

    Verifies that the quantized params match the expected values in both cases,
    and that every requantize node carries calibration ranges after calibration.
    """
    def check_params(params, qparams, qsym=None):
        # Without qsym the params must pass through unchanged; with qsym they
        # must equal what _quantize_params produces for that symbol.
        if qsym is None:
            expected = params
        else:
            expected = mx.contrib.quant._quantize_params(qsym, params)
        assert len(expected) == len(qparams)
        for name, arr in expected.items():
            assert name in qparams
            assert same(arr.asnumpy(), qparams[name].asnumpy())

    def check_qsym_calibrated(qsym):
        for node_name, node_attrs in qsym.attr_dict().items():
            if 'requantize_' in node_name:
                assert 'min_calib_range' in node_attrs
                assert 'max_calib_range' in node_attrs

    sym = get_fp32_sym()
    mod = Module(symbol=sym)
    batch_size = 4
    data_shape = (batch_size, 4, 10, 10)
    label_shape = (batch_size, 10)
    mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)])
    mod.init_params()
    arg_params, aux_params = mod.get_params()

    shared = dict(sym=sym, arg_params=arg_params, aux_params=aux_params,
                  ctx=mx.current_context())

    # 1) Quantization without calibration.
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(calib_mode='none', **shared)
    check_params(arg_params, qarg_params, qsym)
    check_params(aux_params, qaux_params)

    # 2) Quantization with naive calibration on random data.
    calib_data = DummyIter(NDArrayIter(data=mx.nd.random.uniform(shape=data_shape)))
    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(calib_mode='naive',
                                                                     calib_data=calib_data,
                                                                     num_calib_examples=20,
                                                                     **shared)
    check_params(arg_params, qarg_params, qsym)
    check_params(aux_params, qaux_params)
    check_qsym_calibrated(qsym)