调查管道

以下示例展示了如何查看已转换的模型,并轻松查找管道每个步骤中的错误。

创建管道

我们重用示例“管道:链接 PCA 和逻辑回归”中实现的管道。由于 ONNX-ML Imputer 不处理字符串类型,因此这里有一处改动:该部分不能成为最终 ONNX 管道的一部分,必须将其删除。请查找下面以 --- 开头的注释。

import skl2onnx
import onnx
import sklearn
import numpy
import pickle
from skl2onnx.helpers import collect_intermediate_steps
import onnxruntime as rt
from onnxconverter_common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
import numpy as np
import pandas as pd

from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Load the digits dataset and keep the first 1000 samples and labels.
digits = datasets.load_digits()
X_digits, y_digits = digits.data[:1000], digits.target[:1000]

# Two-step pipeline: a PCA transform followed by a logistic-regression classifier.
pipe = Pipeline(
    steps=[
        ("pca", PCA()),
        ("logistic", LogisticRegression()),
    ]
)

# Fit both steps on the selected samples.
pipe.fit(X_digits, y_digits)
Pipeline(steps=[('pca', PCA()), ('logistic', LogisticRegression())])
在 Jupyter 环境中,请重新运行此单元格以显示 HTML 表示形式或信任笔记本。
在 GitHub 上,HTML 表示形式无法呈现,请尝试使用 nbviewer.org 加载此页面。


转换为 ONNX

# Declare the ONNX graph input: a float tensor named "input" whose first
# dimension is left unspecified (None) and whose second dimension equals the
# number of columns of X_digits.
initial_types = [
    ("input", FloatTensorType((None, X_digits.shape[1]))),
]
model_onnx = convert_sklearn(
    pipe,
    initial_types=initial_types,
    target_opset=12,
)

# Load the serialized ONNX model into onnxruntime using the CPU provider.
sess = rt.InferenceSession(
    model_onnx.SerializeToString(),
    providers=["CPUExecutionProvider"],
)

# Probabilities computed by scikit-learn for the first two samples.
print("skl predict_proba", pipe.predict_proba(X_digits[:2]), sep="\n")

# Same two samples through the ONNX model, cast to float32 to match the
# declared input type. The second output (index 1) is the one compared with
# predict_proba.
onx_pred = sess.run(None, {"input": X_digits[:2].astype(np.float32)})[1]
df = pd.DataFrame(onx_pred)
print("onnx predict_proba", df.values, sep="\n")
skl predict_proba
[[9.99998530e-01 7.81608916e-19 4.87445989e-10 1.79842282e-08
  3.58700554e-10 1.18138025e-06 4.14411051e-08 1.48275027e-07
  2.50162860e-08 5.51240034e-08]
 [1.37889361e-14 9.99999324e-01 9.17867392e-11 8.30390364e-13
  2.57277805e-07 8.84035071e-12 5.11781429e-11 2.83346408e-11
  4.18965301e-07 1.32796353e-13]]
onnx predict_proba
[[9.99998569e-01 7.81611026e-19 4.87444585e-10 1.79842026e-08
  3.58700042e-10 1.18137689e-06 4.14409520e-08 1.48274751e-07
  2.50162131e-08 5.51239410e-08]
 [1.37888807e-14 9.99999344e-01 9.17865159e-11 8.30387679e-13
  2.57277748e-07 8.84032951e-12 5.11779785e-11 2.83345725e-11
  4.18964021e-07 1.32796280e-13]]

中间步骤

假设最终输出有误,我们需要逐一检查管道中的每个组件,找出是哪个组件出了问题。以下方法修改了 scikit-learn 管道以截获中间输出,并为每个算子生成更小的 ONNX 图。

# One entry per pipeline step; each entry is indexed below by the keys
# "model" (the scikit-learn step) and "onnx_step" (its ONNX graph).
steps = collect_intermediate_steps(pipe, "pipeline", initial_types)

