RMSNormalization

RMSNormalization - 23

版本

  • 名称: RMSNormalization (GitHub)

  • : main

  • 起始版本: 23

  • 函数: True

  • 支持级别: SupportType.COMMON

  • 形状推断: True

此版本的运算符已从版本 23 开始可用。

摘要

这是 ONNX 中定义的 RMS 归一化,作为函数,如论文 https://arxiv.org/pdf/1910.07467 所述。整体计算可以分为两个阶段。对最后 D 个维度进行均方根范数计算,其中 D 是 normalized_shape 的维度。例如,如果 normalized_shape 是 (3, 5)(一个二维形状),则对输入的最后 2 个维度计算 rms 范数。标准化所需的计算可以用以下公式描述。

XSquared = Mul(X, X)
XSquaredMean = ReduceMean<axes=normalized_axes>(XSquared)
MeanSquareEpsilon = Add(XSquaredMean, epsilon)
RMS = Sqrt(MeanSquareEpsilon)
Normalized = Div(X, RMS)

其中 normalized_axes[axis, ..., X 的秩 - 1]。变量 RMS 代表均方根。根据 stash_type 属性,实际计算必须在不同的浮点精度下进行。例如,如果 stash_type 为 1,则此运算符会将所有输入变量转换为 32 位浮点数,执行计算,最后将 Normalized 转换回 X 的原始类型。然后第二阶段使用以下方式缩放第一阶段的结果

Y= Mul(Normalized, Scale)

d[i] 表示 X 的第 i 个维度。如果 X 的形状是 [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]],则 RMS 的形状是 [d[0], ..., d[axis-1], 1, ..., 1]YX 具有相同的形状。此运算符支持单向广播(Scale 应对张量 X 进行单向广播);更多详细信息请查看 ONNX 中的广播

属性

  • axis - INT (默认为 '-1')

    第一个归一化维度。如果 rank(X) 为 r,则 axis 的允许范围是 [-r, r)。负值表示从后向前计数维度。

  • epsilon - FLOAT (默认为 '1e-05')

    用于避免除以零的 epsilon 值。

  • stash_type - INT (默认为 '1')

    计算第一阶段使用的浮点精度。

输入

  • X (异构) - T

    要进行归一化的输入张量。通常,对于 n 维数据,形状为 (D1, D2, ..., Dn),均方根范数是针对最后 D 个维度计算的,其中 D 由 axis 属性确定。

  • scale (异构) - V

    缩放张量。缩放张量的形状应能广播到归一化形状。

输出

  • Y (异构) - V

    输出数据张量。形状与 X 相同。

类型约束

  • T 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ) 中

    将输入 X 的类型限制为浮点张量。

  • V 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ) 中

    将输出 Y 和 scale 的类型限制为浮点张量。