ONNX Shape Inference¶
ONNX 提供了可选的 ONNX 图形状推断实现。此实现涵盖了每个核心算子,并提供了可扩展性接口。因此,您可以选择对您的图调用现有的形状推断功能,或者为您的自定义算子定义形状推断实现(或者两者都做!)。形状推断函数作为 OpSchema 对象的一个成员存储。
在 ONNX 1.10 版本中,ONNX 图级形状推断增加了符号生成和传播以及形状数据传播。详细提案请参见此处
背景¶
有关静态张量形状的概述,请参阅 IR.md 中的此部分。特别要注意的是,静态张量形状(由 TensorShapeProto
表示)与运行时张量形状不同。此功能通常用于在静态(即编译时)未知确切运行时张量形状的情况。
具有未定义
shape
字段的Tensor
用于表示未知秩的张量。具有已定义
shape
的Tensor
表示已知秩的张量。TensorShapeProto
的每个Dimension
可以具有已知整数值(由dim_value
字段表示),或者可以具有由符号标识表示的未知值(dim_param
字段),或者两个字段都可以未定义(在这种情况下,它表示匿名的未知值)。
调用形状推断¶
可以通过 C++ 或 Python 调用形状推断。Python API 的描述和示例请参见此处。
C++ API 包含一个函数
shape_inference::InferShapes(
ModelProto& m,
const ISchemaRegistry* schema_registry);
第一个参数是要执行形状推断的 ModelProto
,形状信息将直接标注在原地。第二个参数是可选的。
限制¶
不保证形状推断是完整的。特别是一些动态行为会阻碍形状推断的流程,例如将 Reshape 为动态提供的形状。此外,并非所有算子都必须有形状推断实现。
形状推断仅适用于常量和简单变量。它不支持包含变量的算术表达式。例如,对形状为 (5, 2)
和 (7, 2)
的张量执行 Concat
,可以推断出结果形状为 (12, 2)
,但对形状为 (5, 2)
和 (N, 2)
的张量执行 Concat
,只会简单地产生 (M, 2)
,而不是包含 N+5
的表示。请注意,不同的未知符号值会被传播,因此这里的 M
表示一个未知量,与 M
的其他出现是相同的。
这些限制是当前实现的特性,而非根本限制 - 如果您需要更高级的功能,请告知我们!
为算子实现形状推断¶
您可以使用以下方式向您的算子的 Schema 添加形状推断函数:
OpSchema& Opschema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
InferenceFunction
定义在 shape_inference.h 中,以及核心接口结构体 InferenceContext
和各种辅助方法。InferenceContext
是提供给您的推断函数的核心结构体。它允许访问有关算子输入的信息,并且允许写出推断出的信息。
要查看大量示例,请在代码库中搜索 TypeAndShapeInferenceFunction
的出现。其中一个相对复杂的实现在 onnx/defs/tensor/defs.cc 中,用于 Concat
。
在实现算子的形状推断方法时,请注意以下几点,以避免常见错误:
在访问任何输入的
shape
之前,代码必须检查形状是否可用。如果不可用,应将其视为秩未知的动态张量并进行适当处理。通常,形状推断逻辑会由对hasInputShape
或hasNInputShapes
的调用来保护。在访问任何维度的
dim_value
或dim_param
之前,代码必须检查这些字段是否有值。特别是,代码必须处理维度可能没有静态已知值的可能性。
shape_inference.h 中有几个实用函数可以处理各种常见情况。
对于必须具有固定秩的输入,请使用
checkInputRank
。(请参阅RoiAlign
的推断作为示例。)当预期多个输入维度相同时,以及当输入维度传播到特定输出维度时,可以使用
unifyInputDim
、unifyDim
和updateOutputShape
。(请参阅RoiAlign
的推断作为示例。)当使用算术从输入维度计算输出维度时,可以在符号维度上使用重载的
*
和/
算子。(请参阅SpaceToDepth
的推断作为示例。)
这些实用程序可以安全地处理缺失的形状和维度。
示例:考虑一个简单的矩阵乘法算子,它期望输入形状为 [M,K]
和 [K,N]
,并返回形状为 [M,N]
的输出。这可以按如下方式编写:
// Check that input 0 has rank 2 (if its rank is known).
checkInputRank(ctx, 0, 2);
// Check that input 1 has rank 2 (if its rank is known).
checkInputRank(ctx, 1, 2);
Dim M, K, N;
// Check various dimensions, handling missing dimensions/shapes safely.
unifyInputDim(ctx, 0, 0, M);
unifyInputDim(ctx, 0, 1, K);
unifyInputDim(ctx, 1, 0, K);
unifyInputDim(ctx, 1, 1, N);
updateOutputShape(ctx, 0, {M. N});