# The pipeline was built with two steps, so two entries are expected.
assert len(steps) == 2

# Run the scikit-learn pipeline once on the first two samples.
# NOTE(review): presumably this call is what fills the `_debug` records read
# in the loop below -- confirm against the skl2onnx helpers.
pipe.predict_proba(X_digits[:2])

for step in steps:
    # Execute this step's ONNX graph on the same two samples, cast to float32.
    sess = rt.InferenceSession(
        step["onnx_step"].SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    onnx_outputs = sess.run(None, {"input": X_digits[:2].astype(np.float32)})
    skl_outputs = step["model"]._debug.outputs

    # NOTE(review): the label is the literal "step 1" for every iteration;
    # the class name printed next to it is what tells the steps apart.
    print("step 1", type(step["model"]))
    print("skl outputs", skl_outputs, "onnx outputs", onnx_outputs, sep="\n")
step 1 <class 'sklearn.decomposition._pca.PCA'>
skl outputs
{'transform': array([[-9.78697129e+00,  7.22639567e+00,  2.16935601e+01,
        -1.13765854e+01,  3.54566122e+00,  5.59543345e+00,
        -4.71459904e+00, -4.29410146e+00, -5.71520266e+00,
        -3.31533698e+00,  3.42040920e-01, -2.90474751e+00,
         3.18177631e-01,  6.66363079e-01, -2.82714171e+00,
        -5.91632481e+00,  9.69544780e-01, -1.92676767e+00,
         1.71450677e+00, -9.60454853e-01,  3.81570991e-01,
        -1.37130203e+00,  4.29353551e+00,  2.32392659e+00,
         7.13256034e-01,  3.00982060e+00, -1.98303620e+00,
        -4.81811365e-01,  1.90930400e-01, -2.03950266e+00,
         1.59803428e+00,  1.46831581e+00,  1.70903280e+00,
         7.93109126e-02,  1.62244448e-01, -5.10619572e-02,
        -6.63308841e-01, -1.35869345e+00,  1.03930533e+00,
        -2.09485311e+00, -2.15669105e+00, -7.78040093e-02,
         4.01347652e-02,  8.40159293e-01, -4.74891758e-01,
        -1.14564701e-01,  5.31817617e-02, -6.87010227e-01,
        -1.29090165e-01,  2.12032919e-01, -3.63901656e-01,
         1.29285214e-01, -8.14384613e-02, -3.82919696e-02,
        -9.76885583e-03, -1.39046240e-02,  1.59100433e-03,
        -2.87444919e-03,  5.75119957e-03,  1.85595427e-03,
        -5.00911047e-03,  1.16099460e-14,  0.00000000e+00,
         5.24417152e-14],
       [ 1.54267314e+01, -4.91291516e+00, -1.74676972e+01,
         1.13960509e+01, -5.64555024e+00,  5.73696034e+00,
         2.08026490e+00, -5.23721537e+00,  3.37859393e+00,
         3.60754149e+00, -2.90967608e+00,  3.75628331e+00,
         1.21238177e+00,  5.21796290e+00, -4.95051435e+00,
         4.01835168e+00,  2.97046115e+00,  5.64772188e+00,
         5.61898054e+00,  4.32016109e+00,  1.97701819e+00,
        -3.39030059e+00, -5.67779351e-01,  6.70107684e-01,
         6.31443589e+00,  8.65991552e-01, -1.58633137e-01,
        -3.52940090e+00,  6.81737794e-01, -2.47187038e+00,
         1.21588602e+00,  2.22346979e+00, -1.37364649e+00,
        -1.79895009e+00, -3.03710592e+00,  2.63278986e+00,
         3.68918985e+00,  6.08509461e-01, -2.45039011e-01,
         6.63479061e-01,  1.50727140e+00,  1.10449110e+00,
         4.58384385e-01,  3.40399894e-01, -2.67878895e-01,
        -1.87647893e+00,  2.04332870e-01,  4.61919057e-01,
        -2.44538953e-02,  8.66380644e-04,  7.56583008e-02,
        -1.91237218e-01, -4.73950435e-02,  2.74122911e-02,
         4.32524378e-03, -3.66956686e-03, -1.88790753e-03,
         5.22119207e-03, -1.86775268e-03, -5.07041881e-03,
        -1.70805502e-03, -1.38978665e-14,  0.00000000e+00,
        -3.09204766e-14]])}
onnx outputs
[array([[-9.78696918e+00,  7.22639418e+00,  2.16935596e+01,
        -1.13765850e+01,  3.54566121e+00,  5.59543371e+00,
        -4.71459913e+00, -4.29410172e+00, -5.71520233e+00,
        -3.31533718e+00,  3.42040539e-01, -2.90474844e+00,
         3.18177342e-01,  6.66362762e-01, -2.82714128e+00,
        -5.91632557e+00,  9.69543815e-01, -1.92676806e+00,
         1.71450746e+00, -9.60454881e-01,  3.81571263e-01,
        -1.37130213e+00,  4.29353619e+00,  2.32392645e+00,
         7.13255882e-01,  3.00982118e+00, -1.98303699e+00,
        -4.81811404e-01,  1.90929934e-01, -2.03950286e+00,
         1.59803450e+00,  1.46831572e+00,  1.70903301e+00,
         7.93112069e-02,  1.62244260e-01, -5.10617606e-02,
        -6.63308799e-01, -1.35869288e+00,  1.03930473e+00,
        -2.09485388e+00, -2.15669155e+00, -7.78041705e-02,
         4.01349142e-02,  8.40159237e-01, -4.74891722e-01,
        -1.14564866e-01,  5.31819277e-02, -6.87010169e-01,
        -1.29090086e-01,  2.12032884e-01, -3.63901585e-01,
         1.29285216e-01, -8.14384818e-02, -3.82919535e-02,
        -9.76885669e-03, -1.39046200e-02,  1.59100525e-03,
        -2.87444773e-03,  5.75120188e-03,  1.85595278e-03,
        -5.00911009e-03,  1.16099418e-14,  0.00000000e+00,
         5.24416868e-14],
       [ 1.54267330e+01, -4.91291523e+00, -1.74676971e+01,
         1.13960505e+01, -5.64554977e+00,  5.73695993e+00,
         2.08026457e+00, -5.23721600e+00,  3.37859321e+00,
         3.60754204e+00, -2.90967607e+00,  3.75628328e+00,
         1.21238220e+00,  5.21796322e+00, -4.95051479e+00,
         4.01835155e+00,  2.97046089e+00,  5.64772224e+00,
         5.61898088e+00,  4.32016134e+00,  1.97701883e+00,
        -3.39030147e+00, -5.67779541e-01,  6.70108199e-01,
         6.31443739e+00,  8.65990937e-01, -1.58633217e-01,
        -3.52940059e+00,  6.81736946e-01, -2.47186923e+00,
         1.21588576e+00,  2.22346997e+00, -1.37364638e+00,
        -1.79894984e+00, -3.03710651e+00,  2.63278937e+00,
         3.68918991e+00,  6.08509481e-01, -2.45039046e-01,
         6.63479507e-01,  1.50727105e+00,  1.10449100e+00,
         4.58384484e-01,  3.40399802e-01, -2.67878950e-01,
        -1.87647831e+00,  2.04333529e-01,  4.61919039e-01,
        -2.44537946e-02,  8.66464688e-04,  7.56583288e-02,
        -1.91237196e-01, -4.73950393e-02,  2.74122953e-02,
         4.32524411e-03, -3.66956298e-03, -1.88790704e-03,
         5.22119273e-03, -1.86775194e-03, -5.07041626e-03,
        -1.70805526e-03, -1.38978599e-14,  0.00000000e+00,
        -3.09204973e-14]], dtype=float32)]
step 1 <class 'sklearn.linear_model._logistic.LogisticRegression'>
skl outputs
{'decision_function': array([[9.99998530e-01, 7.81608916e-19, 4.87445989e-10, 1.79842282e-08,
        3.58700554e-10, 1.18138025e-06, 4.14411051e-08, 1.48275027e-07,
        2.50162860e-08, 5.51240034e-08],
       [1.37889361e-14, 9.99999324e-01, 9.17867392e-11, 8.30390364e-13,
        2.57277805e-07, 8.84035071e-12, 5.11781429e-11, 2.83346408e-11,
        4.18965301e-07, 1.32796353e-13]]), 'predict_proba': array([[9.99998530e-01, 7.81608916e-19, 4.87445989e-10, 1.79842282e-08,
        3.58700554e-10, 1.18138025e-06, 4.14411051e-08, 1.48275027e-07,
        2.50162860e-08, 5.51240034e-08],
       [1.37889361e-14, 9.99999324e-01, 9.17867392e-11, 8.30390364e-13,
        2.57277805e-07, 8.84035071e-12, 5.11781429e-11, 2.83346408e-11,
        4.18965301e-07, 1.32796353e-13]])}
onnx outputs
[array([0, 1], dtype=int64), array([[9.9999857e-01, 7.8161103e-19, 4.8744458e-10, 1.7984203e-08,
        3.5870004e-10, 1.1813769e-06, 4.1440952e-08, 1.4827475e-07,
        2.5016213e-08, 5.5123941e-08],
       [1.3788881e-14, 9.9999934e-01, 9.1786516e-11, 8.3038768e-13,
        2.5727775e-07, 8.8403295e-12, 5.1177979e-11, 2.8334573e-11,
        4.1896402e-07, 1.3279628e-13]], dtype=float32)]

Pickle

每个步骤都是管道中的一个单独模型,可以独立于其他步骤进行 pickle 序列化。属性 _debug 包含重放模型预测所需的所有信息。

# Bundle what is needed to replay the second step (index 1) on its own: the
# fitted model, the inputs/outputs recorded on its `_debug` attribute, and the
# step's "inputs" / "outputs" entries.
second_step = steps[1]
classifier = second_step["model"]
to_save = {
    "model": classifier,
    "data_input": classifier._debug.inputs,
    "data_output": classifier._debug.outputs,
    "inputs": second_step["inputs"],
    "outputs": second_step["outputs"],
}

# Drop the debug attribute from the model before pickling; the recorded
# inputs/outputs are already referenced by `to_save`.
# NOTE(review): presumably `_debug` itself is not meant to be pickled -- confirm.
del classifier._debug

pickle_path = "classifier.pkl"

# Write the bundle to disk ...
with open(pickle_path, "wb") as sink:
    pickle.dump(to_save, sink)

# ... and read it straight back (the file was produced by this script).
with open(pickle_path, "rb") as source:
    restored = pickle.load(source)

# Replay the prediction using the restored model and the restored recorded input.
replay_input = restored["data_input"]["predict_proba"]
print(restored["model"].predict_proba(replay_input))
[[9.99998530e-01 7.81608916e-19 4.87445989e-10 1.79842282e-08
  3.58700554e-10 1.18138025e-06 4.14411051e-08 1.48275027e-07
  2.50162860e-08 5.51240034e-08]
 [1.37889361e-14 9.99999324e-01 9.17867392e-11 8.30390364e-13
  2.57277805e-07 8.84035071e-12 5.11781429e-11 2.83346408e-11
  4.18965301e-07 1.32796353e-13]]

此示例使用的版本

# Report the version of every package used by this example, one per line.
# Labels are kept exactly as they are printed (including trailing spaces).
for label, module in (
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
):
    print(label, module.__version__)
numpy: 1.26.4
scikit-learn: 1.6.dev0
onnx:  1.17.0
onnxruntime:  1.18.0+cu118
skl2onnx:  1.17.0

脚本的总运行时间:(0 分钟 0.506 秒)

由 Sphinx-Gallery 生成的图库