当前位置: 首页>>代码示例>>Python>>正文


Python onnx.load_model方法代码示例

本文整理汇总了 Python 中 onnx.load_model 方法的典型用法代码示例。如果您正在寻找以下问题的答案：Python onnx.load_model 方法的具体用法是什么？onnx.load_model 怎么用？有哪些使用示例？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 onnx 的其他用法示例。


下文共展示了 onnx.load_model 方法的 14 个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助系统推荐更优质的 Python 代码示例。

示例1: import_to_gluon

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def import_to_gluon(model_file, ctx):
    """
    Imports the ONNX model file, passed as a parameter, into a Gluon SymbolBlock object.

    Parameters
    ----------
    model_file : str
        ONNX model file name
    ctx : Context or list of Context
        Loads the model into one or many context(s).

    Returns
    -------
    sym_block : :class:`~mxnet.gluon.SymbolBlock`
        A SymbolBlock object representing the given model file.

    Raises
    ------
    ImportError
        If the ``onnx`` package cannot be imported.
    """
    # Verify the optional dependency before building any converter state, and
    # chain the original error so the real import failure stays visible.
    try:
        import onnx
    except ImportError as err:
        raise ImportError("Onnx and protobuf need to be installed. Instructions to"
                          + " install - https://github.com/onnx/onnx#installation") from err
    graph = GraphProto()
    model_proto = onnx.load_model(model_file)
    net = graph.graph_to_gluon(model_proto.graph, ctx)
    return net
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:27,代码来源:import_to_gluon.py

