This article collects typical usage examples of the tensorrt.infer method in Python. If you are wondering what tensorrt.infer does, how to call it, or want to see it in real code, the curated examples below may help. You can also explore further usage examples from the module the method belongs to, tensorrt.
The following shows 9 code examples of tensorrt.infer, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: __init__
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def __init__(self, engine, idx_or_name, max_batch_size):
    # Resolve the binding either by name or by index
    if isinstance(idx_or_name, string_types):
        self.name = idx_or_name
        self.index = engine.get_binding_index(self.name)
        if self.index == -1:
            raise IndexError("Binding name not found: %s" % self.name)
    else:
        self.index = idx_or_name
        self.name = engine.get_binding_name(self.index)
        if self.name is None:
            raise IndexError("Binding index out of range: %i" % self.index)
    self.is_input = engine.binding_is_input(self.index)
    # Map the legacy TensorRT data type enum to a NumPy dtype
    dtype = engine.get_binding_data_type(self.index)
    dtype_map = {trt.infer.DataType_kFLOAT: np.float32,
                 trt.infer.DataType_kHALF: np.float16,
                 trt.infer.DataType_kINT8: np.int8}
    if hasattr(trt.infer, 'DataType_kINT32'):
        dtype_map[trt.infer.DataType_kINT32] = np.int32
    self.dtype = dtype_map[dtype]
    shape = engine.get_binding_dimensions(self.index).shape()
    self.shape = (max_batch_size,) + shape
    self._host_buf = None
    self._device_buf = None
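The snippet above only records binding metadata; the two buffer attributes stay unset. Below is a minimal sketch of how such a binding wrapper is commonly paired with PyCUDA to allocate matching host and device buffers. The allocate_binding_buffers helper and its use of the private attributes are assumptions for illustration, not part of the original class.
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # initializes a CUDA context

def allocate_binding_buffers(binding):
    # Hypothetical helper: page-locked host array plus a device allocation of the same size
    nbytes = int(np.prod(binding.shape)) * np.dtype(binding.dtype).itemsize
    binding._host_buf = cuda.pagelocked_empty(binding.shape, binding.dtype)
    binding._device_buf = cuda.mem_alloc(nbytes)
    return binding._host_buf, binding._device_buf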
Example 2: __init__
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def __init__(self, timing_iter):
    # Subclass of the legacy trt.infer.Profiler; accumulates per-layer timings
    trt.infer.Profiler.__init__(self)
    self.timing_iterations = timing_iter
    self.profile = []
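This constructor only stores state; TensorRT reports timings through a per-layer callback on the profiler. A sketch of how the class is typically completed in the legacy samples follows; the callback name report_layer_time matches the legacy trt.infer.Profiler interface, while the body and the print_layer_times helper (called in Examples 3 and 4) are assumed details.
def report_layer_time(self, layer_name, ms):
    # Called by TensorRT once per layer per timed run; accumulate the milliseconds
    record = next((r for r in self.profile if r[0] == layer_name), None)
    if record is None:
        self.profile.append((layer_name, ms))
    else:
        self.profile[self.profile.index(record)] = (layer_name, record[1] + ms)

def print_layer_times(self):
    total = 0.0
    for name, ms in self.profile:
        print("{:40.40s} {:4.3f} ms".format(name, ms / self.timing_iterations))
        total += ms
    print("Time over all layers: {:4.2f} ms per iteration".format(total / self.timing_iterations))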
Example 3: main
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def main():
    path = dir_path = os.path.dirname(os.path.realpath(__file__))
    print("Building and running GPU inference for GoogleNet, N=4")
    # Convert the Caffe model to a TensorRT engine
    engine = trt.utils.caffe_to_trt_engine(G_LOGGER,
                                           MODEL_PROTOTXT,
                                           CAFFEMODEL,
                                           10,
                                           16 << 20,
                                           OUTPUT_LAYERS,
                                           trt.infer.DataType.FLOAT)
    runtime = trt.infer.create_infer_runtime(G_LOGGER)
    print("Bindings after deserializing")
    for bi in range(engine.get_nb_bindings()):
        if engine.binding_is_input(bi):
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Input")
        else:
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Output")
    time_inference(engine, BATCH_SIZE)
    engine.destroy()
    runtime.destroy()
    G_PROFILER.print_layer_times()
    print("Done")
    return
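time_inference is defined elsewhere in the sample and not shown on this page. The following is a rough sketch of what such a helper can look like with the legacy API and PyCUDA; the TIMING_ITERATIONS constant, the float32 assumption for every binding, and the set_profiler call (assumed snake_case counterpart of the C++ setProfiler) are illustration-only assumptions.
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit

def time_inference(engine, batch_size):
    context = engine.create_execution_context()
    context.set_profiler(G_PROFILER)  # assumed to mirror IExecutionContext::setProfiler
    # Allocate device buffers sized from the engine's binding shapes (float32 assumed)
    buffers = []
    for bi in range(engine.get_nb_bindings()):
        shape = engine.get_binding_dimensions(bi).shape()
        count = batch_size * int(np.prod(shape))
        buffers.append(cuda.mem_alloc(count * np.dtype(np.float32).itemsize))
    bindings = [int(buf) for buf in buffers]
    # Run the engine repeatedly so the profiler can average per-layer times
    for _ in range(TIMING_ITERATIONS):
        context.execute(batch_size, bindings)
    context.destroy()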
Example 4: main
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def main():
    path = dir_path = os.path.dirname(os.path.realpath(__file__))
    print("Building and running GPU inference for GoogleNet, N=4")
    # Convert the Caffe model to a TensorRT engine
    engine = trt.utils.caffe_to_trt_engine(G_LOGGER,
                                           MODEL_PROTOTXT,
                                           CAFFEMODEL,
                                           10,
                                           16 << 20,
                                           OUTPUT_LAYERS,
                                           trt.infer.DataType.FLOAT)
    runtime = trt.infer.create_infer_runtime(G_LOGGER)
    print("Bindings after deserializing")
    for bi in range(engine.get_nb_bindings()):
        if engine.binding_is_input(bi):
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Input")
        else:
            print("Binding " + str(bi) + " (" + engine.get_binding_name(bi) + "): Output")
    time_inference(engine, BATCH_SIZE)
    engine.destroy()
    runtime.destroy()
    G_PROFILER.print_layer_times()
    print("Done")
    return
Example 5: __init__
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def __init__(self, sev):
    # Subclass of the legacy trt.infer.Logger; keep only messages at or above this severity
    trt.infer.Logger.__init__(self)
    self.severity = sev
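The base class also expects a log callback that TensorRT calls with each message; a minimal sketch follows, assuming the legacy Python signature mirrors the C++ ILogger::log(severity, msg).
def log(self, sev, msg):
    # Print only messages at least as severe as the threshold set in __init__
    # (lower enum value means more severe in TensorRT)
    if sev <= self.severity:
        print(msg)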
Example 6: main
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def main():
    # Convert the Caffe model to a TensorRT engine
    runtime = trt.infer.create_infer_runtime(G_LOGGER)
    engine = trt.utils.caffe_to_trt_engine(G_LOGGER,
                                           MODEL_PROTOTXT,
                                           CAFFE_MODEL,
                                           1,
                                           1 << 20,
                                           OUTPUT_LAYERS,
                                           trt.infer.DataType.FLOAT)
    # Pick a random test case
    rand_file = randint(0, 9)
    img = get_testcase(DATA + str(rand_file) + '.pgm')
    print("Test case: " + str(rand_file))
    data = apply_mean(img, IMAGE_MEAN)
    context = engine.create_execution_context()
    out = infer(context, data, OUTPUT_SIZE, 1)
    print("Prediction: " + str(np.argmax(out)))
    context.destroy()
    engine.destroy()
    runtime.destroy()
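The infer helper called above comes from the surrounding sample; a rough sketch of a legacy single-input, single-output version with PyCUDA is shown below. The exact body is an assumption modeled on the classic MNIST-style sample, not the original definition from this page.
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit

def infer(context, input_data, output_size, batch_size):
    engine = context.get_engine()
    assert engine.get_nb_bindings() == 2  # one input binding, one output binding
    input_data = input_data.astype(np.float32)
    output = np.empty(output_size, dtype=np.float32)
    # Device allocations for input and output
    d_input = cuda.mem_alloc(batch_size * input_data.size * input_data.dtype.itemsize)
    d_output = cuda.mem_alloc(batch_size * output.size * output.dtype.itemsize)
    bindings = [int(d_input), int(d_output)]
    stream = cuda.Stream()
    # Copy input to the GPU, run inference asynchronously, copy the result back
    cuda.memcpy_htod_async(d_input, input_data, stream)
    context.enqueue(batch_size, bindings, stream.handle, None)
    cuda.memcpy_dtoh_async(output, d_output, stream)
    stream.synchronize()
    return output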
Example 7: convert_to_datatype
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def convert_to_datatype(v):
    # Map a precision bit depth (8/16/32) to the corresponding TensorRT data type
    if v == 8:
        return trt.infer.DataType.INT8
    elif v == 16:
        return trt.infer.DataType.HALF
    elif v == 32:
        return trt.infer.DataType.FLOAT
    else:
        print("ERROR: Invalid model data type bit depth: " + str(v))
        return trt.infer.DataType.INT8
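A short usage note: the argument is a precision bit depth, typically parsed from a command-line flag, and the return value feeds directly into calls such as apex.set_model_dtype in Examples 8 and 9. The calls below are hypothetical illustrations.
dtype = convert_to_datatype(16)   # trt.infer.DataType.HALF
dtype = convert_to_datatype(32)   # trt.infer.DataType.FLOAT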
Example 8: run_onnx
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def run_onnx(onnx_file, data_type, bs, inp):
    # Create the ONNX parser configuration
    apex = onnxparser.create_onnxconfig()
    apex.set_model_file_name(onnx_file)
    apex.set_model_dtype(convert_to_datatype(data_type))
    # Create the parser and convert the model to a TensorRT network
    trt_parser = onnxparser.create_onnxparser(apex)
    assert trt_parser
    data_type = apex.get_model_dtype()
    onnx_filename = apex.get_model_file_name()
    trt_parser.parse(onnx_filename, data_type)
    trt_parser.report_parsing_info()
    trt_parser.convert_to_trtnetwork()
    trt_network = trt_parser.get_trtnetwork()
    assert trt_network
    # Create the inference builder and select the precision mode
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    trt_builder.set_max_batch_size(max_batch_size)
    trt_builder.set_max_workspace_size(max_workspace_size)
    if apex.get_model_dtype() == trt.infer.DataType_kHALF:
        print("------------------- Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif apex.get_model_dtype() == trt.infer.DataType_kINT8:
        print("------------------- Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("------------------- Running FP32 -----------------------------")
    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    # Run timed inference on the engine that was just built
    time_inference(trt_engine, bs, inp)
Example 9: run_onnx
# Required module: import tensorrt [as alias]
# Or: from tensorrt import infer [as alias]
def run_onnx(onnx_file, data_type, bs, inp):
    # Create the ONNX parser configuration
    apex = onnxparser.create_onnxconfig()
    apex.set_model_file_name(onnx_file)
    apex.set_model_dtype(convert_to_datatype(data_type))
    # Create the parser and convert the model to a TensorRT network
    trt_parser = onnxparser.create_onnxparser(apex)
    assert trt_parser
    data_type = apex.get_model_dtype()
    onnx_filename = apex.get_model_file_name()
    trt_parser.parse(onnx_filename, data_type)
    trt_parser.report_parsing_info()
    trt_parser.convert_to_trtnetwork()
    trt_network = trt_parser.get_trtnetwork()
    assert trt_network
    # Create the inference builder and select the precision mode
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    trt_builder.set_max_batch_size(max_batch_size)
    trt_builder.set_max_workspace_size(max_workspace_size)
    if apex.get_model_dtype() == trt.infer.DataType_kHALF:
        print("------------------- Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif apex.get_model_dtype() == trt.infer.DataType_kINT8:
        print("------------------- Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("------------------- Running FP32 -----------------------------")
    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    # Run timed inference on the engine that was just built
    time_inference(trt_engine, bs, inp)