本文整理汇总了Python中mxnet.context模块的典型用法代码示例。如果您正苦于以下问题：Python mxnet.context模块的具体用法？Python mxnet.context怎么用？Python mxnet.context使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在包mxnet的用法示例。
在下文中一共展示了mxnet.context模块的6个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_context
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def get_context(max_gpus: int = 16) -> list:
    """
    Return a list of all usable GPU contexts on this machine, or ``[mx.cpu()]``.

    Every GPU id in ``0 .. max_gpus - 1`` is probed by allocating a tiny
    NDArray on it. Ids whose allocation raises ``mx.MXNetError`` are skipped,
    and probing continues with the next id, so usable ids need not be
    contiguous. Use it to automatically pick MXNet contexts: all usable GPUs,
    falling back to the CPU when there are none.

    :param max_gpus: number of GPU ids to probe (defaults to 16, the
        previously hard-coded limit)
    :return: list of ``mx.context.Context`` objects, one per usable GPU, or
        ``[mx.cpu()]`` if no GPU could be used
    """
    contexts = []
    for gpu_id in range(max_gpus):
        gpu_ctx = mx.gpu(gpu_id)
        try:
            # Allocation fails with MXNetError when this GPU id is not usable.
            mx.nd.array([1, 2, 3], ctx=gpu_ctx)
        except mx.MXNetError:
            # Deliberate best effort: an unusable id is simply skipped.
            continue
        contexts.append(gpu_ctx)
    if not contexts:
        contexts.append(mx.cpu())
    return contexts
示例2: __init__
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def __init__(
        self,
        ctx: mx.context,
        label_encoders: List[ColumnEncoder],
        data_featurizers: List[Featurizer],
        final_fc_hidden_units: List[int]
):
    """
    Keep the configuration of the internal DataWig MXNet module wrapper.

    The arguments are stored as-is (by reference) on the instance; nothing is
    validated or copied here.

    :param ctx: MXNet execution context. NOTE(review): the annotation
        ``mx.context`` names the module, not the ``Context`` class; callers
        presumably pass an ``mx.Context`` or a list of them — confirm.
    :param label_encoders: encoders for the label columns
    :param data_featurizers: featurizers for the data columns
    :param final_fc_hidden_units: hidden-unit counts, presumably one entry per
        final fully-connected layer — confirm against the builder
    """
    self.ctx = ctx
    self.data_featurizers = data_featurizers
    self.label_encoders = label_encoders
    self.final_fc_hidden_units = final_fc_hidden_units
示例3: test_context
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def test_context():
    """
    Check that MXNet's default context is tracked per thread.

    Part 1: ``set_default_context(mx.gpu(11))`` inside a spawned thread is
    observed by that thread, while the main thread's default (captured
    beforehand) is cpu(0).

    Part 2: a spawned thread that has entered ``with mx.cpu(10)`` keeps seeing
    device id 10 even after the main thread switches to ``Context("cpu", 11)``.
    """
    ctx_list = [Context.default_ctx]

    def f():
        set_default_context(mx.gpu(11))
        ctx_list.append(Context.default_ctx)

    thread = threading.Thread(target=f)
    thread.start()
    thread.join()
    assert Context.devtype2str[ctx_list[0].device_typeid] == "cpu"
    assert ctx_list[0].device_id == 0
    assert Context.devtype2str[ctx_list[1].device_typeid] == "gpu"
    assert ctx_list[1].device_id == 11

    entered = threading.Event()  # set by g once it is inside ``with mx.cpu(10)``
    proceed = threading.Event()  # set by main after switching its own context
    status = [False]

    def g():
        with mx.cpu(10):
            entered.set()
            proceed.wait()
            if Context.default_ctx.device_id == 10:
                status[0] = True

    thread = threading.Thread(target=g)
    thread.start()
    # Wait until g's context is active, so the main-thread switch below
    # really overlaps with it (otherwise the interleaving is not exercised).
    entered.wait()
    # Switch via ``with`` so the main thread's default context is restored
    # afterwards instead of leaking cpu(11) into subsequent tests.
    with Context("cpu", 11):
        proceed.set()
        thread.join()
    assert status[0], "Spawned thread didn't set the correct context"
