TensorScatter¶
TensorScatter - 24¶
版本¶
域:
main起始版本:
24函数:
False支持级别:
SupportType.COMMON形状推断:
True
此版本的操作符已可用于版本 24 及以上。
摘要¶
TensorScatter 是一种通用的张量更新操作,其动机是满足 LLM 中常见的 Attention 操作的 KV 缓存更新需求。它是一种函数式操作,模拟对 KV 缓存缓冲区的原地更新。
过去和现在的缓存张量具有相同的形状(batch_size, D1, D2, …, max_sequence_length, …, Dn),其中序列维度(由 axis 属性指示)为 max_sequence_length,因此这些张量的大小在迭代之间无需增长。update 张量的形状仅在序列维度上与缓存张量不同:(batch_size, D1, D2, …, sequence_length, …, Dn),其中 sequence_length <= max_sequence_length。
可选的 write_indices 输入指示批次中每个样本的写入索引,如果未提供则假定为零。当 mode 属性设置为“circular”时,写入索引是 max_sequence_length 的模。该操作可以用以下伪代码描述:
for prefix_idx in np.ndindex(past_cache.shape[:axis]):
batch_idx = prefix_idx[0]
for sequence_idx in range(sequence_length):
cache_idx = (*prefix_idx, write_indices[batch_idx] + sequence_idx)
if mode == "circular":
cache_idx = tuple(np.mod(np.asarray(cache_idx), max_sequence_length))
update_idx = (*prefix_idx, sequence_idx)
present_cache[cache_idx] = update[update_idx]
在 Attention 的预填充(prefill)阶段,只需要前两个输入。在解码(decode)阶段,还需要 write_indices,以便将传入的键或值更新追加到批次中每个样本的最后一个有效 token 之后。
属性¶
axis - INT (默认为
'-2')past_cache和update张量的序列维度。不能为 0(批次维度)。默认为 -2。mode - STRING (默认为
'linear')缓存更新的写入模式。支持的模式包括
linear和circular。linear模式要求 write_indices+sequence_length<=max_sequence_length。对于circular模式,更新以循环方式发生,即更新索引是max_sequence_length的模。
输入¶
2 到 3 个输入之间。
past_cache (异构) - T
过去的状态缓存,用于键或值,形状为
(batch_size, D1, D2, ..., max_sequence_length, ..., Dn)。update (异构) - T
新的更新张量,形状为
(batch_size, D1, D2, ..., sequence_length, ..., Dn)。write_indices (可选, 异构) - tensor(int64)
输入更新张量在缓存中的写入索引。形状为
(batch_size,)。如果未提供,则假定为全零。
输出¶
present_cache (异构) - T
更新后的缓存。形状与
past_cache相同。
类型约束¶
T (
tensor(bfloat16),tensor(bool),tensor(complex128),tensor(complex64),tensor(double),tensor(float),tensor(float16),tensor(float4e2m1),tensor(float8e4m3fn),tensor(float8e4m3fnuz),tensor(float8e5m2),tensor(float8e5m2fnuz),tensor(float8e8m0),tensor(int16),tensor(int32),tensor(int4),tensor(int64),tensor(int8),tensor(string),tensor(uint16),tensor(uint32),tensor(uint4),tensor(uint64),tensor(uint8))将输入和输出类型限制为任何张量类型。