训练和部署 scikit-learn 管道

本程序基于 scikit-learn 文档中的一个示例:绘制个体和投票回归预测,将其转换为 ONNX,最后使用不同的运行时计算预测结果。

训练管道

import numpy
from onnxruntime import InferenceSession
from sklearn.datasets import load_diabetes
from sklearn.ensemble import (
    GradientBoostingRegressor,
    RandomForestRegressor,
    VotingRegressor,
)
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from skl2onnx import to_onnx
from onnx.reference import ReferenceEvaluator


X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Train classifiers
reg1 = GradientBoostingRegressor(random_state=1, n_estimators=5)
reg2 = RandomForestRegressor(random_state=1, n_estimators=5)
reg3 = LinearRegression()

ereg = Pipeline(
    steps=[
        ("voting", VotingRegressor([("gb", reg1), ("rf", reg2), ("lr", reg3)])),
    ]
)
ereg.fit(X_train, y_train)
Pipeline(steps=[('voting',
                 VotingRegressor(estimators=[('gb',
                                              GradientBoostingRegressor(n_estimators=5,
                                                                        random_state=1)),
                                             ('rf',
                                              RandomForestRegressor(n_estimators=5,
                                                                    random_state=1)),
                                             ('lr', LinearRegression())]))])
在 Jupyter 环境中,请重新运行此单元格以显示 HTML 表示或信任笔记本。
在 GitHub 上,无法渲染 HTML 表示,请尝试使用 nbviewer.org 加载此页面。


转换模型

第二个参数提供了用于训练模型的数据样本。它用于推断 ONNX 图的输入类型。它被转换为单精度浮点数,因为 ONNX 运行时可能不完全支持双精度浮点数。

onx = to_onnx(ereg, X_train[:1].astype(numpy.float32), target_opset=12)

使用 ONNX 进行预测

第一个示例使用了 onnxruntime

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]

pred_skl = ereg.predict(X_test.astype(numpy.float32))

print("Onnx Runtime prediction:\n", pred_ort[:5])
print("Sklearn rediction:\n", pred_skl[:5])
Onnx Runtime prediction:
 [[ 95.148315]
 [236.64844 ]
 [192.72498 ]
 [163.4321  ]
 [159.0243  ]]
Sklearn rediction:
 [ 95.14831426 236.64842262 192.7249677  163.4320918  159.02430011]

比较

在部署之前,我们需要比较 scikit-learnONNX 是否返回相同的预测结果。

def diff(p1, p2):
    p1 = p1.ravel()
    p2 = p2.ravel()
    d = numpy.abs(p2 - p1)
    return d.max(), (d / numpy.abs(p1)).max()


print(diff(pred_skl, pred_ort))
(np.float64(2.3916485787367492e-05), np.float64(1.4299860951284095e-07))

结果看起来不错。最大误差(绝对误差和相对误差)都在使用单精度浮点数而非双精度浮点数引入的误差范围内。我们可以将模型保存为 ONNX 格式,并在多个平台中使用 onnxruntime 计算相同的预测结果。

Python 运行时

也可以使用 Python 运行时来计算预测结果。它不适用于生产环境(因为它仍然依赖于 Python),但对于调查转换出错的原因很有用。

oinf = ReferenceEvaluator(onx)
print(oinf)
ReferenceEvaluator(X) -> variable

它的工作方式几乎相同。

pred_pyrt = oinf.run(None, {"X": X_test.astype(numpy.float32)})[0]
print(diff(pred_skl, pred_pyrt))
(np.float64(2.3916485787367492e-05), np.float64(1.4092690429730078e-07))

**脚本总运行时间:** (0 分钟 0.319 秒)

由 Sphinx-Gallery 生成的图库