示例4: _get_ctx
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def _get_ctx(self):
    """Choose the MXNet context to run on, preferring a GPU over the CPU.

    Returns
    -------
    ctx: mx.context
        ``mx.gpu()`` when ``has_gpu()`` is truthy, otherwise ``mx.cpu()``.
    """
    gpu_available = has_gpu()
    return mx.gpu() if gpu_available else mx.cpu()
示例5: __call__
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def __call__(self,
             iter_train: ImputerIterDf) -> mx.mod.Module:
    """
    Given a training iterator, build the MXNet module and return it.

    The module groups the loss symbol with one ``BlockGrad``-wrapped
    prediction output per (label encoder, prediction) pair. Its data inputs
    are restricted to the iterator's data entries that appear among the loss
    symbol's arguments; its label inputs are all of the iterator's labels.
    If the module is not yet bound, it is bound to those same entries.

    :param iter_train: Training data iterator
    :return: mx.mod.Module
    """
    predictions, loss = self.__make_loss()

    logger.debug("Building output symbols")
    output_symbols = [
        mx.sym.BlockGrad(output, name="pred-{}".format(col_enc.output_column))
        for col_enc, output in zip(self.label_encoders, predictions)
    ]

    # Query the loss arguments once (not once per data entry) and filter the
    # data descriptors a single time, so data_names and data_shapes are
    # guaranteed to describe exactly the same inputs in the same order.
    loss_arguments = set(loss.list_arguments())
    data_descs = [desc for desc in iter_train.provide_data
                  if desc.name in loss_arguments]

    mod = mx.mod.Module(
        mx.sym.Group([loss] + output_symbols),
        context=self.ctx,
        data_names=[desc.name for desc in data_descs],
        label_names=[name for name, _ in iter_train.provide_label]
    )

    if not mod.binded:
        mod.bind(data_shapes=data_descs,
                 label_shapes=iter_train.provide_label)

    return mod
示例6: load
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def load(output_path: str) -> Any:
    """
    Load a trained Imputer from its output path.

    Reads the pickled Imputer state from ``<output_path>/imputer.pickle``. It
    then builds an ``Imputer`` from the fields its constructor accepts and
    restores every other pickled field as an attribute. ``module_path`` and
    ``output_path`` are pointed at ``output_path``. Finally it deserializes
    the MXNet module stored under ``<output_path>/model`` onto the contexts
    returned by ``get_context()``.

    WARNING: unpickling can execute arbitrary code; only load models from
    trusted locations.

    :param output_path: output_path field of trained Imputer model
    :return: imputer model
    """
    logger.debug("Output path for loading Imputer %s", output_path)
    # NOTE(security): pickle.load runs arbitrary code from the file; the
    # path must come from a trusted source.
    with open(os.path.join(output_path, "imputer.pickle"), "rb") as pickle_file:
        params = pickle.load(pickle_file)

    imputer_signature = inspect.getfullargspec(Imputer.__init__)[0]
    # get constructor args
    constructor_args = {p: params[p] for p in imputer_signature if p != 'self'}
    # everything else that was pickled is restored as a plain attribute
    non_constructor_args = {p: value for p, value in params.items()
                            if p != 'self' and p not in constructor_args}

    # use all relevant fields to instantiate Imputer
    imputer = Imputer(**constructor_args)
    # then set all other args
    for arg, value in non_constructor_args.items():
        setattr(imputer, arg, value)

    # the module path must be updated when loading the Imputer, too
    imputer.module_path = os.path.join(output_path, 'model')
    imputer.output_path = output_path

    # make sure that the context for this deserialized model is available
    ctx = get_context()

    logger.debug("Loading mxnet model from %s", imputer.module_path)

    data_names = [s.field_name for s in imputer.data_featurizers]
    # for non-numerical (e.g. categorical) outputs, an instance weight input is added
    if not isinstance(imputer.label_encoders[0], NumericalEncoder):
        data_names.append(INSTANCE_WEIGHT_COLUMN)

    # deserialize mxnet module
    imputer.module = mx.module.Module.load(
        imputer.module_path,
        imputer.__get_best_epoch(),
        context=ctx,
        data_names=data_names,
        label_names=[s.output_column for s in imputer.label_encoders]
    )

    return imputer