本文整理汇总了Python中mxnet.nd.load方法的典型用法代码示例。如果您正苦于以下问题：Python nd.load方法的具体用法？Python nd.load怎么用？Python nd.load使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块mxnet.nd的用法示例。
在下文中一共展示了nd.load方法的8个代码示例，这些例子默认根据受欢迎程度排序。注意：其中示例5、示例6和示例7本身并未直接调用nd.load。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: load_model
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def load_model(symbol_file, param_file, logger=None):
    """Load a symbol graph and its parameters from files next to this module.

    Both paths are resolved relative to the directory containing this source
    file, not the current working directory.

    :param symbol_file: str
        Symbol JSON file name, relative to this module's directory.
    :param param_file: str
        Parameter file name, relative to this module's directory. Every key
        must look like ``'<type>:<name>'``. ``arg:`` entries go to
        ``arg_params``, ``aux:`` entries go to ``aux_params``, and any other
        type prefix is ignored.
    :param logger: logging.Logger or None
        Optional logger for progress messages.
    :return: tuple ``(symbol, arg_params, aux_params)``.
    :raises ValueError: if the parameter file does not hold a name->array
        dict, or a key has no ``':'`` type prefix.
    """
    cur_path = os.path.dirname(os.path.realpath(__file__))
    symbol_file_path = os.path.join(cur_path, symbol_file)
    if logger is not None:
        # Lazy %-args: the message is only formatted if it is actually emitted.
        logger.info('Loading symbol from file %s', symbol_file_path)
    symbol = mx.sym.load(symbol_file_path)
    param_file_path = os.path.join(cur_path, param_file)
    if logger is not None:
        logger.info('Loading params from file %s', param_file_path)
    save_dict = nd.load(param_file_path)
    if not isinstance(save_dict, dict):
        # nd.load returns a plain list for files saved without names.
        raise ValueError('Expected a dict of named parameters in %s, got %s'
                         % (param_file_path, type(save_dict).__name__))
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        # partition splits on the first ':' only, so names may contain ':'.
        tp, sep, name = k.partition(':')
        if not sep:
            raise ValueError('Parameter key %r in %s has no "arg:"/"aux:" prefix'
                             % (k, param_file_path))
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
    return symbol, arg_params, aux_params
示例2: load_model
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def load_model(_symbol_file, _param_file, _logger=None):
    """load existing symbol model

    Load a symbol graph and its parameters from files located in the
    directory containing this source file (not the working directory).

    :param _symbol_file: str
        Symbol JSON file name, relative to this module's directory.
    :param _param_file: str
        Parameter file name, relative to this module's directory. Keys must
        look like ``'<type>:<name>'``. ``arg:`` and ``aux:`` entries are
        split into separate dicts, and any other type prefix is ignored.
    :param _logger: logging.Logger or None
        Optional logger for progress messages.
    :return: tuple ``(symbol, arg_params, aux_params)``.
    :raises ValueError: if the parameter file does not hold a name->array
        dict, or a key has no ``':'`` type prefix.
    """
    cur_path = os.path.dirname(os.path.realpath(__file__))
    symbol_file_path = os.path.join(cur_path, _symbol_file)
    if _logger is not None:
        # Lazy %-args: the message is only formatted if it is actually emitted.
        _logger.info('Loading symbol from file %s', symbol_file_path)
    symbol = mx.sym.load(symbol_file_path)
    param_file_path = os.path.join(cur_path, _param_file)
    if _logger is not None:
        _logger.info('Loading params from file %s', param_file_path)
    save_dict = nd.load(param_file_path)
    if not isinstance(save_dict, dict):
        # nd.load returns a plain list for files saved without names.
        raise ValueError('Expected a dict of named parameters in %s, got %s'
                         % (param_file_path, type(save_dict).__name__))
    _arg_params = {}
    _aux_params = {}
    for k, v in save_dict.items():
        # partition splits on the first ':' only, so names may contain ':'.
        tp, sep, name = k.partition(':')
        if not sep:
            raise ValueError('Parameter key %r in %s has no "arg:"/"aux:" prefix'
                             % (k, param_file_path))
        if tp == 'arg':
            _arg_params[name] = v
        elif tp == 'aux':
            _aux_params[name] = v
    return symbol, _arg_params, _aux_params
