TensorScatter

TensorScatter - 24

版本

  • 名称: TensorScatter (GitHub)

  • : main

  • 起始版本24

  • 函数: False

  • 支持级别: SupportType.COMMON

  • 形状推断: True

此版本的操作符已可用于版本 24 及以上

摘要

TensorScatter 是一种通用的张量更新操作,其动机是满足 LLM 中常见的 Attention 操作的 KV 缓存更新需求。它是一种函数式操作,模拟对 KV 缓存缓冲区的原地更新。

过去和现在的缓存张量具有相同的形状(batch_size, D1, D2, …, max_sequence_length, …, Dn),其中序列维度(由 axis 属性指示)为 max_sequence_length,因此这些张量的大小在迭代之间无需增长。update 张量的形状仅在序列维度上与缓存张量不同:(batch_size, D1, D2, …, sequence_length, …, Dn),其中 sequence_length <= max_sequence_length。

可选的 write_indices 输入指示批次中每个样本的写入索引,如果未提供则假定为零。当 mode 属性设置为“circular”时,写入索引是 max_sequence_length 的模。该操作可以用以下伪代码描述:

for prefix_idx in np.ndindex(past_cache.shape[:axis]):
    batch_idx = prefix_idx[0]
    for sequence_idx in range(sequence_length):
        cache_idx = (*prefix_idx, write_indices[batch_idx] + sequence_idx)
        if mode == "circular":
            cache_idx = tuple(np.mod(np.asarray(cache_idx), max_sequence_length))
        update_idx = (*prefix_idx, sequence_idx)
        present_cache[cache_idx] = update[update_idx]

在 Attention 的预填充(prefill)阶段,只需要前两个输入。在解码(decode)阶段,还需要 write_indices,以便将传入的键或值更新追加到批次中每个样本的最后一个有效 token 之后。

属性

  • axis - INT (默认为 '-2')

    past_cacheupdate 张量的序列维度。不能为 0(批次维度)。默认为 -2。

  • mode - STRING (默认为 'linear')

    缓存更新的写入模式。支持的模式包括 linearcircularlinear 模式要求 write_indices+sequence_length<=max_sequence_length。对于 circular 模式,更新以循环方式发生,即更新索引是 max_sequence_length 的模。

输入

2 到 3 个输入之间。

  • past_cache (异构) - T

    过去的状态缓存,用于键或值,形状为 (batch_size, D1, D2, ..., max_sequence_length, ..., Dn)

  • update (异构) - T

    新的更新张量,形状为 (batch_size, D1, D2, ..., sequence_length, ..., Dn)

  • write_indices (可选, 异构) - tensor(int64)

    输入更新张量在缓存中的写入索引。形状为 (batch_size,)。如果未提供,则假定为全零。

输出

  • present_cache (异构) - T

    更新后的缓存。形状与 past_cache 相同。

类型约束

  • T ( tensor(bfloat16), tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(float8e8m0), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8) )

    将输入和输出类型限制为任何张量类型。