ONNX 模型在 MLIR 编译器基础设施中的表示和参考下推
此项目由 onnx 维护
托管于 GitHub Pages — 主题来自 orderedlist
总的来说,onnx-mlir 将自定义加速器视为插件,可以在构建 onnx-mlir 和编译模型时开启/关闭。主要通过 cmake
进行处理,本文档将概述其流程。
除了本文档,NNPA 加速器 可用作已在 onnx-mlir 中部署的示例。
在 onnx-mlir 中,加速器的所有代码都应放在 src/Accelerators
下的单独文件夹中。因此,支持加速器的第一步是在 src/Accelerators
中为此创建一个文件夹。
文件夹名称将用作 onnx-mlir 中的加速器名称。具体来说,它用于
cmake
构建加速器文件夹内的代码,onnx-mlir
命令时为加速器编译模型,以及onnx-mlir-opt
命令时启用与加速器相关的 pass。文件夹的内容取决于每个加速器。但是,我们建议尽可能遵循 onnx-mlir
的根文件夹的结构。这有助于在整个项目中保持一致性。
要在 onnx-mlir 中构建加速器,请在构建 onnx-mlir 时使用 cmake 变量 ONNX_MLIR_ACCELERATORS
。ONNX_MLIR_ACCELERATORS
接受分号分隔的加速器名称列表。例如,
$ cd build
$ cmake .. -DONNX_MLIR_ACCELERATORS='accel1;accel2'
请注意,列表应加引号。
编译器命令 onnx-mlir
有一个选项,即 --maccel
,用于为选定的加速器编译模型。为每个加速器添加一个 --maccel=accel_name
条目。例如,
$ onnx-mlir --maccel=accel1 --maccel=accel2 model.onnx
只有构建的加速器才能与 --maccel
一起使用。
可以通过 onnx-mlir-opt
命令使用 --maccel
选项运行或测试加速器定义的 pass,该选项类似于 onnx-mlir
中的 --maccel
(参见第 1.2 节)。例如,调用加速器 accel1
定义的 --optimize-data-layout
pass
$ onnx-mlir-opt --maccel=accel1 --optimize-data-layout model.mlir
只有构建的加速器才能与 --maccel
一起使用。
每个加速器都需要定义一些宏。这些需要包含在 onnx_mlir::accel::Accelerator 中。这些宏是
INSTRUMENTSTAGE_ENUM_<accel_name>
INSTRUMENTSTAGE_CL_ENUM_<accel_name>
PROFILEIR_CL_ENUM_<accel_name>
OPTREPORT_ENUM_<accel_name>
OPTREPORT_CL_ENUM_<accel_name>
将 <accel_name>
替换为加速器的名称,例如,如果您的加速器名为 ACCEL1
,则使用
#define INSTRUMENTSTAGE_ENUM_ACCEL1
#define INSTRUMENTSTAGE_CL_ENUM_ACCEL1
#define PROFILEIR_CL_ENUM_ACCEL1
#define OPTREPORT_ENUM_ACCEL1
#define OPTREPORT_CL_ENUM_ACCEL1
在 MLIR 中编写代码通常涉及设计方言和 pass。支持加速器也是如此。因此,将加速器代码集成到 onnx-mlir 中就是注册 onnx-mlir 中的方言和 pass。
我们提供了一个基类 onnx_mlir::accel::Accelerator,用户可以从中定义一个派生类并编写钩子来注册方言和 pass。
//===--------------------------------------------------------------------===//
// Hooks for onnx-mlir driver
//===--------------------------------------------------------------------===//
/// Add the transformations necessary to support the accelerator.
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm,
onnx_mlir::EmissionTargetType &emissionTarget) const = 0;
//===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===//
/// Register the MLIR dialects required to support an accelerator.
virtual void registerDialects(mlir::DialectRegistry ®istry) const = 0;
/// Register accelerator transformation passes to make available as
/// command line options.
virtual void registerPasses(int optLevel) const = 0;
//===--------------------------------------------------------------------===//
// Hooks for both onnx-mlir and onnx-mlir-opt drivers
//===--------------------------------------------------------------------===//
/// Configure passes for the accelerator.
virtual void configurePasses() const = 0;
//===--------------------------------------------------------------------===//
// Hooks for onnx-to-krnl pass
//===--------------------------------------------------------------------===//
/// Convert TensorType to MemRefType.
/// Acccelators may have special versions of TensorType. If not, override this
/// method and return nullptr.
virtual mlir::MemRefType convertTensorTypeToMemRefType(
const mlir::TensorType tensorType) const = 0;
/// Define conversion target to be used with ONNXToKrnl.
virtual void conversionTargetONNXToKrnl(
mlir::ConversionTarget &target) const = 0;
/// Define rewrite patterns to be used with ONNXToKrnl.
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const = 0;
//===--------------------------------------------------------------------===//
// Hooks for krnl-to-llvm pass
//===--------------------------------------------------------------------===//
/// Define conversion target to be used with KrnlToLLVM.
virtual void conversionTargetKrnlToLLVM(
mlir::ConversionTarget &target) const = 0;
/// Define rewrite patterns to be used with KrnlToLLVM.
virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns,
mlir::LLVMTypeConverter &typeConverter, mlir::MLIRContext *ctx) const = 0;
虽然 onnx-mlir 中有很多 pass,但我们只为 onnx-to-krnl
和 krnl-to-llvm
这两个 pass 提供了钩子。原因是,原则上它们是 onnx-mlir 中的第一个和最后一个 pass。onnx-to-krnl
pass 是我们可以决定哪些 ONNX 运算符将在主机上运行(通过将它们降低到 Krnl 方言)或在加速器上运行(通过将它们降低到为加速器定义的方言)的地方。krnl-to-llvm
pass 是我们将 Krnl 和加速器运算符降低到 LLVM 方言的地方,例如生成汇编代码或简单地为加速器调用外部 API。onnx-to-krnl
和 krnl-to-llvm
之间可以有任何方言和加速器的 pass。
例如,对于 NNPA 加速器,我们定义了 ZHigh 方言 用于 onnx-to-krnl
,并定义了 ZLow 方言 用于 krnl-to-llvm
。
加速器的测试应放在 test 文件夹中。具体来说,