示例3: load_model
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def load_model(symbol_file, param_file, mlogger=None):
    """load existing symbol model

    Load a symbol graph and its parameters from files located in the
    directory containing this source file (not the working directory).

    :param symbol_file: str
        Symbol JSON file name, relative to this module's directory.
    :param param_file: str
        Parameter file name, relative to this module's directory. Keys must
        look like ``'<type>:<name>'``. ``arg:`` and ``aux:`` entries are
        split into separate dicts, and any other type prefix is ignored.
    :param mlogger: logging.Logger or None
        Optional logger for progress messages.
    :return: tuple ``(symbol, arg_params, aux_params)``.
    :raises ValueError: if the parameter file does not hold a name->array
        dict, or a key has no ``':'`` type prefix.
    """
    cur_path = os.path.dirname(os.path.realpath(__file__))
    symbol_file_path = os.path.join(cur_path, symbol_file)
    if mlogger is not None:
        # Lazy %-args: the message is only formatted if it is actually emitted.
        mlogger.info('Loading symbol from file %s', symbol_file_path)
    symbol = mx.sym.load(symbol_file_path)
    param_file_path = os.path.join(cur_path, param_file)
    if mlogger is not None:
        mlogger.info('Loading params from file %s', param_file_path)
    save_dict = nd.load(param_file_path)
    if not isinstance(save_dict, dict):
        # nd.load returns a plain list for files saved without names.
        raise ValueError('Expected a dict of named parameters in %s, got %s'
                         % (param_file_path, type(save_dict).__name__))
    marg_params = {}
    maux_params = {}
    for k, v in save_dict.items():
        # partition splits on the first ':' only, so names may contain ':'.
        tp, sep, name = k.partition(':')
        if not sep:
            raise ValueError('Parameter key %r in %s has no "arg:"/"aux:" prefix'
                             % (k, param_file_path))
        if tp == 'arg':
            marg_params[name] = v
        elif tp == 'aux':
            maux_params[name] = v
    return symbol, marg_params, maux_params
示例4: initialize_inference
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def initialize_inference(inference, pretrained, start_epoch):
    """Initialize the parameters of a Gluon network in one of three modes.

    The modes are checked in this order:

    1. ``pretrained`` is truthy: load ``ckpt/VGG-FACE/VGG_FACE-0000.params``
       (relative to the working directory). Rename its conv entries to this
       network's naming scheme, copy every matching parameter and
       initialize the remaining ones.
    2. ``start_epoch > 0``: load
       ``<args.ckpt_dir>/<args.prefix>/<args.prefix>-<start_epoch>.params``.
    3. Otherwise: initialize every parameter from scratch.

    NOTE(review): ``ctx`` and ``args`` are neither parameters nor locals.
    They are looked up as globals that must be defined elsewhere in the
    script (not visible here).

    :param inference: Gluon block whose parameters are set up in place.
    :param pretrained: whether to start from the VGG-Face checkpoint.
    :param start_epoch: epoch to resume from when ``pretrained`` is falsy.
    :return: the same ``inference`` object.
    """
    if pretrained:
        print('Loading the pretrained model')
        vggface_weights = nd.load('ckpt/VGG-FACE/VGG_FACE-0000.params')
        # Number of conv layers in each of the five conv stages. Used to turn
        # a (stage, index-within-stage) pair into one flat 0-based conv index.
        vgg_face_layers = [2, 2, 3, 3, 3]
        # change the name: keys are split on '_' into exactly three parts,
        # 'arg:conv<stage>', '<...idx>' (last char used) and '<suffix>', then
        # re-keyed as '<net name>_conv<flat idx>_<suffix>'.
        checkpoint = {}
        for k, v in vggface_weights.items():
            if 'conv' in k:
                ind1, ind2, sub_name = k.split('_')
                ind1 = int(ind1.replace('arg:conv', '')) - 1
                ind2 = int(ind2[-1]) - 1
                ind = sum(vgg_face_layers[:ind1]) + ind2
                key = inference.name + '_conv' + str(ind) + '_' + sub_name
                checkpoint[key] = v
        # load the weights. collect_params() builds a new ParameterDict on
        # every call, so fetch it once instead of three times per parameter.
        params = inference.collect_params()
        for k, param in params.items():
            if k in checkpoint:
                param._load_init(checkpoint[k], ctx)
                print('Loaded %s weights from checkpoints' % k)
            else:
                param.initialize(ctx=ctx)
                print('Initialize %s weights' % k)
        print('Done')
    elif start_epoch > 0:
        print('Loading the weights from [%d] epoch' % start_epoch)
        inference.load_params(
            os.path.join(args.ckpt_dir, args.prefix,
                         '%s-%d.params' % (args.prefix, start_epoch)),
            ctx)
    else:
        inference.collect_params().initialize(ctx=ctx)
    return inference
