维度标记¶
维度标记是一种实验性尝试,旨在为张量轴赋予语义描述,从而赋予类型,并在此基础上后续执行验证步骤。
动机¶
可以通过一个简单的例子来说明这种机制的动机。在下面的线性神经网络规范中,我们假设模型输入采用 NCHW 格式
input_in_NCHW -> Transpose(input, perm=[0, 2, 1, 3]) -> AveragePool(input, ...)
在这个神经网络中,用户错误地构建了一个将 NCHW 输入转置为奇怪的 NHCW 格式,并经过假设 NCHW 输入格式的空间池化操作的神经网络。尽管这是一个明显的错误,但现有基础设施不会向用户报告错误。对于严重依赖类型检查作为程序正确性保证不可或缺的一部分的程序员来说,这应该会让他们深感不安。本提案旨在解决当前神经网络规范范式中固有的缺乏适当类型检查的问题。
本提案包含三个关键组成部分:标记定义、标记传播和标记验证,每个组成部分都将详细讨论。
标记定义¶
首先,我们为张量类型定义了一组类型。这些类型基于以下原则定义
粒度足够细,以消除潜在陷阱。例如,动机部分说明的上述示例要求我们区分通道维度和空间特征维度,以确保 AveragePool 算子的正确执行。
粒度足够粗,以减轻用户的精神负担。例如,在上面的例子中,区分宽度维度和高度维度的必要性大大降低,因为池化和卷积等操作通常不区分各种空间维度。因此,我们将所有空间维度概括为特征维度。
作为原则 2 的一个重要推论,要模型无关。例如,循环神经网络 (RNN) 中特征维度的语义与卷积神经网络 (CNN) 中空间维度的语义几乎无法区分,因此我们允许用户和开发者将两者都描述为特征维度。
具体而言,在我们的第一个提案中,我们定义了以下一组标准标记
DATA_BATCH
描述训练数据的批次维度。这对应于更常用的张量格式表示法NCHW
中的N
维度。DATA_CHANNEL
描述训练数据的通道维度。这对应于C
维度。DATA_TIME
描述时间维度。DATA_FEATURE
描述特征维度。这对应于H
、W
维度或 RNN 中的特征维度。FILTER_IN_CHANNEL
描述过滤器的输入通道维度。这个维度的大小与输入图像特征图的通道维度相同。FILTER_OUT_CHANNEL
描述过滤器的输出通道维度。这个维度的大小与输出图像特征图的通道维度相同。FILTER_SPATIAL
描述过滤器的空间维度。
标记传播¶
当一个操作相对于其输入张量置换、销毁或创建维度时,就会发生标记传播。在这种情况下,我们将实现定制的、特定于操作的函数,以根据输入张量维度标记推断输出张量维度标记。发生标记传播的一个示例操作是 Transpose 操作,其输出维度标记推断的伪代码可以表达为输入维度标记的函数
for i, j in enumerate(perm):
out_dim_denotaion[i] = in_dim_denotation[j]
标记验证¶
当一个操作期望其输入以特定格式到达时,就会发生标记验证。发生标记验证的一个示例操作是 AveragePool 操作,如果输入带有维度标记,在 2D 情况下应具有标记 [DATA_BATCH
, DATA_CHANNEL
, DATA_FEATURE
, DATA_FEATURE
]。如果期望的维度标记与实际维度标记不匹配,则应报告错误。
类型标记¶
有关如何描述图像和其他类型的更多详细信息,请参阅类型标记文档。