Attention¶
Attention - 23¶
版本¶
领域:
main
起始版本:
23
函数:
True
支持级别:
SupportType.COMMON
形状推断:
True
此版本的算子自版本 23 起可用。
摘要¶
计算查询、键和值张量的缩放点积注意力,如果传递了可选的注意力掩码则使用它。
该算子涵盖了基于 K、Q 和 V 序列长度的自注意力和交叉注意力操作变体。
对于自注意力,kv_sequence_length
等于 q_sequence_length
。
对于交叉注意力,查询和键可能具有不同的长度。
该算子还涵盖了基于头数目的以下 3 种变体:
多头注意力 (MHA):论文 https://arxiv.org/pdf/1706.03762 中描述,
q_num_heads = kv_num_heads
。分组查询注意力 (GQA):论文 https://arxiv.org/pdf/2305.13245 中描述,
q_num_heads > kv_num_heads
,且q_num_heads % kv_num_heads == 0
。多查询注意力 (MQA):论文 https://arxiv.org/pdf/1911.02150 中描述,
q_num_heads > kv_num_heads
,且kv_num_heads=1
。
要添加的注意力偏置是根据 attn_mask
输入和 is_causal attribute
计算的,两者只能提供其中一个。
如果
is_causal
设置为1
,当掩码为方阵时,注意力掩码是下三角矩阵。由于对齐,注意力掩码呈现左上角的因果偏置形式。attn_mask
: 一个布尔掩码,其中True
值表示元素应参与注意力计算;或者一个与查询、键、值具有相同类型的浮点掩码,将其添加到注意力得分中。
过去和当前状态的键/值都是可选的。它们应该一起使用,不允许只使用其中一个。在根据提供的序列长度和头数对 K 和 V 输入进行适当重塑后,以下模式应用于 Q、K 和 V 输入:
The following pattern is applied by this operator:
Q K V
| | |
Q*scale K*scale |
| | |
| Transpose |
| | |
---MatMul--- |
| |
at_mask---Add |
| |
softcap (if provided) |
| |
Softmax |
| |
-----MatMul------
|
Y
属性¶
is_causal - INT(默认为
'0'
)如果设置为
1
,当掩码为方阵时,注意力掩码是下三角矩阵。由于对齐,注意力掩码呈现左上角的因果偏置形式。kv_num_heads - INT :
键和值的头数。必须与 Q、K 和 V 的 3D 输入一起使用。
q_num_heads - INT :
查询的头数。必须与 Q、K 和 V 的 3D 输入一起使用。
qk_matmul_output_mode - INT(默认为
'0'
)如果设置为
0
,qk_matmul_output 是 qk 矩阵乘法的输出。如果设置为1
,qk_matmul_output 包括将注意力掩码添加到 qk 矩阵乘法的输出。如果设置为2
,qk_matmul_output 是 softcap 操作后的输出。如果设置为3
,qk_matmul_output 是 softmax 操作后的输出。默认值为 0。scale - FLOAT :
应用的缩放因子。为保证稳定性,在矩阵乘法前对 q 和 k 进行缩放,数学原理参见 https://tinyurl.com/sudb9s96。默认值为
1/sqrt(head_size)
softcap - FLOAT(默认为
'0.0'
)注意力权重的 Softcap 值。默认值为 0。
softmax_precision - INT :
用于 softmax 计算的浮点精度。如果未提供 softmax 精度,则使用与 softmax 输入(Q 和 K)相同的精度。
输入¶
输入在 3 到 6 个之间。
Q (异构) - T1
查询张量。形状为
(batch_size, q_num_heads, q_sequence_length, head_size)
的 4D 张量,或形状为(batch_size, q_sequence_length, q_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,q_hidden_size = q_num_heads * head_size
。K (异构) - T1
键张量。形状为
(batch_size, kv_num_heads, kv_sequence_length, head_size)
的 4D 张量,或形状为(batch_size, kv_sequence_length, k_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,k_hidden_size = kv_num_heads * head_size
。V (异构) - T2
值张量。形状为
(batch_size, kv_num_heads, kv_sequence_length, v_head_size)
的 4D 张量,或形状为(batch_size, kv_sequence_length, v_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,v_hidden_size = kv_num_heads * v_head_size
。attn_mask (可选,异构) - U
注意力掩码。形状必须可广播到形状为
(batch_size, q_num_heads, q_sequence_length, total_sequence_length)
的 4D 张量,其中total_sequence_length = past_sequence_length + kv_sequence_length.
支持两种类型的掩码。一种是布尔掩码,其中值为True
表示该元素应参与注意力计算。另一种是与查询、键、值具有相同类型的浮点掩码,将其添加到注意力得分中。past_key (可选,异构) - T1
键的过去状态缓存,形状为
(batch_size, kv_num_heads, past_sequence_length, head_size)
past_value (可选,异构) - T2
值的过去状态缓存,形状为
(batch_size, kv_num_heads, past_sequence_length, v_head_size)
输出¶
输出在 1 到 4 个之间。
Y (异构) - T1
输出张量。形状为
(batch_size, q_num_heads, q_sequence_length, v_head_size)
的 4D 张量,或形状为(batch_size, q_sequence_length, hidden_size)
的 3D 张量。对于 3D 输入张量的情况,hidden_size = q_num_heads * v_head_size
。present_key (可选,异构) - T1
更新后的键缓存,形状为
(batch_size, kv_num_heads, total_sequence_length, head_size)
,其中total_sequence_length = past_sequence_length + kv_sequence_length
。present_value (可选,异构) - T2
更新后的值缓存,形状为
(batch_size, kv_num_heads, total_sequence_length, v_head_size)
,其中total_sequence_length = past_sequence_length + kv_sequence_length
。qk_matmul_output (可选,异构) - T1
QK 矩阵乘法的输出。形状为
(batch_size, q_num_heads, q_sequence_length, total_sequence_length)
的 4D 张量,其中total_sequence_length = past_sequence_length + kv_sequence_length
。
类型约束¶
T1 类型范围 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)将 Q 和 K 输入类型约束为浮点张量。
T2 类型范围 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)将 V 输入类型约束为浮点张量。
U 类型范围 (
tensor(bfloat16)
,tensor(bool)
,tensor(double)
,tensor(float)
,tensor(float16)
,tensor(int16)
,tensor(int32)
,tensor(int64)
,tensor(int8)
,tensor(uint16)
,tensor(uint32)
,tensor(uint64)
,tensor(uint8)
)将输出“mask”类型约束为布尔张量和输入类型。