将包含 LightGBM 模型的流水线转换为 ONNX

sklearn-onnx 只能将 scikit-learn 模型转换为 ONNX,但许多库都实现了 scikit-learn API,因此它们的模型也可以放入 scikit-learn 流水线中。本示例演示一个包含 LightGBM 模型的流水线:只要为 sklearn-onnx 注册了 LGBMClassifier 对应的转换器,它就可以转换整个流水线。下面看看如何实现。

训练 LightGBM 分类器

import lightgbm
import onnxmltools
import skl2onnx
import onnx
import sklearn
import matplotlib.pyplot as plt
import os
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnxruntime as rt
from onnxruntime.capi.onnxruntime_pybind11_state import Fail as OrtFail
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,
)
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm,
)
import onnxmltools.convert.common.data_types
from skl2onnx.common.data_types import FloatTensorType
import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from lightgbm import LGBMClassifier

# Load the iris dataset and keep only its first two feature columns.
iris = load_iris()
X = iris.data[:, :2]
y = iris.target

# Shuffle samples and labels with one shared random permutation so that
# every row stays paired with its target.
order = numpy.arange(X.shape[0])
numpy.random.shuffle(order)
X, y = X[order, :].copy(), y[order].copy()

# Standardise the features, then fit a three-estimator LightGBM classifier.
steps = [
    ("scaler", StandardScaler()),
    ("lgbm", LGBMClassifier(n_estimators=3)),
]
pipe = Pipeline(steps)
pipe.fit(X, y)
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.003623 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 47
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 2
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
Pipeline(steps=[('scaler', StandardScaler()),
                ('lgbm', LGBMClassifier(n_estimators=3))])
在 Jupyter 环境中,请重新运行此单元格以显示 HTML 表示或信任此 notebook。
在 GitHub 上,HTML 表示无法渲染,请尝试使用 nbviewer.org 加载此页面。


注册 LGBMClassifier 的转换器

该转换器在 onnxmltools 中实现:onnxmltools…LightGbm.py;对应的形状计算器也在 onnxmltools 中:onnxmltools…Classifier.py。

然后我们导入转换器和形状计算器。

让我们注册新的转换器。

# Option names and the values registered as allowed for each of them when
# converting an LGBMClassifier.
lgbm_converter_options = {
    "nocl": [True, False],
    "zipmap": [True, False, "columns"],
}

# Tell skl2onnx which shape calculator and converter to use for
# LGBMClassifier, under the alias "LightGbmLGBMClassifier".
update_registered_converter(
    LGBMClassifier,
    "LightGbmLGBMClassifier",
    calculate_linear_classifier_output_shapes,
    convert_lightgbm,
    options=lgbm_converter_options,
)

转换整个流水线

# One float input named "input" with two columns and an unspecified number
# of rows.
initial_types = [("input", FloatTensorType([None, 2]))]

# Convert the whole fitted pipeline, pinning the opset of the default and
# "ai.onnx.ml" domains.
model_onnx = convert_sklearn(
    pipe,
    "pipeline_lightgbm",
    initial_types,
    target_opset={"": 12, "ai.onnx.ml": 2},
)

# Write the serialized model to disk.
with open("pipeline_lightgbm.onnx", "wb") as onnx_file:
    onnx_file.write(model_onnx.SerializeToString())

比较预测结果

使用 LightGBM 进行预测。

# Reference predictions from the fitted pipeline: labels for the first five
# rows, class probabilities for the first row only.
sklearn_labels = pipe.predict(X[:5])
print("predict", sklearn_labels)
sklearn_probabilities = pipe.predict_proba(X[:1])
print("predict_proba", sklearn_probabilities)
predict [1 2 2 2 2]
predict_proba [[0.25335584 0.45934348 0.28730068]]

使用 onnxruntime 进行预测。

# Load the saved model on the CPU provider. If onnxruntime rejects it with
# OrtFail, report the error and skip the comparison; `sess` then stays None.
sess = None
try:
    sess = rt.InferenceSession(
        "pipeline_lightgbm.onnx", providers=["CPUExecutionProvider"]
    )
except OrtFail as load_error:
    print(load_error)
    print("The converter requires onnxmltools>=1.7.0")
else:
    # Feed the first five rows as float32 under the input name "input".
    onnx_feed = {"input": X[:5].astype(numpy.float32)}
    pred_onx = sess.run(None, onnx_feed)
    print("predict", pred_onx[0])
    print("predict_proba", pred_onx[1][:1])
predict [1 2 2 2 2]
predict_proba [{0: 0.25335583090782166, 1: 0.45934349298477173, 2: 0.287300705909729}]

显示 ONNX 图

# Describe the converted ONNX graph in DOT format, laid out top to bottom
# with filled yellow nodes.
pydot_graph = GetPydotGraph(
    model_onnx.graph,
    name=model_onnx.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"
    ),
)
pydot_graph.write_dot("pipeline.dot")

# Render the DOT file to pipeline.dot.png with Graphviz. The exit status is
# checked so that a missing or failing `dot` is reported here, instead of
# surfacing later as an unrelated image-loading error or as a stale image
# left over from a previous run.
dot_status = os.system("dot -O -Gdpi=300 -Tpng pipeline.dot")
if dot_status != 0:
    raise RuntimeError(
        "Graphviz 'dot' failed with status %r; "
        "check that Graphviz is installed and on PATH." % dot_status
    )

# Show the rendered graph without axes. The axis call stays last so the
# captured output of this cell is unchanged.
image = plt.imread("pipeline.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")
plot pipeline lightgbm
(np.float64(-0.5), np.float64(2549.5), np.float64(2558.5), np.float64(-0.5))

本示例使用的版本

# Report the version of every package this example relies on. The labels
# keep their original spacing so the printed lines are unchanged.
versioned_packages = [
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
    ("onnxmltools: ", onnxmltools),
    ("lightgbm: ", lightgbm),
]
for label, package in versioned_packages:
    print(label, package.__version__)
numpy: 2.2.0
scikit-learn: 1.6.0
onnx:  1.18.0
onnxruntime:  1.21.0+cu126
skl2onnx:  1.18.0
onnxmltools:  1.13.0
lightgbm:  4.5.0

脚本总运行时间: (0 分钟 1.351 秒)

由 Sphinx-Gallery 生成的示例集锦