ONNX 模型在 MLIR 编译器基础设施中的表示和参考下推
此项目由 onnx 维护
托管于 GitHub Pages — 主题来自 orderedlist
本文档介绍了用于对 ONNX 方言中的运算进行常量传播的 --constprop-onnx
传递。
源代码.
给定以下代码
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
%3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
"std.return"(%4) : (tensor<1xf32>) -> ()
}
如果我们调用 onnx-mlir-op --constprop-onnx
,我们将得到
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
"std.return"(%0) : (tensor<1xf32>) -> ()
}
ONNXConstantOp 使用 MLIR DenseElementsAttr 来存储常量值。需要注意的是,一旦创建了 DenseElementsAttr,它就会一直存在并占用内存,直到编译结束。在示例中,三个 ONNXConstantOp 中的所有三个 DenseElementsAttr 都存在直到编译结束。特别是,由折叠两个 ONNXAddOps 产生的两个 ONNXConstantOps 中的两个中间 DenseElementsAttr 也存在。对于实际模型,中间 DenseElementsAttr 的数量会迅速增加,这会导致编译期间内存占用量很大。
为了避免在 --constprop-onnx
期间为中间 ONNXConstantOps 创建过多的 DenseElementsAttr,我们设计了一种机制,可以动态地为中间 ONNXConstantOps 分配和释放缓冲区,并且仅在常量传播和其他 ONNX 方言传递之后、在降低到 Krnl(或任何其他目标方言)之前创建 DenseElementsAttr。
这是通过一个自定义属性 DisposableElementsAttr 实现的,该属性在非复杂标量元素类型(布尔、整数和浮点类型)的常见情况下充当 DenseElementsAttr 的替代品。DisposableElementsAttr 实现与 DenseElementsAttr 相同的 ElementsAttr 接口,在大多数情况下它们在功能上是相同的,并且周围的代码不需要区分。它只需要使用 OnnxElementsAttrBuilder 类和 ElementsAttrHelper 函数来构建和访问 ElementsAttr 实例,以实现内存占用和性能优势。
DisposableElementsAttr 缓冲区的释放发生在编译器传递之间,由 DisposableGarbageCollector 完成。DisposableGarbageCollector 在“module”传递之间(这些传递保证“停止世界”,没有其他传递并行执行)被 PassManager 作为“插装”运行。
DisposableElementsAttr 提供了其他内存和速度优势,这些优势在类源文件中的注释中概述,并在 2022 年 11 月的演示文稿中进行了说明,该演示文稿从会议 wiki 页面链接。
我们使用 MLIR 声明式重写规则 (DRR) 来编写常量传播的模式。用于定义模式的 DRR 定义如下所示。
class Pattern<
dag sourcePattern,
list<dag> resultPatterns,
list<dag> additionalConstraints = [],
list<dag> supplementalPatterns = [],
dag benefitsAdded = (addBenefit 0)
>;
有关 DRR 的更多信息,请在此处查阅。
现在,我们通过一个简单的示例来说明如何为 ONNXAddOp 添加常量传播。
我们首先将一个模式添加到 ConstProp.td 中。
// Constant Propagation for Add
def AddConstProp : Pat<
// source patten: From add(lhs, rhs).
(ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
(ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
// result pattern: To c = lhs + rhs
(CreateAddOfTwoConst $addOp, $lhs, $rhs),
// Additional constraints: if both lhs and rhs are dense constants.
[(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;
上述模式会将输入为常量的 ONNXAddOp 替换为编译时添加输入的新的常量。要检查输入是否为常量,仅使用 ONNXConstantOp 并不足够,因为常量张量可以是稀疏的,而我们目前仅支持密集常量张量。我们需要使用 IsFromDenseONNXConstantOp
来额外检查一个密集常量张量。
在结果模式中,为了生成 ONNXConstantOp,我们将在编译时添加 lhs
和 rhs
,并发出一个 ONNXConstantOp。为了最小化内存占用,此 ONNXConstantOp 具有 DisposableElementsAttr 而不是常规的 DenseElementsAttr。
函数 CreateAddOfTwoConst
将在编译时执行加法运算并返回一个 ONNXConstantOp。
def CreateAddOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
模式中的函数 CreateAddOfTwoConst
调用 ConstProp.cpp 中的 ConstPropElementwiseBinary
,其内容如下。
template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value replacingValue, Value lhsValue, Value rhsValue) {
ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());
// Get lhs and rhs ElementsAttr from the values' defining constant ops.
ElementsAttr lhs = getConstValueElements(lhsValue);
ElementsAttr rhs = getConstValueElements(rhsValue);
Type operandsElemType = lhs.getElementType();
assert(operandsElemType == rhs.getElementType() &&
"all element-wise binary ops have matching operands element types");
OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));
// Construct and return a new ONNXConstantOp with the resultElements attribute.
return createReplacingConstantOp(rewriter, replacingValue, resultElements)
.getResult();
}
其中 OnnxElementsAttrBuilder.combine(...)
会根据需要广播 lhs 和 rhs 元素,并构建一个新的(Disposable)ElementsAttr,其元素是通过按元素应用二元函数 combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType)
得到的,该函数将 ElementwiseBinaryOp ONNX 运算映射到 C++ 运算符。
有关常量传播的更多信息,请参阅 ConstProp.td 和 ConstProp.cpp。