

Python tensorrt.infer Code Examples

This article collects typical usage examples of tensorrt.infer in Python, gathered from open-source projects. Strictly speaking, tensorrt.infer is not a method but the namespace through which the legacy TensorRT Python API (pre-5.0) exposes its runtime classes. If you are wondering what tensorrt.infer is for, or how to use it in practice, the curated examples below may help; you can also explore other usage examples from the tensorrt package.


The following presents 8 code examples of tensorrt.infer, sorted by popularity by default. Upvoting the examples you like or find useful helps the site recommend better Python code samples.
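
All of the examples below follow the same legacy import convention and share a global logger. For orientation, here is a minimal setup sketch; ConsoleLogger and LogSeverity are the legacy logger types used in NVIDIA's old samples, so treat the exact names as assumptions if your installation differs:

import tensorrt as trt

# The legacy (pre-5.0) Python API exposes the runtime through the trt.infer
# namespace. The G_LOGGER used throughout the examples is typically created so:
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)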

Example 1: __init__

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
# Also needed here: import numpy as np; from six import string_types
def __init__(self, engine, idx_or_name, max_batch_size):
    if isinstance(idx_or_name, string_types):
        self.name = idx_or_name
        self.index = engine.get_binding_index(self.name)
        if self.index == -1:
            raise IndexError("Binding name not found: %s" % self.name)
    else:
        self.index = idx_or_name
        self.name = engine.get_binding_name(self.index)
        if self.name is None:
            raise IndexError("Binding index out of range: %i" % self.index)
    self.is_input = engine.binding_is_input(self.index)
    dtype = engine.get_binding_data_type(self.index)
    # Map legacy TensorRT data types to numpy dtypes
    dtype_map = {trt.infer.DataType_kFLOAT: np.float32,
                 trt.infer.DataType_kHALF:  np.float16,
                 trt.infer.DataType_kINT8:  np.int8}
    if hasattr(trt.infer, 'DataType_kINT32'):  # not present in older releases
        dtype_map[trt.infer.DataType_kINT32] = np.int32
    self.dtype = dtype_map[dtype]
    shape = engine.get_binding_dimensions(self.index).shape()
    # Bindings exclude the batch dimension, so prepend it explicitly
    self.shape = (max_batch_size,) + shape
    self._host_buf = None
    self._device_buf = None
Developer: mlperf, Project: training_results_v0.6, Lines: 25, Source file: tensorrt_engine.py
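
The two buffer fields above are placeholders that the rest of the class fills in lazily. A plausible sketch of that allocation, assuming PyCUDA with an initialized CUDA context; the property names host_buffer and device_buffer are assumptions, not confirmed by the snippet:

import numpy as np
import pycuda.driver as cuda

class Binding:  # hypothetical continuation of the class above
    @property
    def host_buffer(self):
        if self._host_buf is None:
            # Page-locked host memory allows fast asynchronous transfers
            self._host_buf = cuda.pagelocked_empty(self.shape, self.dtype)
        return self._host_buf

    @property
    def device_buffer(self):
        if self._device_buf is None:
            nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize
            self._device_buf = cuda.mem_alloc(nbytes)
        return self._device_buf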

Example 2: __init__

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def __init__(self, timing_iter):
    trt.infer.Profiler.__init__(self)
    self.timing_iterations = timing_iter
    self.profile = []
Developer: aimuch, Project: iAI, Lines: 6, Source file: googlenet.py
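
The fragment above is only the constructor of a Profiler subclass. In NVIDIA's googlenet sample the same class also overrides report_layer_time, the callback TensorRT invokes once per layer per iteration, and accumulates the results for print_layer_times (used in Example 3 below). A plausible completion modeled on that sample; the accumulation details are assumptions:

class Profiler(trt.infer.Profiler):
    def __init__(self, timing_iter):
        trt.infer.Profiler.__init__(self)
        self.timing_iterations = timing_iter
        self.profile = []

    def report_layer_time(self, layer_name, ms):
        # Accumulate per-layer time across iterations
        record = next((r for r in self.profile if r[0] == layer_name), None)
        if record is None:
            self.profile.append((layer_name, ms))
        else:
            self.profile[self.profile.index(record)] = (layer_name, record[1] + ms)

    def print_layer_times(self):
        total_time = 0
        for name, ms in self.profile:
            print("{:40.40s} {:4.3f} ms".format(name, ms / self.timing_iterations))
            total_time += ms
        print("Time over all layers: {:4.2f} ms per iteration".format(
            total_time / self.timing_iterations))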

Example 3: main

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def main():
    dir_path = os.path.dirname(os.path.realpath(__file__))  # script directory (unused below)

    print("Building and running GPU inference for GoogleNet, N=4")
    # Convert the Caffe model to a TensorRT engine
    engine = trt.utils.caffe_to_trt_engine(G_LOGGER,
        MODEL_PROTOTXT,
        CAFFEMODEL,
        10,            # max batch size
        16 << 20,      # max workspace size (16 MiB)
        OUTPUT_LAYERS,
        trt.infer.DataType.FLOAT)

    runtime = trt.infer.create_infer_runtime(G_LOGGER)

    print("Bindings after deserializing")
    for bi in range(engine.get_nb_bindings()):
        if engine.binding_is_input(bi):
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Input")
        else:
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Output")

    time_inference(engine, BATCH_SIZE)

    engine.destroy()
    runtime.destroy()

    G_PROFILER.print_layer_times()

    print("Done")
Developer: aimuch, Project: iAI, Lines: 34, Source file: googlenet.py

Example 4: __init__

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def __init__(self, sev):
    trt.infer.Logger.__init__(self)
    self.severity = sev
Developer: aimuch, Project: iAI, Lines: 5, Source file: caffe_mnist.py
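
Only the constructor is shown; a trt.infer.Logger subclass must also implement the log callback that TensorRT calls with each message. A minimal hypothetical completion, assuming the legacy severity enums compare numerically with lower values being more severe:

def log(self, sev, msg):
    # Print messages at least as severe as the configured threshold
    if sev <= self.severity:
        print("LOGGER:", msg)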

Example 5: main

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def main():
    # Convert the Caffe model to a TensorRT engine
    runtime = trt.infer.create_infer_runtime(G_LOGGER)
    engine = trt.utils.caffe_to_trt_engine(G_LOGGER,
        MODEL_PROTOTXT,
        CAFFE_MODEL,
        1,         # max batch size
        1 << 20,   # max workspace size (1 MiB)
        OUTPUT_LAYERS,
        trt.infer.DataType.FLOAT)

    # Pick a random test case (requires: from random import randint)
    rand_file = randint(0, 9)
    img = get_testcase(DATA + str(rand_file) + '.pgm')

    print("Test case: " + str(rand_file))
    data = apply_mean(img, IMAGE_MEAN)

    context = engine.create_execution_context()

    # infer() is a helper defined elsewhere in caffe_mnist.py; see the sketch below
    out = infer(context, data, OUTPUT_SIZE, 1)

    print("Prediction: " + str(np.argmax(out)))

    context.destroy()
    engine.destroy()
    runtime.destroy()
Developer: aimuch, Project: iAI, Lines: 29, Source file: caffe_mnist.py
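
The infer call above refers to a helper defined elsewhere in caffe_mnist.py. NVIDIA's legacy samples implemented it with PyCUDA roughly as follows; this is a sketch assuming the legacy context.enqueue API, not the exact source:

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # creates and activates a CUDA context

def infer(context, input_img, output_size, batch_size):
    # Look up the engine and allocate host/device memory
    engine = context.get_engine()
    assert engine.get_nb_bindings() == 2  # one input binding, one output binding
    input_img = input_img.astype(np.float32)
    output = np.empty(output_size, dtype=np.float32)
    d_input = cuda.mem_alloc(batch_size * input_img.size * input_img.dtype.itemsize)
    d_output = cuda.mem_alloc(batch_size * output.size * output.dtype.itemsize)
    bindings = [int(d_input), int(d_output)]
    stream = cuda.Stream()
    # Copy the input to the GPU, run inference, and copy the result back
    cuda.memcpy_htod_async(d_input, input_img, stream)
    context.enqueue(batch_size, bindings, stream.handle, None)
    cuda.memcpy_dtoh_async(output, d_output, stream)
    stream.synchronize()
    return output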

Example 6: convert_to_datatype

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def convert_to_datatype(v):
    if v == 8:
        return trt.infer.DataType.INT8
    elif v == 16:
        return trt.infer.DataType.HALF
    elif v == 32:
        return trt.infer.DataType.FLOAT
    else:
        print("ERROR: Invalid model data type bit depth: " + str(v))
        return trt.infer.DataType.INT8  # note: falls back to INT8 on invalid input
Developer: CUHKSZ-TQL, Project: EverybodyDanceNow_reproduce_pytorch, Lines: 12, Source file: run_engine.py
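
A hypothetical usage, mapping a command-line bit-depth flag to a TensorRT precision; per the function above, an invalid depth also falls back to INT8:

# e.g. a "--datatype 16" CLI argument selects half precision
dtype = convert_to_datatype(16)  # returns trt.infer.DataType.HALF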

Example 7: run_onnx

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def run_onnx(onnx_file, data_type, bs, inp):
    # Create the ONNX parser configuration
    apex = onnxparser.create_onnxconfig()
    apex.set_model_file_name(onnx_file)
    apex.set_model_dtype(convert_to_datatype(data_type))

    # Create the parser and convert the model to a TensorRT network
    trt_parser = onnxparser.create_onnxparser(apex)
    assert trt_parser
    data_type = apex.get_model_dtype()
    onnx_filename = apex.get_model_file_name()
    trt_parser.parse(onnx_filename, data_type)
    trt_parser.report_parsing_info()
    trt_parser.convert_to_trtnetwork()
    trt_network = trt_parser.get_trtnetwork()
    assert trt_network

    # Create the builder (max_batch_size and max_workspace_size are module-level globals)
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    trt_builder.set_max_batch_size(max_batch_size)
    trt_builder.set_max_workspace_size(max_workspace_size)

    if apex.get_model_dtype() == trt.infer.DataType_kHALF:
        print("-------------------  Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif apex.get_model_dtype() == trt.infer.DataType_kINT8:
        print("-------------------  Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("-------------------  Running FP32 -----------------------------")

    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    time_inference(trt_engine, bs, inp)  # fixed: was `engine`, which is undefined here
Developer: CUHKSZ-TQL, Project: EverybodyDanceNow_reproduce_pytorch, Lines: 38, Source file: run_engine.py
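
run_onnx relies on several names defined elsewhere in run_engine.py: the onnxparser module, the shared G_LOGGER, and the builder limits. A hypothetical sketch of that module-level setup; the import path matches the legacy TensorRT 4.x parsers package, and the concrete values are assumptions:

from tensorrt.parsers import onnxparser  # legacy ONNX parser (TensorRT 4.x era)

# Hypothetical defaults; the real values live elsewhere in run_engine.py.
max_batch_size = 1
max_workspace_size = 1 << 30  # 1 GiB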

Example 8: run_onnx

# Required import: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def run_onnx(onnx_file, data_type, bs, inp):
    # Create the ONNX parser configuration
    apex = onnxparser.create_onnxconfig()
    apex.set_model_file_name(onnx_file)
    apex.set_model_dtype(convert_to_datatype(data_type))

    # Create the parser and convert the model to a TensorRT network
    trt_parser = onnxparser.create_onnxparser(apex)
    assert trt_parser
    data_type = apex.get_model_dtype()
    onnx_filename = apex.get_model_file_name()
    trt_parser.parse(onnx_filename, data_type)
    trt_parser.report_parsing_info()
    trt_parser.convert_to_trtnetwork()
    trt_network = trt_parser.get_trtnetwork()
    assert trt_network

    # Create the builder (max_batch_size and max_workspace_size are module-level globals)
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    trt_builder.set_max_batch_size(max_batch_size)
    trt_builder.set_max_workspace_size(max_workspace_size)

    if apex.get_model_dtype() == trt.infer.DataType_kHALF:
        print("-------------------  Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif apex.get_model_dtype() == trt.infer.DataType_kINT8:
        print("-------------------  Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("-------------------  Running FP32 -----------------------------")

    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    time_inference(trt_engine, bs, inp)  # fixed: was `engine`, which is undefined here
Developer: thomasjhuang, Project: deep-learning-for-document-dewarping, Lines: 38, Source file: run_engine.py


Note: The tensorrt.infer examples in this article were compiled by 純淨天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The snippets are selected from open-source projects contributed by their authors; copyright remains with the original authors, and any use or redistribution should follow the corresponding project licenses. Please do not repost without permission.