
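There are many reasons why a user may want more than simply running a model converted to ONNX. Intermediate results may be needed, that is, the output of every node in the graph. The ONNX graph may need to be altered to remove some nodes; transfer learning, for instance, usually removes the last layers of a deep neural network. Another reason is debugging. The runtime often fails to compute predictions because of a shape mismatch, and it is then useful to get the shape of every intermediate result. This example looks into two ways of doing that.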


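The first way is a tricky one: it overloads the methods transform, predict and predict_proba to keep a copy of their inputs and outputs. It then goes through every step of the pipeline. If the pipeline has n steps, it converts the pipeline with step 1, then the pipeline with steps 1, 2, then 1, 2, 3...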
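The principle of overloading a method to keep its inputs and outputs can be sketched without scikit-learn. In the snippet below, `Scale`, `Shift`, `record_calls` and `run` are hypothetical names introduced for illustration only. This is not how `collect_intermediate_steps` is implemented, just a minimal picture of the same idea: wrap `transform`, run the full pipeline once to fill the cache, then check each truncated pipeline against the cached output of its last step.

```python
import functools


class Scale:
    """Toy stand-in for a transformer (hypothetical, not scikit-learn)."""

    def __init__(self, factor):
        self.factor = factor

    def transform(self, values):
        return [v * self.factor for v in values]


class Shift:
    """Second toy transformer (hypothetical)."""

    def __init__(self, offset):
        self.offset = offset

    def transform(self, values):
        return [v + self.offset for v in values]


def record_calls(obj, method_name):
    """Replace obj.<method_name> with a wrapper storing inputs and outputs."""
    original = getattr(obj, method_name)
    cache = {"inputs": [], "outputs": []}

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        result = original(*args, **kwargs)
        cache["inputs"].append(args)
        cache["outputs"].append(result)
        return result

    # The instance attribute shadows the method defined on the class.
    setattr(obj, method_name, wrapper)
    return cache


def run(steps, values):
    """Apply every step in order, like a pipeline made only of transformers."""
    for _, step in steps:
        values = step.transform(values)
    return values


steps = [("scale", Scale(2.0)), ("shift", Shift(1.0))]
caches = {name: record_calls(step, "transform") for name, step in steps}

# One call on the full pipeline fills the cache of every step.
result = run(steps, [1.0, 2.0])

# Truncated pipelines: step 1, then steps 1 and 2.
prefixes = [steps[: i + 1] for i in range(len(steps))]
for prefix in prefixes:
    last_name = prefix[-1][0]
    expected = caches[last_name]["outputs"][0]
    print([name for name, _ in prefix], run(prefix, [1.0, 2.0]) == expected)
```

In the real example the truncated pipelines are ONNX graphs, and the cached scikit-learn outputs serve as the reference they are compared to.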

import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType


data = load_iris()
X = data.data

pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(3, n_init=3))])
# The pipeline must be fitted before it can be converted.
pipe.fit(X)
Pipeline(steps=[('std', StandardScaler()),
                ('km', KMeans(n_clusters=3, n_init=3))])

The function goes through every step, overloads the method transform and returns an ONNX graph for every step.

steps = collect_intermediate_steps(
    pipe, "pipeline", [("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
)

We call the method transform to populate the cache that the overloaded methods transform keep.

pipe.transform(X)

array([[3.12119834, 0.21295824, 3.98940603],
       [2.6755083 , 0.99604549, 4.01793312],
       [2.97416665, 0.65198444, 4.19343668],
       [2.88014429, 0.9034561 , 4.19784749],
       [3.30022609, 0.40215457, 4.11157152],
       [3.50554424, 1.21154793, 3.89893116],
       [3.14856384, 0.50244932, 4.21638048],
       [2.99184826, 0.09132468, 3.97313411],
       [2.92515933, 1.42174651, 4.40757189],
       [2.79398956, 0.78993078, 4.05764261],
       [3.32125333, 0.78999385, 3.92088109],
       [3.0493632 , 0.27618123, 4.07853631],
       [2.80635045, 1.03497888, 4.16440431],
       [3.21220972, 1.33482453, 4.63069748],
       [3.88834965, 1.63865558, 4.14619343],
       [4.4998303 , 2.39898792, 4.49547518],
       [3.60978017, 1.20748818, 4.02966144],
       [3.05594182, 0.21618828, 3.91388548],
       [3.34493953, 1.20986655, 3.72562039],
       [3.50065397, 0.86706182, 4.10101938],
       [2.80825681, 0.50401564, 3.66383713],
       [3.27800809, 0.66826437, 3.94496718],
       [3.58990876, 0.68658071, 4.51061335],
       [2.55934697, 0.47945627, 3.57996434],
       [2.96493153, 0.36345425, 3.98817445],
       [2.55682739, 0.99023912, 3.88431906],
       [2.8279719 , 0.22683089, 3.79088782],
       [3.05970831, 0.2947186 , 3.89539875],
       [2.95425291, 0.25361098, 3.88085622],
       [2.87745051, 0.65019824, 4.09851673],
       [2.73238773, 0.80138328, 4.01796142],
       [2.73361981, 0.52309257, 3.57350896],
       [4.11853014, 1.57658655, 4.5037664 ],
       [4.22845606, 1.87652483, 4.4465301 ],
       [2.71452112, 0.76858489, 3.97906378],
       [2.86508665, 0.54896332, 4.01986385],
       [3.0573692 , 0.63079314, 3.80064093],
       [3.40284985, 0.45982568, 4.25136846],
       [3.00742655, 1.2336976 , 4.42052558],
       [2.95472117, 0.14580827, 3.90865188],
       [3.12324651, 0.20261743, 4.01192633],
       [2.90164193, 2.67055552, 4.64398605],
       [3.15411688, 0.90927099, 4.42154566],
       [2.8613548 , 0.50081008, 3.70483773],
       [3.34606471, 0.92159916, 3.9078554 ],
       [2.65231058, 1.01946042, 4.01421067],
       [3.53206587, 0.86953764, 4.14238152],
       [2.99813103, 0.72275914, 4.23577398],
       [3.34116935, 0.72324305, 3.97409784],
       [2.90222887, 0.30295342, 3.97223984],
       [1.9003878 , 3.43619989, 0.95288059],
       [1.41851492, 2.97232682, 0.99352148],
       [1.68457079, 3.51850037, 0.72661726],
       [0.96940962, 3.33264308, 2.69898424],
       [0.9112523 , 3.35747592, 1.11074501],
       [0.35721918, 2.77550662, 1.8143491 ],
       [1.59351202, 3.01808184, 1.00650285],
       [1.50213315, 2.77360088, 3.31296552],
       [1.11632078, 3.21148368, 1.14114175],
       [0.77921299, 2.66294828, 2.42994048],
       [1.97194958, 3.62389817, 3.73666782],
       [0.77530513, 2.70011145, 1.45918639],
       [1.25941769, 3.53658932, 2.74268279],
       [0.66155141, 2.98813829, 1.28976474],
       [0.73833453, 2.32311723, 2.05251547],
       [1.46572707, 3.14311522, 0.98780965],
       [0.80185102, 2.68234835, 1.67700171],
       [0.568386  , 2.63954211, 2.12682734],
       [1.19987895, 3.97369206, 2.33743839],
       [0.67881532, 2.87494798, 2.46667974],
       [1.34222961, 3.03853641, 1.1880022 ],
       [0.53061062, 2.8022861 , 1.63233668],
       [0.79234309, 3.68305664, 1.65142259],
       [0.57371215, 2.96833851, 1.54593744],
       [0.90589785, 2.9760862 , 1.2933375 ],
       [1.22490527, 3.13002382, 1.03085926],
       [1.26783271, 3.56679427, 1.09304603],
       [1.42114042, 3.5903606 , 0.52050254],
       [0.58974672, 2.93839428, 1.34712856],
       [0.76432091, 2.58203512, 2.44164622],
       [0.89738242, 2.99796537, 2.69027665],
       [0.98549851, 2.92597852, 2.76965187],
       [0.3921368 , 2.68907313, 2.02829879],
       [0.54223583, 3.42215998, 1.4211892 ],
       [0.90567816, 2.62771445, 1.88799766],
       [1.70872911, 2.75915071, 1.39853465],
       [1.48190142, 3.30075052, 0.78009974],
       [1.06129323, 3.73017167, 2.2083069 ],
       [0.81863359, 2.37943811, 1.87666989],
       [0.599882  , 2.98789866, 2.41035271],
       [0.4914813 , 2.89079656, 2.26782134],
       [0.84409423, 2.86642713, 1.25085451],
       [0.38941349, 2.86642575, 2.11791607],
       [1.53271026, 2.96966239, 3.35089399],
       [0.30831638, 2.77003779, 2.05312152],
       [0.81726253, 2.38255534, 1.83091351],
       [0.56428027, 2.55559903, 1.80454586],
       [0.72672271, 2.8455521 , 1.39825227],
       [1.28805849, 2.56987887, 3.06324547],
       [0.38163798, 2.64007308, 1.89861511],
       [2.31271244, 4.24274589, 1.0584579 ],
       [0.76585766, 3.57067982, 1.5185265 ],
       [2.14762671, 4.44150237, 0.52472   ],
       [1.17645413, 3.69480186, 0.77236486],
       [1.73594932, 4.11613683, 0.53031563],
       [2.78128346, 5.03326801, 1.2022172 ],
       [1.22550604, 3.3503222 , 2.74462238],
       [2.2426558 , 4.577021  , 0.92275933],
       [1.50462864, 4.363498  , 1.40314162],
       [3.22975724, 4.79334275, 1.48323372],
       [1.71837714, 3.62749566, 0.4787491 ],
       [1.10409694, 3.89360823, 1.0325986 ],
       [1.80475907, 4.1132966 , 0.27818948],
       [0.94858807, 3.82688169, 1.91870424],
       [1.39433359, 3.91538879, 1.49910975],
       [1.90677079, 3.89835633, 0.68622715],
       [1.39713702, 3.70128288, 0.46463058],
       [3.85224062, 5.18341242, 2.10127163],
       [2.95786451, 5.58136629, 1.83092395],
       [1.17790381, 4.02615768, 2.37017622],
       [2.27442972, 4.31907679, 0.52540209],
       [0.91211061, 3.4288432 , 1.62249456],
       [2.77937737, 5.19031307, 1.47042293],
       [0.84735471, 3.64273089, 1.15814207],
       [2.15695444, 4.00723617, 0.520093  ],
       [2.33581345, 4.2637671 , 0.66660166],
       [0.79774043, 3.45930032, 1.08324891],
       [1.022307  , 3.27575645, 0.94925151],
       [1.3842265 , 4.05342943, 0.84098317],
       [2.03854964, 4.1585729 , 0.75748198],
       [2.28297732, 4.71100584, 1.07124861],
       [3.88774921, 5.12224641, 2.17345728],
       [1.47357101, 4.13401784, 0.87682321],
       [0.7964005 , 3.39830644, 1.11534598],
       [0.80521086, 3.63719075, 1.59782917],
       [2.8607372 , 5.08776655, 1.25982873],
       [2.3101089 , 4.00416552, 1.07214028],
       [1.46990247, 3.58815834, 0.51434392],
       [0.97017134, 3.19454679, 1.0762733 ],
       [1.97333575, 4.09907253, 0.23050145],
       [2.07939567, 4.28416057, 0.57373487],
       [2.06609741, 4.17402084, 0.51130902],
       [0.76585766, 3.57067982, 1.5185265 ],
       [2.24723796, 4.32128686, 0.54141867],
       [2.42521977, 4.3480018 , 0.85128501],
       [1.82594618, 4.1240495 , 0.52475835],
       [1.03093862, 3.97564407, 1.52100812],
       [1.44892686, 3.7539635 , 0.44371189],
       [2.17585453, 3.7969924 , 1.08437101],
       [1.00508668, 3.25638099, 1.13739231]])

We compute every step and compare the ONNX and scikit-learn outputs.

for step in steps:
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]

    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# compare them to the output produced by truncated ONNX
# graphs built from the first one.
difference 4.799262827148709e-07
KMeans(n_clusters=3, n_init=3)
difference 1.095537650763756e-06
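Differences of this magnitude are what one would expect when a float64 computation is compared with a float32 one, since the ONNX graphs above receive the input cast to `numpy.float32`. The following sketch, on synthetic data and with a hypothetical helper `max_abs_diff`, shows that merely rounding float64 values to float32 already produces a discrepancy bounded by the float32 machine epsilon times the largest magnitude. It only illustrates the order of magnitude and does not reproduce the pipeline.

```python
import numpy


def max_abs_diff(expected, got):
    """Largest absolute discrepancy between two arrays, compared in float64."""
    expected = numpy.asarray(expected, dtype=numpy.float64).ravel()
    got = numpy.asarray(got, dtype=numpy.float64).ravel()
    return float(numpy.abs(expected - got).max())


rng = numpy.random.default_rng(0)
x64 = rng.normal(size=(150, 3)) * 5      # synthetic float64 reference values
x32 = x64.astype(numpy.float32)          # same values rounded to float32

diff = max_abs_diff(x64, x32)
eps32 = numpy.finfo(numpy.float32).eps   # about 1.19e-07
bound = numpy.abs(x64).max() * eps32     # loose bound on the rounding error

print("difference", diff)
print("bound", bound)
```

A comparison with a tolerance of a few multiples of this bound is therefore more meaningful than an exact equality test.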

Python runtime to look into every node

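The Python runtime can be used to easily look into every node of the ONNX graph. This option helps to find out where a computation fails because of nan values or a dimension mismatch.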

onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)

oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
       [2.675508  , 0.99604493, 4.017933  ]], dtype=float32)]
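Once intermediate results are available as named arrays, locating the first one that contains nan values is a simple scan. The sketch below assumes the results were gathered in a dictionary ordered by execution; `summarize` and `first_nan` are hypothetical helpers, and the arrays are made-up values that only borrow result names from this example.

```python
import numpy


def summarize(name, value):
    """Describe an array by dtype, shape and value range."""
    value = numpy.asarray(value)
    return f"{name}: {value.dtype}:{value.shape} in [{value.min()}, {value.max()}]"


def first_nan(results):
    """Return the name of the first floating-point result containing nan."""
    for name, value in results.items():
        arr = numpy.asarray(value)
        if arr.dtype.kind == "f" and numpy.isnan(arr).any():
            return name
    return None


# Made-up intermediate results, in execution order.
results = {
    "variable": numpy.array([[1.0, -2.0]], dtype=numpy.float32),
    "Ad_C0": numpy.array([[-1.0, 4.0]], dtype=numpy.float32),
}
with numpy.errstate(invalid="ignore"):
    # The square root of a negative value yields nan.
    results["scores"] = numpy.sqrt(results["Ad_C0"])

for name, value in results.items():
    print(summarize(name, value))
print("first nan in:", first_nan(results))
```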


oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})

# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
 +C Ad_Addcst: float32:(3,) in [0.9830552339553833, 5.035177230834961]
 +C Ge_Gemmcst: float32:(3, 4) in [-1.3049873113632202, 1.1359702348709106]
 +C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
 +I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
 + variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
 + Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
 + Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
 + Ge_Y0: float32:(2, 3) in [-10.366023063659668, 7.967348575592041]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
 + Ad_C01: float32:(2, 3) in [-4.98982572555542, 12.817853927612305]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
 + Ad_C0: float32:(2, 3) in [0.045351505279541016, 16.143783569335938]
ArgMin(Ad_C0) -> label
 + label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
 + scores: float32:(2, 3) in [0.2129589319229126, 4.017932891845703]

[array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
       [2.675508  , 0.99604493, 4.017933  ]], dtype=float32)]
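The sequence of nodes in the trace (ReduceSumSquare, Gemm, two Add nodes, then ArgMin and Sqrt) is consistent with the expansion ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2. The NumPy sketch below replays that sequence on random data. The Gemm attributes (alpha = -2 and a transposed second input) are an assumption inferred from the shapes in the trace, not read from the model, and the data and centers are made up.

```python
import numpy

rng = numpy.random.default_rng(0)
X = rng.normal(size=(5, 4)).astype(numpy.float32)        # plays the role of "variable"
centers = rng.normal(size=(3, 4)).astype(numpy.float32)  # hypothetical cluster centers

# ReduceSumSquare: squared norm of every row, shape (5, 1)
x_sq = (X ** 2).sum(axis=1, keepdims=True)

# Gemm, assuming alpha=-2 and transB=1: -2 * X @ centers.T, shape (5, 3).
# In the trace its third input is a zero tensor, so nothing else is added here.
cross = (-2.0 * X) @ centers.T

# Constant holding the squared norm of every center, shape (3,)
c_sq = (centers ** 2).sum(axis=1)

# Two Add nodes: squared distance of every row to every center
sq_dist = c_sq + (x_sq + cross)

labels = sq_dist.argmin(axis=1)   # ArgMin -> label
scores = numpy.sqrt(sq_dist)      # Sqrt -> scores

# Reference: distances computed directly from the differences
direct = numpy.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
print(numpy.abs(scores - direct).max())
```

The expanded form avoids building the (rows, centers, features) array of differences, at the cost of some cancellation error when a point lies very close to a center.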
