

Python torch.onnx Code Examples

This article collects typical usage examples of torch.onnx in Python. If you are wondering what torch.onnx is for, how to use it, or what real-world code that uses it looks like, the curated examples below should help. You can also explore further usage examples from the torch package that torch.onnx belongs to.


The following presents 15 code examples that use torch.onnx, sorted by popularity by default.
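Before the collected examples, here is a minimal sketch of the export call that all of them build on. The torchvision model, the input shape and the output file name are illustrative placeholders, not taken from any example below.

import torch
import torchvision

# Minimal sketch: put a torchvision ResNet-18 into eval mode and export it to ONNX.
# The (1, 3, 224, 224) dummy input and the output file name are illustrative.
model = torchvision.models.resnet18(pretrained=False)
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)  # a single RGB 224x224 image
torch.onnx.export(model, dummy_input, "resnet18.onnx")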

Example 1: pytorch2onnx

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def pytorch2onnx(args):
    # PyTorch exports to ONNX without the need for an external converter
    import torch
    from torch.autograd import Variable
    import torch.onnx
    import torchvision
    # Create a dummy input with the correct dimensions for your model's input
    if args.model_input_shapes is None:
        raise ValueError("Please provide --model_input_shapes to convert PyTorch models.")
    dummy_model_input = []
    if len(args.model_input_shapes) == 1:
        dummy_model_input = Variable(torch.randn(*args.model_input_shapes))
    else:
        for shape in args.model_input_shapes:
            dummy_model_input.append(Variable(torch.randn(*shape)))

    # load the PyTorch model
    model = torch.load(args.model, map_location="cpu")

    # export the PyTorch model as an ONNX protobuf
    torch.onnx.export(model, dummy_model_input, args.output_onnx_path) 
Developer: microsoft, Project: OLive, Lines of code: 23, Source file: onnx_converter.py
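For reference, a hedged sketch of how pytorch2onnx above might be driven programmatically; the argparse.Namespace attributes mirror the ones the function reads, while the file paths and input shape are hypothetical.

from argparse import Namespace

# Hypothetical arguments: args.model must point to a model saved with
# torch.save(model, ...), and model_input_shapes is a list of input shapes.
args = Namespace(
    model="my_model.pth",                   # placeholder path to a pickled nn.Module
    model_input_shapes=[(1, 3, 224, 224)],  # one shape -> a single dummy tensor
    output_onnx_path="my_model.onnx",       # placeholder output path
)
pytorch2onnx(args)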

Example 2: main

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def main(args):
    dataset = load_config(args.dataset)

    num_classes = len(dataset["common"]["classes"])
    net = UNet(num_classes)

    def map_location(storage, _):
        return storage.cpu()

    chkpt = torch.load(args.checkpoint, map_location=map_location)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(chkpt["state_dict"])

    # Todo: make input channels configurable, not hard-coded to three channels for RGB
    batch = torch.autograd.Variable(torch.randn(1, 3, args.image_size, args.image_size))

    torch.onnx.export(net, batch, args.model) 
Developer: mapbox, Project: robosat, Lines of code: 19, Source file: export.py

Example 3: onnx_inference

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def onnx_inference(args):
    # Load the ONNX model
    model = onnx.load("models/deepspeech_{}.onnx".format(args.continue_from))

    # Check that the IR is well formed
    onnx.checker.check_model(model)

    onnx.helper.printable_graph(model.graph)

    print("model checked, preparing backend!")
    rep = backend.prepare(model, device="CPU")  # or "CUDA:0" to run on GPU

    print("running inference!")

    # Hard coded input dim
    inputs = np.random.randn(16, 1, 161, 129).astype(np.float32)

    start = time.time()
    outputs = rep.run(inputs)
    print("time used: {}".format(time.time() - start))
    # To run networks with more than one input, pass a tuple
    # rather than a single numpy ndarray.
    print(outputs[0]) 
Developer: mlperf, Project: inference, Lines of code: 25, Source file: convert_onnx.py
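The snippet above runs the exported network through a Caffe2-style ONNX backend. As an alternative that is not part of the original script, here is a hedged sketch using onnxruntime instead, assuming it is installed and that the model file name matches the one exported earlier.

import time

import numpy as np
import onnxruntime as ort

# Assumption: onnxruntime is available and the file name below matches the
# "models/deepspeech_{}.onnx" file written by the original conversion script.
sess = ort.InferenceSession("models/deepspeech_final.onnx")
input_name = sess.get_inputs()[0].name

inputs = np.random.randn(16, 1, 161, 129).astype(np.float32)  # same hard-coded dims
start = time.time()
outputs = sess.run(None, {input_name: inputs})
print("time used: {}".format(time.time() - start))
print(outputs[0])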

Example 4: stylize_onnx_caffe2

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def stylize_onnx_caffe2(content_image, args):
    """
    Read ONNX model and run it using Caffe2
    """

    assert not args.export_onnx

    import onnx
    import onnx_caffe2.backend

    model = onnx.load(args.model)

    prepared_backend = onnx_caffe2.backend.prepare(model, device='CUDA' if args.cuda else 'CPU')
    inp = {model.graph.input[0].name: content_image.numpy()}
    c2_out = prepared_backend.run(inp)[0]

    return torch.from_numpy(c2_out) 
Developer: pytorch, Project: examples, Lines of code: 19, Source file: neural_style.py

Example 5: convert_models

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def convert_models(args):
    # Quick format check
    model_extension = get_extension(args.model)
    if (args.model_type == "onnx" or model_extension == "onnx"):
        print("Input model is already ONNX model. Skipping conversion.")
        if args.model != args.output_onnx_path:
            copyfile(args.model, args.output_onnx_path)
        return

    if converters.get(args.model_type) is None:
        raise ValueError('Model type {} is not currently supported. \n\
            Please select one of the following model types -\n\
                cntk, coreml, keras, pytorch, scikit-learn, tensorflow'.format(args.model_type))

    suffix = suffix_format_map.get(model_extension)

    if suffix is not None and suffix != args.model_type:
        raise ValueError('Model with extension {} does not come from {}'.format(model_extension, args.model_type))

    # Find the corresponding converter for current model
    converter = converters.get(args.model_type)
    # Run converter
    converter(args) 
Developer: microsoft, Project: OLive, Lines of code: 25, Source file: onnx_converter.py
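convert_models dispatches through two module-level mappings that the snippet does not show. Below is a hedged reconstruction of what they might look like; only the "pytorch" entry is wired to pytorch2onnx from Example 1, and the remaining entries and the get_extension helper are placeholders, not the actual OLive implementation.

# Hypothetical lookup tables and helper used by convert_models above.
converters = {
    "pytorch": pytorch2onnx,              # Example 1 on this page
    # "keras": keras2onnx_converter,      # placeholder
    # "tensorflow": tf2onnx_converter,    # placeholder
}

# Maps a file extension to the framework it usually implies.
suffix_format_map = {
    "h5": "keras",
    "pb": "tensorflow",
    "pth": "pytorch",
}

def get_extension(path):
    # Minimal helper matching how convert_models uses the result.
    return path.rsplit(".", 1)[-1].lower()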

Example 6: test_dcgan

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def test_dcgan(self):
        # dcgan is flaky on some seeds, see:
        # https://github.com/ProjectToffee/onnx/pull/70
        torch.manual_seed(1)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(1)

        netD = dcgan._netD(1)
        netD.apply(dcgan.weights_init)
        input = Variable(torch.randn(BATCH_SIZE, 3, dcgan.imgsz, dcgan.imgsz))
        self.run_model_test(netD, train=False, batch_size=BATCH_SIZE,
                            input=input)

        netG = dcgan._netG(1)
        netG.apply(dcgan.weights_init)
        state_dict = model_zoo.load_url(model_urls['dcgan_b'], progress=False)
        # state_dict = model_zoo.load_url(model_urls['dcgan_f'], progress=False)
        noise = Variable(
            torch.randn(BATCH_SIZE, dcgan.nz, 1, 1).normal_(0, 1))
        self.run_model_test(netG, train=False, batch_size=BATCH_SIZE,
                            input=noise, state_dict=state_dict, rtol=1e-2, atol=1e-6) 
Developer: onnxbot, Project: onnx-fb-universe, Lines of code: 23, Source file: test_caffe2.py

Example 7: test_symbolic_override_nested

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def test_symbolic_override_nested(self):
        def symb(g, x, y):
            assert isinstance(x, torch._C.Value)
            assert isinstance(y[0], torch._C.Value)
            assert isinstance(y[1], torch._C.Value)
            return g.op('Sum', x, y[0], y[1]), (
                g.op('Neg', x), g.op('Neg', y[0]))

        @torch.onnx.symbolic_override_first_arg_based(symb)
        def foo(x, y):
            return x + y[0] + y[1], (-x, -y[0])

        class BigModule(torch.nn.Module):
            def forward(self, x, y):
                return foo(x, y)

        inp = (Variable(torch.FloatTensor([1])),
               (Variable(torch.FloatTensor([2])),
                Variable(torch.FloatTensor([3]))))
        BigModule()(*inp)
        self.assertONNX(BigModule(), inp) 
Developer: onnxbot, Project: onnx-fb-universe, Lines of code: 23, Source file: test_operators.py

Example 8: predict

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def predict():
    list_of_files = glob.glob('output/models/*')  # '*' matches every file; use a pattern like '*.csv' for a specific format
    model_path = max(list_of_files, key=os.path.getctime)

    print("Generating ONNX from model:", model_path)
    model = torch.load(model_path)

    input_sequences = [
        "SRSLVISTINQISEDSKEFYFTLDNGKTMFPSNSQAWGGEKFENGQRAFVIFNELEQPVNGYDYNIQVRDITKVLTKEIVTMDDEE" \
        "NTEEKIGDDKINATYMWISKDKKYLTIEFQYYSTHSEDKKHFLNLVINNKDNTDDEYINLEFRHNSERDSPDHLGEGYVSFKLDKI" \
        "EEQIEGKKGLNIRVRTLYDGIKNYKVQFP"]

    input_sequences_encoded = list(torch.IntTensor(encode_primary_string(aa))
                                   for aa in input_sequences)

    print("Exporting to ONNX...")

    output_path = "./tests/output/openprotein.onnx"
    onnx_from_model(model, input_sequences_encoded, output_path)

    print("Wrote ONNX to", output_path) 
Developer: biolib, Project: openprotein, Lines of code: 23, Source file: onnx_export.py

Example 9: torch2tvm_module

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def torch2tvm_module(torch_module: torch.nn.Module, torch_inputs: Tuple[torch.Tensor, ...], target):
    torch_module.eval()
    input_names = []
    input_shapes = {}
    with torch.no_grad():
        for index, torch_input in enumerate(torch_inputs):
            name = "i" + str(index)
            input_names.append(name)
            input_shapes[name] = torch_input.shape
        buffer = io.BytesIO()
        torch.onnx.export(torch_module, torch_inputs, buffer, input_names=input_names,
                          output_names=["o" + str(i) for i in range(len(torch_inputs))])
        outs = torch_module(*torch_inputs)
        buffer.seek(0, 0)
        onnx_model = onnx.load_model(buffer)
        relay_module, params = tvm.relay.frontend.from_onnx(onnx_model, shape=input_shapes)
    with tvm.relay.build_config(opt_level=3):
        graph, tvm_module, params = tvm.relay.build(relay_module, target, params=params)
    return graph, tvm_module, params 
Developer: mit-han-lab, Project: temporal-shift-module, Lines of code: 20, Source file: main.py
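torch2tvm_module returns the three artifacts produced by tvm.relay.build. A hedged sketch of feeding them into TVM's graph runtime follows; it uses the older tvm.contrib.graph_runtime API that matches the relay.build call above, and the model and input tensor are illustrative placeholders.

import torch
import torchvision
import tvm
from tvm.contrib import graph_runtime

# Placeholder model and input; "llvm" is the usual CPU target string.
my_torch_model = torchvision.models.resnet18(pretrained=False).eval()
torch_inputs = (torch.randn(1, 3, 224, 224),)
graph, tvm_module, params = torch2tvm_module(my_torch_model, torch_inputs, target="llvm")

ctx = tvm.cpu(0)
runtime = graph_runtime.create(graph, tvm_module, ctx)
runtime.set_input(**params)
# Inputs were named "i0", "i1", ... inside torch2tvm_module.
runtime.set_input("i0", tvm.nd.array(torch_inputs[0].numpy()))
runtime.run()
out = runtime.get_output(0).asnumpy()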

Example 10: convert_to_onnx

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def convert_to_onnx(model, input_shape, output_file, input_names, output_names):
    """Convert PyTorch model to ONNX and check the resulting onnx model"""

    output_file.parent.mkdir(parents=True, exist_ok=True)
    model.eval()
    dummy_input = torch.randn(input_shape)
    model(dummy_input)
    torch.onnx.export(model, dummy_input, str(output_file), verbose=False,
                      input_names=input_names.split(','), output_names=output_names.split(','))

    # Model check after conversion
    model = onnx.load(str(output_file))
    try:
        onnx.checker.check_model(model)
        print('ONNX check passed successfully.')
    except onnx.onnx_cpp2py_export.checker.ValidationError as exc:
        sys.exit('ONNX check failed with error: ' + str(exc)) 
Developer: opencv, Project: open_model_zoo, Lines of code: 19, Source file: pytorch_to_onnx.py
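A hedged usage sketch for convert_to_onnx above: output_file must be a pathlib.Path (the function calls .parent.mkdir on it), and the input and output names are comma-separated strings that get split inside the function. The model and the names here are placeholders.

from pathlib import Path

import torchvision

# Placeholder model and tensor names; convert_to_onnx splits the name strings on ','.
model = torchvision.models.mobilenet_v2(pretrained=False)
convert_to_onnx(
    model,
    input_shape=(1, 3, 224, 224),
    output_file=Path("output/mobilenet_v2.onnx"),
    input_names="input",
    output_names="prob",
)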

Example 11: export_onnx_model

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def export_onnx_model(model, inputs, passes):
    """Trace and export a model to onnx format. Modified from
    https://github.com/facebookresearch/detectron2/

    Args:
        model (nn.Module):
        inputs (tuple[args]): the model will be called by `model(*inputs)`
        passes (None or list[str]): the optimization passes for the ONNX model

    Returns:
        an onnx model
    """
    assert isinstance(model, torch.nn.Module)

    # make sure all modules are in eval mode, onnx may change the training
    # state of the module if the states are not consistent
    def _check_eval(module):
        assert not module.training

    model.apply(_check_eval)

    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    # Apply ONNX's Optimization
    if passes is not None:
        all_passes = optimizer.get_available_passes()
        assert all(p in all_passes for p in passes), \
            f'Only {all_passes} are supported'
        onnx_model = optimizer.optimize(onnx_model, passes)
    return onnx_model 
Developer: open-mmlab, Project: mmdetection, Lines of code: 43, Source file: pytorch2onnx.py
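export_onnx_model returns an in-memory onnx.ModelProto rather than writing a file. A short hedged sketch of calling it and persisting the result; the model, input and file name are placeholders, and 'fuse_bn_into_conv' is one of the passes reported by optimizer.get_available_passes().

import onnx
import torch
import torchvision

# Placeholder model and input; any nn.Module in eval mode works here.
model = torchvision.models.resnet18(pretrained=False).eval()
inputs = (torch.randn(1, 3, 224, 224),)

onnx_model = export_onnx_model(model, inputs, passes=["fuse_bn_into_conv"])
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, "exported_model.onnx")  # illustrative file name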

Example 12: export_onnx

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def export_onnx(path, batch_size, seq_len):
    print('The model is also exported in ONNX format at {}'.
          format(os.path.realpath(args.onnx_export)))
    model.eval()
    dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device)
    hidden = model.init_hidden(batch_size)
    torch.onnx.export(model, (dummy_input, hidden), path)


# Loop over epochs. 
Developer: L0SG, Project: relational-rnn-pytorch, Lines of code: 12, Source file: train_rmc.py

Example 13: export

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def export(dir):
    dummy_input = Variable(torch.randn(1, 3, 4, 4))
    model = broadcast_mul()
    model.eval()
    torch.save(model.state_dict(), os.path.join(dir, "broadcast_mul.pth"))
    onnx.export(model, dummy_input, os.path.join(dir, "broadcast_mul.onnx"), verbose=True) 
Developer: MTlab, Project: onnx2caffe, Lines of code: 8, Source file: broadcast_mul.py

Example 14: export

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def export(dir):
    file_path = os.path.realpath(__file__)
    file_dir = os.path.dirname(file_path)
    dummy_input = Variable(torch.randn(1, 3, 32, 32))
    model = ResNet34()
    # model = load_network(model,os.path.join(file_dir,'..','model','pose_v02.pth'))
    model.eval()
    torch.save(model.state_dict(), os.path.join(dir, "resnet.pth"))
    onnx.export(model, dummy_input, os.path.join(dir, "resnet.onnx"), verbose=True) 
Developer: MTlab, Project: onnx2caffe, Lines of code: 11, Source file: resnet.py

Example 15: export

# Required import: import torch [as alias]
# Or: from torch import onnx [as alias]
def export(dir):
    file_path = os.path.realpath(__file__)
    file_dir = os.path.dirname(file_path)
    dummy_input = Variable(torch.randn(1, 3, 32, 32))
    model = GoogLeNet()
    # model = load_network(model,os.path.join(file_dir,'..','model','pose_v02.pth'))
    model.eval()
    torch.save(model.state_dict(), os.path.join(dir, "googlenet.pth"))
    onnx.export(model, dummy_input, os.path.join(dir, "googlenet.onnx"), verbose=True) 
Developer: MTlab, Project: onnx2caffe, Lines of code: 11, Source file: googlenet.py


Note: The torch.onnx examples in this article were compiled by 纯净天空 from GitHub, MSDocs and other open-source code and documentation platforms. The code snippets are selected from open-source projects contributed by many developers, and copyright of the source code remains with the original authors; please consult the corresponding project's license before distributing or using it. Do not reproduce this article without permission.