MatMulInteger

MatMulInteger - 10

版本

  • 名称: MatMulInteger (GitHub)

  • : main

  • 自版本: 10

  • 函数: False

  • 支持级别: SupportType.COMMON

  • 形状推断: True

此版本的算子自版本 10 起可用。

摘要

行为类似于 numpy.matmul 的矩阵乘积。乘法结果绝不能溢出。累加结果仅在 32 位时可能溢出。

输入

2 到 4 个输入。

  • A (异构) - T1

    N 维矩阵 A

  • B (异构) - T2

    N 维矩阵 B

  • a_zero_point (可选,异构) - T1

    输入 'A' 的零点张量。它是可选的,默认值为 0。它可以是标量或 N 维张量。标量表示按张量量化,而 N 维表示按行量化。如果输入是形状为 [M, K] 的 2D 张量,则零点张量可能是一个 M 个元素的向量 [zp_1, zp_2, …, zp_M]。如果输入是形状为 [D1, D2, M, K] 的 N 维张量,则零点张量可能具有形状 [D1, D2, M, 1]。

  • b_zero_point (可选,异构) - T2

    输入 'B' 的零点张量。它是可选的,默认值为 0。它可以是标量或 N 维张量。标量表示按张量量化,而 N 维表示按列量化。如果输入是形状为 [K, N] 的 2D 张量,则零点张量可能是一个 N 个元素的向量 [zp_1, zp_2, …, zp_N]。如果输入是形状为 [D1, D2, K, N] 的 N 维张量,则零点张量可能具有形状 [D1, D2, 1, N]。

输出

  • Y (异构) - T3

    A * B 的矩阵乘法结果

类型约束

  • T1 类型为 ( tensor(int8), tensor(uint8) )

    约束输入 A 的数据类型为 8 位整数张量。

  • T2 类型为 ( tensor(int8), tensor(uint8) )

    约束输入 B 的数据类型为 8 位整数张量。

  • T3 类型为 ( tensor(int32) )

    约束输出 Y 的数据类型为 32 位整数张量。