在 MLIR 编译器基础设施中表示和参考降低 ONNX 模型
此项目由 onnx 维护
托管在 GitHub Pages 上 — 主题由 orderedlist 提供
本文档描述了 --constprop-onnx
传递,该传递用于对 ONNX 方言中的操作进行常量传播。
源代码.
给定以下代码
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
%3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
"std.return"(%4) : (tensor<1xf32>) -> ()
}
如果我们调用 onnx-mlir-op --constprop-onnx
,我们将得到
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
"std.return"(%0) : (tensor<1xf32>) -> ()
}
ONNXConstantOp 使用 MLIR DenseElementsAttr 存储常量值。需要注意的是,一旦创建了 DenseElementsAttr,它就会一直存在并占用内存,直到编译结束。在 示例 中,三个 ONNXConstantOp 中的所有三个 DenseElementsAttr 都存在直到编译结束。特别是,由折叠两个 ONNXAddOp 生成的两个 ONNXConstantOp 中的两个中间 DenseElementsAttr 也存在。对于真实世界的模型,中间 DenseElementsAttr 的数量会迅速增加,这会导致编译期间内存占用量很大。
为了避免在 --constprop-onnx
期间为中间 ONNXConstantOp 创建过多的 DenseElementsAttr,我们设计了一种机制,可以为中间 ONNXConstantOp 动态分配和释放缓冲区,并且仅在常量传播和其他 ONNX 方言传递之后,就在降低到 Krnl(或任何其他目标方言)之前创建 DenseElementsAttr。
这是通过一个自定义属性 DisposableElementsAttr 实现的,该属性充当非复杂标量元素类型(布尔值、整数和浮点数类型)的常见情况下的 DenseElementsAttr 的替代品。DisposableElementsAttr 实现与 DenseElementsAttr 相同的 ElementsAttr 接口,在大多数情况下,它们在功能上是相同的,周围的代码不需要进行区分。它只需要使用 OnnxElementsAttrBuilder 类和 ElementsAttrHelper 函数来构造和访问 ElementsAttr 实例,以获得内存占用和性能优势。
DisposableElementsAttr 缓冲区的释放发生在 DisposableGarbageCollector 中的编译器传递之间,该收集器由 PassManager 在“模块”传递(保证“停止世界”,没有其他传递并行执行)之间作为“检测”运行。
DisposableElementsAttr 提供了其他内存和速度优势,这些优势在类源文件中注释中进行了概述,并在 2022 年 11 月的演示文稿中进行了说明,该演示文稿链接自 会议 Wiki 页面。
我们使用 MLIR 声明性重写规则 (DRR) 为常量传播编写模式。用于定义模式的 DRR 定义如下所示
class Pattern<
dag sourcePattern,
list<dag> resultPatterns,
list<dag> additionalConstraints = [],
list<dag> supplementalPatterns = [],
dag benefitsAdded = (addBenefit 0)
>;
有关 DRR 的更多信息,请参见 此处。
现在,我们来看一个简单的示例,该示例为 ONNXAddOp 添加常量传播。
我们首先向 ConstProp.td 添加一个模式。
// Constant Propagation for Add
def AddConstProp : Pat<
// source patten: From add(lhs, rhs).
(ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
(ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
// result pattern: To c = lhs + rhs
(CreateAddOfTwoConst $addOp, $lhs, $rhs),
// Additional constraints: if both lhs and rhs are dense constants.
[(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;
上述模式将用一个新的常量替换输入为常量的 ONNXAddOp,方法是在编译时添加输入。要检查输入是否为常量,仅使用 ONNXConstantOp 不够,因为常量张量可以是稀疏的,而我们现在仅支持密集常量张量。我们需要另外使用 IsFromDenseONNXConstantOp
检查密集常量张量。
在结果模式中,为了生成 ONNXConstantOp,我们将在编译时添加 lhs
和 rhs
,并发出 ONNXConstantOp。为了最大程度地减少内存占用,此 ONNXConstantOp 具有 DisposableElementsAttr 而不是传统的 DenseElementsAttr。
函数 CreateAddOfTwoConst
将在编译时执行加法并返回 ONNXConstantOp。
def CreateAddOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
模式中的函数 CreateAddOfTwoConst
调用 ConstProp.cpp 中的 ConstPropElementwiseBinary
,其内容如下。
template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value replacingValue, Value lhsValue, Value rhsValue) {
ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());
// Get lhs and rhs ElementsAttr from the values' defining constant ops.
ElementsAttr lhs = getConstValueElements(lhsValue);
ElementsAttr rhs = getConstValueElements(rhsValue);
Type operandsElemType = lhs.getElementType();
assert(operandsElemType == rhs.getElementType() &&
"all element-wise binary ops have matching operands element types");
OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));
// Construct and return a new ONNXConstantOp with the resultElements attribute.
return createReplacingConstantOp(rewriter, replacingValue, resultElements)
.getResult();
}
其中 OnnxElementsAttrBuilder.combine(...)
根据需要广播 lhs 和 rhs 元素,并构造一个新的 (Disposable) ElementsAttr,其元素是二元函数 combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType)
的逐元素应用的结果,该函数将 ElementwiseBinaryOp ONNX 操作映射到 c++ 运算符。
有关常量传播的更多信息,请参见 ConstProp.td 和 ConstProp.cpp。