Intermediate results and investigation

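There are many reasons why a user may want more than a model converted to ONNX. Intermediate results may be needed, that is, the output of every node in the graph. The ONNX graph may need to be altered to remove some nodes. Transfer learning, for example, usually removes the last layers of a deep neural network. Another reason is debugging. The runtime often fails to compute predictions because of a shape mismatch, and it then helps to know the shape of every intermediate result. This example looks into two ways of doing this.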

Look into pipeline steps

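The first way is a tricky one: it overloads the methods transform, predict and predict_proba to keep a copy of their inputs and outputs. It then goes through every step of the pipeline. If the pipeline has n steps, it converts the pipeline made of step 1, then the pipeline made of steps 1 and 2, then steps 1, 2 and 3, and so on.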
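The overloading trick can be illustrated without scikit-learn. The following is a minimal sketch under stated assumptions: `Scale` is a hypothetical toy transformer and `record_method` a hypothetical helper, neither of which is part of skl2onnx. It replaces a method on a single instance with a wrapper that stores copies of what goes in and what comes out.

```python
import numpy as np


class Scale:
    """Hypothetical toy transformer: multiplies its input by a constant."""

    def __init__(self, factor):
        self.factor = factor

    def transform(self, X):
        return X * self.factor


def record_method(obj, name, store):
    """Hypothetical helper: replace obj.<name> on this instance only with a
    wrapper that keeps copies of the input and the output in ``store``."""
    original = getattr(obj, name)  # bound method taken before overloading

    def wrapper(X):
        result = original(X)
        # Copies, so later in-place changes cannot alter what was recorded.
        store.setdefault(name, []).append((X.copy(), result.copy()))
        return result

    setattr(obj, name, wrapper)  # instance attribute shadows the class method
    return obj


cache = {}
step = record_method(Scale(2.0), "transform", cache)
out = step.transform(np.array([[1.0, 2.0]]))
inp, saved = cache["transform"][0]
print(inp, saved)
```

Assigning the wrapper to the instance rather than to the class leaves every other `Scale` object untouched.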

import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType

The pipeline.

data = load_iris()
X = data.data

pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(3, n_init=3))])
pipe.fit(X)
Pipeline(steps=[('std', StandardScaler()),
                ('km', KMeans(n_clusters=3, n_init=3))])


The function goes through every step, overloads the transform method and returns an ONNX graph for every step.

steps = collect_intermediate_steps(
    pipe, "pipeline", [("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
)
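Comparing truncated graphs with recorded outputs relies on a simple property: running only the first k steps yields exactly the output recorded for step k during a full run. A numpy-only sketch of that property, where the two functions are hypothetical stand-ins for `StandardScaler` and `KMeans.transform` and the centers are made up:

```python
import numpy as np


def standardize(X):
    # Stand-in for StandardScaler (statistics computed on X itself).
    return (X - X.mean(axis=0)) / X.std(axis=0)


def distances(X):
    # Stand-in for KMeans.transform with two fixed, made-up centers.
    centers = np.array([[0.0, 0.0], [1.0, 1.0]])
    return np.sqrt(((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2))


steps = [("std", standardize), ("km", distances)]


def run_prefix(steps, k, X):
    """Run only the first k steps, like a truncated pipeline."""
    for _, fn in steps[:k]:
        X = fn(X)
    return X


X = np.array([[1.0, 2.0], [3.0, 5.0], [4.0, 4.0]])

# Record the output of every step while running the full chain once.
recorded = []
current = X
for name, fn in steps:
    current = fn(current)
    recorded.append((name, current))

# Each prefix should reproduce the output recorded for its last step.
diffs = {}
for k, (name, expected) in enumerate(recorded, start=1):
    diffs[name] = np.abs(run_prefix(steps, k, X) - expected).max()
print(diffs)
```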

We call method transform to populate the cache kept by the overloaded transform methods.

pipe.transform(X)

array([[3.04731565, 0.16691922, 3.90105048],
       [2.53765667, 1.05008012, 3.89969118],
       [2.8534191 , 0.70223132, 4.08275531],
       [2.74622964, 0.95461003, 4.08107334],
       [3.23237193, 0.35053394, 4.02752145],
       [3.48845737, 1.15922234, 3.84344238],
       [3.04980527, 0.52431514, 4.11468249],
       [2.90441631, 0.13441315, 3.87710094],
       [2.76651178, 1.4735953 , 4.28118798],
       [2.6666299 , 0.84235354, 3.94465298],
       [3.27919712, 0.74060536, 3.85127337],
       [2.95563885, 0.30307356, 3.97956436],
       [2.66465535, 1.08829832, 4.04591428],
       [3.06103061, 1.38230617, 4.50871998],
       [3.87938657, 1.58899913, 4.10425124],
       [4.51669438, 2.34530067, 4.47566454],
       [3.58762514, 1.1553485 , 3.97299158],
       [2.98512391, 0.17213393, 3.82533283],
       [3.32956199, 1.16389168, 3.67087693],
       [3.45834067, 0.81319337, 4.03145794],
       [2.73958458, 0.50120097, 3.57505244],
       [3.23074088, 0.61712537, 3.86923945],
       [3.50566186, 0.66747441, 4.42022781],
       [2.4750801 , 0.51155505, 3.47802339],
       [2.8736682 , 0.38345232, 3.8892069 ],
       [2.42225934, 1.04291694, 3.76676185],
       [2.74731742, 0.24845351, 3.69412992],
       [2.99076871, 0.25714597, 3.80890374],
       [2.87391396, 0.2611406 , 3.78824565],
       [2.75790131, 0.70011574, 3.98744779],
       [2.60379694, 0.85497633, 3.90292713],
       [2.66975272, 0.52152421, 3.48410551],
       [4.09438407, 1.52253885, 4.45342411],
       [4.22214621, 1.82234641, 4.40895404],
       [2.58864695, 0.82260166, 3.86530846],
       [2.75335951, 0.60043553, 3.91302139],
       [2.999389  , 0.60703836, 3.7203448 ],
       [3.32859606, 0.41861164, 4.16574085],
       [2.85880737, 1.28372352, 4.29823638],
       [2.87114889, 0.16719611, 3.81424762],
       [3.04765881, 0.16490106, 3.92166326],
       [2.70488461, 2.72505665, 4.50242861],
       [3.02636161, 0.9507447 , 4.30808784],
       [2.79904696, 0.48729268, 3.61429734],
       [3.31203486, 0.87083853, 3.83944771],
       [2.51304561, 1.07390656, 3.89419068],
       [3.48756201, 0.81546763, 4.07280665],
       [2.87514706, 0.77022366, 4.12394547],
       [3.29502383, 0.67120692, 3.90226132],
       [2.80290905, 0.35513147, 3.87052324],
       [2.05944499, 3.43556314, 0.9527343 ],
       [1.57420922, 2.9747863 , 0.91533955],
       [1.85627645, 3.52176851, 0.6991334 ],
       [0.81977325, 3.37169191, 2.55294274],
       [1.08188141, 3.37366488, 0.98171662],
       [0.41559238, 2.7993061 , 1.67145475],
       [1.75089133, 3.01695781, 0.94903551],
       [1.3054677 , 2.81936231, 3.16550893],
       [1.26497468, 3.2240354 , 1.03500423],
       [0.66027189, 2.69540221, 2.28485901],
       [1.79458962, 3.67188095, 3.59062915],
       [0.90090717, 2.71474217, 1.32750103],
       [1.15615482, 3.57454207, 2.60882281],
       [0.82846537, 3.00417652, 1.15164743],
       [0.7161265 , 2.34736206, 1.91677593],
       [1.61900765, 3.14784569, 0.91938068],
       [0.89305298, 2.69864257, 1.54559525],
       [0.48160791, 2.66850781, 1.98816804],
       [1.19881076, 4.00685571, 2.2104156 ],
       [0.48795536, 2.9111416 , 2.32130025],
       [1.49602283, 3.04351709, 1.09155949],
       [0.62552403, 2.82390229, 1.49525468],
       [0.90399583, 3.70843522, 1.52048216],
       [0.69344472, 2.98843992, 1.4110861 ],
       [1.04278165, 2.99074759, 1.17226863],
       [1.37965904, 3.13914603, 0.93245527],
       [1.42318454, 3.58065437, 0.99964865],
       [1.61092178, 3.59787573, 0.41699742],
       [0.7588276 , 2.95535251, 1.20525347],
       [0.60280715, 2.61712971, 2.30054805],
       [0.70912153, 3.03724529, 2.5441796 ],
       [0.79425838, 2.96618351, 2.62458034],
       [0.31391047, 2.71773006, 1.88550818],
       [0.72639669, 3.44369296, 1.2765473 ],
       [0.9578884 , 2.64549177, 1.75825155],
       [1.83883648, 2.75522333, 1.33978733],
       [1.65128962, 3.30521622, 0.71260555],
       [1.06192984, 3.76176356, 2.08134113],
       [0.85377093, 2.39781055, 1.74666846],
       [0.41758567, 3.02333265, 2.26242483],
       [0.33048259, 2.92302464, 2.12172553],
       [0.99615987, 2.87894547, 1.12220529],
       [0.25549279, 2.89750284, 1.97352593],
       [1.33587433, 3.01599101, 3.2031426 ],
       [0.20110153, 2.79890287, 1.90698574],
       [0.85870837, 2.40014831, 1.70298258],
       [0.61211268, 2.57682121, 1.6668071 ],
       [0.85882898, 2.86194221, 1.26766899],
       [1.09695347, 2.61357618, 2.91712098],
       [0.38833729, 2.66524256, 1.75577427],
       [2.5012404 , 4.2420242 , 1.0918546 ],
       [0.92015293, 3.5923261 , 1.37898965],
       [2.3460413 , 4.4461714 , 0.61533733],
       [1.37537194, 3.7069135 , 0.64037065],
       [1.93680867, 4.1234534 , 0.49148505],
       [2.97325888, 5.03573504, 1.30938606],
       [1.14289326, 3.38471765, 2.6019388 ],
       [2.42763947, 4.58333402, 0.97719501],
       [1.65576504, 4.38346835, 1.31984222],
       [3.4249117 , 4.78218736, 1.61700085],
       [1.90971055, 3.62917997, 0.4469887 ],
       [1.29266937, 3.91108898, 0.9070538 ],
       [2.00571419, 4.1195457 , 0.28701748],
       [1.03296583, 3.8538939 , 1.7839299 ],
       [1.54437256, 3.93228102, 1.3928708 ],
       [2.09882765, 3.9004045 , 0.68129057],
       [1.59624925, 3.70929828, 0.33306798],
       [4.04228067, 5.16667811, 2.24280811],
       [3.13108094, 5.59206313, 1.88051097],
       [1.17369301, 4.0595081 , 2.24060385],
       [2.47497267, 4.31907189, 0.64593529],
       [1.04109807, 3.44883957, 1.48873586],
       [2.95941231, 5.19705273, 1.54093757],
       [1.03607405, 3.66188818, 1.01844967],
       [2.35220999, 4.0048474 , 0.61341103],
       [2.52679968, 4.26239611, 0.786576  ],
       [0.99715764, 3.4764269 , 0.93620274],
       [1.21104872, 3.28685101, 0.8130537 ],
       [1.57955582, 4.06724081, 0.73703291],
       [2.21952527, 4.16299138, 0.79795731],
       [2.46401818, 4.71948935, 1.11257122],
       [4.0726311 , 5.10470559, 2.31231128],
       [1.66681009, 4.14760691, 0.78525226],
       [0.98336185, 3.41503341, 0.97686289],
       [0.92929783, 3.66012734, 1.46832263],
       [3.05337662, 5.09008925, 1.3692878 ],
       [2.49347287, 4.00041166, 1.10609628],
       [1.66406596, 3.59345432, 0.41554383],
       [1.15070886, 3.20642858, 0.94210414],
       [2.17245566, 4.10202725, 0.33887736],
       [2.27944809, 4.28777875, 0.62347517],
       [2.26277812, 4.17709068, 0.57349609],
       [0.92015293, 3.5923261 , 1.37898965],
       [2.44823886, 4.32158336, 0.64919807],
       [2.62180347, 4.3459761 , 0.93691617],
       [2.02412194, 4.13081941, 0.51118153],
       [1.17215474, 3.99933061, 1.396946  ],
       [1.649638  , 3.76205335, 0.31254892],
       [2.35342278, 3.79343916, 1.09600723],
       [1.18085642, 3.26840383, 1.00795325]])

We compute every step and compare the ONNX and scikit-learn outputs.

for step in steps:
    print("----------------------------")
    print(step["model"])
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]

    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# compare them to the output produced by truncated ONNX
# graphs built from the first one.
#
----------------------------
StandardScaler()
difference 4.799262827148709e-07
----------------------------
KMeans(n_clusters=3, n_init=3)
difference 1.6719496572781267e-06
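The differences above are small but not zero. A likely explanation is that the ONNX graphs run in float32 (the input is declared as `FloatTensorType` and cast with `astype(numpy.float32)`), while scikit-learn computes in float64. The sketch below uses made-up random data rather than the iris dataset and simply redoes a standardization in float32; it is an illustration of the precision gap, not a reproduction of the ONNX Scaler node.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(150, 4))  # made-up data, not iris

# Reference: statistics and result in float64.
mean, scale = X.mean(axis=0), X.std(axis=0)
expected = (X - mean) / scale

# Same computation with inputs and constants cast to float32.
X32 = X.astype(np.float32)
got = (X32 - mean.astype(np.float32)) / scale.astype(np.float32)

diff = np.abs(expected - got).max()
print("difference", diff)
print("float32 epsilon", np.finfo(np.float32).eps)
```

Differences of a few times float32 epsilon, scaled by the magnitude of the values, are expected in this situation and do not indicate a conversion bug.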

Python runtime to look into every node

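The Python runtime may be useful to easily look into every node of the ONNX graph. This option can be used to find where the computation fails because of NaN values or a dimension mismatch.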
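The kind of check described above can be sketched in plain numpy. `trace` below is a hypothetical helper, not `ReferenceEvaluator`: it runs a list of named functions, prints the type, shape and range of every intermediate result in a format loosely modeled on the verbose output of the Python runtime, and stops at the first NaN.

```python
import numpy as np


def trace(steps, x):
    """Hypothetical helper: run named numpy functions in sequence, print the
    dtype, shape and range of every intermediate result, stop at the first NaN."""
    for name, fn in steps:
        with np.errstate(invalid="ignore", divide="ignore"):
            x = fn(x)
        print(f"{name} -> {x.dtype}:{x.shape} in [{np.nanmin(x)}, {np.nanmax(x)}]")
        if np.isnan(x).any():
            raise ValueError(f"NaN produced by step {name!r}")
    return x


steps = [
    ("Sub", lambda x: x - x.mean(axis=0)),
    ("Div", lambda x: x / x.std(axis=0)),  # a constant column gives 0/0 = NaN
]

ok = trace(steps, np.array([[1.0, 2.0], [3.0, 4.0]]))

message = None
try:
    trace(steps, np.array([[1.0, 2.0], [1.0, 4.0]]))  # first column is constant
except ValueError as err:
    message = str(err)
print(message)
```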

onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)

oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[3.0473158 , 0.16691935, 3.9010506 ],
       [2.5376565 , 1.05008   , 3.8996909 ]], dtype=float32)]

And to see every intermediate result:

oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})

# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
 +C Ad_Addcst: float32:(3,) in [1.0728117227554321, 5.101926803588867]
 +C Ge_Gemmcst: float32:(3, 4) in [-1.3031082153320312, 1.0335986614227295]
 +C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
 +I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
 + variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
 + Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
 + Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
 + Ge_Y0: float32:(2, 3) in [-10.450262069702148, 7.452452182769775]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
 + Ad_C01: float32:(2, 3) in [-5.0740647315979, 12.313563346862793]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
 + Ad_C0: float32:(2, 3) in [0.027862071990966797, 15.218194961547852]
ArgMin(Ad_C0) -> label
 + label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
 + scores: float32:(2, 3) in [0.16691935062408447, 3.901050567626953]

[array([1, 1]), array([[3.0473158 , 0.16691935, 3.9010506 ],
       [2.5376565 , 1.05008   , 3.8996909 ]], dtype=float32)]
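The node sequence in this trace appears to implement the usual expansion of the squared Euclidean distance, ‖x - c‖² = ‖x‖² - 2 x·c + ‖c‖². Reading the values, `Re_reduced0` looks like ‖x‖², `Ge_Y0` like -2 x·cᵀ and `Ad_Addcst` like the squared norms of the centers; the Gemm attributes (alpha=-2, transB=1) are an assumption inferred from the printed ranges, not read from the graph. A numpy sketch with made-up centers checks that this decomposition gives the same labels and scores as a direct distance computation:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4)).astype(np.float32)        # plays the role of `variable`
centers = rng.normal(size=(3, 4)).astype(np.float32)  # made-up cluster centers

# Node-by-node version, following the trace above.
re_reduced0 = (x ** 2).sum(axis=1, keepdims=True)  # ReduceSumSquare -> (2, 1)
ge_y0 = -2.0 * x @ centers.T                       # Gemm, assuming alpha=-2, transB=1
ad_c01 = re_reduced0 + ge_y0                       # Add -> (2, 3)
ad_c0 = ad_c01 + (centers ** 2).sum(axis=1)        # Add with the (3,) constant
label = ad_c0.argmin(axis=1)                       # ArgMin
scores = np.sqrt(ad_c0)                            # Sqrt

# Direct computation of the distances for comparison.
direct = np.sqrt(((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2))
print(label, direct.argmin(axis=1))
print(np.abs(scores - direct).max())
```

The expanded form replaces an explicit (n, k, d) difference tensor with a single matrix product, which is why converters tend to favor it.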

Total running time of the script: (0 minutes 0.139 seconds)

Gallery generated by Sphinx-Gallery