sklearn-onnx:将您的 scikit-learn 模型转换为 ONNX

sklearn-onnx 使您能够将来自 scikit-learn 工具包的模型转换为 ONNX

问题,疑问

您应该查找 现有问题 或提交新的问题。源代码可在 onnx/sklearn-onnx 上获取。

ONNX 版本

转换器可以为特定版本的 ONNX 转换模型。每个 ONNX 版本都用一个 opset 版本号标记,该版本号由函数 onnx_opset_version 返回。如果在转换模型时未指定目标 opset 参数(参数 target_opset),则此函数将返回该参数的默认值。每个操作符都有版本。库为每个操作符选择低于或等于目标 opset 版本号的最新版本。ONNX 模型每个操作符域都有一个 opset 版本号,此值是所有 onnx 节点中最大的 opset 版本号。

<<<

from skl2onnx import __max_supported_opset__, __version__

print("documentation for version:", __version__)
print("Last supported opset:", __max_supported_opset__)

>>>

    documentation for version: 1.18.0
    Last supported opset: 21

后端

sklearn-onnx 将模型转换为 ONNX 格式,然后可以使用您选择的任何后端进行预测计算。但是,存在一种方法可以使用 onnxruntimeonnxruntime-gpu 自动检查每个转换器。每个转换器都使用此后端进行测试。

入门

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)

# Convert into ONNX format.
from skl2onnx import to_onnx

onx = to_onnx(clr, X[:1])
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

# Compute the prediction with onnxruntime.
import onnxruntime as rt

sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]

相关转换器

sklearn-onnx 仅转换来自 scikit-learn 的模型。onnxmltools 可用于转换 libsvmlightgbmxgboost 的模型。其他转换器可在 github/onnxtorch.onnxONNX-MXNet APIMicrosoft.ML.Onnx… 上找到。

更改日志

参见 CHANGELOGS.md

鸣谢

该软件包由以下微软工程师和数据科学家从 2017 年冬季开始启动:Zeeshan Ahmed、Wei-Sheng Chin、Aidan Crook、Xavier Dupré、Costin Eseanu、Tom Finley、Lixin Gong、Scott Inglis、Pei Jiang、Ivan Matantsev、Prabhat Roy、M. Zeeshan Siddiqui、Shouheng Yi、Shauheen Zahirazami、Yiwen Zhu、Du Li、Xuan Li、Wenbing Li。

许可证

它使用 Apache License v2.0 许可。

旧版本