注意¶
注意 - 24¶
版本¶
域:
main
起始版本:
24
函数:
True
支持级别:
SupportType.COMMON
形状推断:
True
此版本的操作符已可用于版本 24 及以上。
摘要¶
对查询、键和值张量计算缩放点积注意力,如果传入可选的注意力掩码则使用该掩码。
此操作符涵盖了基于 K、Q 和 V 序列长度的注意力操作的自注意力和交叉注意力变体。
对于自注意力,kv_sequence_length
等于 q_sequence_length
。
对于交叉注意力,查询和键的长度可能不同。
此操作符还涵盖了基于头数目的以下 3 种变体:
多头注意力 (MHA):在论文 https://arxiv.org/pdf/1706.03762 中描述,
q_num_heads = kv_num_heads
。组查询注意力 (GQA):在论文 https://arxiv.org/pdf/2305.13245 中描述,
q_num_heads > kv_num_heads
,q_num_heads % kv_num_heads == 0
。多查询注意力 (MQA):在论文 https://arxiv.org/pdf/1911.02150 中描述,
q_num_heads > kv_num_heads
,kv_num_heads=1
。
要添加的注意力偏置是根据 attn_mask
输入和 is_causal
属性计算的。
attn_mask
:一个布尔掩码,其中True
值表示该元素应参与注意力,或者是一个与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。如果
is_causal
设置为1
,则无论attn_mask
输入如何,对角线上方的注意力分数都会被掩盖。
关于 KV 缓存更新,此操作符允许以下两种用例:
缓存更新发生在 Attention 操作符内部。在这种情况下,
K
和V
输入仅包含当前自回归步骤的传入标记,并且所有四个可选的输入/输出(过去和当前的键和值)都是必需的。Attention 操作会分别对过去和传入的键和值执行 Concat 操作,以形成当前的键和值。请注意,这仅适用于过去键和值不包含填充标记的特殊情况。缓存更新发生在 Attention 操作符外部(例如,通过
TensorScatter
操作符)。在这种情况下,K
和V
输入对应于整个缓存张量,因此不应使用四个可选的输入/输出(过去和当前的键和值)。可以提供一个形状为 (batch_size,) 的附加输入nonpad_kv_seqlen
,以指示批次中每个样本的非填充标记数量,从而节省不必要的计算。这里,attn_mask
的 kv_sequence 维度可以比K
和V
短,但仍然需要至少与nonpad_kv_seqlen
的最大值一样长。
过去和当前状态的键/值都是可选的。它们应一起使用,不允许只使用其中一个。在根据提供的序列长度和头数对 K 和 V 输入进行适当的重塑后,以下模式应用于 Q、K 和 V 输入:
The following pattern is applied by this operator:
Q K V
| | |
Q*sqrt(scale) K*sqrt(scale) |
| | |
| Transpose |
| | |
---MatMul--- |
| |
at_mask---Add |
| |
softcap (if provided) |
| |
Softmax |
| |
-----MatMul------
|
Y
属性¶
is_causal - INT(默认为
'0'
)如果设置为
1
,则当掩码为方阵时,注意力掩码是一个下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。kv_num_heads - INT :
键和值的头数。必须与 Q、K 和 V 的 3D 输入一起使用。
q_num_heads - INT :
查询的头数。必须与 Q、K 和 V 的 3D 输入一起使用。
qk_matmul_output_mode - INT(默认为
'0'
)如果设置为
0
,qk_matmul_output 是 qk 矩阵乘法的输出。如果设置为1
,qk_matmul_output 包含注意力掩码添加到 qk 矩阵乘法输出的结果。如果设置为2
,qk_matmul_output 是 softcap 操作后的输出。如果设置为3
,qk_matmul_output 是 softmax 操作后的输出。默认值为 0。scale - FLOAT :
应用于 \(Q*K^T\) 的缩放因子。默认值为
1/sqrt(head_size)
。为防止数值溢出,在矩阵乘法之前,将Q
、K
按sqrt(scale)
缩放。softcap - FLOAT(默认为
'0.0'
)注意力权重的 softcap 值。默认值为 0。
softmax_precision - INT :
softmax 计算中使用的浮点精度。如果未提供 softmax 精度,则使用 softmax 输入(Q 和 K)的相同精度。
输入¶
3 到 7 个输入之间。
Q(异构)- T1
查询张量。形状为
(batch_size, q_num_heads, q_sequence_length, head_size)
的 4D 张量,或形状为(batch_size, q_sequence_length, q_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,q_hidden_size = q_num_heads * head_size
K(异构)- T1
键张量。形状为
(batch_size, kv_num_heads, kv_sequence_length, head_size)
的 4D 张量,或形状为(batch_size, kv_sequence_length, k_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,k_hidden_size = kv_num_heads * head_size
V(异构)- T2
值张量。形状为
(batch_size, kv_num_heads, kv_sequence_length, v_head_size)
的 4D 张量,或形状为(batch_size, kv_sequence_length, v_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,v_hidden_size = kv_num_heads * v_head_size
attn_mask(可选,异构)- U
注意力掩码。形状必须可广播到
(batch_size, q_num_heads, q_sequence_length, total_sequence_length)
,其中total_sequence_length = past_sequence_length + kv_sequence_length.
最后一个维度也可以比total_sequence_length
短,并将用负无穷填充到total_sequence_length
。支持两种类型的掩码:一个布尔掩码,其中True
值表示该元素应参与注意力;或者是一个与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。past_key(可选,异构)- T1
过去状态缓存的键,形状为
(batch_size, kv_num_heads, past_sequence_length, head_size)
past_value(可选,异构)- T2
过去状态缓存的值,形状为
(batch_size, kv_num_heads, past_sequence_length, v_head_size)
nonpad_kv_seqlen(可选,异构)- tensor(int64)
一个形状为
(batch_size,)
的整数向量,表示每个样本中有效(即非填充)标记的数量。可以由此推导出填充掩码。此输入不应与past_key
和past_value
输入或present_key
和present_value
输出一起使用(参见操作符描述中的 KV 缓存用例)。
输出¶
1 到 4 个输出之间。
Y(异构)- T1
输出张量。形状为
(batch_size, q_num_heads, q_sequence_length, v_head_size)
的 4D 张量,或形状为(batch_size, q_sequence_length, hidden_size)
的 3D 张量。对于 3D 输入张量的情况,hidden_size = q_num_heads * v_head_size
present_key(可选,异构)- T1
更新后的键缓存,形状为
(batch_size, kv_num_heads, total_sequence_length, head_size)
,其中total_sequence_length = past_sequence_length + kv_sequence_length
。present_value(可选,异构)- T2
更新后的值缓存,形状为
(batch_size, kv_num_heads, total_sequence_length, v_head_size)
,其中total_sequence_length = past_sequence_length + kv_sequence_length
。qk_matmul_output(可选,异构)- T1
QK 矩阵乘法的输出。形状为
(batch_size, q_num_heads, q_sequence_length, total_sequence_length)
的 4D 张量,其中total_sequence_length = past_sequence_length + kv_sequence_length
。
类型约束¶
T1 在 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)将 Q 和 K 输入类型限制为浮点张量。
T2 在 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)将 V 输入类型限制为浮点张量。
U 在 (
tensor(bfloat16)
,tensor(bool)
,tensor(double)
,tensor(float)
,tensor(float16)
,tensor(int16)
,tensor(int32)
,tensor(int64)
,tensor(int8)
,tensor(uint16)
,tensor(uint32)
,tensor(uint64)
,tensor(uint8)
)将输出“mask”类型限制为布尔张量和输入类型。
注意 - 23¶
版本¶
域:
main
起始版本:
23
函数:
True
支持级别:
SupportType.COMMON
形状推断:
True
此版本的操作符已可用于版本 23 及以上。
摘要¶
对查询、键和值张量计算缩放点积注意力,如果传入可选的注意力掩码则使用该掩码。
此操作符涵盖了基于 K、Q 和 V 序列长度的注意力操作的自注意力和交叉注意力变体。
对于自注意力,kv_sequence_length
等于 q_sequence_length
。
对于交叉注意力,查询和键的长度可能不同。
此操作符还涵盖了基于头数目的以下 3 种变体:
多头注意力 (MHA):在论文 https://arxiv.org/pdf/1706.03762 中描述,
q_num_heads = kv_num_heads
。组查询注意力 (GQA):在论文 https://arxiv.org/pdf/2305.13245 中描述,
q_num_heads > kv_num_heads
,q_num_heads % kv_num_heads == 0
。多查询注意力 (MQA):在论文 https://arxiv.org/pdf/1911.02150 中描述,
q_num_heads > kv_num_heads
,kv_num_heads=1
。
要添加的注意力偏置是根据 attn_mask
输入和 is_causal attribute
计算的,只能提供其中之一。
如果
is_causal
设置为1
,则当掩码为方阵时,注意力掩码是一个下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。attn_mask
:一个布尔掩码,其中True
值表示该元素应参与注意力,或者是一个与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。
过去和当前状态的键/值都是可选的。它们应一起使用,不允许只使用其中一个。在根据提供的序列长度和头数对 K 和 V 输入进行适当的重塑后,以下模式应用于 Q、K 和 V 输入:
The following pattern is applied by this operator:
Q K V
| | |
Q*sqrt(scale) K*sqrt(scale) |
| | |
| Transpose |
| | |
---MatMul--- |
| |
at_mask---Add |
| |
softcap (if provided) |
| |
Softmax |
| |
-----MatMul------
|
Y
属性¶
is_causal - INT(默认为
'0'
)如果设置为
1
,则当掩码为方阵时,注意力掩码是一个下三角矩阵。由于对齐,注意力掩码具有左上角因果偏置的形式。kv_num_heads - INT :
键和值的头数。必须与 Q、K 和 V 的 3D 输入一起使用。
q_num_heads - INT :
查询的头数。必须与 Q、K 和 V 的 3D 输入一起使用。
qk_matmul_output_mode - INT(默认为
'0'
)如果设置为
0
,qk_matmul_output 是 qk 矩阵乘法的输出。如果设置为1
,qk_matmul_output 包含注意力掩码添加到 qk 矩阵乘法输出的结果。如果设置为2
,qk_matmul_output 是 softcap 操作后的输出。如果设置为3
,qk_matmul_output 是 softmax 操作后的输出。默认值为 0。scale - FLOAT :
应用于 \(Q*K^T\) 的缩放因子。默认值为
1/sqrt(head_size)
。为防止数值溢出,在矩阵乘法之前,将Q
、K
按sqrt(scale)
缩放。softcap - FLOAT(默认为
'0.0'
)注意力权重的 softcap 值。默认值为 0。
softmax_precision - INT :
softmax 计算中使用的浮点精度。如果未提供 softmax 精度,则使用 softmax 输入(Q 和 K)的相同精度。
输入¶
3 到 6 个输入之间。
Q(异构)- T1
查询张量。形状为
(batch_size, q_num_heads, q_sequence_length, head_size)
的 4D 张量,或形状为(batch_size, q_sequence_length, q_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,q_hidden_size = q_num_heads * head_size
K(异构)- T1
键张量。形状为
(batch_size, kv_num_heads, kv_sequence_length, head_size)
的 4D 张量,或形状为(batch_size, kv_sequence_length, k_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,k_hidden_size = kv_num_heads * head_size
V(异构)- T2
值张量。形状为
(batch_size, kv_num_heads, kv_sequence_length, v_head_size)
的 4D 张量,或形状为(batch_size, kv_sequence_length, v_hidden_size)
的 3D 张量。对于 3D 输入张量的情况,v_hidden_size = kv_num_heads * v_head_size
attn_mask(可选,异构)- U
注意力掩码。形状必须可广播到形状为
(batch_size, q_num_heads, q_sequence_length, total_sequence_length)
的 4D 张量,其中total_sequence_length = past_sequence_length + kv_sequence_length.
支持两种类型的掩码。一个布尔掩码,其中True
值表示该元素应参与注意力。还支持与查询、键、值相同类型的浮点掩码,该掩码会添加到注意力分数中。past_key(可选,异构)- T1
过去状态缓存的键,形状为
(batch_size, kv_num_heads, past_sequence_length, head_size)
past_value(可选,异构)- T2
过去状态缓存的值,形状为
(batch_size, kv_num_heads, past_sequence_length, v_head_size)
输出¶
1 到 4 个输出之间。
Y(异构)- T1
输出张量。形状为
(batch_size, q_num_heads, q_sequence_length, v_head_size)
的 4D 张量,或形状为(batch_size, q_sequence_length, hidden_size)
的 3D 张量。对于 3D 输入张量的情况,hidden_size = q_num_heads * v_head_size
present_key(可选,异构)- T1
更新后的键缓存,形状为
(batch_size, kv_num_heads, total_sequence_length, head_size)
,其中total_sequence_length = past_sequence_length + kv_sequence_length
。present_value(可选,异构)- T2
更新后的值缓存,形状为
(batch_size, kv_num_heads, total_sequence_length, v_head_size)
,其中total_sequence_length = past_sequence_length + kv_sequence_length
。qk_matmul_output(可选,异构)- T1
QK 矩阵乘法的输出。形状为
(batch_size, q_num_heads, q_sequence_length, total_sequence_length)
的 4D 张量,其中total_sequence_length = past_sequence_length + kv_sequence_length
。
类型约束¶
T1 在 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)将 Q 和 K 输入类型限制为浮点张量。
T2 在 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)将 V 输入类型限制为浮点张量。
U 在 (
tensor(bfloat16)
,tensor(bool)
,tensor(double)
,tensor(float)
,tensor(float16)
,tensor(int16)
,tensor(int32)
,tensor(int64)
,tensor(int8)
,tensor(uint16)
,tensor(uint32)
,tensor(uint64)
,tensor(uint8)
)将输出“mask”类型限制为布尔张量和输入类型。