onnx._custom_element_types¶
此模块定义了 numpy 不支持的自定义 dtype。函数 onnx.numpy_helper.from_array()
和 onnx.numpy_helper.to_array()
使用它们来将数组从/转换为这些类型。类 onnx.reference.ReferenceEvalutor
也使用它们。例如,要为单元测试创建此类数组,可以使用以下便捷方式:
import numpy as np
from onnx import TensorProto
from onnx.reference.ops.op_cast import Cast_19 as Cast
tensor_bfloat16 = Cast.eval(np.array([0, 1], dtype=np.float32), to=TensorProto.BFLOAT16)
下面使用的 numpy 表示 dtype 仅用于内部用途。它们未来可能会根据这些 numpy 类型的行业标准化而改变。
- onnx._custom_element_types.bfloat16 = dtype((numpy.uint16, [('bfloat16', '<u2')]))¶
将 bfloat16 定义为 uint16。
- onnx._custom_element_types.float4e2m1 = dtype((numpy.uint8, [('float4e2m1', 'u1')]))¶
定义 float 4 e2m1 类型,请参阅 存储在 4 位中的浮点数 获取技术详情。请注意,一个整数使用一个字节存储,因此是其 ONNX 大小的两倍。
- onnx._custom_element_types.float8e4m3fn = dtype((numpy.uint8, [('e4m3fn', 'u1')]))¶
定义 float 8 e4m3fn 类型,请参阅 存储在 8 位中的浮点数 获取技术详情。
- onnx._custom_element_types.float8e4m3fnuz = dtype((numpy.uint8, [('e4m3fnuz', 'u1')]))¶
定义 float 8 e4m3fnuz 类型,请参阅 存储在 8 位中的浮点数 获取技术详情。
- onnx._custom_element_types.float8e5m2 = dtype((numpy.uint8, [('e5m2', 'u1')]))¶
定义 float 8 e5m2 类型,请参阅 存储在 8 位中的浮点数 获取技术详情。
- onnx._custom_element_types.float8e5m2fnuz = dtype((numpy.uint8, [('e5m2fnuz', 'u1')]))¶
定义 float 8 e5m2fnuz 类型,请参阅 存储在 8 位中的浮点数 获取技术详情。