

Python tensorrt.infer Usage Examples

This article collects typical usage examples of the Python tensorrt.infer method (the trt.infer namespace of the legacy TensorRT Python API). If you are unsure what tensorrt.infer does or how to call it, the curated examples below should help. You can also explore the other usage examples available for the tensorrt module.


Shown below are 7 code examples of the tensorrt.infer method, sorted by popularity.

Example 1: __init__

# Required import: import tensorrt as trt
# Or: from tensorrt import infer
# This snippet also assumes: import numpy as np; from six import string_types
def __init__(self, engine, idx_or_name, max_batch_size):
        if isinstance(idx_or_name, string_types):
            self.name = idx_or_name
            self.index  = engine.get_binding_index(self.name)
            if self.index == -1:
                raise IndexError("Binding name not found: %s" % self.name)
        else:
            self.index = idx_or_name
            self.name  = engine.get_binding_name(self.index)
            if self.name is None:
                raise IndexError("Binding index out of range: %i" % self.index)
        self.is_input = engine.binding_is_input(self.index)
        dtype = engine.get_binding_data_type(self.index)
        dtype_map = {trt.infer.DataType_kFLOAT: np.float32,
                     trt.infer.DataType_kHALF:  np.float16,
                     trt.infer.DataType_kINT8:  np.int8}
        if hasattr(trt.infer, 'DataType_kINT32'):
            dtype_map[trt.infer.DataType_kINT32] = np.int32
        self.dtype = dtype_map[dtype]
        shape = engine.get_binding_dimensions(self.index).shape()
        self.shape = (max_batch_size,) + shape
        self._host_buf   = None
        self._device_buf = None 
Author: mlperf, Project: training_results_v0.6, Lines: 25, Source: tensorrt_engine.py
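
For context: the _host_buf and _device_buf slots this class initializes to None are typically filled lazily with pycuda allocations. Below is a minimal sketch of such an allocation, assuming pycuda is installed; the allocate helper name is ours, not the original project's.

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- initializes a CUDA context on import

def allocate(binding):
    # Page-locked (pinned) host memory makes host<->device copies faster.
    host_buf = cuda.pagelocked_empty(binding.shape, dtype=binding.dtype)
    # Device buffer of matching size; int(...) of this allocation is what
    # gets passed to TensorRT as a binding pointer.
    device_buf = cuda.mem_alloc(host_buf.nbytes)
    return host_buf, device_buf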

Example 2: __init__

# Required import: import tensorrt as trt
# Or: from tensorrt import infer
def __init__(self, timing_iter):
        trt.infer.Profiler.__init__(self)
        self.timing_iterations = timing_iter
        self.profile = [] 
Author: aimuch, Project: iAI, Lines: 6, Source: googlenet.py
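
The snippet above shows only the constructor. In the legacy API, a trt.infer.Profiler subclass also implements a report_layer_time callback that TensorRT invokes once per layer per run. A sketch of the full class, modeled on the legacy googlenet.py sample (the aggregation details are an assumption; verify against your TensorRT version):

import tensorrt as trt

class Profiler(trt.infer.Profiler):
    def __init__(self, timing_iter):
        trt.infer.Profiler.__init__(self)
        self.timing_iterations = timing_iter
        self.profile = []

    def report_layer_time(self, layer_name, ms):
        # TensorRT calls this once per layer per run; accumulate the times.
        for i, (name, total_ms) in enumerate(self.profile):
            if name == layer_name:
                self.profile[i] = (name, total_ms + ms)
                return
        self.profile.append((layer_name, ms))

    def print_layer_times(self):
        total = 0.0
        for name, total_ms in self.profile:
            print("{:40.40} {:4.3f} ms".format(name, total_ms / self.timing_iterations))
            total += total_ms
        print("Time over all layers: {:4.3f} ms per iteration".format(
            total / self.timing_iterations))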

Example 3: main

# Required import: import tensorrt as trt
# Or: from tensorrt import infer
def main():
    dir_path = os.path.dirname(os.path.realpath(__file__))

    print("Building and running GPU inference for GoogleNet, N=4")
    # Convert the Caffe model into a TensorRT engine.
    engine = trt.utils.caffe_to_trt_engine(G_LOGGER,
        MODEL_PROTOTXT,
        CAFFEMODEL,
        10,          # max batch size
        16 << 20,    # max workspace size (16 MiB)
        OUTPUT_LAYERS,
        trt.infer.DataType.FLOAT)

    runtime = trt.infer.create_infer_runtime(G_LOGGER)

    print("Bindings after deserializing")
    for bi in range(engine.get_nb_bindings()):
        if engine.binding_is_input(bi):
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Input")
        else:
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Output")

    time_inference(engine, BATCH_SIZE)

    engine.destroy()
    runtime.destroy()

    G_PROFILER.print_layer_times()

    print("Done")

    return 
Author: aimuch, Project: iAI, Lines: 34, Source: googlenet.py
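
time_inference is defined elsewhere in the same googlenet.py sample. A rough sketch of what such a helper looks like under the legacy API, assuming pycuda; G_PROFILER and TIMING_ITERATIONS are module-level globals of that sample, and the float32 buffer sizing is our assumption:

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401

def time_inference(engine, batch_size):
    context = engine.create_execution_context()
    context.set_profiler(G_PROFILER)  # route per-layer timings to the profiler

    # One dummy device buffer per binding, sized for this batch (assumes
    # float32 bindings, matching the FLOAT engine built above).
    bindings = []
    for bi in range(engine.get_nb_bindings()):
        shape = engine.get_binding_dimensions(bi).shape()
        count = batch_size * int(np.prod(shape))
        bindings.append(int(cuda.mem_alloc(count * np.dtype(np.float32).itemsize)))

    # Execute synchronously; the profiler accumulates layer times per run.
    for _ in range(TIMING_ITERATIONS):
        context.execute(batch_size, bindings)

    context.destroy()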

Example 4: __init__

# Required import: import tensorrt as trt
# Or: from tensorrt import infer
def __init__(self, sev):
        trt.infer.Logger.__init__(self)
        self.severity = sev 
Author: aimuch, Project: iAI, Lines: 5, Source: caffe_mnist.py
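
As with the profiler, the constructor alone is not enough: a trt.infer.Logger subclass supplies a log callback that TensorRT calls for every message. A sketch of the complete class, modeled on the legacy caffe_mnist.py sample (the filtering logic is our assumption):

import tensorrt as trt

class Logger(trt.infer.Logger):
    def __init__(self, sev):
        trt.infer.Logger.__init__(self)
        self.severity = sev

    def log(self, sev, msg):
        # TensorRT invokes this for each log record; print only records at
        # least as severe as the threshold (lower values are more severe).
        if sev <= self.severity:
            print(msg)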

Example 5: main

# Required import: import tensorrt as trt
# Or: from tensorrt import infer
# This snippet also assumes: import numpy as np; from random import randint
def main():
    # Convert the Caffe model into a TensorRT engine.
    runtime = trt.infer.create_infer_runtime(G_LOGGER)
    engine = trt.utils.caffe_to_trt_engine(G_LOGGER,
        MODEL_PROTOTXT,
        CAFFE_MODEL,
        1,          # max batch size
        1 << 20,    # max workspace size (1 MiB)
        OUTPUT_LAYERS,
        trt.infer.DataType.FLOAT)

    # Pick a random test digit (0-9) and load it.
    rand_file = randint(0, 9)
    img = get_testcase(DATA + str(rand_file) + '.pgm')

    print("Test case: " + str(rand_file))
    data = apply_mean(img, IMAGE_MEAN)

    context = engine.create_execution_context()

    out = infer(context, data, OUTPUT_SIZE, 1)

    print("Prediction: " + str(np.argmax(out)))

    context.destroy()
    engine.destroy()
    runtime.destroy() 
Author: aimuch, Project: iAI, Lines: 29, Source: caffe_mnist.py
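
Note that infer here is not the tensorrt.infer module but a helper defined elsewhere in the same caffe_mnist.py sample. A sketch of the typical legacy-API implementation, assuming pycuda (the enqueue-based flow follows the legacy MNIST sample; verify the calls against your TensorRT version):

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401

def infer(context, input_img, output_size, batch_size):
    engine = context.get_engine()
    assert engine.get_nb_bindings() == 2  # one input, one output

    input_img = input_img.astype(np.float32)
    output = np.empty(output_size, dtype=np.float32)

    # Device buffers for input and output, plus the binding pointer list.
    d_input = cuda.mem_alloc(batch_size * input_img.nbytes)
    d_output = cuda.mem_alloc(batch_size * output.nbytes)
    bindings = [int(d_input), int(d_output)]

    stream = cuda.Stream()
    cuda.memcpy_htod_async(d_input, input_img, stream)   # host -> device
    context.enqueue(batch_size, bindings, stream.handle, None)
    cuda.memcpy_dtoh_async(output, d_output, stream)     # device -> host
    stream.synchronize()
    return output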

Example 6: convert_to_datatype

# Required import: import tensorrt as trt
# Or: from tensorrt import infer
def convert_to_datatype(v):
    if v == 8:
        return trt.infer.DataType.INT8
    elif v == 16:
        return trt.infer.DataType.HALF
    elif v == 32:
        return trt.infer.DataType.FLOAT
    else:
        print("ERROR: Invalid model data type bit depth: " + str(v))
        return trt.infer.DataType.INT8  # fall back to INT8
Author: CUHKSZ-TQL, Project: EverybodyDanceNow_reproduce_pytorch, Lines: 12, Source: run_engine.py
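
A hypothetical call, e.g. when mapping a bit-depth command-line flag to a TensorRT type (the flag is an illustration, not from the original project):

# Hypothetical usage: map a flag such as --datatype 16 to a TensorRT type.
data_type = convert_to_datatype(16)  # returns trt.infer.DataType.HALF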

Example 7: run_onnx

# Required import: import tensorrt as trt
# Or: from tensorrt import infer
def run_onnx(onnx_file, data_type, bs, inp):
    # Create onnx_config
    apex = onnxparser.create_onnxconfig()
    apex.set_model_file_name(onnx_file)
    apex.set_model_dtype(convert_to_datatype(data_type))

    # Create the parser.
    trt_parser = onnxparser.create_onnxparser(apex)
    assert trt_parser
    data_type = apex.get_model_dtype()
    onnx_filename = apex.get_model_file_name()
    trt_parser.parse(onnx_filename, data_type)
    trt_parser.report_parsing_info()
    trt_parser.convert_to_trtnetwork()
    trt_network = trt_parser.get_trtnetwork()
    assert trt_network

    # Create the engine builder (max_batch_size and max_workspace_size
    # are module-level globals in the original run_engine.py).
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    trt_builder.set_max_batch_size(max_batch_size)
    trt_builder.set_max_workspace_size(max_workspace_size)
    
    if apex.get_model_dtype() == trt.infer.DataType_kHALF:
        print("-------------------  Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif apex.get_model_dtype() == trt.infer.DataType_kINT8:
        print("-------------------  Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("-------------------  Running FP32 -----------------------------")

    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    time_inference(trt_engine, bs, inp)
Author: CUHKSZ-TQL, Project: EverybodyDanceNow_reproduce_pytorch, Lines: 38, Source: run_engine.py
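
run_onnx leans on several names defined elsewhere in run_engine.py: the onnxparser import, G_LOGGER, max_batch_size and max_workspace_size. A sketch of that module-level setup under the legacy API (the concrete values are assumptions):

import tensorrt as trt
from tensorrt.parsers import onnxparser  # legacy ONNX parser bindings

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)
max_batch_size = 1            # assumed value
max_workspace_size = 1 << 30  # assumed: 1 GiB build workspace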


Note: The tensorrt.infer examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their authors; copyright remains with the original authors, and redistribution or use should follow each project's License.