onnx-mlir

Logo

在 MLIR 编译器基础设施中表示和参考降低 ONNX 模型

在 GitHub 上查看项目 onnx/onnx-mlir

操作指南

使用 Python 进行推理
使用 C/C++ 进行推理
使用 Java 进行推理

参考文献

ONNX 方言
OMTensor C99 运行时 API
OMTensorList C99 运行时 API
OMTensor Java 运行时 API
OMTensorList Java 运行时 API
生成 ONNX 方言
关于文档

开发

添加操作
测试指南
错误处理
命令行选项
检测
常量传播
添加加速器

工具

工具

RunONNXModel.py
DocCheck

此项目由 onnx 维护

托管在 GitHub Pages 上 — 主题由 orderedlist 提供

ONNX 操作的常量传播

本文档描述了 --constprop-onnx 传递,该传递用于对 ONNX 方言中的操作进行常量传播。

源代码.

示例

给定以下代码

func @foo() -> tensor<1xf32> {
  %0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  %3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  "std.return"(%4) : (tensor<1xf32>) -> ()
}

如果我们调用 onnx-mlir-op --constprop-onnx,我们将得到

func @foo() -> tensor<1xf32> {
  %0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  "std.return"(%0) : (tensor<1xf32>) -> ()
}

备注

ONNXConstantOp 使用 MLIR DenseElementsAttr 存储常量值。需要注意的是,一旦创建了 DenseElementsAttr,它就会一直存在并占用内存,直到编译结束。在 示例 中,三个 ONNXConstantOp 中的所有三个 DenseElementsAttr 都存在直到编译结束。特别是,由折叠两个 ONNXAddOp 生成的两个 ONNXConstantOp 中的两个中间 DenseElementsAttr 也存在。对于真实世界的模型,中间 DenseElementsAttr 的数量会迅速增加,这会导致编译期间内存占用量很大。

为了避免在 --constprop-onnx 期间为中间 ONNXConstantOp 创建过多的 DenseElementsAttr,我们设计了一种机制,可以为中间 ONNXConstantOp 动态分配和释放缓冲区,并且仅在常量传播和其他 ONNX 方言传递之后,就在降低到 Krnl(或任何其他目标方言)之前创建 DenseElementsAttr。

这是通过一个自定义属性 DisposableElementsAttr 实现的,该属性充当非复杂标量元素类型(布尔值、整数和浮点数类型)的常见情况下的 DenseElementsAttr 的替代品。DisposableElementsAttr 实现与 DenseElementsAttr 相同的 ElementsAttr 接口,在大多数情况下,它们在功能上是相同的,周围的代码不需要进行区分。它只需要使用 OnnxElementsAttrBuilder 类和 ElementsAttrHelper 函数来构造和访问 ElementsAttr 实例,以获得内存占用和性能优势。

DisposableElementsAttr 缓冲区的释放发生在 DisposableGarbageCollector 中的编译器传递之间,该收集器由 PassManager 在“模块”传递(保证“停止世界”,没有其他传递并行执行)之间作为“检测”运行。

DisposableElementsAttr 提供了其他内存和速度优势,这些优势在类源文件中注释中进行了概述,并在 2022 年 11 月的演示文稿中进行了说明,该演示文稿链接自 会议 Wiki 页面

编写常量传播规则

我们使用 MLIR 声明性重写规则 (DRR) 为常量传播编写模式。用于定义模式的 DRR 定义如下所示

class Pattern<
   dag sourcePattern,
   list<dag> resultPatterns,
   list<dag> additionalConstraints = [],
   list<dag> supplementalPatterns = [],
   dag benefitsAdded = (addBenefit 0)
>;

有关 DRR 的更多信息,请参见 此处

现在,我们来看一个简单的示例,该示例为 ONNXAddOp 添加常量传播。

步骤 1:编写 DRR 模式

我们首先向 ConstProp.td 添加一个模式。

// Constant Propagation for Add
def AddConstProp : Pat<
    // source patten: From add(lhs, rhs).
    (ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
                      (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
    // result pattern: To c = lhs + rhs
    (CreateAddOfTwoConst $addOp, $lhs, $rhs),
    // Additional constraints: if both lhs and rhs are dense constants.
    [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;

上述模式将用一个新的常量替换输入为常量的 ONNXAddOp,方法是在编译时添加输入。要检查输入是否为常量,仅使用 ONNXConstantOp 不够,因为常量张量可以是稀疏的,而我们现在仅支持密集常量张量。我们需要另外使用 IsFromDenseONNXConstantOp 检查密集常量张量。

在结果模式中,为了生成 ONNXConstantOp,我们将在编译时添加 lhsrhs,并发出 ONNXConstantOp。为了最大程度地减少内存占用,此 ONNXConstantOp 具有 DisposableElementsAttr 而不是传统的 DenseElementsAttr。

函数 CreateAddOfTwoConst 将在编译时执行加法并返回 ONNXConstantOp。

def CreateAddOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;

步骤 2:为输入和结果准备数组缓冲区

模式中的函数 CreateAddOfTwoConst 调用 ConstProp.cpp 中的 ConstPropElementwiseBinary,其内容如下。

template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
    Value replacingValue, Value lhsValue, Value rhsValue) {
  ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
  Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());

  // Get lhs and rhs ElementsAttr from the values' defining constant ops.
  ElementsAttr lhs = getConstValueElements(lhsValue);
  ElementsAttr rhs = getConstValueElements(rhsValue);

  Type operandsElemType = lhs.getElementType();
  assert(operandsElemType == rhs.getElementType() &&
         "all element-wise binary ops have matching operands element types");
  OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
  ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
      combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));

  // Construct and return a new ONNXConstantOp with the resultElements attribute.
  return createReplacingConstantOp(rewriter, replacingValue, resultElements)
      .getResult();
}

其中 OnnxElementsAttrBuilder.combine(...) 根据需要广播 lhs 和 rhs 元素,并构造一个新的 (Disposable) ElementsAttr,其元素是二元函数 combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType) 的逐元素应用的结果,该函数将 ElementwiseBinaryOp ONNX 操作映射到 c++ 运算符。

待办事项:描述如何为新的操作添加 OnnxElementsAttrBuilder 构建器方法

有关常量传播的更多信息,请参见 ConstProp.tdConstProp.cpp