RotaryEmbedding

RotaryEmbedding - 23

版本

  • 名称: RotaryEmbedding (GitHub)

  • : main

  • 起始版本23

  • 函数: True

  • 支持级别: SupportType.COMMON

  • 形状推断: True

此版本的操作符已可用于版本 23 及以上

摘要

RotaryEmbedding 是基于论文 https://arxiv.org/pdf/2104.09864 实现的旋转位置嵌入 (RoPE)。RoPE 的关键优势在于,它能够让模型同时理解 token 的绝对位置以及 token 之间的相对距离。这通过一种旋转机制实现,旋转的幅度根据 token 的绝对位置 (position_ids) 来计算。

旋转机制由用于表示旋转角度的正弦和余弦函数定义。对于序列中的每个 token,其位置嵌入通过旋转其嵌入向量来计算。这可以通过将嵌入向量分成两半,或者隔一个 token 交叉排列,然后将旋转矩阵应用于嵌入向量的每一半来实现。旋转矩阵由 token 在序列中的位置参数化。旋转后的嵌入向量被连接起来,形成每个 token 的最终位置嵌入。旋转后的位置嵌入用于自注意力机制。旋转确保模型捕捉到绝对和相对的位置信息。

旋转嵌入的定义遵循以下算法

def compute_rotary_embedding(
    input,
    position_ids,
    sin_cache,
    cos_cache,
    interleaved=0,
    rotary_embedding_dim=0,
    num_heads=0,
):
    # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
    if len(input.shape) == 4:
        input = np.transpose(input, (0, 2, 1, 3))
    batch_size = input.shape[0]
    sequence_length = input.shape[1]
    if len(input.shape) == 3:
        hidden_size = input.shape[2]
        assert num_heads != 0
        head_size = int(hidden_size / num_heads)
        new_shape = [batch_size, sequence_length, num_heads, head_size]
        input = np.reshape(input, new_shape)
    assert len(input.shape) == 4
    head_size = input.shape[3]

    # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
    if rotary_embedding_dim == 0:
        # If rotary_embedding_dim not provided, perform full rotation by using head_size
        rotary_embedding_dim = head_size
    x_rotate = input[:, :, :, :rotary_embedding_dim]
    x_not_rotate = input[:, :, :, rotary_embedding_dim:]
    rotary_embedding_dim_half = int(rotary_embedding_dim / 2)

    # Retrieve sin and cos caches using position ids
    if position_ids is not None:
        cos = cos_cache[position_ids]  # Shape: [batch_size, sequence_length, head_size/2]
        sin = sin_cache[position_ids]  # Shape: [batch_size, sequence_length, head_size/2]
    else:
        cos = cos_cache
        sin = sin_cache
    cos = cos[:, :, :rotary_embedding_dim_half]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
    sin = sin[:, :, :rotary_embedding_dim_half]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
    cos = np.expand_dims(cos, axis=2)  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
    sin = np.expand_dims(sin, axis=2)  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]

    # Either divide the input in halves or interleave (based on interleaved attribute)
    if interleaved:
        x1 = x_rotate[:, :, :, 0::2]
        x2 = x_rotate[:, :, :, 1::2]
    else:
        x1, x2 = np.split(x_rotate, 2, axis=-1)

    # Calculate real and imaginary values
    real = cos * x1 - sin * x2
    imag = sin * x1 + cos * x2

    # Inserted rotated embeddings back to the original input
    if interleaved:
        # x_rotate[:, :, :, 0::2] = real
        # x_rotate[:, :, :, 1::2] = imag
        real = np.expand_dims(real, axis=-1)
        imag = np.expand_dims(imag, axis=-1)
        x_rotate_concat = np.concatenate((real, imag), axis=-1)
        x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
    else:
        x_rotate = np.concatenate((real, imag), axis=-1)
    output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
    if len(original_input_shape) == 3:
        output = np.reshape(output, input.shape)
    else:
        output = np.transpose(output, (0, 2, 1, 3))
    return output

属性

  • interleaved - INT (默认值为 '0')

    使用交错模式进行旋转。默认值为 0 (False)。

  • num_heads - INT :

    注意力头的数量。当输入为 3D 张量时必须提供。

  • rotary_embedding_dim - INT (默认值为 '0')

    用于应用部分旋转嵌入的旋转嵌入维度。

输入

介于 3 和 4 之间的输入。

  • X (异构) - T

    代表 token 嵌入的输入张量。形状为 (batch_size, num_heads, sequence_length, head_size) 的 4D 张量,或形状为 (batch_size, sequence_length, hidden_size) 的 3D 张量。对于 4D 输入张量,head_size 必须是偶数。对于 3D 输入张量,必须提供 num_heads 属性,并且 hidden_size 必须是 num_heads 的偶数倍,其中 hidden_size = num_heads * head_size

  • cos_cache (异构) - T

    旋转的余弦值。当提供 position_ids 时,对于完全旋转,形状为 (max_position_id_plus_1, head_size / 2) 的 2D 张量;对于部分旋转,形状为 (max_position_id_plus_1, rotary_embedding_dim / 2)。当不提供 position_ids 时,对于完全旋转,形状为 (batch_size, sequence_length, head_size / 2) 的 3D 张量;对于部分旋转,形状为 (batch_size, sequence_length, rotary_embedding_dim / 2)max_position_id_plus_1 是模型的参数。

  • sin_cache (异构) - T

    旋转的正弦值。当提供 position_ids 时,对于完全旋转,形状为 (max_position_id_plus_1, head_size / 2) 的 2D 张量;对于部分旋转,形状为 (max_position_id_plus_1, rotary_embedding_dim / 2)。当不提供 position_ids 时,对于完全旋转,形状为 (batch_size, sequence_length, head_size / 2) 的 3D 张量;对于部分旋转,形状为 (batch_size, sequence_length, rotary_embedding_dim / 2)max_position_id_plus_1 是模型的参数。

  • position_ids (可选,异构) - M

    token 的位置索引。形状为 (batch_size, sequence_length) 的 2D 张量

输出

  • Y (异构) - T

    与输入形状相同的张量。

类型约束

  • T 在 ( tensor(bfloat16), tensor(float), tensor(float16) )

    将输入和输出类型限制为浮点张量。

  • M 在 ( tensor(int64) )

    将输入和输出类型约束为整数张量。