TensorScatter¶
TensorScatter - 24¶
版本¶
域:
main
起始版本:
24
函数:
False
支持级别:
SupportType.COMMON
形状推断:
True
此版本的操作符已可用于版本 24 及以上。
摘要¶
TensorScatter 是一种通用的张量更新操作,其动机是满足 LLM 中常见的 Attention 操作的 KV 缓存更新需求。它是一种函数式操作,模拟对 KV 缓存缓冲区的原地更新。
过去和现在的缓存张量具有相同的形状(batch_size, D1, D2, …, max_sequence_length, …, Dn),其中序列维度(由 axis
属性指示)为 max_sequence_length,因此这些张量的大小在迭代之间无需增长。update
张量的形状仅在序列维度上与缓存张量不同:(batch_size, D1, D2, …, sequence_length, …, Dn),其中 sequence_length <= max_sequence_length。
可选的 write_indices
输入指示批次中每个样本的写入索引,如果未提供则假定为零。当 mode
属性设置为“circular”时,写入索引是 max_sequence_length 的模。该操作可以用以下伪代码描述:
for prefix_idx in np.ndindex(past_cache.shape[:axis]):
batch_idx = prefix_idx[0]
for sequence_idx in range(sequence_length):
cache_idx = (*prefix_idx, write_indices[batch_idx] + sequence_idx)
if mode == "circular":
cache_idx = tuple(np.mod(np.asarray(cache_idx), max_sequence_length))
update_idx = (*prefix_idx, sequence_idx)
present_cache[cache_idx] = update[update_idx]
在 Attention 的预填充(prefill)阶段,只需要前两个输入。在解码(decode)阶段,还需要 write_indices
,以便将传入的键或值更新追加到批次中每个样本的最后一个有效 token 之后。
属性¶
axis - INT (默认为
'-2'
)past_cache
和update
张量的序列维度。不能为 0(批次维度)。默认为 -2。mode - STRING (默认为
'linear'
)缓存更新的写入模式。支持的模式包括
linear
和circular
。linear
模式要求 write_indices+sequence_length<=max_sequence_length。对于circular
模式,更新以循环方式发生,即更新索引是max_sequence_length
的模。
输入¶
2 到 3 个输入之间。
past_cache (异构) - T
过去的状态缓存,用于键或值,形状为
(batch_size, D1, D2, ..., max_sequence_length, ..., Dn)
。update (异构) - T
新的更新张量,形状为
(batch_size, D1, D2, ..., sequence_length, ..., Dn)
。write_indices (可选, 异构) - tensor(int64)
输入更新张量在缓存中的写入索引。形状为
(batch_size,)
。如果未提供,则假定为全零。
输出¶
present_cache (异构) - T
更新后的缓存。形状与
past_cache
相同。
类型约束¶
T (
tensor(bfloat16)
,tensor(bool)
,tensor(complex128)
,tensor(complex64)
,tensor(double)
,tensor(float)
,tensor(float16)
,tensor(float4e2m1)
,tensor(float8e4m3fn)
,tensor(float8e4m3fnuz)
,tensor(float8e5m2)
,tensor(float8e5m2fnuz)
,tensor(float8e8m0)
,tensor(int16)
,tensor(int32)
,tensor(int4)
,tensor(int64)
,tensor(int8)
,tensor(string)
,tensor(uint16)
,tensor(uint32)
,tensor(uint4)
,tensor(uint64)
,tensor(uint8)
)将输入和输出类型限制为任何张量类型。