当前位置: 首页>>代码示例>>Python>>正文


Python mxnet.context方法代码示例

本文整理汇总了 Python 中 mxnet.context 模块的典型用法代码示例。如果您正苦于以下问题:Python mxnet.context 的具体用法是什么?Python mxnet.context 怎么用?有哪些 Python mxnet.context 的使用例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在的 mxnet 包的用法示例。


下文一共展示了 mxnet.context 的 6 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: get_context

# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def get_context() -> mx.context:
    """
    Collect the MXNet GPU contexts that are usable on this machine.

    GPU ids 0 through 15 are probed by allocating a tiny array on each one;
    an id whose probe raises ``mx.MXNetError`` is skipped. When no id
    succeeds, a single CPU context is returned instead.

    Note that, despite the return annotation, the value returned is a list.

    :return: list of ``mx.gpu(i)`` contexts that passed the probe, or
        ``[mx.cpu()]`` if none did
    """
    usable = []
    for device_id in range(16):
        try:
            # The allocation is the actual availability check; its result is discarded.
            mx.nd.array([1, 2, 3], ctx=mx.gpu(device_id))
            usable.append(mx.gpu(device_id))
        except mx.MXNetError:
            continue

    if not usable:
        usable.append(mx.cpu())

    return usable
开发者ID:awslabs,项目名称:datawig,代码行数:24,代码来源:utils.py

示例2: __init__

# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def __init__(
            self,
            ctx: mx.context,
            label_encoders: List[ColumnEncoder],
            data_featurizers: List[Featurizer],
            final_fc_hidden_units: List[int]
    ):
        """
        Keep the configuration for the internal DataWig MXNet module wrapper.

        Nothing is built here; each argument is only stored on the instance
        under an attribute of the same name.

        :param ctx: MXNet execution context
        :param label_encoders: encoders for the label columns
        :param data_featurizers: featurizers for the data columns
        :param final_fc_hidden_units: list of integers, described upstream as
            the "number of hidden parameters" (presumably the sizes of the
            final fully connected layers -- confirm against the model builder)
        """
        # Stored in signature order.
        self.ctx = ctx
        self.label_encoders = label_encoders
        self.data_featurizers = data_featurizers
        self.final_fc_hidden_units = final_fc_hidden_units
开发者ID:awslabs,项目名称:datawig,代码行数:21,代码来源:imputer.py

示例3: test_context

# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def test_context():
    """Check how ``Context.default_ctx`` behaves across threads.

    Part one: a worker thread calls ``set_default_context(mx.gpu(11))`` and
    records the default it then sees; the main thread's earlier recording
    must be cpu 0 and the worker's must be gpu 11.

    Part two: a worker thread enters ``with mx.cpu(10)`` and waits; the main
    thread then assigns ``Context.default_ctx = Context("cpu", 11)``. The
    worker must still observe device id 10 as its default.
    """
    observed = [Context.default_ctx]

    def switch_to_gpu():
        set_default_context(mx.gpu(11))
        observed.append(Context.default_ctx)

    worker = threading.Thread(target=switch_to_gpu)
    worker.start()
    worker.join()

    main_ctx = observed[0]
    assert Context.devtype2str[main_ctx.device_typeid] == "cpu"
    assert main_ctx.device_id == 0
    worker_ctx = observed[1]
    assert Context.devtype2str[worker_ctx.device_typeid] == "gpu"
    assert worker_ctx.device_id == 11

    release = threading.Event()
    scoped_ctx_seen = False

    def check_scoped_ctx():
        nonlocal scoped_ctx_seen
        with mx.cpu(10):
            # Block until the main thread has reassigned its own default context.
            release.wait()
            if Context.default_ctx.device_id == 10:
                scoped_ctx_seen = True

    worker = threading.Thread(target=check_scoped_ctx)
    worker.start()
    Context.default_ctx = Context("cpu", 11)
    release.set()
    worker.join()
    release.clear()
    assert scoped_ctx_seen, "Spawned thread didn't set the correct context"
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:30,代码来源:test_thread_local.py

示例4: _get_ctx

# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def _get_ctx(self):
        """Pick the MXNet context to run on, preferring GPU over CPU.

        Returns
        -------
        ctx: mx.context
            ``mx.gpu()`` when ``has_gpu()`` is truthy, otherwise ``mx.cpu()``.
        """
        return mx.gpu() if has_gpu() else mx.cpu()
