Attention

Attention - 23

版本

  • 名称: Attention (GitHub)

  • 领域: main

  • 起始版本: 23

  • 函数: True

  • 支持级别: SupportType.COMMON

  • 形状推断: True

此版本的算子自版本 23 起可用。

摘要

计算查询、键和值张量的缩放点积注意力,如果传递了可选的注意力掩码则使用它。

该算子涵盖了基于 K、Q 和 V 序列长度的自注意力和交叉注意力操作变体。

对于自注意力,kv_sequence_length 等于 q_sequence_length

对于交叉注意力,查询和键可能具有不同的长度。

该算子还涵盖了基于头数目的以下 3 种变体:

  1. 多头注意力 (MHA):论文 https://arxiv.org/pdf/1706.03762 中描述,q_num_heads = kv_num_heads

  2. 分组查询注意力 (GQA):论文 https://arxiv.org/pdf/2305.13245 中描述,q_num_heads > kv_num_heads,且 q_num_heads % kv_num_heads == 0

  3. 多查询注意力 (MQA):论文 https://arxiv.org/pdf/1911.02150 中描述,q_num_heads > kv_num_heads,且 kv_num_heads=1

要添加的注意力偏置是根据 attn_mask 输入和 is_causal attribute 计算的,两者只能提供其中一个。

  1. 如果 is_causal 设置为 1,当掩码为方阵时,注意力掩码是下三角矩阵。由于对齐,注意力掩码呈现左上角的因果偏置形式。

  2. attn_mask: 一个布尔掩码,其中 True 值表示元素应参与注意力计算;或者一个与查询、键、值具有相同类型的浮点掩码,将其添加到注意力得分中。

过去和当前状态的键/值都是可选的。它们应该一起使用,不允许只使用其中一个。在根据提供的序列长度和头数对 K 和 V 输入进行适当重塑后,以下模式应用于 Q、K 和 V 输入:

  The following pattern is applied by this operator:
      Q          K          V
      |          |          |
     Q*scale     K*scale    |
      |          |          |
      |       Transpose     |
      |          |          |
      ---MatMul---          |
            |               |
 at_mask---Add              |
            |               |
  softcap (if provided)     |
            |               |
         Softmax            |
            |               |
            -----MatMul------
                   |
                   Y

属性

  • is_causal - INT(默认为 '0'

    如果设置为 1,当掩码为方阵时,注意力掩码是下三角矩阵。由于对齐,注意力掩码呈现左上角的因果偏置形式。

  • kv_num_heads - INT :

    键和值的头数。必须与 Q、K 和 V 的 3D 输入一起使用。

  • q_num_heads - INT :

    查询的头数。必须与 Q、K 和 V 的 3D 输入一起使用。

  • qk_matmul_output_mode - INT(默认为 '0'

    如果设置为 0,qk_matmul_output 是 qk 矩阵乘法的输出。如果设置为 1,qk_matmul_output 包括将注意力掩码添加到 qk 矩阵乘法的输出。如果设置为 2,qk_matmul_output 是 softcap 操作后的输出。如果设置为 3,qk_matmul_output 是 softmax 操作后的输出。默认值为 0。

  • scale - FLOAT :

    应用的缩放因子。为保证稳定性,在矩阵乘法前对 q 和 k 进行缩放,数学原理参见 https://tinyurl.com/sudb9s96。默认值为 1/sqrt(head_size)

  • softcap - FLOAT(默认为 '0.0'

    注意力权重的 Softcap 值。默认值为 0。

  • softmax_precision - INT :

    用于 softmax 计算的浮点精度。如果未提供 softmax 精度,则使用与 softmax 输入(Q 和 K)相同的精度。

输入

输入在 3 到 6 个之间。

  • Q (异构) - T1

    查询张量。形状为 (batch_size, q_num_heads, q_sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, q_sequence_length, q_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,q_hidden_size = q_num_heads * head_size

  • K (异构) - T1

    键张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, kv_sequence_length, k_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,k_hidden_size = kv_num_heads * head_size

  • V (异构) - T2

    值张量。形状为 (batch_size, kv_num_heads, kv_sequence_length, v_head_size) 的 4D 张量,或形状为 (batch_size, kv_sequence_length, v_hidden_size) 的 3D 张量。对于 3D 输入张量的情况,v_hidden_size = kv_num_heads * v_head_size

  • attn_mask (可选,异构) - U

    注意力掩码。形状必须可广播到形状为 (batch_size, q_num_heads, q_sequence_length, total_sequence_length) 的 4D 张量,其中 total_sequence_length = past_sequence_length + kv_sequence_length. 支持两种类型的掩码。一种是布尔掩码,其中值为 True 表示该元素应参与注意力计算。另一种是与查询、键、值具有相同类型的浮点掩码,将其添加到注意力得分中。

  • past_key (可选,异构) - T1

    键的过去状态缓存,形状为 (batch_size, kv_num_heads, past_sequence_length, head_size)

  • past_value (可选,异构) - T2

    值的过去状态缓存,形状为 (batch_size, kv_num_heads, past_sequence_length, v_head_size)

输出

输出在 1 到 4 个之间。

  • Y (异构) - T1

    输出张量。形状为 (batch_size, q_num_heads, q_sequence_length, v_head_size) 的 4D 张量,或形状为 (batch_size, q_sequence_length, hidden_size) 的 3D 张量。对于 3D 输入张量的情况,hidden_size = q_num_heads * v_head_size

  • present_key (可选,异构) - T1

    更新后的键缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • present_value (可选,异构) - T2

    更新后的值缓存,形状为 (batch_size, kv_num_heads, total_sequence_length, v_head_size),其中 total_sequence_length = past_sequence_length + kv_sequence_length

  • qk_matmul_output (可选,异构) - T1

    QK 矩阵乘法的输出。形状为 (batch_size, q_num_heads, q_sequence_length, total_sequence_length) 的 4D 张量,其中 total_sequence_length = past_sequence_length + kv_sequence_length

类型约束

  • T1 类型范围 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    将 Q 和 K 输入类型约束为浮点张量。

  • T2 类型范围 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    将 V 输入类型约束为浮点张量。

  • U 类型范围 ( tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) )

    将输出“mask”类型约束为布尔张量和输入类型。