注意
转到末尾 以下载完整的示例代码。
训练和部署 scikit-learn 管道¶
此程序从 scikit-learn 文档中的一个示例开始:Plot individual and voting regression predictions,将其转换为 ONNX,最后在不同的运行时上计算预测。
训练管道¶
import numpy
from onnxruntime import InferenceSession
from sklearn.datasets import load_diabetes
from sklearn.ensemble import (
GradientBoostingRegressor,
RandomForestRegressor,
VotingRegressor,
)
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from skl2onnx import to_onnx
from onnx.reference import ReferenceEvaluator
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Train classifiers
reg1 = GradientBoostingRegressor(random_state=1, n_estimators=5)
reg2 = RandomForestRegressor(random_state=1, n_estimators=5)
reg3 = LinearRegression()
ereg = Pipeline(
steps=[
("voting", VotingRegressor([("gb", reg1), ("rf", reg2), ("lr", reg3)])),
]
)
ereg.fit(X_train, y_train)
转换模型¶
第二个参数提供了用于训练模型的数据样本。它用于推断 ONNX 图的输入类型。它被转换为单精度浮点数,ONNX 运行时可能不支持双精度浮点数。
onx = to_onnx(ereg, X_train[:1].astype(numpy.float32), target_opset=12)
使用 ONNX 进行预测¶
第一个示例使用 onnxruntime。
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]
pred_skl = ereg.predict(X_test.astype(numpy.float32))
print("Onnx Runtime prediction:\n", pred_ort[:5])
print("Sklearn rediction:\n", pred_skl[:5])
Onnx Runtime prediction:
[[131.3207 ]
[195.07472]
[177.4441 ]
[132.33994]
[119.30467]]
Sklearn rediction:
[131.32069594 195.07471731 177.44409343 132.33992579 119.30466842]
比较¶
在部署之前,我们需要比较 *scikit-learn* 和 *ONNX* 是否返回相同的预测。
(np.float64(1.6595706966882062e-05), np.float64(1.0211908589281735e-07))
看起来不错。最大的误差(绝对值和相对值)在由于使用单精度浮点数而不是双精度浮点数而引入的误差范围内。我们可以将模型保存为 ONNX 格式,并使用 onnxruntime 在许多平台上计算相同的预测。
Python 运行时¶
也可以使用 Python 运行时来计算预测。它不适用于生产环境(因为它仍然依赖于 Python),但对于调查转换为何出错很有用。
oinf = ReferenceEvaluator(onx)
print(oinf)
ReferenceEvaluator(X) -> variable
它的工作方式几乎相同。
pred_pyrt = oinf.run(None, {"X": X_test.astype(numpy.float32)})[0]
print(diff(pred_skl, pred_pyrt))
(np.float64(1.6595706966882062e-05), np.float64(1.0211908589281735e-07))
脚本总运行时间: (0 分 0.444 秒)