开发者ID:geek-ai,项目名称:MAgent,代码行数:13,代码来源:base.py

示例5: __call__

# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def __call__(self,
                 iter_train: ImputerIterDf) -> mx.mod.Module:
        """
        Given a training iterator, build the MXNet module, bind it and return it.

        The module groups the loss symbol with one gradient-blocked prediction
        symbol per label encoder (named ``pred-<output_column>``). Only those
        entries of ``iter_train.provide_data`` whose name appears among the
        arguments of the loss symbol are passed on as data inputs.

        :param iter_train: Training data iterator; its ``provide_data`` and
            ``provide_label`` descriptors supply the input names and shapes
        :return: mx.mod.Module bound to the iterator's data and label shapes
        """

        predictions, loss = self.__make_loss()

        logger.debug("Building output symbols")
        output_symbols = [
            mx.sym.BlockGrad(output, name="pred-{}".format(col_enc.output_column))
            for col_enc, output in zip(self.label_encoders, predictions)
        ]

        # The argument list of the loss symbol does not change while filtering,
        # so compute it once and use a set for constant-time membership tests.
        loss_arguments = set(loss.list_arguments())
        used_data = [d for d in iter_train.provide_data if d.name in loss_arguments]

        mod = mx.mod.Module(
            mx.sym.Group([loss] + output_symbols),
            context=self.ctx,
            data_names=[d.name for d in used_data],
            label_names=[name for name, _ in iter_train.provide_label]
        )

        if not mod.binded:
            mod.bind(data_shapes=used_data,
                     label_shapes=iter_train.provide_label)

        return mod
开发者ID:awslabs,项目名称:datawig,代码行数:32,代码来源:imputer.py

示例6: load

# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import context [as 别名]
def load(output_path: str) -> Any:
        """
        Load a trained Imputer model from ``output_path``.

        The pickled parameter dictionary in ``<output_path>/imputer.pickle`` is
        split into arguments accepted by ``Imputer.__init__`` (used to construct
        the instance) and all remaining entries (set as attributes afterwards).
        The MXNet module is then loaded from ``<output_path>/model`` onto the
        contexts returned by ``get_context()``.

        :param output_path: output_path field of trained Imputer model
        :return: imputer model
        """

        logger.debug("Output path for loading Imputer {}".format(output_path))
        # SECURITY: unpickling can execute arbitrary code; only load from
        # output paths that come from a trusted source.
        # The context manager closes the file handle even if unpickling fails.
        with open(os.path.join(output_path, "imputer.pickle"), "rb") as pickle_file:
            params = pickle.load(pickle_file)
        imputer_signature = inspect.getfullargspec(Imputer.__init__)[0]
        # get constructor args
        constructor_args = {p: params[p] for p in imputer_signature if p != 'self'}
        non_constructor_args = {p: value for p, value in params.items()
                                if p != 'self' and p not in constructor_args}

        # use all relevant fields to instantiate Imputer
        imputer = Imputer(**constructor_args)
        # then set all other args
        for arg, value in non_constructor_args.items():
            setattr(imputer, arg, value)

        # the module path must be updated when loading the Imputer, too
        imputer.module_path = os.path.join(output_path, 'model')
        imputer.output_path = output_path
        # make sure that the context for this deserialized model is available
        ctx = get_context()

        logger.debug("Loading mxnet model from {}".format(imputer.module_path))

        # for categorical outputs, instance weight is added
        needs_instance_weight = not isinstance(imputer.label_encoders[0], NumericalEncoder)
        data_names = [s.field_name for s in imputer.data_featurizers]
        if needs_instance_weight:
            data_names.append(INSTANCE_WEIGHT_COLUMN)

        # deserialize mxnet module
        imputer.module = mx.module.Module.load(
            imputer.module_path,
            imputer.__get_best_epoch(),
            context=ctx,
            data_names=data_names,
            label_names=[s.output_column for s in imputer.label_encoders]
        )
        return imputer
开发者ID:awslabs,项目名称:datawig,代码行数:49,代码来源:imputer.py


注:本文中的mxnet.context方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。