实现 ONNX 后端

什么是 ONNX 后端

ONNX 后端是一个可以运行 ONNX 模型的库。由于已经存在许多深度学习框架,您可能不需要从头开始创建一切。相反,您可能需要创建一个转换器,将 ONNX 模型转换为相应的框架特定表示,然后将执行委托给该框架。例如,onnx-caffe2 (作为 caffe2 的一部分)onnx-coremlonnx-tensorflow 都实现了转换器。

统一的后端接口

ONNX 在 onnx/backend/base.py 中定义了一个统一的 (Python) 后端接口。

该接口中有三个核心概念:DeviceBackendBackendRep

  • Device 是各种硬件(例如 CPU、GPU 等)的轻量级抽象。

  • Backend 是将 ONNX 模型和输入进行计算并返回输出的实体。

    对于一次性执行,用户可以使用 run_noderun_model 快速获得结果。

    对于重复执行,用户应该使用 prepare,其中 Backend 完成重复执行模型的所有准备工作(例如,加载初始化器),并返回一个 BackendRep 句柄。

  • BackendRepBackend 在准备好重复执行模型后返回的句柄。然后,用户将输入传递给 BackendReprun 函数以检索相应的结果。

请注意,即使 ONNX 统一后端接口是用 Python 定义的,您的后端也不必用 Python 实现。例如,您的后端可以用 C++ 创建,并使用 pybind11cython 等工具来实现该接口。

ONNX 后端测试

ONNX 提供了一个标准的后端测试套件来协助后端实现的验证。强烈建议每个 ONNX 后端都运行此测试。

将 ONNX 后端测试套件集成到您的 CI 中很简单。以下是一些示例,展示了后端如何执行集成:

如果您安装了 pytest,您可以在运行 ONNX 后端测试后获得覆盖率报告,以了解您的后端表现如何。

---------- onnx coverage: ----------
Operators (passed/loaded/total): 21/21/70
------------------------------------
╒════════════════════╤════════════════════╕
│ Operator           │ Attributes         │
│                    │ (name: #values)    │
╞════════════════════╪════════════════════╡
│ Slice              │ axes: 2            │
│                    │ ends: 3            │
│                    │ starts: 3          │
├────────────────────┼────────────────────┤
│ Constant           │ value: 1           │
├────────────────────┼────────────────────┤
│ Concat             │ axis: 0            │
├────────────────────┼────────────────────┤
│ Conv               │ group: 6           │
│                    │ kernel_shape: 5    │
│                    │ pads: 4            │
│                    │ strides: 3         │
│                    │ auto_pad: 0        │
│                    │ dilations: 0       │
├────────────────────┼────────────────────┤
│ Reshape            │ shape: 9           │
├────────────────────┼────────────────────┤
│ BatchNormalization │ consumed_inputs: 1 │
│                    │ epsilon: 2         │
│                    │ is_test: 1         │
│                    │ momentum: 0        │
│                    │ spatial: 0         │
├────────────────────┼────────────────────┤
│ Dropout            │ is_test: 1         │
│                    │ ratio: 2           │
├────────────────────┼────────────────────┤
│ MaxPool            │ kernel_shape: 2    │
│                    │ pads: 3            │
│                    │ strides: 2         │
│                    │ auto_pad: 0        │
│                    │ dilations: 0       │
├────────────────────┼────────────────────┤
│ Transpose          │ perm: 1            │
├────────────────────┼────────────────────┤
│ MatMul             │ No attributes      │
├────────────────────┼────────────────────┤
│ Relu               │ No attributes      │
├────────────────────┼────────────────────┤
│ LRN                │ alpha: 2           │
│                    │ beta: 1            │
│                    │ bias: 2            │
│                    │ size: 1            │
├────────────────────┼────────────────────┤
│ Add                │ axis: 1            │
│                    │ broadcast: 1       │
├────────────────────┼────────────────────┤
│ Abs                │ No attributes      │
├────────────────────┼────────────────────┤
│ Pad                │ mode: 3            │
│                    │ paddings: 2        │
│                    │ value: 1           │
├────────────────────┼────────────────────┤
│ Softmax            │ axis: 0            │
├────────────────────┼────────────────────┤
│ GlobalAveragePool  │ No attributes      │
├────────────────────┼────────────────────┤
│ Mul                │ axis: 1            │
│                    │ broadcast: 1       │
├────────────────────┼────────────────────┤
│ Sum                │ No attributes      │
├────────────────────┼────────────────────┤
│ Gemm               │ broadcast: 1       │
│                    │ transB: 1          │
│                    │ alpha: 0           │
│                    │ beta: 0            │
│                    │ transA: 0          │
├────────────────────┼────────────────────┤
│ AveragePool        │ kernel_shape: 3    │
│                    │ pads: 3            │
│                    │ strides: 2         │
│                    │ auto_pad: 0        │
╘════════════════════╧════════════════════╛

行中数字 Operators (passed/loaded/total): 21/21/70 表示您的后端在所有测试用例中覆盖了 21 个运算符并已通过,ONNX 后端测试的所有测试用例覆盖了 21 个运算符,而 ONNX 总共有 70 个运算符。