维度标注¶
维度标注是一项实验性尝试,旨在为张量轴提供语义描述,从而定义类型,并基于这些类型执行后续的验证步骤。
动机¶
此类机制的动机可以通过一个简单的例子来说明。在下面的线性神经网络规范中,我们假设输入模型格式为 NCHW。
input_in_NCHW -> Transpose(input, perm=[0, 2, 1, 3]) -> AveragePool(input, ...)
在此神经网络中,用户错误地构建了一个将 NCHW 输入转置为奇怪的 NHCW 格式,并通过了假定 NCHW 输入格式的空间池化操作。尽管这明显是错误的,但现有基础设施不会向用户报告错误。这应该让高度依赖类型检查作为程序正确性保证的程序员深感不安。本提案旨在解决当前神经网络规范范式中固有的、缺乏恰当类型检查的真空。
本提案包含三个关键组成部分:标注定义、标注传播和标注验证,每个部分都将进行详细讨论。
标注定义¶
首先,我们为张量类型定义了一组类型。这些类型基于以下原则定义:
足够精细,可以消除潜在的陷阱。例如,动机部分说明的上述示例要求我们区分通道维度和空间特征维度,以确保 AveragePool 操作执行的正确性。
足够粗粒度,可以减轻用户的精神负担。例如,在上述示例中,区分宽度维度和高度维度在很大程度上没有必要,因为池化和卷积等操作通常不会区分不同的空间维度。因此,我们将所有空间维度概括为特征维度。
作为第 2 点的重要推论,做到模型无关。例如,循环神经网络 (RNN) 中特征维度的语义与卷积神经网络 (CNN) 中空间维度的语义几乎无法区分,因此我们允许用户和开发人员将两者都描述为特征维度。
具体来说,在我们最初的提案中,我们定义了以下一组标准标注:
DATA_BATCH
描述训练数据的批次维度。这对应于更常用的张量格式表示法NCHW
中的N
维度。DATA_CHANNEL
描述训练数据的通道维度。这对应于C
维度。DATA_TIME
描述时间维度。DATA_FEATURE
描述特征维度。这对应于H
、W
维度或 RNN 中的特征维度。FILTER_IN_CHANNEL
描述滤波器输入通道维度。这是与输入图像特征图的通道维度(在大小上)相同的维度。FILTER_OUT_CHANNEL
描述滤波器输出通道维度。这是与输出图像特征图的通道维度(在大小上)相同的维度。FILTER_SPATIAL
描述滤波器空间维度。
标注传播¶
当操作相对于输入张量进行维度重排、销毁或创建时,会发生标注传播。在这种情况下,我们将实现自定义的、针对特定操作的函数,以根据输入张量维度标注推断输出张量维度标注。一个发生标注传播的操作示例是 Transpose 操作,其中输出维度标注推断的伪代码可以表示为输入维度标注的函数。
for i, j in enumerate(perm):
out_dim_denotaion[i] = in_dim_denotation[j]
标注验证¶
当操作期望其输入以特定格式到达时,会发生标注验证。一个发生标注验证的操作示例是 AveragePool 操作,其中输入(如果带有关联的维度标注)在 2D 情况下的标注应为 [DATA_BATCH
, DATA_CHANNEL
, DATA_FEATURE
, DATA_FEATURE
]。如果期望的维度标注与实际的维度标注不匹配,则应报告错误。
类型标注¶
有关如何描述图像和其他类型的详细信息,请参阅 类型标注文档。