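Convert a pipeline with a LightGBM model to ONNX
sklearn-onnx only converts scikit-learn models to ONNX, but many libraries implement the scikit-learn API so that their models can be included in a scikit-learn pipeline. This example considers a pipeline that contains a LightGBM model. sklearn-onnx can convert the whole pipeline as long as it knows the converter associated with LGBMClassifier. Let's see how to do that.
Train a LightGBM classifier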
import lightgbm
import onnxmltools
import skl2onnx
import onnx
import sklearn
import matplotlib.pyplot as plt
import os
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnxruntime as rt
from onnxruntime.capi.onnxruntime_pybind11_state import Fail as OrtFail
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,
)  # noqa
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm,
)  # noqa
import onnxmltools.convert.common.data_types
from skl2onnx.common.data_types import FloatTensorType
import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from lightgbm import LGBMClassifier
data = load_iris()
X = data.data[:, :2]
y = data.target
ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()
pipe = Pipeline(
    [("scaler", StandardScaler()), ("lgbm", LGBMClassifier(n_estimators=3))]
)
pipe.fit(X, y)
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000688 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 47
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 2
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
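Register the converter for LGBMClassifier
The converter is implemented in onnxmltools: onnxmltools…LightGbm.py. The shape calculator is implemented there as well: onnxmltools…Classifier.py.
The converter and the shape calculator are then imported (both imports appear at the top of this example).
Let's register the new converter.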
update_registered_converter(
    LGBMClassifier,
    "LightGbmLGBMClassifier",
    calculate_linear_classifier_output_shapes,
    convert_lightgbm,
    options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
)
Convert again
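The conversion code for this step is not shown on this page. What follows is a minimal sketch rather than the example's own code. It assumes the fitted `pipe` from above, the two input features, the input name `input`, and the file name `pipeline_lightgbm.onnx` that the later steps rely on. The helper name `convert_and_save` and the `target_opset` values are assumptions introduced for illustration.

```python
def convert_and_save(pipe, n_features=2, path="pipeline_lightgbm.onnx"):
    """Convert a fitted scikit-learn pipeline to ONNX and write it to disk.

    Hypothetical helper: the name, defaults and opset values are
    assumptions of this sketch, not taken from the example itself.
    """
    # Imported inside the function so the helper can be defined even
    # when skl2onnx is not installed.
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType

    model_onnx = convert_sklearn(
        pipe,
        "pipeline_lightgbm",
        # One float input named "input" with a free batch dimension.
        initial_types=[("input", FloatTensorType([None, n_features]))],
        # Assumed opsets; adjust them to what the installed runtime supports.
        target_opset={"": 12, "ai.onnx.ml": 2},
    )
    # Serialize the ONNX model so that onnxruntime can load it from a file.
    with open(path, "wb") as f:
        f.write(model_onnx.SerializeToString())
    return model_onnx
```

Calling `model_onnx = convert_and_save(pipe)` after the converter has been registered would produce both the `model_onnx` object and the file used in the following sections.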
Compare the predictions
Predictions with LightGBM.
predict [0 0 1 1 2]
predict_proba [[0.51995794 0.24549283 0.23454923]]
Predictions with onnxruntime.
try:
    sess = rt.InferenceSession(
        "pipeline_lightgbm.onnx", providers=["CPUExecutionProvider"]
    )
except OrtFail as e:
    print(e)
    print("The converter requires onnxmltools>=1.7.0")
    sess = None

if sess is not None:
    pred_onx = sess.run(None, {"input": X[:5].astype(numpy.float32)})
    print("predict", pred_onx[0])
    print("predict_proba", pred_onx[1][:1])
predict [0 0 1 1 2]
predict_proba [{0: 0.519957959651947, 1: 0.2454928159713745, 2: 0.23454922437667847}]
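The two probability outputs above use different layouts: LightGBM prints a row of an array, while onnxruntime returns a list of dictionaries keyed by class label. A small sketch of how the two could be compared numerically, using only the values printed above. The helper names `zipmap_to_rows` and `max_abs_diff` are hypothetical, and the `1e-5` tolerance is an assumption chosen to allow for float32 rounding.

```python
def zipmap_to_rows(zipmap_output):
    """Turn [{label: prob, ...}, ...] into rows ordered by sorted label."""
    return [[row[label] for label in sorted(row)] for row in zipmap_output]


def max_abs_diff(expected, got):
    """Largest absolute element-wise difference between two lists of rows."""
    return max(
        abs(e - g)
        for exp_row, got_row in zip(expected, got)
        for e, g in zip(exp_row, got_row)
    )


# Values copied from the printed outputs above.
lgbm_proba = [[0.51995794, 0.24549283, 0.23454923]]
onnx_proba = [
    {0: 0.519957959651947, 1: 0.2454928159713745, 2: 0.23454922437667847}
]

diff = max_abs_diff(lgbm_proba, zipmap_to_rows(onnx_proba))
print("max abs diff:", diff)
print("match within 1e-5:", diff < 1e-5)
```

With real arrays, `numpy.testing.assert_allclose` would do the same job once the dictionaries are flattened into rows.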
Display the ONNX graph
pydot_graph = GetPydotGraph(
    model_onnx.graph,
    name=model_onnx.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"
    ),
)
pydot_graph.write_dot("pipeline.dot")

os.system("dot -O -Gdpi=300 -Tpng pipeline.dot")

image = plt.imread("pipeline.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")
(-0.5, 2549.5, 2558.5, -0.5)
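The `os.system` call above assumes that the Graphviz `dot` executable is on the PATH, and it does not check the exit status, so a missing installation only shows up later when the PNG cannot be read. One possible, more defensive variant is sketched below. The helper name `render_dot` is hypothetical; the flags are the ones used in the command above.

```python
import shutil
import subprocess


def render_dot(dot_file, dpi=300, executable="dot"):
    """Render dot_file to PNG with Graphviz; return the PNG path or None.

    Returns None when the Graphviz executable cannot be found, so the
    caller can skip plotting instead of failing on a missing image.
    """
    exe = shutil.which(executable)
    if exe is None:
        return None
    # Same flags as the os.system call: -O writes <dot_file>.png.
    subprocess.run(
        [exe, "-O", f"-Gdpi={dpi}", "-Tpng", str(dot_file)],
        check=True,  # raise CalledProcessError if dot reports an error
    )
    return f"{dot_file}.png"
```

A caller could then write `png = render_dot("pipeline.dot")` and only call `plt.imread(png)` when the result is not `None`.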
Versions used for this example
print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", rt.__version__)
print("skl2onnx: ", skl2onnx.__version__)
print("onnxmltools: ", onnxmltools.__version__)
print("lightgbm: ", lightgbm.__version__)
numpy: 1.26.4
scikit-learn: 1.6.dev0
onnx: 1.17.0
onnxruntime: 1.18.0+cu118
skl2onnx: 1.17.0
onnxmltools: 1.13.0
lightgbm: 4.2.0
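The version report above imports every package first. As an alternative sketch, the standard library's `importlib.metadata` can read installed versions without importing the packages and can tolerate a missing one. The helper name `installed_version` is hypothetical, and the distribution names in the loop are assumed to match the names the packages were installed under (a GPU build of onnxruntime, for instance, may be installed under a different distribution name).

```python
from importlib import metadata


def installed_version(distribution):
    """Return the installed version string, or 'not installed'."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return "not installed"


# Distribution names are an assumption of this sketch.
for name in (
    "numpy",
    "scikit-learn",
    "onnx",
    "onnxruntime",
    "skl2onnx",
    "onnxmltools",
    "lightgbm",
):
    print(f"{name}: {installed_version(name)}")
```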
Total running time of the script: (0 minutes 3.752 seconds)