转换包含 XGBoost 模型的管道

sklearn-onnx 只能将 scikit-learn 模型转换为 ONNX，但许多库实现了 scikit-learn API，因此它们的模型可以放入 scikit-learn 管道中。本示例考虑一个包含 XGBoost 模型的管道：只要 sklearn-onnx 知道 XGBClassifier 对应的转换器，就能转换整个管道。下面来看看具体做法。

训练 XGBoost 分类器

import numpy
import onnxruntime as rt
from sklearn.datasets import load_iris, load_diabetes, make_classification
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier, XGBRegressor, DMatrix, train as train_xgb
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn, to_onnx, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,
    calculate_linear_regressor_output_shapes,
)
from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost
from onnxmltools.convert import convert_xgboost as convert_xgboost_booster


# Train a small XGBoost classifier on the first two iris features,
# preceded by a StandardScaler inside a scikit-learn pipeline.
data = load_iris()
X = data.data[:, :2]
y = data.target

# Shuffle the rows with a seeded generator so the printed results are
# reproducible from one run of the example to the next. Fancy indexing
# already returns new arrays, so no explicit copy is needed.
rng = numpy.random.default_rng(0)
ind = rng.permutation(X.shape[0])
X = X[ind, :]
y = y[ind]

pipe = Pipeline([("scaler", StandardScaler()), ("xgb", XGBClassifier(n_estimators=3))])
pipe.fit(X, y)

# First conversion attempt. No converter has been registered for
# XGBClassifier at this point, so a failure is expected: the exception is
# caught and only its message is printed.
try:
    convert_sklearn(
        pipe,
        name="pipeline_xgboost",
        initial_types=[("input", FloatTensorType([None, 2]))],
        target_opset={"": 12, "ai.onnx.ml": 2},
    )
except Exception as conversion_error:
    print(conversion_error)

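# The error message is expected to say that no converter was found
# for XGBoost models. By default, sklearn-onnx only handles models
# from scikit-learn, but it can be extended to every model following
# the scikit-learn API as long as the module knows there exists a
# converter for every model used in a pipeline. That's why we need to
# register a converter.
# NOTE: the message actually printed below is different. It is an
# AttributeError about '__sklearn_tags__', raised inside scikit-learn's
# tag lookup (is_classifier -> get_tags). It most likely comes from a
# scikit-learn / xgboost version mismatch rather than from the missing
# converter. Registering the converter does not fix it, as the traceback
# after the second conversion shows.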
'super' object has no attribute '__sklearn_tags__'

注册 XGBClassifier 的转换器

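转换器在 onnxmltools 中实现：onnxmltools/convert/xgboost/operator_converters/XGBoost.py；对应的形状计算器见 onnxmltools/convert/xgboost/shape_calculators/Classifier.py。注意，下面的代码实际使用的是 skl2onnx 自带的 calculate_linear_classifier_output_shapes 作为形状计算器。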

# Teach skl2onnx about XGBClassifier. The shape calculator comes from
# skl2onnx and the converter function from onnxmltools. The options dict
# declares the option names this converter accepts and their allowed values.
xgb_classifier_options = {
    "nocl": [True, False],
    "zipmap": [True, False, "columns"],
}
update_registered_converter(
    XGBClassifier,
    "XGBoostXGBClassifier",
    calculate_linear_classifier_output_shapes,
    convert_xgboost,
    options=xgb_classifier_options,
)

再次转换

# Convert the pipeline again, now that a converter is registered for
# XGBClassifier.
model_onnx = convert_sklearn(
    pipe,
    name="pipeline_xgboost",
    initial_types=[("input", FloatTensorType([None, 2]))],
    target_opset={"": 12, "ai.onnx.ml": 2},
)

# Write the serialized model to disk; onnxruntime loads it from this file below.
with open("pipeline_xgboost.onnx", "wb") as onnx_file:
    onnx_file.write(model_onnx.SerializeToString())
Traceback (most recent call last):
  File "/home/xadupre/github/sklearn-onnx/docs/tutorial/plot_gexternal_xgboost.py", line 96, in <module>
    model_onnx = convert_sklearn(
                 ^^^^^^^^^^^^^^^^
  File "/home/xadupre/github/sklearn-onnx/skl2onnx/convert.py", line 192, in convert_sklearn
    topology = parse_sklearn_model(
               ^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/github/sklearn-onnx/skl2onnx/_parse.py", line 847, in parse_sklearn_model
    outputs = parse_sklearn(
              ^^^^^^^^^^^^^^
  File "/home/xadupre/github/sklearn-onnx/skl2onnx/_parse.py", line 757, in parse_sklearn
    res = _parse_sklearn(scope, model, inputs, custom_parsers=custom_parsers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/github/sklearn-onnx/skl2onnx/_parse.py", line 688, in _parse_sklearn
    outputs = sklearn_parsers_map[tmodel](
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/github/sklearn-onnx/skl2onnx/_parse.py", line 295, in _parse_sklearn_pipeline
    ) and is_classifier(step[1]):
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/base.py", line 1237, in is_classifier
    return get_tags(estimator).estimator_type == "classifier"
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/_tags.py", line 405, in get_tags
    sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator)  # type: ignore[attr-defined]
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/base.py", line 540, in __sklearn_tags__
    tags = super().__sklearn_tags__()
           ^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'super' object has no attribute '__sklearn_tags__'

比较预测结果

使用 XGBoost 的预测结果。

# Reference outputs from the fitted scikit-learn / XGBoost pipeline:
# labels for the first five rows, probabilities for the first row only.
xgb_labels = pipe.predict(X[:5])
print("predict", xgb_labels)
xgb_probas = pipe.predict_proba(X[:1])
print("predict_proba", xgb_probas)

使用 onnxruntime 的预测结果。

# Score the same rows with the exported model through onnxruntime. The feed
# is cast to float32 to match the FloatTensorType declared at conversion time.
sess = rt.InferenceSession(
    "pipeline_xgboost.onnx", providers=["CPUExecutionProvider"]
)
onnx_feed = {"input": X[:5].astype(numpy.float32)}
pred_onx = sess.run(None, onnx_feed)
# Output 0 is printed as the labels, output 1 as the probabilities.
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1][:1])

使用 XGBRegressor 的相同示例

# Register XGBRegressor the same way, using the linear-regressor shape
# calculator. No extra converter options are declared for the regressor.
update_registered_converter(
    XGBRegressor,
    "XGBoostXGBRegressor",
    shape_fct=calculate_linear_regressor_output_shapes,
    convert_fct=convert_xgboost,
)


# Regression variant on the diabetes dataset. Half of the rows are held out
# to compare the pipeline with its ONNX export. The split is seeded (as in
# the Booster example) so the printed predictions are reproducible.
data = load_diabetes()
x = data.data
y = data.target
X_train, X_test, y_train, _ = train_test_split(x, y, test_size=0.5, random_state=42)

pipe = Pipeline([("scaler", StandardScaler()), ("xgb", XGBRegressor(n_estimators=3))])
pipe.fit(X_train, y_train)

print("predict", pipe.predict(X_test[:5]))

ONNX

# Export the regression pipeline. to_onnx receives a float32 copy of the
# training data as its sample input.
onx = to_onnx(
    pipe,
    X_train.astype(numpy.float32),
    target_opset={"": 12, "ai.onnx.ml": 2},
)

sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
# The exported model is fed under the input name "X".
test_rows = X_test[:5].astype(numpy.float32)
pred_onx = sess.run(None, {"X": test_rows})
# ravel() flattens the first output so it prints like the sklearn predictions.
print("predict", pred_onx[0].ravel())

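两者的预测结果可能会出现一些差异：ONNX 模型的输入被转换为 float32（见上面的 astype(numpy.float32)），而 scikit-learn 管道直接处理原始的 float64 数据。遇到这种情况时，请阅读“切换到 float 时的问题”（Issues when switching to float）一节。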

使用 Booster 的相同示例

Booster 不能插入到管道中。它需要一个不同的转换函数,因为它不遵循 scikit-learn API。

# Same example with a raw xgboost Booster trained through xgboost.train.
# A Booster cannot be placed in a pipeline, so it is converted with
# onnxmltools' dedicated converter instead of convert_sklearn.
# The number of classes is shared by the data generator and the softmax
# objective, so the model's num_class matches the labels actually present.
n_classes = 3
x, y = make_classification(
    n_classes=n_classes,
    n_features=5,
    n_samples=100,
    random_state=42,
    n_informative=3,
)
X_train, X_test, y_train, _ = train_test_split(x, y, test_size=0.5, random_state=42)

dtrain = DMatrix(X_train, label=y_train)

# multi:softmax predicts a class index in [0, num_class).
param = {"objective": "multi:softmax", "num_class": n_classes}
bst = train_xgb(param, dtrain, 10)

initial_type = [("float_input", FloatTensorType([None, X_train.shape[1]]))]

try:
    onx = convert_xgboost_booster(bst, "name", initial_types=initial_type)
    cont = True
except AssertionError as e:
    # As the message says, an AssertionError here is treated as a version
    # mismatch between xgboost and onnxmltools; the ONNX run is skipped.
    print("XGBoost is too recent or onnxmltools too old.", e)
    cont = False

if cont:
    sess = rt.InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    # Look names up instead of hard-coding them; only the first output is requested.
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
    print(pred_onx)

脚本总运行时间:(0 分钟 0.032 秒)

由 Sphinx-Gallery 生成的图库