ONNX Shape Inference

ONNX 提供了一个可选的 ONNX 图形状推断实现。此实现涵盖了核心运算符的每个运算符,并提供了可扩展的接口。因此,您可以选择在您的图上调用现有的形状推断功能,或者为您自定义运算符定义形状推断实现(或两者兼有!)。形状推断函数存储在 OpSchema 对象的一个成员中。

在 ONNX 1.10 版本中,符号生成和传播以及形状数据传播被添加到 ONNX 图级别的形状推断中。详细的提案在这里:这里

背景

请参阅 IR.md此部分,了解静态张量形状的概述。特别是,静态张量形状(由 TensorShapeProto 表示)与运行时张量形状不同。当精确的运行时张量形状在静态(即编译时)未知时,通常会使用此功能。

  • 具有未定义 shape 字段的 Tensor 用于表示秩未知的张量。

  • 具有已定义 shapeTensor 表示秩已知的张量。

  • TensorShapeProto 的每个 Dimension 可以具有已知的整数值(由 dim_value 字段表示),或者可以具有由符号标识符表示的未知值(dim_param 字段),或者两者都没有定义(在这种情况下,它表示一个匿名的未知值)。

调用形状推断

可以通过 C++ 或 Python 调用形状推断。Python API 附带示例,此处 有详细描述。

C++ API 由一个函数组成

shape_inference::InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry);

第一个参数是要进行形状推断的 ModelProto,它将就地用形状信息进行注解。第二个参数是可选的。

局限性

形状推断不保证是完整的。特别是,一些动态行为会阻塞形状推断的流程,例如将张量重塑为动态提供的形状。此外,并非所有运算符都需要具有形状推断实现。

形状推断仅适用于常量和简单变量。它不支持包含变量的算术表达式。例如,形状为 (5, 2)(7, 2) 的张量的 Concat 可以推断为产生形状为 (12, 2) 的结果,但是形状为 (5, 2)(N, 2) 的张量的 Concat 将简单地产生 (M, 2),而不是包含 N+5 的表示。请注意,不同的未知符号值将被传播,因此这里的 M 表示一个未知量,它与其他 M 的出现相同。

这些限制是当前实现的属性,而不是根本性的约束——如果您需要更高级的功能,请告诉我们!

为运算符实现形状推断

您可以使用以下方式为运算符的 Schema 添加形状推断函数:

OpSchema& Opschema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);

InferenceFunction 定义在 shape_inference.h 中,以及核心接口结构 InferenceContext 和一系列辅助方法。 InferenceContext 是提供给推断函数的核心结构。它允许访问运算符输入的有关信息,也允许写出推断的信息。

要查看大量示例,请在代码库中搜索 TypeAndShapeInferenceFunction 的出现。一个相对复杂的例子是 Concat 在 onnx/defs/tensor/defs.cc 中的实现。

在为运算符实现形状推断方法时,请注意以下几点,以避免常见错误:

  • 在访问任何输入的 shape 之前,代码必须检查形状是否可用。如果不可用,则应将其视为秩未知的动态张量并进行适当处理。通常,形状推断逻辑会通过调用 hasInputShapehasNInputShapes 来保护。

  • 在访问任何维度的 dim_valuedim_param 之前,代码必须检查这些字段是否具有值。特别是,代码必须处理维度可能没有静态已知值的可能性。

shape_inference.h 中有几个实用函数可以处理各种常见情况。

  • 对于必须具有固定秩的输入,请使用 checkInputRank。(请参阅 RoiAlign 的推断作为示例。)

  • unifyInputDimunifyDimupdateOutputShape 可用于多个输入维度预期相同,以及当输入维度传播到特定输出维度时。(请参阅 RoiAlign 的推断作为示例。)

  • 当输出维度使用算术从输入维度计算时,可以在符号维度上使用重载运算符 */。(请参阅 SpaceToDepth 的推断作为示例。)

这些实用程序可安全地处理缺失的形状和维度。

示例:考虑一个简单的矩阵乘法运算符,它期望输入形状为 [M,K][K,N],并返回形状为 [M,N] 的输出。这可以如下编码:

   // Check that input 0 has rank 2 (if its rank is known).
   checkInputRank(ctx, 0, 2);
   // Check that input 1 has rank 2 (if its rank is known).
   checkInputRank(ctx, 1, 2);
   Dim M, K, N;
   // Check various dimensions, handling missing dimensions/shapes safely.
   unifyInputDim(ctx, 0, 0, M);
   unifyInputDim(ctx, 0, 1, K);
   unifyInputDim(ctx, 1, 0, K);
   unifyInputDim(ctx, 1, 1, N);
   updateOutputShape(ctx, 0, {M. N});