[Inference Engines] A Detailed Comparison and Study of Model Inference

1. Intro

Driven by a long-standing interest in model inference and compute acceleration, I have tested many inference frameworks and related technologies over the years. Some are tied to proprietary platforms, such as TI's TIDL/TIDL-RT, NXP's APEX, and NVIDIA's TensorRT. Beyond edge devices, there are also cloud inference solutions from data-center chip vendors such as Cambricon and Huawei Ascend. These platform-specific inference frameworks (or toolkits), built on intimate knowledge of the chip architecture, can apply extremely aggressive optimizations to specific operators and reach very high performance.

However, the drawbacks of these proprietary platforms are equally obvious. Because they have not been tested and adapted at scale against the rapidly evolving combinations of new operators, they occasionally exhibit functional or performance anomalies. And because of their black-box nature, when users run into such problems it is hard to resolve them effectively without help from the vendor's developers or FAEs.

For that reason, fully open or semi-open frameworks such as TVM, NCNN, and ONNX Runtime have broadened and enriched users' options, allowing experienced engineers to extend them in a pure white-box fashion.

Out of strong curiosity about these technologies, while learning TVM I wanted to compare the performance of these general-purpose inference frameworks side by side on the same network, and also to evaluate how easy each one makes custom-operator development. This will be a long study document, and I will update it step by step as I learn.

  • The test results are shown first; the detailed test process follows afterwards:

  • [Diagram] Life of an IRModule in TensorIR — stages: Parse, Convert, Print, Schedule Transform, Pass Transform, Build, Transform, Display, Save. Key components:

    TVM Script - a Python-AST-based format that allows writing complex programs by hand
    IRModule - the central data structure that all of the stages above operate on
    Tensor Expression (TE) - a DSL for declaring computation and for scheduling and optimizing it
    Runnable Module - the built artifact that can be executed directly
    Relay / Relay IR - Relay IR, defined in functional form, is TVM's intermediate representation, used to define, optimize, and compile deep learning models
    JSON serialization - converts a Relay IR model into a format that is convenient to save and transfer

  • [Diagram] TensorIR interactive optimization flow
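  • To make the diagram concrete, here is a minimal IRModule written by hand in TVMScript. This is a sketch based on the public TVM tutorials, not part of the tests below; the exact T.Buffer annotation syntax varies slightly across TVM versions.

    import tvm
    from tvm.script import tir as T

    @tvm.script.ir_module
    class MyModule:
        @T.prim_func
        def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            for i in range(8):
                with T.block("add"):
                    vi = T.axis.spatial(8, i)
                    B[vi] = A[vi] + 1.0

    print(MyModule.script())                     # Print the module back out as TVMScript
    sch = tvm.tir.Schedule(MyModule)             # entry point for interactive schedule transforms
    rt_mod = tvm.build(MyModule, target="llvm")  # Build -> Runnable Module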

2. Background Notes from the Internet on the Different Inference Frameworks

  1. TensorRT
    Performance: TensorRT is highly optimized for NVIDIA GPUs and typically offers the best performance in terms of latency and throughput. It takes advantage of low-level GPU optimizations, including precision calibration (FP16, INT8), kernel fusion, and layer optimization.
    Ease of Use: Requires model conversion from frameworks like PyTorch to ONNX and then to TensorRT. Some additional tuning and calibration might be needed for optimal performance.

  2. ONNX Runtime
    Performance: ONNX Runtime provides good performance, especially when using the GPU execution provider. While it might not match TensorRT’s performance, it is quite efficient and easier to integrate with models already exported to ONNX.
    Ease of Use: Straightforward if you have an ONNX model. It supports various optimizations and execution providers, including NVIDIA GPUs.

  3. TVM
    Performance: TVM can offer excellent performance by compiling models specifically optimized for the target hardware. Performance can be close to or even surpass TensorRT in some cases, depending on the level of optimization and tuning.
    Ease of Use: Requires more effort to tune and optimize. The process involves compiling the model specifically for the hardware, which can be complex and time-consuming.

  4. PyTorch
    Performance: PyTorch is generally not as optimized for inference on GPUs compared to the other specialized inference engines. Performance can be improved with TorchScript and using the PyTorch native AMP (Automatic Mixed Precision) for FP16 operations.
    Ease of Use: Easiest to use if the model is already in PyTorch. Minimal code changes are needed to run the model in inference mode.

  Performance Comparison Summary

TensorRT: Typically offers the best performance on NVIDIA GPUs, especially with FP16/INT8 optimizations.

ONNX Runtime: Good performance and easier integration if you have an ONNX model.

TVM: Potentially very high performance but requires significant tuning and expertise.

PyTorch: Good for ease of use but generally not as fast as the others for GPU inference.

3. The Tests

3.1. Utils

Model: https://github.com/onnx/models/raw/b9a54e89508f101a1611cd64f4ef56b9cb62c7cf/vision/classification/resnet/model/resnet50-v2-7.onnx (we use a ResNet50-v2 model; the input is 1 × 3 × 224 × 224, the output is 1 × 1000)

[Attachment] resnet50-v2-7.onnx.png

Input Data: https://s3.amazonaws.com/model-server/inputs/kitten.jpg (224 × 224 × 3, RGB)

[Attachment] kitten (1).jpg

Label: https://s3.amazonaws.com/onnx-model-zoo/synset.txt (1000 classes)
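
The three assets can be fetched with a small helper script; a sketch using only the URLs listed above:

    # Download the model, test image, and label file into the working directory
    import urllib.request

    assets = {
        "resnet50-v2-7.onnx": "https://github.com/onnx/models/raw/b9a54e89508f101a1611cd64f4ef56b9cb62c7cf/vision/classification/resnet/model/resnet50-v2-7.onnx",
        "kitten.jpg": "https://s3.amazonaws.com/model-server/inputs/kitten.jpg",
        "synset.txt": "https://s3.amazonaws.com/onnx-model-zoo/synset.txt",
    }
    for filename, url in assets.items():
        urllib.request.urlretrieve(url, filename)
        print("downloaded", filename)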

3.2. ONNX Runtime

3.2.1. Float32 (GPU, CPU)

    import onnxruntime as ort
    import numpy as np
    import time
    import timeit
    from PIL import Image

    # https://github.com/onnx/models/raw/b9a54e89508f101a1611cd64f4ef56b9cb62c7cf/vision/classification/resnet/model/resnet50-v2-7.onnx
    model_path = "resnet50-v2-7.onnx"
    # https://s3.amazonaws.com/model-server/inputs/kitten.jpg
    img_path = 'kitten.jpg'
    # Resize it to 224x224
    resized_image = Image.open(img_path).resize((224, 224))
    img_data = np.asarray(resized_image).astype("float32")
    # Our input image is in HWC layout while ONNX expects CHW input, so convert the array
    img_data = np.transpose(img_data, (2, 0, 1))
    # Normalize according to the ImageNet input specification
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
    # Add the batch dimension, as we are expecting 4-dimensional input: NCHW.
    input_data = np.expand_dims(norm_img_data, axis=0).astype("float32")

    # timing_number needs to be large enough; with only 10 iterations the CPU and
    # GPU numbers are indistinguishable
    timing_number = 200
    timing_repeat = 10

    # CPU execution time
    def run_model_cpu(model_path, input_data):
        session_cpu = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        print("Available cpu providers: ", session_cpu.get_providers())
        input_name = session_cpu.get_inputs()[0].name
        cpu_result = (
            np.array(timeit.Timer(lambda: session_cpu.run(None, {input_name: input_data})).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
        )
        cpu_result = {"mean": np.mean(cpu_result), "median": np.median(cpu_result), "std": np.std(cpu_result)}
        print("Onnxruntime_CPU_Float32: %s" % (cpu_result))

    # GPU execution time
    def run_model_gpu(model_path, input_data):
        try:
            session_gpu = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
            print("Available GPU providers: ", session_gpu.get_providers())
        except Exception as e:
            print("Error initializing GPU session: ", e)
            return
        input_name = session_gpu.get_inputs()[0].name
        gpu_result = (
            np.array(timeit.Timer(lambda: session_gpu.run(None, {input_name: input_data})).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
        )
        gpu_result = {"mean": np.mean(gpu_result), "median": np.median(gpu_result), "std": np.std(gpu_result)}
        print("Onnxruntime_GPU_Float32: %s" % (gpu_result))

    run_model_cpu(model_path, input_data)
    run_model_gpu(model_path, input_data)
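  • The class predictions shown in the result below come from a separate top-5 readout that the timing script above omits. A minimal sketch of that readout (assuming synset.txt from section 3.1 is in the working directory):

    from scipy.special import softmax

    def print_top5(output, labels_path="synset.txt"):
        # Map the 1000 logits to ImageNet class names and print the 5 best
        with open(labels_path, "r") as f:
            labels = [l.rstrip() for l in f]
        scores = np.squeeze(softmax(output))
        for rank in np.argsort(scores)[::-1][:5]:
            print("class='%s' with probability=%f" % (labels[rank], scores[rank]))

    # Example usage with the CPU session:
    # outputs = session_cpu.run(None, {input_name: input_data})
    # print_top5(outputs[0])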
  • Execution Result:

    # CPU
    class='n02123045 tabby, tabby cat' with probability=0.621104
    class='n02123159 tiger cat' with probability=0.356378
    class='n02124075 Egyptian cat' with probability=0.019712
    class='n02129604 tiger, Panthera tigris' with probability=0.001215
    class='n04040759 radiator' with probability=0.000262
    # GPU
    class='n02123045 tabby, tabby cat' with probability=0.621105
    class='n02123159 tiger cat' with probability=0.356377
    class='n02124075 Egyptian cat' with probability=0.019713
    class='n02129604 tiger, Panthera tigris' with probability=0.001215
    class='n04040759 radiator' with probability=0.000262
    Available cpu providers: ['CPUExecutionProvider']
    Onnxruntime_CPU_Float32: {'mean': 0.009986212425981647, 'median': 0.009982698982494185, 'std': 1.576711878851109e-05}
    Available GPU providers: ['CUDAExecutionProvider', 'CPUExecutionProvider']
    Onnxruntime_GPU_Float32: {'mean': 0.004069052441569511, 'median': 0.0037474105950968803, 'std': 0.0009607391846988234}

3.2.2. Int8 (GPU, CPU)

After profiling, we found that INT8 dynamic quantization does not actually speed this model up: on the CPU the INT8 ResNet50 is slightly slower than Float32, and on the GPU it is markedly slower. Dynamic quantization reduces the data width of the weights while largely preserving accuracy, but the convolution-heavy workload does not benefit from it here.

What the numbers do show is that, for a comparatively small model like ResNet50, the INT8 latency on CPU and GPU is nearly the same. This consistency provides some flexibility, allowing a switch between hardware configurations without a large performance penalty.

Further experiments corroborate this finding. When benchmarking ResNet152 with the same setup, we also observed no discernible difference in INT8 inference speed between CPU and GPU. This suggests that for models of this size, the choice of hardware platform depends more on available resources and the specific application scenario than on raw performance.

  • Model quantization (ONNX Runtime quantization only supports 8-bit integer types; inspect QuantType to see them)

    from onnxruntime.quantization import quantize_dynamic, QuantType
    help(QuantType)  # lists the available weight types (QInt8 / QUInt8)

  • Quantize the model from float32 to int8

    import onnx
    import onnxruntime as ort
    import numpy as np
    import timeit
    from onnxruntime.quantization import quantize_dynamic, QuantType

    # Define paths
    model_fp32 = "resnet50-v2-7.onnx"
    model_int8 = "resnet50-v2-7-int8.onnx"
    img_path = 'kitten.jpg'

    # Load and update the model to opset 11
    model = onnx.load(model_fp32)
    model_opset_version = model.opset_import[0].version
    if model_opset_version < 11:
        print(f"The original model opset version is {model_opset_version}, updating to opset 11.")
        model = onnx.version_converter.convert_version(model, 11)
        onnx.save(model, "resnet50-v2-11.onnx")
        model_fp32 = "resnet50-v2-11.onnx"

    # Quantize the model
    quantized_model = quantize_dynamic(model_fp32, model_int8, weight_type=QuantType.QUInt8)
    print(f"Quantized model saved to {model_int8}")
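  • Note that quantize_dynamic quantizes only the weights; activations stay in float. ONNX Runtime also provides quantize_static, which quantizes activations too but needs a calibration data reader. A hedged sketch (not used in the measurements below; the one-image calibration set is for illustration only, and input_data is the NCHW tensor from 3.2.1):

    from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType

    class KittenReader(CalibrationDataReader):
        # Feeds preprocessed samples to the calibrator, one {input_name: array} dict per call
        def __init__(self, samples):
            self._iter = iter(samples)
        def get_next(self):
            return next(self._iter, None)

    reader = KittenReader([{"data": input_data}])
    quantize_static(model_fp32, "resnet50-v2-7-int8-static.onnx", reader,
                    weight_type=QuantType.QInt8)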
  • Then just load the int8 model in the benchmark script from 3.2.1

    model_path = "resnet50-v2-7-int8.onnx"
  • Execution Result:

    # CPU Resnet50
    class='n02123045 tabby, tabby cat' with probability=0.657020
    class='n02123159 tiger cat' with probability=0.316742
    class='n02124075 Egyptian cat' with probability=0.023232
    class='n02129604 tiger, Panthera tigris' with probability=0.001320
    class='n04040759 radiator' with probability=0.000253
    # GPU Resnet50
    class='n02123045 tabby, tabby cat' with probability=0.647389
    class='n02123159 tiger cat' with probability=0.325191
    class='n02124075 Egyptian cat' with probability=0.024315
    class='n02129604 tiger, Panthera tigris' with probability=0.001383
    class='n04040759 radiator' with probability=0.000234
    # CPU Resnet152
    class='n02124075 Egyptian cat' with probability=0.618835
    class='n02123045 tabby, tabby cat' with probability=0.089756
    class='n04589890 window screen' with probability=0.055205
    class='n04590129 window shade' with probability=0.054228
    class='n02123159 tiger cat' with probability=0.041171
    # GPU Resnet152
    class='n02124075 Egyptian cat' with probability=0.616358
    class='n02123045 tabby, tabby cat' with probability=0.091070
    class='n04589890 window screen' with probability=0.055855
    class='n04590129 window shade' with probability=0.054603
    class='n02123159 tiger cat' with probability=0.042334
    Available cpu providers: ['CPUExecutionProvider']
    Onnxruntime_CPU_Int8: {'mean': 0.011375669833010759, 'median': 0.011376627020072192, 'std': 1.180472312469979e-05}
    Available GPU providers: ['CUDAExecutionProvider', 'CPUExecutionProvider']
    Onnxruntime_GPU_Int8: {'mean': 0.013842492688942002, 'median': 0.013839954254799523, 'std': 2.1459357287547657e-05}
    Available cpu providers: ['CPUExecutionProvider']
    Onnxruntime_CPU_ResNet152_Int8: {'mean': 0.03108119218947832, 'median': 0.031068197362474168, 'std': 3.971259752378018e-05}
    Available GPU providers: ['CUDAExecutionProvider', 'CPUExecutionProvider']
    Onnxruntime_GPU_ResNet152_Int8: {'mean': 0.03715013448652462, 'median': 0.03713770236005075, 'std': 3.551800251648068e-05}

3.3. PyTorch

3.3.1. Float32 (GPU, CPU)

    import numpy as np
    import timeit
    import torch
    from torchvision import models, transforms
    from PIL import Image

    img_path = 'kitten.jpg'
    resized_image = Image.open(img_path).resize((224, 224))
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(resized_image)
    input_batch = input_tensor.unsqueeze(0)

    model = models.resnet50(pretrained=True)
    model.eval()

    timing_number = 200
    timing_repeat = 10

    def run_model_cpu(model, input_batch):
        model.to('cpu')
        with torch.no_grad():
            cpu_result = (
                np.array(timeit.Timer(lambda: model(input_batch)).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
            )
        cpu_result = {"mean": np.mean(cpu_result), "median": np.median(cpu_result), "std": np.std(cpu_result)}
        print("Torch_CPU_Float32: %s" % (cpu_result))

    def run_model_gpu(model, input_batch):
        if torch.cuda.is_available():
            model.to('cuda')
            input_batch_gpu = input_batch.to('cuda')
            with torch.no_grad():
                gpu_result = (
                    np.array(timeit.Timer(lambda: model(input_batch_gpu)).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
                )
            gpu_result = {"mean": np.mean(gpu_result), "median": np.median(gpu_result), "std": np.std(gpu_result)}
            print("Torch_GPU_Float32: %s" % (gpu_result))
        else:
            print("CUDA Invalid!")

    run_model_cpu(model, input_batch)
    run_model_gpu(model, input_batch)
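  • One caveat about the GPU numbers: CUDA kernel launches are asynchronous, so a plain timeit over model(input) measures launch-plus-queue time rather than guaranteed completion. Over 200 back-to-back iterations the launch queue saturates and the average is still meaningful, but a stricter variant would synchronize explicitly; a sketch:

    def timed_inference(model, input_batch_gpu):
        # Wait for previously queued GPU work, run once, then wait for completion
        torch.cuda.synchronize()
        with torch.no_grad():
            out = model(input_batch_gpu)
        torch.cuda.synchronize()
        return out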
  • Execution Result:

    # CPU
    class='n02123045 tabby, tabby cat' with probability=0.476912
    class='n02123159 tiger cat' with probability=0.465774
    class='n02124075 Egyptian cat' with probability=0.046519
    class='n03958227 plastic bag' with probability=0.002096
    class='n02971356 carton' with probability=0.000678
    # GPU
    class='n02123045 tabby, tabby cat' with probability=0.476538
    class='n02123159 tiger cat' with probability=0.466136
    class='n02124075 Egyptian cat' with probability=0.046531
    class='n03958227 plastic bag' with probability=0.002095
    class='n02971356 carton' with probability=0.000678
    Torch_CPU_Float32: {'mean': 0.021870234759990125, 'median': 0.021858943702391116, 'std': 3.487536414872911e-05}
    Torch_GPU_Float32: {'mean': 0.0018875277980696411, 'median': 0.0018826104051549919, 'std': 2.818227506480741e-05}

3.4. TVM

3.4.1. Float32 (CPU, CPU auto-tune)

  • float32

    import onnx
    from tvm.contrib.download import download_testdata
    from PIL import Image
    import numpy as np
    import tvm.relay as relay
    import tvm
    from tvm.contrib import graph_executor

    onnx_model = onnx.load('resnet50-v2-7.onnx')
    img_path = 'kitten.jpg'
    # Resize it to 224x224
    resized_image = Image.open(img_path).resize((224, 224))
    img_data = np.asarray(resized_image).astype("float32")
    # Our input image is in HWC layout while ONNX expects CHW input, so convert the array
    img_data = np.transpose(img_data, (2, 0, 1))
    # Normalize according to the ImageNet input specification
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
    # Add the batch dimension, as we are expecting 4-dimensional input: NCHW.
    img_data = np.expand_dims(norm_img_data, axis=0).astype("float32")
    print(img_data.shape)
    print(img_data.dtype)

    # The input name may vary across model types. You can use a tool
    # like Netron to check input names
    input_name = "data"
    target = "llvm"
    shape_dict = {input_name: img_data.shape}
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    dev = tvm.device(str(target), 0)
    module = graph_executor.GraphModule(lib["default"](dev))
    dtype = "float32"
    module.set_input(input_name, img_data)
    module.run()
    output_shape = (1, 1000)
    tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()

    from scipy.special import softmax
    labels_path = "synset.txt"
    with open(labels_path, "r") as f:
        labels = [l.rstrip() for l in f]
    # Open the output and read the output tensor
    scores = softmax(tvm_output)
    scores = np.squeeze(scores)
    ranks = np.argsort(scores)[::-1]
    for rank in ranks[0:5]:
        print("class='%s' with probability=%f" % (labels[rank], scores[rank]))

    import timeit
    # Note: smaller than the 200 iterations used in the other sections; these are
    # the TVM tutorial defaults
    timing_number = 10
    timing_repeat = 10
    unoptimized = (
        np.array(timeit.Timer(lambda: module.run()).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
    )
    unoptimized = {
        "mean": np.mean(unoptimized),
        "median": np.median(unoptimized),
        "std": np.std(unoptimized),
    }

    import tvm.auto_scheduler as auto_scheduler
    # GATuner/RandomTuner/GridSearchTuner are imported here as well, since the
    # tuner-selection chain below references them
    from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
    from tvm import autotvm

    number = 10
    repeat = 1
    min_repeat_ms = 0  # since we're tuning on a CPU, can be set to 0
    timeout = 10  # in seconds
    # create a TVM runner
    runner = autotvm.LocalRunner(
        number=number,
        repeat=repeat,
        timeout=timeout,
        min_repeat_ms=min_repeat_ms,
        enable_cpu_cache_flush=True,
    )
    tuning_option = {
        "tuner": "xgb",
        "trials": 20,
        "early_stopping": 100,
        "measure_option": autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func="default"), runner=runner
        ),
        "tuning_records": "resnet-50-v2-autotuning.json",
    }
    # begin by extracting the tasks from the onnx model
    tasks = autotvm.task.extract_from_program(mod["main"], target=target, params=params)
    print(tasks)
    # Tune the extracted tasks sequentially.
    for i, task in enumerate(tasks):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
        # choose tuner
        tuner = "xgb"
        # create tuner
        if tuner == "xgb":
            tuner_obj = XGBTuner(task, loss_type="reg")
        elif tuner == "xgb_knob":
            tuner_obj = XGBTuner(task, loss_type="reg", feature_type="knob")
        elif tuner == "xgb_itervar":
            tuner_obj = XGBTuner(task, loss_type="reg", feature_type="itervar")
        elif tuner == "xgb_curve":
            tuner_obj = XGBTuner(task, loss_type="reg", feature_type="curve")
        elif tuner == "xgb_rank":
            tuner_obj = XGBTuner(task, loss_type="rank")
        elif tuner == "xgb_rank_knob":
            tuner_obj = XGBTuner(task, loss_type="rank", feature_type="knob")
        elif tuner == "xgb_rank_itervar":
            tuner_obj = XGBTuner(task, loss_type="rank", feature_type="itervar")
        elif tuner == "xgb_rank_curve":
            tuner_obj = XGBTuner(task, loss_type="rank", feature_type="curve")
        elif tuner == "xgb_rank_binary":
            tuner_obj = XGBTuner(task, loss_type="rank-binary")
        elif tuner == "xgb_rank_binary_knob":
            tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="knob")
        elif tuner == "xgb_rank_binary_itervar":
            tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="itervar")
        elif tuner == "xgb_rank_binary_curve":
            tuner_obj = XGBTuner(task, loss_type="rank-binary", feature_type="curve")
        elif tuner == "ga":
            tuner_obj = GATuner(task, pop_size=50)
        elif tuner == "random":
            tuner_obj = RandomTuner(task)
        elif tuner == "gridsearch":
            tuner_obj = GridSearchTuner(task)
        else:
            raise ValueError("Invalid tuner: " + tuner)
        tuner_obj.tune(
            n_trial=min(tuning_option["trials"], len(task.config_space)),
            early_stopping=tuning_option["early_stopping"],
            measure_option=tuning_option["measure_option"],
            callbacks=[
                autotvm.callback.progress_bar(tuning_option["trials"], prefix=prefix),
                autotvm.callback.log_to_file(tuning_option["tuning_records"]),
            ],
        )

    # Rebuild the model with the best schedules found during tuning
    with autotvm.apply_history_best(tuning_option["tuning_records"]):
        with tvm.transform.PassContext(opt_level=3, config={}):
            lib = relay.build(mod, target=target, params=params)
    dev = tvm.device(str(target), 0)
    module = graph_executor.GraphModule(lib["default"](dev))
    dtype = "float32"
    module.set_input(input_name, img_data)
    module.run()
    output_shape = (1, 1000)
    tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()
    scores = softmax(tvm_output)
    scores = np.squeeze(scores)
    ranks = np.argsort(scores)[::-1]
    for rank in ranks[0:5]:
        print("class='%s' with probability=%f" % (labels[rank], scores[rank]))

    timing_number = 10
    timing_repeat = 10
    optimized = (
        np.array(timeit.Timer(lambda: module.run()).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
    )
    optimized = {"mean": np.mean(optimized), "median": np.median(optimized), "std": np.std(optimized)}
    print("TVM_CPU_Float32_Auto_Tune_XGB: %s" % (optimized))
    print("TVM_CPU_Float32: %s" % (unoptimized))
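  • As an aside, TVM also ships its own benchmarking helper, which invokes the graph executor without per-call Python overhead. A sketch equivalent to the timeit loops above, reusing module and dev:

    # Measure with TVM's built-in evaluator (results are in seconds)
    ftimer = module.module.time_evaluator("run", dev, number=10, repeat=10)
    prof_res = np.array(ftimer().results)
    print("time_evaluator: mean=%f s, median=%f s, std=%f s"
          % (np.mean(prof_res), np.median(prof_res), np.std(prof_res)))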
  • Execution Result:

    # No tune
    class='n02123045 tabby, tabby cat' with probability=0.621103
    class='n02123159 tiger cat' with probability=0.356379
    class='n02124075 Egyptian cat' with probability=0.019712
    class='n02129604 tiger, Panthera tigris' with probability=0.001215
    class='n04040759 radiator' with probability=0.000262
    [Task  1/25] Current/Best: 283.00/ 348.36 GFLOPS | Progress: (20/20) | 8.42 s Done.
    [Task  2/25] Current/Best:  74.32/ 278.44 GFLOPS | Progress: (20/20) | 5.78 s Done.
    [Task  3/25] Current/Best: 116.92/ 337.35 GFLOPS | Progress: (20/20) | 6.30 s Done.
    [Task  4/25] Current/Best: 304.23/ 328.07 GFLOPS | Progress: (20/20) | 7.33 s Done.
    [Task  5/25] Current/Best: 189.33/ 316.91 GFLOPS | Progress: (20/20) | 6.24 s Done.
    [Task  6/25] Current/Best: 221.74/ 314.72 GFLOPS | Progress: (20/20) | 7.33 s Done.
    [Task  7/25] Current/Best: 255.23/ 307.63 GFLOPS | Progress: (20/20) | 6.55 s Done.
    [Task  8/25] Current/Best:  26.67/ 299.85 GFLOPS | Progress: (20/20) | 8.58 s Done.
    [Task  9/25] Current/Best: 174.97/ 347.91 GFLOPS | Progress: (20/20) | 8.16 s Done.
    [Task 10/25] Current/Best: 154.63/ 280.17 GFLOPS | Progress: (20/20) | 6.48 s Done.
    [Task 11/25] Current/Best: 278.62/ 338.75 GFLOPS | Progress: (20/20) | 6.02 s Done.
    [Task 12/25] Current/Best: 224.97/ 308.45 GFLOPS | Progress: (20/20) | 7.64 s Done.
    [Task 13/25] Current/Best: 261.67/ 308.99 GFLOPS | Progress: (20/20) | 6.87 s Done.
    [Task 14/25] Current/Best:  92.82/ 312.59 GFLOPS | Progress: (20/20) | 12.86 s Done.
    [Task 15/25] Current/Best: 214.17/ 325.88 GFLOPS | Progress: (20/20) | 7.92 s Done.
    [Task 16/25] Current/Best: 207.14/ 267.99 GFLOPS | Progress: (20/20) | 6.40 s Done.
    [Task 17/25] Current/Best: 135.56/ 323.36 GFLOPS | Progress: (20/20) | 6.73 s Done.
    [Task 18/25] Current/Best: 199.88/ 292.03 GFLOPS | Progress: (20/20) | 7.68 s Done.
    [Task 19/25] Current/Best:  63.22/ 279.44 GFLOPS | Progress: (20/20) | 7.64 s Done.
    [Task 20/25] Current/Best:  16.96/ 287.96 GFLOPS | Progress: (20/20) | 12.55 s Done.
    [Task 21/25] Current/Best: 121.29/ 333.84 GFLOPS | Progress: (20/20) | 12.13 s Done.
    [Task 22/25] Current/Best: 258.80/ 267.89 GFLOPS | Progress: (20/20) | 6.79 s Done.
    [Task 23/25] Current/Best: 305.98/ 305.98 GFLOPS | Progress: (20/20) | 7.84 s Done.
    [Task 25/25] Current/Best:   0.00/   0.00 GFLOPS | Progress: (0/20) | 0.00 s Done.
    [Task 25/25] Current/Best:  56.92/  78.51 GFLOPS | Progress: (20/20) | 15.41 s Done.
    # Tune
    class='n02123045 tabby, tabby cat' with probability=0.621104
    class='n02123159 tiger cat' with probability=0.356378
    class='n02124075 Egyptian cat' with probability=0.019712
    class='n02129604 tiger, Panthera tigris' with probability=0.001215
    class='n04040759 radiator' with probability=0.000262
    TVM_CPU_Float32_Auto_Tune_XGB: {'mean': 0.03430713728850241, 'median': 0.033569534790149194, 'std': 0.004160319905261305}
    TVM_CPU_Float32: {'mean': 0.03740469979154296, 'median': 0.03763830485739163, 'std': 0.00046793406370928035}

3.4.2. Float32 (GPU, GPU auto-tune)

    # just change the target from llvm to cuda and reuse the code above
    target = "cuda"
    # or, depending on your platform:
    target = "cuda -arch=sm_89"
    # use `nvcc --list-gpu-arch` to check which arch values your toolkit supports
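  • If the CUDA build fails or silently falls back, it is worth confirming that TVM itself was compiled with CUDA support before switching targets; a quick check:

    import tvm

    dev = tvm.cuda(0)
    # False here usually means the TVM build/wheel was compiled without USE_CUDA
    print("CUDA device present:", dev.exist)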
  • Execution Result:

    # No tune
    class='n02123045 tabby, tabby cat' with probability=0.621104
    class='n02123159 tiger cat' with probability=0.356378
    class='n02124075 Egyptian cat' with probability=0.019712
    class='n02129604 tiger, Panthera tigris' with probability=0.001215
    class='n04040759 radiator' with probability=0.000262
    # Tune
    class='n02123045 tabby, tabby cat' with probability=0.621104
    class='n02123159 tiger cat' with probability=0.356378
    class='n02124075 Egyptian cat' with probability=0.019712
    class='n02129604 tiger, Panthera tigris' with probability=0.001215
    class='n04040759 radiator' with probability=0.000262
    TVM_GPU_Float32_Auto_Tune_XGB: {'mean': 0.0067301002040039744, 'median': 0.006774019840086112, 'std': 0.00013180002687155013}
    TVM_GPU_Float32: {'mean': 0.0018996601785183885, 'median': 0.0019048284900782164, 'std': 1.3217539930054235e-05}

3.4.3. Int8 (CPU; TVM's built-in quantization only supports int8)

    import onnx
    from tvm.contrib.download import download_testdata
    from PIL import Image
    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor
    from scipy.special import softmax
    import timeit

    onnx_model = onnx.load('resnet50-v2-7.onnx')
    img_path = 'kitten.jpg'
    resized_image = Image.open(img_path).resize((224, 224))
    img_data = np.asarray(resized_image).astype("float32")
    img_data = np.transpose(img_data, (2, 0, 1))
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
    img_data = np.expand_dims(norm_img_data, axis=0).astype("float32")

    input_name = "data"
    shape_dict = {input_name: img_data.shape}
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

    calibration_samples = [tvm.nd.array(img_data)]
    with tvm.transform.PassContext(opt_level=3):
        with relay.quantize.qconfig(calibrate_mode='global_scale', global_scale=8.0, skip_conv_layers=[]):
            quantized_mod = relay.quantize.quantize(mod, params, dataset=calibration_samples)

    target = "llvm"
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(quantized_mod, target=target, params=params)
    dev = tvm.device(str(target), 0)
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input(input_name, img_data)
    module.run()
    output_shape = (1, 1000)
    tvm_output = module.get_output(0, tvm.nd.empty(output_shape)).numpy()

    labels_path = "synset.txt"
    with open(labels_path, "r") as f:
        labels = [l.rstrip() for l in f]
    # Open the output and read the output tensor
    scores = softmax(tvm_output)
    scores = np.squeeze(scores)
    ranks = np.argsort(scores)[::-1]
    for rank in ranks[0:5]:
        print("class='%s' with probability=%f" % (labels[rank], scores[rank]))

    timing_number = 200
    timing_repeat = 10
    unoptimized = (
        np.array(timeit.Timer(lambda: module.run()).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
    )
    unoptimized = {
        "mean": np.mean(unoptimized),
        "median": np.median(unoptimized),
        "std": np.std(unoptimized),
    }
    print("TVM_CPU_Int8:", unoptimized)
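  • The fixed global_scale=8.0 calibration clearly costs accuracy here (the top-1 confidence drops from 0.62 to 0.40). TVM's quantizer also supports entropy calibration from data; a hedged sketch, reusing img_data as a stand-in calibration set (a real calibration set should contain many representative images):

    def calibrate_dataset():
        # Yield calibration batches keyed by input name; iterate over real samples in practice
        for _ in range(1):
            yield {"data": img_data}

    with tvm.transform.PassContext(opt_level=3):
        with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):
            quantized_mod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())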
  • Execution Result:

    One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
    class='n02123045 tabby, tabby cat' with probability=0.404191
    class='n02123159 tiger cat' with probability=0.210446
    class='n02124075 Egyptian cat' with probability=0.179177
    class='n03223299 doormat, welcome mat' with probability=0.033031
    class='n02127052 lynx, catamount' with probability=0.011398
    TVM_CPU_Int8: {'mean': 0.12497943410099835, 'median': 0.1239671477425145, 'std': 0.004604382417633903}

3.4.4. Int8 (GPU; TVM's built-in quantization only supports int8)

    # reuse the code from 3.4.3 with a CUDA target
    target = 'cuda -arch=sm_89'
  • Execution Result:

    One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
    class='n02123045 tabby, tabby cat' with probability=0.404193
    class='n02123159 tiger cat' with probability=0.210444
    class='n02124075 Egyptian cat' with probability=0.179176
    class='n03223299 doormat, welcome mat' with probability=0.033031
    class='n02127052 lynx, catamount' with probability=0.011398
    TVM_GPU_Int8: {'mean': 0.05043668361147866, 'median': 0.049358997289964464, 'std': 0.002297650553083921}

3.5. TensorRT

3.5.1. Float32 (GPU)

    import tensorrt as trt
    import pycuda.driver as cuda
    import pycuda.autoinit
    import numpy as np
    import onnx
    import timeit
    from PIL import Image

    def check_result(output):
        from scipy.special import softmax
        labels_path = "synset.txt"
        with open(labels_path, "r") as f:
            labels = [l.rstrip() for l in f]
        scores = softmax(output)
        scores = np.squeeze(scores)
        ranks = np.argsort(scores)[::-1]
        for rank in ranks[0:5]:
            print("class='%s' with probability=%f" % (labels[rank], scores[rank]))

    onnx_file_path = 'resnet50-v2-7.onnx'
    # Loading prints an INT64-weights warning; it can be ignored
    onnx_model = onnx.load(onnx_file_path)
    # Add a TRT logger
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(TRT_LOGGER)
    # Avoid the "In node -1 (importModel): INVALID_VALUE: Assertion failed:
    # !_importer_ctx.network()->hasImplicitBatchDimension() && 'This version of the ONNX
    # parser only supports TensorRT INetworkDefinitions with an explicit batch dimension.
    # Please ensure the network was created using the EXPLICIT_BATCH
    # NetworkDefinitionCreationFlag.'" error
    explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(explicit_batch) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        # parse the ONNX model
        with open(onnx_file_path, 'rb') as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print("Error {}: {}".format(error, parser.get_error(error)))
            else:
                print("Model parsed successfully")
        # Inspect the network layers if needed
        # for i in range(network.num_layers):
        #     layer = network.get_layer(i)
        #     print(layer.name, layer.get_output(0).shape)
        # Configure and build the TRT engine; the deprecation warnings can be ignored
        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30  # 1GB
        engine = builder.build_engine(network, config)
        if engine is None:
            raise RuntimeError("Failed to create the TensorRT engine")
        # Save the engine file if needed (optional); it contains the optimized model graph
        engine_file_path = 'resnet50.trt'
        with open(engine_file_path, 'wb') as f:
            f.write(engine.serialize())

    # load engine
    with open(engine_file_path, 'rb') as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    # create an execution context; the context is like a handle
    context = engine.create_execution_context()

    img_path = 'kitten.jpg'
    # Resize it to 224x224
    resized_image = Image.open(img_path).resize((224, 224))
    img_data = np.asarray(resized_image).astype("float32")
    # Our input image is in HWC layout while ONNX expects CHW input, so convert the array
    img_data = np.transpose(img_data, (2, 0, 1))
    # Normalize according to the ImageNet input specification
    imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    imagenet_stddev = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    norm_img_data = (img_data / 255 - imagenet_mean) / imagenet_stddev
    # Add the batch dimension, as we are expecting 4-dimensional input: NCHW
    host_input = np.expand_dims(norm_img_data, axis=0).astype("float32")
    # make sure the numpy array is contiguous
    host_input = np.ascontiguousarray(host_input)
    # prepare input
    device_input = cuda.mem_alloc(host_input.nbytes)
    cuda.memcpy_htod(device_input, host_input)
    # prepare output
    output_shape = context.get_binding_shape(1)
    host_output = np.empty(output_shape, dtype=np.float32)
    device_output = cuda.mem_alloc(host_output.nbytes)
    # check the shapes
    assert context.all_binding_shapes_specified
    assert context.get_binding_shape(0) == host_input.shape
    # execute
    context.execute_v2(bindings=[int(device_input), int(device_output)])
    cuda.memcpy_dtoh(host_output, device_output)
    check_result(host_output)

    timing_number = 200
    timing_repeat = 10
    gpu_result = (
        np.array(timeit.Timer(lambda: context.execute_v2(bindings=[int(device_input), int(device_output)])).repeat(repeat=timing_repeat, number=timing_number)) / timing_number
    )
    gpu_result = {"mean": np.mean(gpu_result), "median": np.median(gpu_result), "std": np.std(gpu_result)}
    print("Tensorrt_GPU_Float32: %s" % (gpu_result))
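  • Although only Float32 is measured here, TensorRT's headline gains usually come from reduced precision. Enabling FP16 is a one-flag change on the builder config; a sketch (not measured in this document; INT8 would additionally require a calibrator):

    # Inside the builder block above, before build_engine():
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)  # allow FP16 kernels where beneficial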
  • Execution Result:

    [05/22/2024-15:26:37] [TRT] [W] onnx2trt_utils.cpp:369: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
    Model parsed successfully
    test_tensorrt_resnet50_float32_gpu.py:50: DeprecationWarning: Use set_memory_pool_limit instead.
      config.max_workspace_size = 1 << 30  # 1GB
    test_tensorrt_resnet50_float32_gpu.py:51: DeprecationWarning: Use build_serialized_network instead.
      engine = builder.build_engine(network, config)
    [05/22/2024-15:26:38] [TRT] [W] TensorRT was linked against cuDNN 8.4.1 but loaded cuDNN 8.4.0
    [05/22/2024-15:26:45] [TRT] [W] TensorRT was linked against cuDNN 8.4.1 but loaded cuDNN 8.4.0
    [05/22/2024-15:26:45] [TRT] [W] The getMaxBatchSize() function should not be used with an engine built from a network created with NetworkDefinitionCreationFlag::kEXPLICIT_BATCH flag. This function will always return 1.
    [05/22/2024-15:26:45] [TRT] [W] The getMaxBatchSize() function should not be used with an engine built from a network created with NetworkDefinitionCreationFlag::kEXPLICIT_BATCH flag. This function will always return 1.
    [05/22/2024-15:26:45] [TRT] [W] TensorRT was linked against cuDNN 8.4.1 but loaded cuDNN 8.4.0
    [05/22/2024-15:26:45] [TRT] [W] TensorRT was linked against cuDNN 8.4.1 but loaded cuDNN 8.4.0
    class='n02123045 tabby, tabby cat' with probability=0.621154
    class='n02123159 tiger cat' with probability=0.356326
    class='n02124075 Egyptian cat' with probability=0.019713
    class='n02129604 tiger, Panthera tigris' with probability=0.001215
    class='n04040759 radiator' with probability=0.000262
    Tensorrt_GPU_Float32: {'mean': 0.0011052313830005004, 'median': 0.0011054065149801316, 'std': 6.103668141089059e-07}

3.6. Summary

Platform                Precision   Mean (s)     Median (s)   Std (s)
ONNXRuntime (CPU)       Float32     0.00998621   0.00998270   0.00001577
ONNXRuntime (GPU)       Float32     0.00406905   0.00374741   0.00096074
ONNXRuntime (CPU)       INT8        0.01137567   0.01137663   0.00001180
ONNXRuntime (GPU)       INT8        0.01384249   0.01383995   0.00002146
PyTorch (CPU)           Float32     0.02187023   0.02185894   0.00003488
PyTorch (GPU)           Float32     0.00188753   0.00188261   0.00002818
TVM (CPU, no tune)      Float32     0.03740470   0.03763830   0.00046793
TVM (CPU, auto-tune)    Float32     0.03430714   0.03356953   0.00416032
TVM (GPU, no tune)      Float32     0.00189966   0.00190483   0.00001322
TVM (GPU, auto-tune)    Float32     0.00673010   0.00677402   0.00013180
TVM (CPU)               INT8        0.12497943   0.12396715   0.00460438
TVM (GPU)               INT8        0.05043668   0.04935900   0.00229765
TensorRT (GPU)          Float32     0.00110523   0.00110541   0.00000061

A few observations from the table: TensorRT is the clear latency winner on the GPU. Untuned TVM on CUDA is already on par with PyTorch GPU, while the auto-tuned CUDA build is actually slower than the untuned one, presumably because 20 trials per task is far too small a search budget for CUDA schedules; the same budget did help slightly on CPU. Finally, none of the INT8 paths beat their Float32 counterparts in these tests.

