This article collects typical usage examples of Python's mxnet.module.module.Module. If you are unsure how module.Module is used in practice, the curated code examples below may help; you can also explore other members of the mxnet.module.module namespace.
The following presents 5 code examples of module.Module, ordered by popularity by default.
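Before the examples, here is a minimal, self-contained sketch of constructing, binding, and initializing a Module. The toy network, shapes, and hyperparameters are illustrative only and are not taken from any of the examples below.

import mxnet as mx
from mxnet.module.module import Module

# Toy network: one fully-connected layer followed by a softmax output.
data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
net = mx.sym.SoftmaxOutput(fc, name='softmax')

# Construct the module, bind it to fixed input shapes, and initialize it.
mod = Module(net, data_names=('data',), label_names=('softmax_label',),
             context=mx.cpu())
mod.bind(data_shapes=[('data', (32, 100))],
         label_shapes=[('softmax_label', (32,))])
mod.init_params(initializer=mx.init.Xavier())
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1),))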
Example 1: forward
# Required import: from mxnet.module import module [as alias]
# Or: from mxnet.module.module import Module [as alias]
def forward(self, data_batch, is_train=None):
    assert self.binded and self.params_initialized

    # get current_shapes
    if self._curr_module.label_shapes is not None:
        current_shapes = dict(self._curr_module.data_shapes + self._curr_module.label_shapes)
    else:
        current_shapes = dict(self._curr_module.data_shapes)

    # get input_shapes
    if data_batch.provide_label is not None:
        input_shapes = dict(data_batch.provide_data + data_batch.provide_label)
    else:
        input_shapes = dict(data_batch.provide_data)

    # decide if shape changed
    shape_changed = False
    for k, v in current_shapes.items():
        if v != input_shapes[k]:
            shape_changed = True

    if shape_changed:
        module = Module(self._symbol, self._data_names, self._label_names,
                        logger=self.logger, context=self._context,
                        work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind(data_batch.provide_data, data_batch.provide_label, self._curr_module.for_training,
                    self._curr_module.inputs_need_grad, force_rebind=False,
                    shared_module=self._curr_module)
        self._curr_module = module

    self._curr_module.forward(data_batch, is_train=is_train)
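This forward compares data_batch.provide_data / provide_label against the shapes the current module was bound with, and rebinds (sharing memory via shared_module) only when they differ. A batch that carries its own shape metadata, as this check expects, can be built roughly like this; the names and shapes are made up for illustration:

import mxnet as mx

# Hypothetical batch whose provide_data / provide_label describe its shapes,
# which is what the shape-change check in forward() reads.
batch = mx.io.DataBatch(
    data=[mx.nd.zeros((8, 3, 112, 112))],
    label=[mx.nd.zeros((8,))],
    provide_data=[mx.io.DataDesc('data', (8, 3, 112, 112))],
    provide_label=[mx.io.DataDesc('softmax_label', (8,))])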
Example 2: borrow_optimizer
# Required import: from mxnet.module import module [as alias]
# Or: from mxnet.module.module import Module [as alias]
def borrow_optimizer(self, shared_module):
    """Borrows optimizer from a shared module. Used in bucketing, where exactly the same
    optimizer (esp. kvstore) is used.

    Parameters
    ----------
    shared_module : Module
    """
    assert shared_module.optimizer_initialized
    self._optimizer = shared_module._optimizer
    self._kvstore = shared_module._kvstore
    self._update_on_kvstore = shared_module._update_on_kvstore
    self._updater = shared_module._updater
    self.optimizer_initialized = True
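borrow_optimizer is also part of the stock mxnet Module API. A small sketch of the bucketing-style pattern it supports, with a toy network and illustrative shapes, might look like this:

import mxnet as mx
from mxnet.module.module import Module

data = mx.sym.Variable('data')
net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(data, num_hidden=10), name='softmax')

# Primary module owns the optimizer (and, in distributed runs, the kvstore).
main_mod = Module(net, context=mx.cpu())
main_mod.bind(data_shapes=[('data', (32, 100))],
              label_shapes=[('softmax_label', (32,))])
main_mod.init_params()
main_mod.init_optimizer(optimizer='sgd')

# A second module for a different bucket shares executors and parameters,
# and borrows exactly the same optimizer instead of creating a new one.
bucket_mod = Module(net, context=mx.cpu())
bucket_mod.bind(data_shapes=[('data', (16, 100))],
                label_shapes=[('softmax_label', (16,))],
                shared_module=main_mod)
bucket_mod.borrow_optimizer(main_mod)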
Example 3: __init__
# Required import: from mxnet.module import module [as alias]
# Or: from mxnet.module.module import Module [as alias]
def __init__(self, symbol, data_names, label_names,
             logger=logging, context=ctx.cpu(), work_load_list=None,
             asymbol=None, args=None):
    super(ParallModule, self).__init__(logger=logger)
    self._symbol = symbol
    self._asymbol = asymbol
    self._data_names = data_names
    self._label_names = label_names
    self._context = context
    self._work_load_list = work_load_list

    self._num_classes = config.num_classes
    self._batch_size = args.batch_size
    self._verbose = args.verbose
    self._emb_size = config.emb_size
    self._local_class_start = args.local_class_start
    self._iter = 0

    self._curr_module = None
    self._num_workers = config.num_workers
    self._num_ctx = len(self._context)
    self._ctx_num_classes = args.ctx_num_classes
    self._nd_cache = {}
    self._ctx_cpu = mx.cpu()
    self._ctx_single_gpu = self._context[-1]
    self._fixed_param_names = None
    self._curr_module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                               context=self._context, work_load_list=self._work_load_list,
                               fixed_param_names=self._fixed_param_names)
    self._arcface_modules = []
    self._ctx_class_start = []
    for i in range(len(self._context)):
        args._ctxid = i
        _module = Module(self._asymbol(args), self._data_names, self._label_names, logger=self.logger,
                         context=mx.gpu(i), work_load_list=self._work_load_list,
                         fixed_param_names=self._fixed_param_names)
        self._arcface_modules.append(_module)
        _c = args.local_class_start + i * args.ctx_num_classes
        self._ctx_class_start.append(_c)
    self._usekv = False
    if self._usekv:
        self._distkv = mx.kvstore.create('dist_sync')
        self._kvinit = {}
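The loop above splits the classification layer across GPUs: GPU i hosts args.ctx_num_classes classes starting at args.local_class_start + i * args.ctx_num_classes. A small sketch with made-up numbers shows how the resulting _ctx_class_start offsets map a global label to a GPU:

# Hypothetical sizes, only to illustrate the per-GPU class partitioning.
local_class_start = 0
ctx_num_classes = 25000      # classes hosted on each GPU
num_gpus = 4

ctx_class_start = [local_class_start + i * ctx_num_classes
                   for i in range(num_gpus)]
print(ctx_class_start)       # [0, 25000, 50000, 75000]

# A global label y then lives on GPU y // ctx_num_classes, at local index
# y % ctx_num_classes (ignoring multi-worker offsets).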
Example 4: bind
# Required import: from mxnet.module import module [as alias]
# Or: from mxnet.module.module import Module [as alias]
def bind(self, data_shapes, label_shapes=None, for_training=True,
         inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
    # in case we already initialized params, keep them
    if self.params_initialized:
        arg_params, aux_params = self.get_params()

    # force rebinding is typically used when one wants to switch from
    # training to prediction phase.
    if force_rebind:
        self._reset_bind()

    if self.binded:
        self.logger.warning('Already binded, ignoring bind()')
        return

    assert shared_module is None, 'shared_module for MutableModule is not supported'

    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad
    self.binded = True

    max_shapes_dict = dict()
    if self._max_data_shapes is not None:
        max_shapes_dict.update(dict(self._max_data_shapes))
    if self._max_label_shapes is not None:
        max_shapes_dict.update(dict(self._max_label_shapes))

    max_data_shapes = list()
    for name, shape in data_shapes:
        if name in max_shapes_dict:
            max_data_shapes.append((name, max_shapes_dict[name]))
        else:
            max_data_shapes.append((name, shape))

    max_label_shapes = list()
    if label_shapes is not None:
        for name, shape in label_shapes:
            if name in max_shapes_dict:
                max_label_shapes.append((name, max_shapes_dict[name]))
            else:
                max_label_shapes.append((name, shape))

    if len(max_label_shapes) == 0:
        max_label_shapes = None

    module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                    context=self._context, work_load_list=self._work_load_list,
                    fixed_param_names=self._fixed_param_names)
    module.bind(max_data_shapes, max_label_shapes, for_training, inputs_need_grad,
                force_rebind=False, shared_module=None)
    self._curr_module = module

    # copy back saved params, if already initialized
    if self.params_initialized:
        self.set_params(arg_params, aux_params)
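The params_initialized branch above preserves trained weights across a rebind: parameters are fetched before the old binding is replaced and copied back into the freshly bound module. On a plain Module, the same save/rebind/restore pattern looks roughly like this (toy network, illustrative shapes):

import mxnet as mx
from mxnet.module.module import Module

data = mx.sym.Variable('data')
net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(data, num_hidden=10), name='softmax')

mod = Module(net, context=mx.cpu())
mod.bind(data_shapes=[('data', (32, 100))],
         label_shapes=[('softmax_label', (32,))])
mod.init_params()

# Save parameters, rebind at a larger shape, then copy them back, mirroring
# the params_initialized handling in the bind() method above.
arg_params, aux_params = mod.get_params()
mod.bind(data_shapes=[('data', (64, 100))],
         label_shapes=[('softmax_label', (64,))],
         force_rebind=True)
mod.set_params(arg_params, aux_params)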
Example 5: bind
# Required import: from mxnet.module import module [as alias]
# Or: from mxnet.module.module import Module [as alias]
def bind(self, data_shapes, label_shapes=None, for_training=True,
         inputs_need_grad=False, force_rebind=False, shared_module=None):
    # in case we already initialized params, keep them
    if self.params_initialized:
        arg_params, aux_params = self.get_params()

    # force rebinding is typically used when one wants to switch from
    # training to prediction phase.
    if force_rebind:
        self._reset_bind()

    if self.binded:
        self.logger.warning('Already binded, ignoring bind()')
        return

    assert shared_module is None, 'shared_module for MutableModule is not supported'

    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad
    self.binded = True

    max_shapes_dict = dict()
    if self._max_data_shapes is not None:
        max_shapes_dict.update(dict(self._max_data_shapes))
    if self._max_label_shapes is not None:
        max_shapes_dict.update(dict(self._max_label_shapes))

    max_data_shapes = list()
    for name, shape in data_shapes:
        if name in max_shapes_dict:
            max_data_shapes.append((name, max_shapes_dict[name]))
        else:
            max_data_shapes.append((name, shape))

    max_label_shapes = list()
    if label_shapes is not None:
        for name, shape in label_shapes:
            if name in max_shapes_dict:
                max_label_shapes.append((name, max_shapes_dict[name]))
            else:
                max_label_shapes.append((name, shape))

    if len(max_label_shapes) == 0:
        max_label_shapes = None

    module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                    context=self._context, work_load_list=self._work_load_list,
                    fixed_param_names=self._fixed_param_names)
    module.bind(max_data_shapes, max_label_shapes, for_training, inputs_need_grad,
                force_rebind=False, shared_module=None)
    self._curr_module = module

    # copy back saved params, if already initialized
    if self.params_initialized:
        self.set_params(arg_params, aux_params)