MLIR 编译器基础设施中的 ONNX 模型表示和参考降低
本项目由 onnx 维护
托管于 GitHub Pages — 主题作者 orderedlist
本文档描述了 --constprop-onnx
pass,它用于对 ONNX 方言中的操作进行常量传播。
源代码.
假设有以下代码
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
%3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
"std.return"(%4) : (tensor<1xf32>) -> ()
}
如果我们调用 onnx-mlir-op --constprop-onnx
,我们将得到
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
"std.return"(%0) : (tensor<1xf32>) -> ()
}
ONNXConstantOp 使用 MLIR DenseElementsAttr 来存储常量值。值得注意的是,一旦创建了 DenseElementsAttr,它将一直存在并消耗内存直到编译结束。在示例中,三个 ONNXConstantOp 中的所有三个 DenseElementsAttr 都存在直到编译结束。特别是,由折叠两个 ONNXAddOp 生成的两个 ONNXConstantOp 中的两个中间 DenseElementsAttr 也存在。对于一个真实的模型,中间 DenseElementsAttr 的数量会迅速增加,这导致编译期间的内存占用很大。
为了避免在 --constprop-onnx
过程中为中间的 ONNXConstantOp 创建过多的 DenseElementsAttr,我们设计了一种机制,它动态地为中间的 ONNXConstantOp 分配和释放缓冲区,并且只在常量传播和其他 ONNX 方言 passes 之后、就在降低到 Krnl(或任何其他目标方言)之前才创建 DenseElementsAttr。
这通过一个自定义属性 DisposableElementsAttr 实现,它作为 DenseElementsAttr 的替代品,用于非复杂标量元素类型的常见情况:布尔类型、整数类型和浮点类型。DisposableElementsAttr 实现了与 DenseElementsAttr 相同的 ElementsAttr 接口,并且在大多数情况下,它们在功能上是相同的,周围的代码无需区分。它只需要使用 OnnxElementsAttrBuilder 类和 ElementsAttrHelper 函数来构建和访问 ElementsAttr 实例,以获得内存占用和性能方面的收益。
DisposableElementsAttr 缓冲区的释放发生在 DisposableGarbageCollector 中,它由 PassManager 在“module” passes 之间运行(这些 passes 保证会“停止所有活动”,没有其他 passes 并行执行),作为一种“插桩”。
DisposableElementsAttr 还提供了其他内存和速度方面的优势,这些优势在类源代码文件中的注释中有所概述,并在 2022 年 11 月的演示文稿中进行了解释,链接自会议 wiki 页面。
我们使用 MLIR 声明式重写规则 (DRR) 来编写常量传播的模式。用于定义模式的 DRR 定义如下所示
class Pattern<
dag sourcePattern,
list<dag> resultPatterns,
list<dag> additionalConstraints = [],
list<dag> supplementalPatterns = [],
dag benefitsAdded = (addBenefit 0)
>;
可以在此处找到关于 DRR 的更多信息。
现在,我们来看一个简单的示例,它为 ONNXAddOp 添加常量传播。
我们首先向ConstProp.td 添加一个模式。
// Constant Propagation for Add
def AddConstProp : Pat<
// source patten: From add(lhs, rhs).
(ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
(ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
// result pattern: To c = lhs + rhs
(CreateAddOfTwoConst $addOp, $lhs, $rhs),
// Additional constraints: if both lhs and rhs are dense constants.
[(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;
上述模式将替换一个输入为常量的 ONNXAddOp,通过在编译时将输入相加,生成一个新的常量。要检查输入是否为常量,仅使用 ONNXConstantOp 是不够的,因为常量张量可以是稀疏的,而我们目前只支持密集常量张量。我们还需要通过使用 IsFromDenseONNXConstantOp
来检查是否为密集常量张量。
在结果模式中,为了生成一个 ONNXConstantOp,我们将在编译时将 lhs
和 rhs
相加,并发出一个 ONNXConstantOp。为了最小化内存占用,这个 ONNXConstantOp 使用了 DisposableElementsAttr 而不是传统的 DenseElementsAttr。
函数 CreateAddOfTwoConst
将在编译时执行加法并返回一个 ONNXConstantOp。
def CreateAddOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
模式中的函数 CreateAddOfTwoConst
调用ConstProp.cpp 中的 ConstPropElementwiseBinary
,其内容如下。
template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value replacingValue, Value lhsValue, Value rhsValue) {
ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());
// Get lhs and rhs ElementsAttr from the values' defining constant ops.
ElementsAttr lhs = getConstValueElements(lhsValue);
ElementsAttr rhs = getConstValueElements(rhsValue);
Type operandsElemType = lhs.getElementType();
assert(operandsElemType == rhs.getElementType() &&
"all element-wise binary ops have matching operands element types");
OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));
// Construct and return a new ONNXConstantOp with the resultElements attribute.
return createReplacingConstantOp(rewriter, replacingValue, resultElements)
.getResult();
}
其中 OnnxElementsAttrBuilder.combine(...)
根据需要广播 lhs 和 rhs 元素,并构建一个新的 (Disposable) ElementsAttr,其元素是二进制函数 combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType)
元素级应用的结果,它将 ElementwiseBinaryOp ONNX op 映射到 C++ 运算符。
关于常量传播的更多信息,请参阅ConstProp.td 和ConstProp.cpp。