示例5: benchmark_score
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None,
                    input_shape=None, num_warmup=5):
    """Measure the inference throughput of a symbol file, in samples/second.

    The symbol is loaded from a path relative to this module's directory. It
    is bound for inference with Xavier-initialized weights and fed one batch
    of uniform(-1, 1) random data repeatedly.

    :param symbol_file: symbol JSON file name relative to this module.
    :param ctx: MXNet context to run on.
    :param batch_size: number of samples per batch.
    :param num_batches: number of timed forward passes; must be positive.
    :param logger: optional logger for progress messages.
    :param input_shape: per-sample input shape (without the batch axis).
        When omitted, falls back to the module-level ``data_shape`` global,
        which the original signature always relied on.
    :param num_warmup: number of untimed forward passes run first.
    :return: float, ``num_batches * batch_size`` divided by the timed
        wall-clock seconds.
    :raises ValueError: if ``num_batches`` is not positive. Before this
        check, that case failed with UnboundLocalError or ZeroDivisionError.
    """
    if num_batches <= 0:
        raise ValueError('num_batches must be positive, got %r' % (num_batches,))
    if input_shape is None:
        input_shape = data_shape
    # get mod
    cur_path = os.path.dirname(os.path.realpath(__file__))
    symbol_file_path = os.path.join(cur_path, symbol_file)
    if logger is not None:
        logger.info('Loading symbol from file %s', symbol_file_path)
    sym = mx.sym.load(symbol_file_path)
    mod = mx.mod.Module(symbol=sym, context=ctx)
    mod.bind(for_training=False,
             inputs_need_grad=False,
             data_shapes=[('data', (batch_size,) + tuple(input_shape))])
    mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
    # get data
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx)
            for _, shape in mod.data_shapes]
    batch = mx.io.DataBatch(data, [])  # empty label

    def _run_once():
        # Block until outputs are ready so the timing covers the real compute.
        mod.forward(batch, is_train=False)
        for output in mod.get_outputs():
            output.wait_to_read()

    for _ in range(num_warmup):
        _run_once()
    tic = time.perf_counter()
    for _ in range(num_batches):
        _run_once()
    # return num images per second
    return num_batches * batch_size / (time.perf_counter() - tic)
示例6: load_object
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def load_object(filename):
    """Unpickle and return the object stored in ``filename``.

    WARNING: ``pickle.load`` can execute arbitrary code while deserializing.
    Only call this on files you created yourself or otherwise trust.

    :param filename: path to a file written with ``pickle.dump``.
    :return: the deserialized object.
    """
    # Named 'f' rather than 'input' so the builtin input() is not shadowed.
    with open(filename, 'rb') as f:
        return pickle.load(f)
示例7: __init__
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def __init__(self, net, params_filename):
    """Prepare a trained Gluon network for freezing.

    Loads the trained weights into ``net`` and hybridizes it. It then traces
    the network into a symbol with a ``SoftmaxOutput`` head named
    ``'softmax'``, which is converted with the ``"MKLDNN"`` backend. Finally
    the network's current parameter values are split into ``self.args`` and
    ``self.auxes``.

    :param net: mxnet.gluon.nn.Block
        The original network whose trained parameters are loaded and frozen.
    :param params_filename: str
        File holding the trained parameters; extra entries in it are ignored.
    """
    self.origin_net = net
    self.gluon_params_filename = params_filename
    # Set before anything that can fail, so the attributes exist even if a
    # later step raises.
    self.sym = None
    self.args = None
    self.auxes = None

    net.load_parameters(params_filename, ignore_extra=True)
    net.hybridize()

    # Trace the network symbolically by calling it on a placeholder variable.
    traced = net(mx.sym.var('data'))
    traced = mx.sym.SoftmaxOutput(data=traced, name='softmax')
    self.sym = mx.symbol.load_json(traced.tojson()).get_backend_symbol("MKLDNN")

    # Parameters whose name contains 'running' (presumably BatchNorm running
    # statistics) are treated as auxiliary states; all others are arguments.
    self.args = {}
    self.auxes = {}
    for param in net.collect_params().values():
        bucket = self.auxes if 'running' in param.name else self.args
        bucket[param.name] = param._reduce()
示例8: _act_max_list
# 需要导入模块: from mxnet import nd [as 别名]
# 或者: from mxnet.nd import load [as 别名]
def _act_max_list(self):
    """Collect every ``act_max`` value stored in the Gluon parameter file.

    Each key of the loaded file is split on ``'.'``. Keys whose last
    component is ``act_max`` are resolved, component by component, from
    ``self.origin_net`` to the block that owns the value.

    :return: OrderedDict mapping the owning block's ``name`` to the value
        converted with ``asscalar()``. Entries follow the iteration order
        of the dict returned by ``nd.load``.
    """
    saved = nd.load(self.gluon_params_filename)
    result = OrderedDict()

    def _step(block, part):
        # Components accepted by self._is_number are used as integer indices;
        # all other components are attribute names.
        return block[int(part)] if self._is_number(part) else getattr(block, part)

    for key, value in saved.items():
        *path, leaf = key.split(".")
        if leaf != "act_max":
            continue
        owner = functools.reduce(_step, path, self.origin_net)
        result[f'{owner.name}'] = value.asscalar()
    return result