注意

注意 - 24

版本

  • 名称Attention (GitHub)

  • : main

  • 起始版本24

  • 函数: True

  • 支持级别: SupportType.COMMON

  • 形状推断: True

此版本的操作符已可用于版本 24 及以上

摘要

对查询、键和值张量计算缩放点积注意力,如果传入可选的注意力掩码则使用该掩码。

此操作符涵盖了基于 K、Q 和 V 序列长度的注意力操作的自注意力和交叉注意力变体。

对于自注意力,kv_sequence_length 等于 q_sequence_length

对于交叉注意力,查询和键的长度可能不同。

此操作符还涵盖了基于头数目的以下 3 种变体:

  1. 多头注意力 (MHA):在论文 https://arxiv.org/pdf/1706.03762 中描述,q_num_heads = kv_num_heads

  2. 组查询注意力 (GQA):在论文 https://arxiv.org/pdf/2305.13245 中描述,q_num_heads > kv_num_headsq_num_heads % kv_num_heads == 0

  3. 多查询注意力 (MQA):在论文 https://arxiv.org/pdf/1911.02150 中描述,q_num_heads > kv_num_headskv_num_heads=1

要添加的注意力偏置是根据 attn_mask 输入和 is_causal 属性计算的。

  1. attn_mask:一个布尔掩码,其中 True 值表示该元素应参与注意力,或者是一个与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。

  2. 如果 is_causal 设置为 1,则无论 attn_mask 输入如何,对角线上方的注意力分数都会被掩盖。

关于 KV 缓存更新,此操作符允许以下两种用例:

  1. 缓存更新发生在 Attention 操作符内部。在这种情况下,KV 输入仅包含当前自回归步骤的传入标记,并且所有四个可选的输入/输出(过去和当前的键和值)都是必需的。Attention 操作会分别对过去和传入的键和值执行 Concat 操作,以形成当前的键和值。请注意,这仅适用于过去键和值不包含填充标记的特殊情况。

  2. 缓存更新发生在 Attention 操作符外部(例如,通过 TensorScatter 操作符)。在这种情况下,KV 输入对应于整个缓存张量,因此不应使用四个可选的输入/输出(过去和当前的键和值)。可以提供一个形状为 (batch_size,) 的附加输入 nonpad_kv_seqlen,以指示批次中每个样本的非填充标记数量,从而节省不必要的计算。这里,attn_mask 的 kv_sequence 维度可以比 KV 短,但仍然需要至少与 nonpad_kv_seqlen 的最大值一样长。

过去和当前状态的键/值都是可选的。它们应一起使用,不允许只使用其中一个。在根据提供的序列长度和头数对 K 和 V 输入进行适当的重塑后,以下模式应用于 Q、K 和 V 输入:

  The following pattern is applied by this operator:
      Q          K          V
      |          |          |
Q*sqrt(scale) K*sqrt(scale) |
      |          |          |
      |       Transpose     |
      |          |          |
      ---MatMul---          |
            |               |
 at_mask---Add              |
            |               |
  softcap (if provided)     |
            |               |
         Softmax            |
            |               |
            -----MatMul------
                   |
                   Y

