RotaryEmbedding

RotaryEmbedding - 23

版本

  • **名称**: RotaryEmbedding (GitHub)

  • **域**: main

  • **起始版本**: 23

  • **函数**: True

  • **支持级别**: SupportType.COMMON

  • **形状推断**: True

此版本的算子已从 **版本 23** 开始可用。

摘要

RotaryEmbedding 是基于论文 https://arxiv.org/pdf/2104.09864 实现的旋转位置嵌入 (RoPE)。RoPE 的主要优势在于它允许模型理解令牌的绝对位置以及令牌之间的相对距离。这是通过一个旋转机制实现的,其中旋转量根据令牌的绝对位置 (position_ids) 计算。

旋转机制由用于表示旋转角度的正弦和余弦函数定义。对于序列中的每个令牌,其位置嵌入通过旋转其嵌入向量来计算。这可以通过将嵌入向量分成两半或隔位交错并对嵌入向量的每一半应用旋转矩阵来实现。旋转矩阵由令牌在序列中的位置参数化。嵌入向量的旋转后的两半被连接起来,形成每个令牌的最终位置嵌入。旋转后的位置嵌入用于自注意力机制。旋转确保模型捕获绝对和相对位置信息。

旋转嵌入使用以下算法定义

def compute_rotary_embedding(
    input,
    position_ids,
    sin_cache,
    cos_cache,
    interleaved=0,
    rotary_embedding_dim=0,
    num_heads=0,
):
    # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
    if len(input.shape) == 4:
        input = np.transpose(input, (0, 2, 1, 3))
    batch_size = input.shape[0]
    sequence_length = input.shape[1]
    if len(input.shape) == 3:
        hidden_size = input.shape[2]
        assert num_heads != 0
        head_size = int(hidden_size / num_heads)
        new_shape = [batch_size, sequence_length, num_heads, head_size]
        input = np.reshape(input, new_shape)
    assert len(input.shape) == 4
    head_size = input.shape[3]

    # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
    if rotary_embedding_dim == 0:
        # If rotary_embedding_dim not provided, perform full rotation by using head_size
        rotary_embedding_dim = head_size
    x_rotate = input[:, :, :, :rotary_embedding_dim]
    x_not_rotate = input[:, :, :, rotary_embedding_dim:]
    rotary_embedding_dim_half = int(rotary_embedding_dim / 2)

    # Retrieve sin and cos caches using position ids
    if position_ids is not None:
        cos = cos_cache[position_ids]  # Shape: [batch_size, sequence_length, head_size/2]
        sin = sin_cache[position_ids]  # Shape: [batch_size, sequence_length, head_size/2]
    else:
        cos = cos_cache
        sin = sin_cache
    cos = cos[:, :, :rotary_embedding_dim_half]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
    sin = sin[:, :, :rotary_embedding_dim_half]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
    cos = np.expand_dims(cos, axis=2)  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
    sin = np.expand_dims(sin, axis=2)  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]

    # Either divide the input in halves or interleave (based on interleaved attribute)
    if interleaved:
        x1 = x_rotate[:, :, :, 0::2]
        x2 = x_rotate[:, :, :, 1::2]
    else:
        x1, x2 = np.split(x_rotate, 2, axis=-1)

    # Calculate real and imaginary values
    real = cos * x1 - sin * x2
    imag = sin * x1 + cos * x2

    # Inserted rotated embeddings back to the original input
    if interleaved:
        # x_rotate[:, :, :, 0::2] = real
        # x_rotate[:, :, :, 1::2] = imag
        real = np.expand_dims(real, axis=-1)
        imag = np.expand_dims(imag, axis=-1)
        x_rotate_concat = np.concatenate((real, imag), axis=-1)
        x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
    else:
        x_rotate = np.concatenate((real, imag), axis=-1)
    output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
    if len(original_input_shape) == 3:
        output = np.reshape(output, input.shape)
    else:
        output = np.transpose(output, (0, 2, 1, 3))
    return output

属性

  • **interleaved - INT** (默认为 '0')

    使用交错模式进行旋转。默认值为 0 (False)。

  • num_heads - INT :

    注意力头的数量。当输入是 3D 张量时必须提供。

  • **rotary_embedding_dim - INT** (默认为 '0')

    用于应用部分旋转嵌入的旋转嵌入维度。

输入

3 到 4 个输入。

  • **X** (异构) - **T**

    表示令牌嵌入的输入张量。形状为 (batch_size, num_heads, sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, sequence_length, hidden_size) 的 3D 张量。对于 4D 输入张量的情况,head_size 必须是偶数。对于 3D 输入张量的情况,必须提供 num_heads 属性,并且 hidden_size 必须是 num_heads 的偶数倍,其中 hidden_size = num_heads * head_size

  • **cos_cache** (异构) - **T**

    旋转的余弦值。当提供 position_ids 时,形状为 (max_position_id_plus_1, head_size / 2) 的 2D 张量用于完全旋转,或形状为 (max_position_id_plus_1, rotary_embedding_dim / 2) 的 2D 张量用于部分旋转。当未提供 position_ids 时,形状为 (batch_size, sequence_length, head_size / 2) 的 3D 张量用于完全旋转,或形状为 (batch_size, sequence_length, rotary_embedding_dim / 2) 的 3D 张量用于部分旋转。max_position_id_plus_1 是模型的参数。

  • **sin_cache** (异构) - **T**

    旋转的正弦值。当提供 position_ids 时,形状为 (max_position_id_plus_1, head_size / 2) 的 2D 张量用于完全旋转,或形状为 (max_position_id_plus_1, rotary_embedding_dim / 2) 的 2D 张量用于部分旋转。当未提供 position_ids 时,形状为 (batch_size, sequence_length, head_size / 2) 的 3D 张量用于完全旋转,或形状为 (batch_size, sequence_length, rotary_embedding_dim / 2) 的 3D 张量用于部分旋转。max_position_id_plus_1 是模型的参数。

  • **position_ids** (可选, 异构) - **M**

    令牌的位置索引。形状为 (batch_size, sequence_length) 的 2D 张量

输出

  • **Y** (异构) - **T**

    与输入形状相同的张量。

类型约束

  • **T** 属于 ( tensor(bfloat16), tensor(float), tensor(float16) )

    将输入和输出类型约束为浮点张量。

  • **M** 属于 ( tensor(int64) )

    将输入和输出类型约束为整数张量。