维度标注

维度标注是一项实验性尝试,旨在为张量轴提供语义描述,从而定义类型,并基于这些类型执行后续的验证步骤。

动机

此类机制的动机可以通过一个简单的例子来说明。在下面的线性神经网络规范中,我们假设输入模型格式为 NCHW。

input_in_NCHW -> Transpose(input, perm=[0, 2, 1, 3]) -> AveragePool(input, ...)

在此神经网络中,用户错误地构建了一个将 NCHW 输入转置为奇怪的 NHCW 格式,并通过了假定 NCHW 输入格式的空间池化操作。尽管这明显是错误的,但现有基础设施不会向用户报告错误。这应该让高度依赖类型检查作为程序正确性保证的程序员深感不安。本提案旨在解决当前神经网络规范范式中固有的、缺乏恰当类型检查的真空。

本提案包含三个关键组成部分:标注定义、标注传播和标注验证,每个部分都将进行详细讨论。

标注定义

首先,我们为张量类型定义了一组类型。这些类型基于以下原则定义:

  1. 足够精细,可以消除潜在的陷阱。例如,动机部分说明的上述示例要求我们区分通道维度和空间特征维度,以确保 AveragePool 操作执行的正确性。

  2. 足够粗粒度,可以减轻用户的精神负担。例如,在上述示例中,区分宽度维度和高度维度在很大程度上没有必要,因为池化和卷积等操作通常不会区分不同的空间维度。因此,我们将所有空间维度概括为特征维度。

  3. 作为第 2 点的重要推论,做到模型无关。例如,循环神经网络 (RNN) 中特征维度的语义与卷积神经网络 (CNN) 中空间维度的语义几乎无法区分,因此我们允许用户和开发人员将两者都描述为特征维度。

具体来说,在我们最初的提案中,我们定义了以下一组标准标注:

  1. DATA_BATCH 描述训练数据的批次维度。这对应于更常用的张量格式表示法 NCHW 中的 N 维度。

  2. DATA_CHANNEL 描述训练数据的通道维度。这对应于 C 维度。

  3. DATA_TIME 描述时间维度。

  4. DATA_FEATURE 描述特征维度。这对应于 HW 维度或 RNN 中的特征维度。

  5. FILTER_IN_CHANNEL 描述滤波器输入通道维度。这是与输入图像特征图的通道维度(在大小上)相同的维度。

  6. FILTER_OUT_CHANNEL 描述滤波器输出通道维度。这是与输出图像特征图的通道维度(在大小上)相同的维度。

  7. FILTER_SPATIAL 描述滤波器空间维度。

标注传播

当操作相对于输入张量进行维度重排、销毁或创建时,会发生标注传播。在这种情况下,我们将实现自定义的、针对特定操作的函数,以根据输入张量维度标注推断输出张量维度标注。一个发生标注传播的操作示例是 Transpose 操作,其中输出维度标注推断的伪代码可以表示为输入维度标注的函数。

for i, j in enumerate(perm):
    out_dim_denotaion[i] = in_dim_denotation[j]

标注验证

当操作期望其输入以特定格式到达时,会发生标注验证。一个发生标注验证的操作示例是 AveragePool 操作,其中输入(如果带有关联的维度标注)在 2D 情况下的标注应为 [DATA_BATCH, DATA_CHANNEL, DATA_FEATURE, DATA_FEATURE]。如果期望的维度标注与实际的维度标注不匹配,则应报告错误。

类型标注

有关如何描述图像和其他类型的详细信息,请参阅 类型标注文档