属性

  • is_causal - INT(默认为'0'

    如果设置为 1,则当掩码为方阵时,注意力掩码是一个下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。

  • kv_num_heads - INT :

    键和值的头数。必须与 Q、K 和 V 的 3D 输入一起使用。

  • q_num_heads - INT :

    查询的头数。必须与 Q、K 和 V 的 3D 输入一起使用。

  • qk_matmul_output_mode - INT(默认为'0'

    如果设置为 0,qk_matmul_output 是 qk 矩阵乘法的输出。如果设置为 1,qk_matmul_output 包含注意力掩码添加到 qk 矩阵乘法输出的结果。如果设置为 2,qk_matmul_output 是 softcap 操作后的输出。如果设置为 3,qk_matmul_output 是 softmax 操作后的输出。默认值为 0。

  • scale - FLOAT :

    应用于 \(Q*K^T\) 的缩放因子。默认值为 1/sqrt(head_size)。为防止数值溢出,在矩阵乘法之前,将 QKsqrt(scale) 缩放。

  • softcap - FLOAT(默认为'0.0'

    注意力权重的 softcap 值。默认值为 0。

  • softmax_precision - INT :

    softmax 计算中使用的浮点精度。如果未提供 softmax 精度,则使用 softmax 输入(Q 和 K)的相同精度。

输入

3 到 7 个输入之间。

  • Q(异构)- T1

    查询张量。形状为 (batch_size, q_num_heads, q_sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, q_sequence_length, q_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,q_hidden_size = q_num_heads * head_size

  • K(异构)- T1

    键张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, kv_sequence_length, k_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,k_hidden_size = kv_num_heads * head_size

  • V(异构)- T2

    值张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, v_head_size) 的 4D 张量,或形状为 (batch_size, kv_sequence_length, v_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,v_hidden_size = kv_num_heads * v_head_size

  • attn_mask(可选,异构)- U

    注意力掩码。形状必须可广播到 (batch_size, q_num_heads, q_sequence_length, total_sequence_length),其中 total_sequence_length = past_sequence_length + kv_sequence_length. 最后一个维度也可以比 total_sequence_length 短,并将用负无穷填充到 total_sequence_length。支持两种类型的掩码:一个布尔掩码,其中 True 值表示该元素应参与注意力;或者是一个与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。

  • past_key(可选,异构)- T1

    过去状态缓存的键,形状为 (batch_size, kv_num_heads, past_sequence_length, head_size)

  • past_value(可选,异构)- T2

    过去状态缓存的值,形状为 (batch_size, kv_num_heads, past_sequence_length, v_head_size)

  • nonpad_kv_seqlen(可选,异构)- tensor(int64)

    一个形状为 (batch_size,) 的整数向量,表示每个样本中有效(即非填充)标记的数量。可以由此推导出填充掩码。此输入不应与 past_keypast_value 输入或 present_keypresent_value 输出一起使用(参见操作符描述中的 KV 缓存用例)。

输出

1 到 4 个输出之间。

  • Y(异构)- T1

    输出张量。形状为 (batch_size, q_num_heads, q_sequence_length, v_head_size) 的 4D 张量,或形状为 (batch_size, q_sequence_length, hidden_size) 的 3D 张量。对于 3D 输入张量的情况,hidden_size = q_num_heads * v_head_size

  • present_key(可选,异构)- T1

    更新后的键缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • present_value(可选,异构)- T2

    更新后的值缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, v_head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • qk_matmul_output(可选,异构)- T1

    QK 矩阵乘法的输出。形状为 (batch_size, q_num_heads, q_sequence_length, total_sequence_length) 的 4D 张量,其中 total_sequence_length = past_sequence_length + kv_sequence_length

类型约束

  • T1 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    将 Q 和 K 输入类型限制为浮点张量。

  • T2 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    将 V 输入类型限制为浮点张量。

  • U 在 ( tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) )

    将输出“mask”类型限制为布尔张量和输入类型。

注意 - 23

版本

  • 名称Attention (GitHub)

  • : main

  • 起始版本23

  • 函数: True

  • 支持级别: SupportType.COMMON

  • 形状推断: True

此版本的操作符已可用于版本 23 及以上

摘要

对查询、键和值张量计算缩放点积注意力,如果传入可选的注意力掩码则使用该掩码。

此操作符涵盖了基于 K、Q 和 V 序列长度的注意力操作的自注意力和交叉注意力变体。

对于自注意力,kv_sequence_length 等于 q_sequence_length

对于交叉注意力,查询和键的长度可能不同。

此操作符还涵盖了基于头数目的以下 3 种变体:

  1. 多头注意力 (MHA):在论文 https://arxiv.org/pdf/1706.03762 中描述,q_num_heads = kv_num_heads

  2. 组查询注意力 (GQA):在论文 https://arxiv.org/pdf/2305.13245 中描述,q_num_heads > kv_num_headsq_num_heads % kv_num_heads == 0

  3. 多查询注意力 (MQA):在论文 https://arxiv.org/pdf/1911.02150 中描述,q_num_heads > kv_num_headskv_num_heads=1

要添加的注意力偏置是根据 attn_mask 输入和 is_causal attribute 计算的,只能提供其中之一。

  1. 如果 is_causal 设置为 1,则当掩码为方阵时,注意力掩码是一个下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。

  2. attn_mask:一个布尔掩码,其中 True 值表示该元素应参与注意力,或者是一个与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。

过去和当前状态的键/值都是可选的。它们应一起使用,不允许只使用其中一个。在根据提供的序列长度和头数对 K 和 V 输入进行适当的重塑后,以下模式应用于 Q、K 和 V 输入:

  The following pattern is applied by this operator:
      Q          K          V
      |          |          |
Q*sqrt(scale) K*sqrt(scale) |
      |          |          |
      |       Transpose     |
      |          |          |
      ---MatMul---          |
            |               |
 at_mask---Add              |
            |               |
  softcap (if provided)     |
            |               |
         Softmax            |
            |               |
            -----MatMul------
                   |
                   Y

属性

  • is_causal - INT(默认为'0'

    如果设置为 1,则当掩码为方阵时,注意力掩码是一个下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。

  • kv_num_heads - INT :

    键和值的头数。必须与 Q、K 和 V 的 3D 输入一起使用。

  • q_num_heads - INT :

    查询的头数。必须与 Q、K 和 V 的 3D 输入一起使用。

  • qk_matmul_output_mode - INT(默认为'0'

    如果设置为 0,qk_matmul_output 是 qk 矩阵乘法的输出。如果设置为 1,qk_matmul_output 包含注意力掩码添加到 qk 矩阵乘法输出的结果。如果设置为 2,qk_matmul_output 是 softcap 操作后的输出。如果设置为 3,qk_matmul_output 是 softmax 操作后的输出。默认值为 0。

  • scale - FLOAT :

    应用于 \(Q*K^T\) 的缩放因子。默认值为 1/sqrt(head_size)。为防止数值溢出,在矩阵乘法之前,将 QKsqrt(scale) 缩放。

  • softcap - FLOAT(默认为'0.0'

    注意力权重的 softcap 值。默认值为 0。

  • softmax_precision - INT :

    softmax 计算中使用的浮点精度。如果未提供 softmax 精度,则使用 softmax 输入(Q 和 K)的相同精度。

输入

3 到 6 个输入之间。

  • Q(异构)- T1

    查询张量。形状为 (batch_size, q_num_heads, q_sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, q_sequence_length, q_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,q_hidden_size = q_num_heads * head_size

  • K(异构)- T1

    键张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, kv_sequence_length, k_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,k_hidden_size = kv_num_heads * head_size

  • V(异构)- T2

    值张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, v_head_size) 的 4D 张量,或形状为 (batch_size, kv_sequence_length, v_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,v_hidden_size = kv_num_heads * v_head_size

  • attn_mask(可选,异构)- U

    注意力掩码。形状必须可广播到形状为 (batch_size, q_num_heads, q_sequence_length, total_sequence_length) 的 4D 张量,其中 total_sequence_length = past_sequence_length + kv_sequence_length. 支持两种类型的掩码。一个布尔掩码,其中 True 值表示该元素应参与注意力。还支持与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。

  • past_key(可选,异构)- T1

    过去状态缓存的键,形状为 (batch_size, kv_num_heads, past_sequence_length, head_size)

  • past_value(可选,异构)- T2

    过去状态缓存的值,形状为 (batch_size, kv_num_heads, past_sequence_length, v_head_size)

输出

1 到 4 个输出之间。

  • Y(异构)- T1

    输出张量。形状为 (batch_size, q_num_heads, q_sequence_length, v_head_size) 的 4D 张量,或形状为 (batch_size, q_sequence_length, hidden_size) 的 3D 张量。对于 3D 输入张量的情况,hidden_size = q_num_heads * v_head_size

  • present_key(可选,异构)- T1

    更新后的键缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • present_value(可选,异构)- T2

    更新后的值缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, v_head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • qk_matmul_output(可选,异构)- T1

    QK 矩阵乘法的输出。形状为 (batch_size, q_num_heads, q_sequence_length, total_sequence_length) 的 4D 张量,其中 total_sequence_length = past_sequence_length + kv_sequence_length

类型约束

  • T1 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    将 Q 和 K 输入类型限制为浮点张量。

  • T2 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    将 V 输入类型限制为浮点张量。

  • U 在 ( tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) )

    将输出“mask”类型限制为布尔张量和输入类型。