onnx-mlir

Logo

MLIR 编译器基础设施中的 ONNX 模型表示和参考降低

在 GitHub 上查看项目 onnx/onnx-mlir

操作指南

使用 Python 进行推理
使用 C/C++ 进行推理
使用 Java 进行推理

参考

ONNX 方言
OMTensor C99 运行时 API
OMTensorList C99 运行时 API
OMTensor Java 运行时 API
OMTensorList Java 运行时 API
生成 ONNX 方言
关于文档

开发

添加操作
测试指南
错误处理
命令行选项
插桩
常量传播
添加加速器

工具

工具

RunONNXModel.py
DocCheck

本项目由 onnx 维护

托管于 GitHub Pages — 主题作者 orderedlist

ONNX 操作的常量传播

本文档描述了 --constprop-onnx pass,它用于对 ONNX 方言中的操作进行常量传播。

源代码.

示例

假设有以下代码

func @foo() -> tensor<1xf32> {
  %0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  %3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  "std.return"(%4) : (tensor<1xf32>) -> ()
}

如果我们调用 onnx-mlir-op --constprop-onnx,我们将得到

func @foo() -> tensor<1xf32> {
  %0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  "std.return"(%0) : (tensor<1xf32>) -> ()
}

注意

ONNXConstantOp 使用 MLIR DenseElementsAttr 来存储常量值。值得注意的是,一旦创建了 DenseElementsAttr,它将一直存在并消耗内存直到编译结束。在示例中,三个 ONNXConstantOp 中的所有三个 DenseElementsAttr 都存在直到编译结束。特别是,由折叠两个 ONNXAddOp 生成的两个 ONNXConstantOp 中的两个中间 DenseElementsAttr 也存在。对于一个真实的模型,中间 DenseElementsAttr 的数量会迅速增加,这导致编译期间的内存占用很大。

为了避免在 --constprop-onnx 过程中为中间的 ONNXConstantOp 创建过多的 DenseElementsAttr,我们设计了一种机制,它动态地为中间的 ONNXConstantOp 分配和释放缓冲区,并且只在常量传播和其他 ONNX 方言 passes 之后、就在降低到 Krnl(或任何其他目标方言)之前才创建 DenseElementsAttr。

这通过一个自定义属性 DisposableElementsAttr 实现,它作为 DenseElementsAttr 的替代品,用于非复杂标量元素类型的常见情况:布尔类型、整数类型和浮点类型。DisposableElementsAttr 实现了与 DenseElementsAttr 相同的 ElementsAttr 接口,并且在大多数情况下,它们在功能上是相同的,周围的代码无需区分。它只需要使用 OnnxElementsAttrBuilder 类和 ElementsAttrHelper 函数来构建和访问 ElementsAttr 实例,以获得内存占用和性能方面的收益。

DisposableElementsAttr 缓冲区的释放发生在 DisposableGarbageCollector 中,它由 PassManager 在“module” passes 之间运行(这些 passes 保证会“停止所有活动”,没有其他 passes 并行执行),作为一种“插桩”。

DisposableElementsAttr 还提供了其他内存和速度方面的优势,这些优势在类源代码文件中的注释中有所概述,并在 2022 年 11 月的演示文稿中进行了解释,链接自会议 wiki 页面

编写常量传播规则

我们使用 MLIR 声明式重写规则 (DRR) 来编写常量传播的模式。用于定义模式的 DRR 定义如下所示

class Pattern<
   dag sourcePattern,
   list<dag> resultPatterns,
   list<dag> additionalConstraints = [],
   list<dag> supplementalPatterns = [],
   dag benefitsAdded = (addBenefit 0)
>;

可以在此处找到关于 DRR 的更多信息。

现在,我们来看一个简单的示例,它为 ONNXAddOp 添加常量传播。

步骤 1:编写 DRR 模式

我们首先向ConstProp.td 添加一个模式。

// Constant Propagation for Add
def AddConstProp : Pat<
    // source patten: From add(lhs, rhs).
    (ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
                      (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
    // result pattern: To c = lhs + rhs
    (CreateAddOfTwoConst $addOp, $lhs, $rhs),
    // Additional constraints: if both lhs and rhs are dense constants.
    [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;

上述模式将替换一个输入为常量的 ONNXAddOp,通过在编译时将输入相加,生成一个新的常量。要检查输入是否为常量,仅使用 ONNXConstantOp 是不够的,因为常量张量可以是稀疏的,而我们目前只支持密集常量张量。我们还需要通过使用 IsFromDenseONNXConstantOp 来检查是否为密集常量张量。

在结果模式中,为了生成一个 ONNXConstantOp,我们将在编译时将 lhsrhs 相加,并发出一个 ONNXConstantOp。为了最小化内存占用,这个 ONNXConstantOp 使用了 DisposableElementsAttr 而不是传统的 DenseElementsAttr。

函数 CreateAddOfTwoConst 将在编译时执行加法并返回一个 ONNXConstantOp。

def CreateAddOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;

步骤 2:为输入和结果准备数组缓冲区

模式中的函数 CreateAddOfTwoConst 调用ConstProp.cpp 中的 ConstPropElementwiseBinary,其内容如下。

template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
    Value replacingValue, Value lhsValue, Value rhsValue) {
  ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
  Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());

  // Get lhs and rhs ElementsAttr from the values' defining constant ops.
  ElementsAttr lhs = getConstValueElements(lhsValue);
  ElementsAttr rhs = getConstValueElements(rhsValue);

  Type operandsElemType = lhs.getElementType();
  assert(operandsElemType == rhs.getElementType() &&
         "all element-wise binary ops have matching operands element types");
  OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
  ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
      combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));

  // Construct and return a new ONNXConstantOp with the resultElements attribute.
  return createReplacingConstantOp(rewriter, replacingValue, resultElements)
      .getResult();
}

其中 OnnxElementsAttrBuilder.combine(...) 根据需要广播 lhs 和 rhs 元素,并构建一个新的 (Disposable) ElementsAttr,其元素是二进制函数 combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType) 元素级应用的结果,它将 ElementwiseBinaryOp ONNX op 映射到 C++ 运算符。

TODO:描述如何为新的 ops 添加 OnnxElementsAttrBuilder builder 方法

关于常量传播的更多信息,请参阅ConstProp.tdConstProp.cpp