Intermediate results and investigation
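There are many reasons why a user may want more than a plain conversion of a model to ONNX. Intermediate results may be needed, that is, the output of every node in the graph. The ONNX graph may need to be altered to remove some nodes. Transfer learning usually consists of removing the last layers of a deep neural network. Another reason is debugging. The runtime often fails to compute predictions because of a shape mismatch, and in that case it is useful to get the shape of every intermediate result. This example looks into two ways of doing that.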
Look into pipeline steps
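The first way is a bit tricky: it overloads the methods transform, predict and predict_proba to keep a copy of their inputs and outputs. It then goes through every step of the pipeline. If the pipeline has n steps, it converts the pipeline made of step 1, then the pipeline made of steps 1 and 2, then steps 1, 2, 3, and so on.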
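The idea can be illustrated with scikit-learn alone. The sketch below is not the code skl2onnx runs internally; it only shows the two ingredients described above under simple assumptions: a hypothetical helper `record_transform` shadows `transform` on each fitted step to copy its input and output, and pipeline slicing (`pipe[:n]`) produces the truncated pipelines. The `random_state=0` argument is added here only to make the sketch reproducible.

```python
import numpy
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def record_transform(name, estimator, cache):
    """Shadow estimator.transform with a wrapper that copies input and output."""
    original = estimator.transform

    def wrapper(X, *args, **kwargs):
        output = original(X, *args, **kwargs)
        cache[name] = {
            "input": numpy.array(X, copy=True),
            "output": numpy.array(output, copy=True),
        }
        return output

    # An instance attribute takes precedence over the class method.
    estimator.transform = wrapper


X = load_iris().data
pipe = Pipeline(
    [("std", StandardScaler()), ("km", KMeans(3, n_init=3, random_state=0))]
)
pipe.fit(X)

# Wrap every fitted step, then run the pipeline once to fill the cache.
cache = {}
for name, estimator in pipe.steps:
    record_transform(name, estimator, cache)
pipe.transform(X)

for name, io in cache.items():
    print(name, io["input"].shape, "->", io["output"].shape)

# Truncated pipelines: step 1 only, then steps 1 and 2.
prefix_shapes = [
    pipe[:n].transform(X).shape for n in range(1, len(pipe.steps) + 1)
]
print(prefix_shapes)
```

Each entry of `cache` can then be compared with the output of the ONNX graph converted from the matching truncated pipeline.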
import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType
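The pipeline.
The function goes through every step, overloads the method transform, and returns an ONNX graph for every step.
We call the method transform to populate the cache kept by the overloaded transform methods.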
array([[3.04731565, 0.16691922, 3.90105048],
[2.53765667, 1.05008012, 3.89969118],
[2.8534191 , 0.70223132, 4.08275531],
[2.74622964, 0.95461003, 4.08107334],
[3.23237193, 0.35053394, 4.02752145],
[3.48845737, 1.15922234, 3.84344238],
[3.04980527, 0.52431514, 4.11468249],
[2.90441631, 0.13441315, 3.87710094],
[2.76651178, 1.4735953 , 4.28118798],
[2.6666299 , 0.84235354, 3.94465298],
[3.27919712, 0.74060536, 3.85127337],
[2.95563885, 0.30307356, 3.97956436],
[2.66465535, 1.08829832, 4.04591428],
[3.06103061, 1.38230617, 4.50871998],
[3.87938657, 1.58899913, 4.10425124],
[4.51669438, 2.34530067, 4.47566454],
[3.58762514, 1.1553485 , 3.97299158],
[2.98512391, 0.17213393, 3.82533283],
[3.32956199, 1.16389168, 3.67087693],
[3.45834067, 0.81319337, 4.03145794],
[2.73958458, 0.50120097, 3.57505244],
[3.23074088, 0.61712537, 3.86923945],
[3.50566186, 0.66747441, 4.42022781],
[2.4750801 , 0.51155505, 3.47802339],
[2.8736682 , 0.38345232, 3.8892069 ],
[2.42225934, 1.04291694, 3.76676185],
[2.74731742, 0.24845351, 3.69412992],
[2.99076871, 0.25714597, 3.80890374],
[2.87391396, 0.2611406 , 3.78824565],
[2.75790131, 0.70011574, 3.98744779],
[2.60379694, 0.85497633, 3.90292713],
[2.66975272, 0.52152421, 3.48410551],
[4.09438407, 1.52253885, 4.45342411],
[4.22214621, 1.82234641, 4.40895404],
[2.58864695, 0.82260166, 3.86530846],
[2.75335951, 0.60043553, 3.91302139],
[2.999389 , 0.60703836, 3.7203448 ],
[3.32859606, 0.41861164, 4.16574085],
[2.85880737, 1.28372352, 4.29823638],
[2.87114889, 0.16719611, 3.81424762],
[3.04765881, 0.16490106, 3.92166326],
[2.70488461, 2.72505665, 4.50242861],
[3.02636161, 0.9507447 , 4.30808784],
[2.79904696, 0.48729268, 3.61429734],
[3.31203486, 0.87083853, 3.83944771],
[2.51304561, 1.07390656, 3.89419068],
[3.48756201, 0.81546763, 4.07280665],
[2.87514706, 0.77022366, 4.12394547],
[3.29502383, 0.67120692, 3.90226132],
[2.80290905, 0.35513147, 3.87052324],
[2.05944499, 3.43556314, 0.9527343 ],
[1.57420922, 2.9747863 , 0.91533955],
[1.85627645, 3.52176851, 0.6991334 ],
[0.81977325, 3.37169191, 2.55294274],
[1.08188141, 3.37366488, 0.98171662],
[0.41559238, 2.7993061 , 1.67145475],
[1.75089133, 3.01695781, 0.94903551],
[1.3054677 , 2.81936231, 3.16550893],
[1.26497468, 3.2240354 , 1.03500423],
[0.66027189, 2.69540221, 2.28485901],
[1.79458962, 3.67188095, 3.59062915],
[0.90090717, 2.71474217, 1.32750103],
[1.15615482, 3.57454207, 2.60882281],
[0.82846537, 3.00417652, 1.15164743],
[0.7161265 , 2.34736206, 1.91677593],
[1.61900765, 3.14784569, 0.91938068],
[0.89305298, 2.69864257, 1.54559525],
[0.48160791, 2.66850781, 1.98816804],
[1.19881076, 4.00685571, 2.2104156 ],
[0.48795536, 2.9111416 , 2.32130025],
[1.49602283, 3.04351709, 1.09155949],
[0.62552403, 2.82390229, 1.49525468],
[0.90399583, 3.70843522, 1.52048216],
[0.69344472, 2.98843992, 1.4110861 ],
[1.04278165, 2.99074759, 1.17226863],
[1.37965904, 3.13914603, 0.93245527],
[1.42318454, 3.58065437, 0.99964865],
[1.61092178, 3.59787573, 0.41699742],
[0.7588276 , 2.95535251, 1.20525347],
[0.60280715, 2.61712971, 2.30054805],
[0.70912153, 3.03724529, 2.5441796 ],
[0.79425838, 2.96618351, 2.62458034],
[0.31391047, 2.71773006, 1.88550818],
[0.72639669, 3.44369296, 1.2765473 ],
[0.9578884 , 2.64549177, 1.75825155],
[1.83883648, 2.75522333, 1.33978733],
[1.65128962, 3.30521622, 0.71260555],
[1.06192984, 3.76176356, 2.08134113],
[0.85377093, 2.39781055, 1.74666846],
[0.41758567, 3.02333265, 2.26242483],
[0.33048259, 2.92302464, 2.12172553],
[0.99615987, 2.87894547, 1.12220529],
[0.25549279, 2.89750284, 1.97352593],
[1.33587433, 3.01599101, 3.2031426 ],
[0.20110153, 2.79890287, 1.90698574],
[0.85870837, 2.40014831, 1.70298258],
[0.61211268, 2.57682121, 1.6668071 ],
[0.85882898, 2.86194221, 1.26766899],
[1.09695347, 2.61357618, 2.91712098],
[0.38833729, 2.66524256, 1.75577427],
[2.5012404 , 4.2420242 , 1.0918546 ],
[0.92015293, 3.5923261 , 1.37898965],
[2.3460413 , 4.4461714 , 0.61533733],
[1.37537194, 3.7069135 , 0.64037065],
[1.93680867, 4.1234534 , 0.49148505],
[2.97325888, 5.03573504, 1.30938606],
[1.14289326, 3.38471765, 2.6019388 ],
[2.42763947, 4.58333402, 0.97719501],
[1.65576504, 4.38346835, 1.31984222],
[3.4249117 , 4.78218736, 1.61700085],
[1.90971055, 3.62917997, 0.4469887 ],
[1.29266937, 3.91108898, 0.9070538 ],
[2.00571419, 4.1195457 , 0.28701748],
[1.03296583, 3.8538939 , 1.7839299 ],
[1.54437256, 3.93228102, 1.3928708 ],
[2.09882765, 3.9004045 , 0.68129057],
[1.59624925, 3.70929828, 0.33306798],
[4.04228067, 5.16667811, 2.24280811],
[3.13108094, 5.59206313, 1.88051097],
[1.17369301, 4.0595081 , 2.24060385],
[2.47497267, 4.31907189, 0.64593529],
[1.04109807, 3.44883957, 1.48873586],
[2.95941231, 5.19705273, 1.54093757],
[1.03607405, 3.66188818, 1.01844967],
[2.35220999, 4.0048474 , 0.61341103],
[2.52679968, 4.26239611, 0.786576 ],
[0.99715764, 3.4764269 , 0.93620274],
[1.21104872, 3.28685101, 0.8130537 ],
[1.57955582, 4.06724081, 0.73703291],
[2.21952527, 4.16299138, 0.79795731],
[2.46401818, 4.71948935, 1.11257122],
[4.0726311 , 5.10470559, 2.31231128],
[1.66681009, 4.14760691, 0.78525226],
[0.98336185, 3.41503341, 0.97686289],
[0.92929783, 3.66012734, 1.46832263],
[3.05337662, 5.09008925, 1.3692878 ],
[2.49347287, 4.00041166, 1.10609628],
[1.66406596, 3.59345432, 0.41554383],
[1.15070886, 3.20642858, 0.94210414],
[2.17245566, 4.10202725, 0.33887736],
[2.27944809, 4.28777875, 0.62347517],
[2.26277812, 4.17709068, 0.57349609],
[0.92015293, 3.5923261 , 1.37898965],
[2.44823886, 4.32158336, 0.64919807],
[2.62180347, 4.3459761 , 0.93691617],
[2.02412194, 4.13081941, 0.51118153],
[1.17215474, 3.99933061, 1.396946 ],
[1.649638 , 3.76205335, 0.31254892],
[2.35342278, 3.79343916, 1.09600723],
[1.18085642, 3.26840383, 1.00795325]])
We compute every step and compare the ONNX and scikit-learn outputs.
for step in steps:
    print("----------------------------")
    print(step["model"])
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]

    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# compare them to the output produced by truncated ONNX
# graphs built from the first one.
----------------------------
StandardScaler()
difference 4.799262827148709e-07
----------------------------
KMeans(n_clusters=3, n_init=3)
difference 1.6719496572781267e-06
Python runtime to look into every node
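The Python runtime (ReferenceEvaluator from onnx) makes it easy to look into every node of the ONNX graph. This option can be used to check when a computation fails because of nan values or a dimension mismatch.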
onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)
oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[3.0473158 , 0.16691935, 3.9010506 ],
[2.5376565 , 1.05008 , 3.8996909 ]], dtype=float32)]
And to get a sense of the intermediate results.
oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
+C Ad_Addcst: float32:(3,) in [1.0728117227554321, 5.101926803588867]
+C Ge_Gemmcst: float32:(3, 4) in [-1.3031082153320312, 1.0335986614227295]
+C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
+I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
+ variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
+ Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
+ Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
+ Ge_Y0: float32:(2, 3) in [-10.450262069702148, 7.452452182769775]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
+ Ad_C01: float32:(2, 3) in [-5.0740647315979, 12.313563346862793]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
+ Ad_C0: float32:(2, 3) in [0.027862071990966797, 15.218194961547852]
ArgMin(Ad_C0) -> label
+ label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
+ scores: float32:(2, 3) in [0.16691935062408447, 3.901050567626953]
[array([1, 1]), array([[3.0473158 , 0.16691935, 3.9010506 ],
[2.5376565 , 1.05008 , 3.8996909 ]], dtype=float32)]