示例2: generate_onnx_file

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def generate_onnx_file(symbol_path='../symbol_farm/symbol_10_320_20L_5scales_v2_deploy.json',
                       param_path='../saved_model/configuration_10_320_20L_5scales_v2/train_10_320_20L_5scales_v2_iter_1800000.params',
                       onnx_path='./onnx_files/v2.onnx',
                       input_shape=(1, 3, 480, 640)):
    """Export an MXNet symbol/params pair to ONNX and validate the resulting graph.

    Parameters
    ----------
    symbol_path : str
        Path to the MXNet symbol JSON file.
    param_path : str
        Path to the MXNet ``.params`` file.
    onnx_path : str
        Destination of the exported ONNX model. Its parent directory is created
        if it does not exist yet.
    input_shape : tuple of int
        Input shape baked into the exported model.
        CAUTION: in TensorRT, the input size cannot be changed dynamically,
        so you must set it here.
    """
    import os

    logging.basicConfig(level=logging.INFO)

    net_symbol = mxnet.symbol.load(symbol_path)
    net_params_raw = mxnet.nd.load(param_path)
    # Keys are typically "<prefix>:<name>" (MXNet checkpoints use "arg:"/"aux:");
    # keep only the name so it matches the symbol's arguments. Keys without a
    # prefix are kept unchanged instead of failing to unpack.
    net_params = {key.split(':', 1)[-1]: value for key, value in net_params_raw.items()}

    # export_model writes straight to onnx_path, so the directory must exist.
    out_dir = os.path.dirname(onnx_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    onnx_mxnet.export_model(net_symbol, net_params, [input_shape], numpy.float32, onnx_path, verbose=True)

    # Load onnx model
    model_proto = onnx.load_model(onnx_path)

    # Check if converted ONNX protobuf is valid
    checker.check_graph(model_proto.graph)
开发者ID:becauseofAI,项目名称:lffd-pytorch,代码行数:26,代码来源:to_onnx.py

示例3: torch2tvm_module

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def torch2tvm_module(torch_module: torch.nn.Module, torch_inputs: Tuple[torch.Tensor, ...], target):
    """Convert a PyTorch module to a compiled TVM module by way of an in-memory ONNX model.

    Parameters
    ----------
    torch_module : torch.nn.Module
        Module to convert. It is put into eval mode (this mutates the module).
    torch_inputs : tuple of torch.Tensor
        Example inputs used for tracing. Their shapes are fixed into the Relay
        graph under the names "i0", "i1", ...
    target
        TVM build target passed to ``tvm.relay.build``.

    Returns
    -------
    tuple
        ``(graph, tvm_module, params)`` as returned by ``tvm.relay.build`` at opt_level=3.
    """
    torch_module.eval()
    input_names = ["i" + str(index) for index in range(len(torch_inputs))]
    input_shapes = {name: torch_input.shape for name, torch_input in zip(input_names, torch_inputs)}
    with torch.no_grad():
        # Run the module once to learn how many (top-level) outputs it returns.
        # Output names must be counted from the outputs, not the inputs:
        # torch.onnx.export raises when given more output names than outputs.
        outs = torch_module(*torch_inputs)
        num_outputs = len(outs) if isinstance(outs, (tuple, list)) else 1
        output_names = ["o" + str(i) for i in range(num_outputs)]

        buffer = io.BytesIO()
        torch.onnx.export(torch_module, torch_inputs, buffer,
                          input_names=input_names, output_names=output_names)
        buffer.seek(0, 0)
        onnx_model = onnx.load_model(buffer)
    relay_module, params = tvm.relay.frontend.from_onnx(onnx_model, shape=input_shapes)
    with tvm.relay.build_config(opt_level=3):
        graph, tvm_module, params = tvm.relay.build(relay_module, target, params=params)
    return graph, tvm_module, params
开发者ID:mit-han-lab,项目名称:temporal-shift-module,代码行数:20,代码来源:main.py

示例4: import_model

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def import_model(model_file):
    """Imports the ONNX model file, passed as a parameter, into MXNet symbol and parameters.
    Operator support and coverage -
    https://cwiki.apache.org/confluence/display/MXNET/MXNet-ONNX+Integration

    Parameters
    ----------
    model_file : str
        ONNX model file name

    Returns
    -------
    sym : :class:`~mxnet.symbol.Symbol`
        MXNet symbol object

    arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
        Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format

    aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
        Dict of converted auxiliary parameters stored in ``mxnet.ndarray.NDArray`` format

    Raises
    ------
    ImportError
        If the ``onnx`` package cannot be imported.
    """
    # Verify the optional dependency before building any converter state, and
    # chain the original error so the real import failure stays visible.
    try:
        import onnx
    except ImportError as err:
        raise ImportError("Onnx and protobuf need to be installed. "
                          + "Instructions to install - https://github.com/onnx/onnx") from err
    graph = GraphProto()
    # loads model file and returns ONNX protobuf object
    model_proto = onnx.load_model(model_file)
    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
    return sym, arg_params, aux_params
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:34,代码来源:import_model.py

示例5: get_model_metadata

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def get_model_metadata(model_file):
    """
    Returns the name and shape information of input and output tensors of the given ONNX model file.

    Parameters
    ----------
    model_file : str
        ONNX model file name

    Returns
    -------
    model_metadata : dict
        A dictionary object mapping various metadata to its corresponding value.
        The dictionary will have the following template.
        {
            'input_tensor_data' : <list of tuples representing the shape of the input parameters>,
            'output_tensor_data' : <list of tuples representing the shape of the output
                                    of the model>
        }

    Raises
    ------
    ImportError
        If the ``onnx`` package cannot be imported.
    """
    # Verify the optional dependency before building any converter state, and
    # chain the original error so the real import failure stays visible.
    try:
        import onnx
    except ImportError as err:
        raise ImportError("Onnx and protobuf need to be installed. "
                          + "Instructions to install - https://github.com/onnx/onnx") from err
    graph = GraphProto()
    model_proto = onnx.load_model(model_file)
    metadata = graph.get_graph_metadata(model_proto.graph)
    return metadata
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:32,代码来源:import_model.py

示例6: run

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def run(args):
    """Build an ONNX test model with TVM, check its outputs, then benchmark it.

    Parameters
    ----------
    args : namespace
        Reads ``test_dir``, ``model_file``, ``frontend`` ('nnvm' or 'relay')
        and ``iterations``.

    Returns
    -------
    Whatever ``run_onnx_util.run_benchmark`` returns for the timed ``compute`` closure.

    Raises
    ------
    RuntimeError
        If ``args.frontend`` is neither 'nnvm' nor 'relay'.
    AssertionError
        If any output differs from the reference data in ``test_data_set_0``
        beyond rtol=1e-3 / atol=1e-4.
    """
    onnx_model = onnx.load_model(run_onnx_util.onnx_model_file(args.test_dir, args.model_file))
    ctx = tvm.gpu()

    input_names, output_names = run_onnx_util.onnx_input_output_names(
        os.path.join(args.test_dir, args.model_file))

    test_data_dir = os.path.join(args.test_dir, 'test_data_set_0')
    inputs, outputs = run_onnx_util.load_test_data(
        test_data_dir, input_names, output_names)

    inputs = dict(inputs)
    if args.frontend == 'nnvm':
        graph_module = build_graph_nnvm(args, ctx, onnx_model, inputs, input_names)
    elif args.frontend == 'relay':
        graph_module = build_graph_relay(args, ctx, onnx_model, inputs, input_names)
    else:
        raise RuntimeError('Invalid frontend: {}'.format(args.frontend))

    graph_module.run()

    for i, (name, expected) in enumerate(outputs):
        tvm_output = tvm.nd.empty(expected.shape, expected.dtype, ctx=ctx)
        actual = graph_module.get_output(i, tvm_output).asnumpy()
        # err_msg makes a mismatch report name the failing output; the previous
        # trailing ", name" only built a discarded tuple.
        np.testing.assert_allclose(expected, actual,
                                   rtol=1e-3, atol=1e-4, err_msg=str(name))
        print('%s: OK' % name)
    print('ALL OK')

    def compute():
        # Synchronize so the benchmark measures completed GPU work.
        graph_module.run()
        cupy.cuda.device.Device().synchronize()

    return run_onnx_util.run_benchmark(compute, args.iterations)
开发者ID:pfnet-research,项目名称:chainer-compiler,代码行数:37,代码来源:run_onnx_tvm.py

示例7: test_save_and_load_model

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def test_save_and_load_model(self):  # type: () -> None
        """Round-trip a ModelProto through bytes, a file-like object and a file on disk."""
        proto = self._simple_model()
        cls = ModelProto
        proto_string = onnx._serialize(proto)

        # Test if input is string
        loaded_proto = onnx.load_model_from_string(proto_string)
        self.assertEqual(proto, loaded_proto)

        # Test if input has a read function
        buf = io.BytesIO()
        onnx.save_model(proto_string, buf)
        buf = io.BytesIO(buf.getvalue())
        loaded_proto = onnx.load_model(buf, cls)
        self.assertEqual(proto, loaded_proto)

        # Test if input is a file name.
        # The temp file is created before ``try`` so ``finally`` never sees an
        # unbound name. ``with`` closes it even if saving fails; an open file
        # cannot be removed on Windows.
        fi = tempfile.NamedTemporaryFile(delete=False)
        try:
            with fi:
                onnx.save_model(proto, fi)
            loaded_proto = onnx.load_model(fi.name, cls)
            self.assertEqual(proto, loaded_proto)
        finally:
            os.remove(fi.name)
开发者ID:mlperf,项目名称:training_results_v0.6,代码行数:28,代码来源:basic_test.py

示例8: check_model_expect

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def check_model_expect(test_path, input_names=None, rtol=1e-5, atol=1e-5):
    """Run ``<test_path>/model.onnx`` with ONNX Runtime and compare it to stored test data.

    Parameters
    ----------
    test_path : str
        Directory containing ``model.onnx`` and ``test_data_set_*`` subdirectories.
    input_names : list of str, optional
        If given, must match the runtime session's input names (order-insensitive).
    rtol, atol : float
        Tolerances passed to ``np.testing.assert_allclose``.

    Raises
    ------
    ImportError
        If ONNX Runtime is not available.
    AssertionError
        On input-name mismatch, a non-directory test set, or an output mismatch.
    """
    if not ONNXRUNTIME_AVAILABLE:
        raise ImportError('ONNX Runtime is not found on checking module.')

    model_path = os.path.join(test_path, 'model.onnx')
    with open(model_path, 'rb') as f:
        onnx_model = onnx.load_model(f)
    sess = rt.InferenceSession(onnx_model.SerializeToString())
    rt_input_names = [value.name for value in sess.get_inputs()]
    rt_output_names = [value.name for value in sess.get_outputs()]

    # To detect unexpected inputs created by exporter, check input names
    if input_names is not None:
        assert sorted(input_names) == sorted(rt_input_names)

    test_data_sets = sorted(
        p for p in os.listdir(test_path) if p.startswith('test_data_set_'))
    for test_data in test_data_sets:
        test_data_path = os.path.join(test_path, test_data)
        assert os.path.isdir(test_data_path)
        inputs, outputs = load_test_data(
            test_data_path, rt_input_names, rt_output_names)

        expected_names = list(outputs.keys())
        rt_out = sess.run(expected_names, inputs)
        for name, cy, my in zip(expected_names, outputs.values(), rt_out):
            # Name the test set and output so a failure is easy to locate.
            np.testing.assert_allclose(
                cy, my, rtol=rtol, atol=atol,
                err_msg='{}: output {}'.format(test_data, name))
开发者ID:chainer,项目名称:chainer,代码行数:28,代码来源:test_onnxruntime.py

示例9: test_mobilenetv2

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def test_mobilenetv2(self):
        """Import the MobileNetV2 ONNX model into DLPy with per-channel normalization."""
        try:
            import onnx
            from dlpy.model_conversion.onnx_transforms import (Transformer, OpTypePattern,
                                                               ConstToInitializer,
                                                               InitReshape, InitUnsqueeze,
                                                               FuseMulAddBN)
            from dlpy.model_conversion.onnx_graph import OnnxGraph
            from onnx import helper, numpy_helper
        except ImportError:
            self.skipTest('onnx package not found')

        from dlpy.model import Model

        if self.data_dir_local is None:
            self.skipTest("DLPY_DATA_DIR_LOCAL is not set in the environment variables")

        path = os.path.join(self.data_dir_local, 'mobilenetv2-1.0.onnx')

        # Per-channel mean/std (these look like the common ImageNet RGB values)
        # scaled to the 0-255 pixel range. ``255*[a, b, c]`` would repeat the list
        # 255 times (765 entries) rather than scale it, so scale element-wise.
        offsets = [255 * mean for mean in (0.485, 0.456, 0.406)]
        norm_stds = [255 * std for std in (0.229, 0.224, 0.225)]

        onnx_model = onnx.load_model(path)
        model1 = Model.from_onnx_model(self.s,
                                       onnx_model,
                                       output_model_table='mobilenetv2',
                                       offsets=offsets,
                                       norm_stds=norm_stds)
开发者ID:sassoftware,项目名称:python-dlpy,代码行数:28,代码来源:test_model.py

示例10: verify_onnx_forward_impl

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def verify_onnx_forward_impl(graph_file, data_shape, out_shape, dtype='float32', rtol=1e-5, atol=1e-5):
    """Check that TVM and ONNX Runtime agree on an ONNX model for one random input.

    Parameters
    ----------
    graph_file : str or file-like
        ONNX model accepted by ``onnx.load_model``.
    data_shape : tuple of int
        Shape of the random input drawn from ``np.random.uniform``.
    out_shape : tuple of int
        Expected output shape, forwarded to ``get_tvm_output``.
    dtype : str, optional
        Data type forwarded to both runtimes (default ``'float32'``).
    rtol, atol : float, optional
        Comparison tolerances (default ``1e-5`` each).
    """
    x = np.random.uniform(size=data_shape)
    model = onnx.load_model(graph_file)
    # Reference result from ONNX Runtime.
    ort_out = get_onnxruntime_output(model, x, dtype)
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
        tvm.testing.assert_allclose(ort_out, tvm_out, rtol=rtol, atol=atol)
开发者ID:apache,项目名称:incubator-tvm,代码行数:10,代码来源:test_forward.py

示例11: main

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def main(argv=None):
    """Command-line entry point: apply ``decast`` to an ONNX model and save the result.

    Usage: ``decast.py model_in model_out <op1, ...>``

    Parameters
    ----------
    argv : list of str, optional
        Full argument vector including the program name, laid out like
        ``sys.argv``. Defaults to ``sys.argv``.

    Returns
    -------
    int or None
        1 when too few arguments are given (usage is printed to stderr);
        None after a successful run.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 4:
        print('decast.py model_in  model_out <op1, ...>', file=sys.stderr)
        return 1

    # Avoid shadowing the builtin ``input``.
    model_in = argv[1]
    model_out = argv[2]
    op_list = argv[3:]

    oxml = onnx.load_model(model_in)
    oxml = decast(oxml, op_list)
    onnx.save_model(oxml, model_out)
开发者ID:microsoft,项目名称:onnxconverter-common,代码行数:14,代码来源:decast.py

示例12: load

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def load(path_or_model, name=None, inputs=None, outputs=None):
        """
        Construct a Graph object by loading an ONNX model.

        Parameters
        ----------
        path_or_model : str, os.PathLike, binary file object or ModelProto
            A model file path, a readable binary stream, or an already-loaded model.
        name : str, optional
            Graph name; defaults to the model's ``graph.name``.
        inputs, outputs : optional
            Forwarded unchanged to ``Graph._bind``.

        Raises
        ------
        RuntimeError
            If the model's default-domain opset differs from ``Graph.opset``.
        """
        import os

        if isinstance(path_or_model, (str, os.PathLike)):
            # Normalize path-like objects to a plain path for onnx.load_model.
            oxml = onnx.load_model(os.fspath(path_or_model))
        elif hasattr(path_or_model, 'read'):
            oxml = onnx.load_model(path_or_model)
        else:
            oxml = path_or_model
        for opset_import in oxml.opset_import:
            # Only the default ('') domain must match this Graph's opset.
            if opset_import.domain == '':
                if Graph.opset != opset_import.version:
                    raise RuntimeError("Graph opset and model opset mismatch: Graph opset = " + str(Graph.opset)
                                       + ", model opset = " + str(opset_import.version))
                break
        g = Graph(name or oxml.graph.name)
        g._bind(oxml, inputs=inputs, outputs=outputs)
        return g
开发者ID:microsoft,项目名称:onnxconverter-common,代码行数:16,代码来源:onnx_fx.py

示例13: test_onnx_models

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def test_onnx_models(self):
        """Each test model must keep the expected number of Transpose nodes after optimization."""
        # (model file under ./data, expected Transpose count after optimize_onnx_model)
        expected_transposes = [
            ('mobile_segnet_no_opt.onnx', 2),
            ('srgan_no_opt.onnx', 3),
            ('test_model_0_no_opt.onnx', 11),
            ('test_model_1_no_opt.onnx', 5),
        ]
        data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
        for model_name, expected_count in expected_transposes:
            # subTest labels a failure with the model that caused it.
            with self.subTest(model=model_name):
                origin_model = onnx.load_model(os.path.join(data_dir, model_name))
                opt_model = optimize_onnx_model(origin_model)
                self.assertIsNotNone(opt_model)
                num_transpose = sum(1 for node in opt_model.graph.node if node.op_type == 'Transpose')
                self.assertEqual(num_transpose, expected_count)
开发者ID:microsoft,项目名称:onnxconverter-common,代码行数:14,代码来源:test_opt.py

示例14: test_set_denotation

# 需要导入模块: import onnx [as 别名]
# 或者: from onnx import load_model [as 别名]
def test_set_denotation(self):
        """After set_denotation on input "1", the first graph input carries the expected denotations."""
        model_path = os.path.join(os.path.dirname(__file__),
                                  "coreml_OneHotEncoder_BikeSharing.onnx")
        model = onnx.load_model(model_path)
        opset = onnx.defs.onnx_opset_version()
        set_denotation(model, "1", "IMAGE", opset, dimension_denotation=["DATA_FEATURE"])

        first_input_type = model.graph.input[0].type
        self.assertEqual(first_input_type.denotation, "IMAGE")
        self.assertEqual(first_input_type.tensor_type.shape.dim[0].denotation, "DATA_FEATURE")
开发者ID:microsoft,项目名称:onnxconverter-common,代码行数:9,代码来源:test_onnx.py


注:本文中的onnx.load_model方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。