Intermediate results and investigation
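There are many reasons why a user may want more than a converted model that runs with ONNX. Intermediate results may be needed, that is, the output of every node in the graph. The ONNX graph may need to be modified to remove some nodes; transfer learning, for example, usually removes the last layers of a deep neural network. Another reason is debugging: it often happens that the runtime fails to compute a prediction because of a shape mismatch, and it is then useful to get the shape of every intermediate result. This example looks into two ways of doing this.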
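Removing nodes amounts to keeping only the part of the graph that a chosen output depends on. A minimal pure-Python sketch of that idea, assuming a hypothetical simplified representation in which each node is a tuple (op type, input names, output names) and the node list is topologically sorted, as ONNX requires. The node names are copied from the verbose trace printed at the end of this example. This is not how skl2onnx or onnx implement it; for real models, onnx.utils.extract_model performs a comparable extraction on a saved file.

```python
from typing import List, Tuple

# Hypothetical simplified node: (op_type, input names, output names).
Node = Tuple[str, List[str], List[str]]


def truncate(nodes: List[Node], targets: List[str]) -> List[Node]:
    """Return the nodes needed to compute `targets`, in their original order.

    Walks the topologically sorted list backwards: a node is kept when it
    produces a name that is still needed, and its inputs become needed too.
    """
    needed = set(targets)
    kept: List[Node] = []
    for op_type, inputs, outputs in reversed(nodes):
        if needed.intersection(outputs):
            kept.append((op_type, inputs, outputs))
            needed.update(inputs)
    kept.reverse()
    return kept


# Node names taken from the ReferenceEvaluator trace of the StandardScaler + KMeans pipeline.
KMEANS_GRAPH: List[Node] = [
    ("Scaler", ["X"], ["variable"]),
    ("ReduceSumSquare", ["variable"], ["Re_reduced0"]),
    ("Mul", ["Re_reduced0", "Mu_Mulcst"], ["Mu_C0"]),
    ("Gemm", ["variable", "Ge_Gemmcst", "Mu_C0"], ["Ge_Y0"]),
    ("Add", ["Re_reduced0", "Ge_Y0"], ["Ad_C01"]),
    ("Add", ["Ad_Addcst", "Ad_C01"], ["Ad_C0"]),
    ("ArgMin", ["Ad_C0"], ["label"]),
    ("Sqrt", ["Ad_C0"], ["scores"]),
]

# Keep only the scaling step, like a pipeline truncated after StandardScaler.
print([op for op, _, _ in truncate(KMEANS_GRAPH, ["variable"])])
# Keep what the label needs; the Sqrt node is dropped.
print([op for op, _, _ in truncate(KMEANS_GRAPH, ["label"])])
```

Walking backwards once is enough because, in a topologically sorted list, every producer appears before its consumers.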
Look into pipeline steps
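The first way is a tricky one: it overloads the methods transform, predict and predict_proba to keep a copy of their inputs and outputs. It then goes through every step of the pipeline. If the pipeline has n steps, it converts the pipeline made of step 1, then the pipeline made of steps 1 and 2, then steps 1, 2, 3, and so on.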
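The mechanism can be illustrated without scikit-learn or ONNX. The sketch below assumes two hypothetical toy transformers (Scale and Shift) and a hypothetical helper capture_transform that shadows an instance's transform method with a wrapper keeping copies of its input and output, roughly the role the `_debug` attribute plays later in this example; it is not skl2onnx's implementation. It then checks that the outputs cached during one full pass equal the outputs of the truncated pipelines made of the first 1, 2, ... steps.

```python
import numpy as np


class Scale:
    """Hypothetical toy transformer standing in for a scikit-learn step."""

    def __init__(self, factor):
        self.factor = factor

    def transform(self, X):
        return X * self.factor


class Shift:
    """Another hypothetical toy transformer."""

    def __init__(self, offset):
        self.offset = offset

    def transform(self, X):
        return X + self.offset


def capture_transform(step):
    """Shadow step.transform with a wrapper that keeps copies of its input and output."""
    original = step.transform  # bound method defined on the class
    step.captured = {}

    def wrapped(X):
        step.captured["inputs"] = np.array(X, copy=True)
        out = original(X)
        step.captured["outputs"] = np.array(out, copy=True)
        return out

    step.transform = wrapped  # an instance attribute takes precedence over the class method
    return step


def prefix_outputs(steps, X):
    """Outputs of the truncated pipelines steps[:1], steps[:2], ..., steps[:n]."""
    results = []
    for n in range(1, len(steps) + 1):
        Y = X
        for step in steps[:n]:
            Y = step.transform(Y)
        results.append(Y)
    return results


# One full pass fills the caches, like calling pipe.transform(X) in the example.
X = np.array([[1.0, 2.0], [3.0, 4.0]])
toy_steps = [capture_transform(Scale(2.0)), capture_transform(Shift(1.0))]
Y = X
for step in toy_steps:
    Y = step.transform(Y)
captured = [step.captured["outputs"] for step in toy_steps]

# Compare every cached output with the output of the matching truncated pipeline.
for step, cached, expected in zip(toy_steps, captured, prefix_outputs(toy_steps, X)):
    print(type(step).__name__, np.abs(cached - expected).max())
```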
import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType
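The pipeline.
data = load_iris()
X = data.data
pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(n_clusters=3, n_init=3))])
pipe.fit(X)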
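The function goes through every step, overloads the method transform, and returns an ONNX graph for every step.
steps = collect_intermediate_steps(
    pipe, initial_types=[("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
)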
We call the method transform to populate the cache kept by the overloaded transform methods.
pipe.transform(X)
array([[3.12119834, 0.21295824, 3.98940603],
[2.6755083 , 0.99604549, 4.01793312],
[2.97416665, 0.65198444, 4.19343668],
[2.88014429, 0.9034561 , 4.19784749],
[3.30022609, 0.40215457, 4.11157152],
[3.50554424, 1.21154793, 3.89893116],
[3.14856384, 0.50244932, 4.21638048],
[2.99184826, 0.09132468, 3.97313411],
[2.92515933, 1.42174651, 4.40757189],
[2.79398956, 0.78993078, 4.05764261],
[3.32125333, 0.78999385, 3.92088109],
[3.0493632 , 0.27618123, 4.07853631],
[2.80635045, 1.03497888, 4.16440431],
[3.21220972, 1.33482453, 4.63069748],
[3.88834965, 1.63865558, 4.14619343],
[4.4998303 , 2.39898792, 4.49547518],
[3.60978017, 1.20748818, 4.02966144],
[3.05594182, 0.21618828, 3.91388548],
[3.34493953, 1.20986655, 3.72562039],
[3.50065397, 0.86706182, 4.10101938],
[2.80825681, 0.50401564, 3.66383713],
[3.27800809, 0.66826437, 3.94496718],
[3.58990876, 0.68658071, 4.51061335],
[2.55934697, 0.47945627, 3.57996434],
[2.96493153, 0.36345425, 3.98817445],
[2.55682739, 0.99023912, 3.88431906],
[2.8279719 , 0.22683089, 3.79088782],
[3.05970831, 0.2947186 , 3.89539875],
[2.95425291, 0.25361098, 3.88085622],
[2.87745051, 0.65019824, 4.09851673],
[2.73238773, 0.80138328, 4.01796142],
[2.73361981, 0.52309257, 3.57350896],
[4.11853014, 1.57658655, 4.5037664 ],
[4.22845606, 1.87652483, 4.4465301 ],
[2.71452112, 0.76858489, 3.97906378],
[2.86508665, 0.54896332, 4.01986385],
[3.0573692 , 0.63079314, 3.80064093],
[3.40284985, 0.45982568, 4.25136846],
[3.00742655, 1.2336976 , 4.42052558],
[2.95472117, 0.14580827, 3.90865188],
[3.12324651, 0.20261743, 4.01192633],
[2.90164193, 2.67055552, 4.64398605],
[3.15411688, 0.90927099, 4.42154566],
[2.8613548 , 0.50081008, 3.70483773],
[3.34606471, 0.92159916, 3.9078554 ],
[2.65231058, 1.01946042, 4.01421067],
[3.53206587, 0.86953764, 4.14238152],
[2.99813103, 0.72275914, 4.23577398],
[3.34116935, 0.72324305, 3.97409784],
[2.90222887, 0.30295342, 3.97223984],
[1.9003878 , 3.43619989, 0.95288059],
[1.41851492, 2.97232682, 0.99352148],
[1.68457079, 3.51850037, 0.72661726],
[0.96940962, 3.33264308, 2.69898424],
[0.9112523 , 3.35747592, 1.11074501],
[0.35721918, 2.77550662, 1.8143491 ],
[1.59351202, 3.01808184, 1.00650285],
[1.50213315, 2.77360088, 3.31296552],
[1.11632078, 3.21148368, 1.14114175],
[0.77921299, 2.66294828, 2.42994048],
[1.97194958, 3.62389817, 3.73666782],
[0.77530513, 2.70011145, 1.45918639],
[1.25941769, 3.53658932, 2.74268279],
[0.66155141, 2.98813829, 1.28976474],
[0.73833453, 2.32311723, 2.05251547],
[1.46572707, 3.14311522, 0.98780965],
[0.80185102, 2.68234835, 1.67700171],
[0.568386 , 2.63954211, 2.12682734],
[1.19987895, 3.97369206, 2.33743839],
[0.67881532, 2.87494798, 2.46667974],
[1.34222961, 3.03853641, 1.1880022 ],
[0.53061062, 2.8022861 , 1.63233668],
[0.79234309, 3.68305664, 1.65142259],
[0.57371215, 2.96833851, 1.54593744],
[0.90589785, 2.9760862 , 1.2933375 ],
[1.22490527, 3.13002382, 1.03085926],
[1.26783271, 3.56679427, 1.09304603],
[1.42114042, 3.5903606 , 0.52050254],
[0.58974672, 2.93839428, 1.34712856],
[0.76432091, 2.58203512, 2.44164622],
[0.89738242, 2.99796537, 2.69027665],
[0.98549851, 2.92597852, 2.76965187],
[0.3921368 , 2.68907313, 2.02829879],
[0.54223583, 3.42215998, 1.4211892 ],
[0.90567816, 2.62771445, 1.88799766],
[1.70872911, 2.75915071, 1.39853465],
[1.48190142, 3.30075052, 0.78009974],
[1.06129323, 3.73017167, 2.2083069 ],
[0.81863359, 2.37943811, 1.87666989],
[0.599882 , 2.98789866, 2.41035271],
[0.4914813 , 2.89079656, 2.26782134],
[0.84409423, 2.86642713, 1.25085451],
[0.38941349, 2.86642575, 2.11791607],
[1.53271026, 2.96966239, 3.35089399],
[0.30831638, 2.77003779, 2.05312152],
[0.81726253, 2.38255534, 1.83091351],
[0.56428027, 2.55559903, 1.80454586],
[0.72672271, 2.8455521 , 1.39825227],
[1.28805849, 2.56987887, 3.06324547],
[0.38163798, 2.64007308, 1.89861511],
[2.31271244, 4.24274589, 1.0584579 ],
[0.76585766, 3.57067982, 1.5185265 ],
[2.14762671, 4.44150237, 0.52472 ],
[1.17645413, 3.69480186, 0.77236486],
[1.73594932, 4.11613683, 0.53031563],
[2.78128346, 5.03326801, 1.2022172 ],
[1.22550604, 3.3503222 , 2.74462238],
[2.2426558 , 4.577021 , 0.92275933],
[1.50462864, 4.363498 , 1.40314162],
[3.22975724, 4.79334275, 1.48323372],
[1.71837714, 3.62749566, 0.4787491 ],
[1.10409694, 3.89360823, 1.0325986 ],
[1.80475907, 4.1132966 , 0.27818948],
[0.94858807, 3.82688169, 1.91870424],
[1.39433359, 3.91538879, 1.49910975],
[1.90677079, 3.89835633, 0.68622715],
[1.39713702, 3.70128288, 0.46463058],
[3.85224062, 5.18341242, 2.10127163],
[2.95786451, 5.58136629, 1.83092395],
[1.17790381, 4.02615768, 2.37017622],
[2.27442972, 4.31907679, 0.52540209],
[0.91211061, 3.4288432 , 1.62249456],
[2.77937737, 5.19031307, 1.47042293],
[0.84735471, 3.64273089, 1.15814207],
[2.15695444, 4.00723617, 0.520093 ],
[2.33581345, 4.2637671 , 0.66660166],
[0.79774043, 3.45930032, 1.08324891],
[1.022307 , 3.27575645, 0.94925151],
[1.3842265 , 4.05342943, 0.84098317],
[2.03854964, 4.1585729 , 0.75748198],
[2.28297732, 4.71100584, 1.07124861],
[3.88774921, 5.12224641, 2.17345728],
[1.47357101, 4.13401784, 0.87682321],
[0.7964005 , 3.39830644, 1.11534598],
[0.80521086, 3.63719075, 1.59782917],
[2.8607372 , 5.08776655, 1.25982873],
[2.3101089 , 4.00416552, 1.07214028],
[1.46990247, 3.58815834, 0.51434392],
[0.97017134, 3.19454679, 1.0762733 ],
[1.97333575, 4.09907253, 0.23050145],
[2.07939567, 4.28416057, 0.57373487],
[2.06609741, 4.17402084, 0.51130902],
[0.76585766, 3.57067982, 1.5185265 ],
[2.24723796, 4.32128686, 0.54141867],
[2.42521977, 4.3480018 , 0.85128501],
[1.82594618, 4.1240495 , 0.52475835],
[1.03093862, 3.97564407, 1.52100812],
[1.44892686, 3.7539635 , 0.44371189],
[2.17585453, 3.7969924 , 1.08437101],
[1.00508668, 3.25638099, 1.13739231]])
We compute every step and compare the ONNX and scikit-learn outputs.
for step in steps:
    print("----------------------------")
    print(step["model"])
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    # the last output is the one comparable with transform (scores for KMeans)
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]
    # comparison: maximum absolute difference
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# and compare them to the outputs produced by truncated ONNX
# graphs built from the first one.
----------------------------
StandardScaler()
difference 4.799262827148709e-07
----------------------------
KMeans(n_clusters=3, n_init=3)
difference 1.095537650763756e-06
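The differences printed above are small but not zero. A likely explanation is that scikit-learn computes in float64 while the ONNX graph receives float32 inputs (X.astype(numpy.float32)) and works with float32 parameters. The numpy sketch below reproduces an error of that order for a standard scaling on synthetic data (not the iris set), assuming the scaling is evaluated as (X - offset) * scale in float32, the form used by the ONNX-ML Scaler operator, with offset the mean and scale the inverse standard deviation.

```python
import numpy as np

# Synthetic data (not the iris set), generated deterministically.
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(150, 4))

# float64 reference, as scikit-learn's StandardScaler computes it.
mean, std = X.mean(axis=0), X.std(axis=0)
expected = (X - mean) / std

# Same scaling with float32 inputs and parameters, Scaler-style: (X - offset) * scale.
offset = mean.astype(np.float32)
scale = (1.0 / std).astype(np.float32)
got = (X.astype(np.float32) - offset) * scale

diff = np.abs(expected - got).max()
print("difference", diff)
```

When such comparisons are automated, a tolerance-based check such as numpy.testing.assert_allclose is usually more robust than exact equality.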
Python runtime to look into every node
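The Python runtime (here onnx.reference.ReferenceEvaluator) can be useful to easily look into every node of the ONNX graph. This option can be used to find where the computation fails because of NaN values or a dimension mismatch.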
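The toy sketch below illustrates what evaluating a graph node by node buys when debugging; it is not the ReferenceEvaluator API. Each node is a hypothetical (name, function, input names, output name) tuple evaluated with numpy, every intermediate result is printed in a format loosely modelled on the verbose trace, and evaluation stops at the first node that raises a shape error or produces a NaN, naming that node.

```python
import numpy as np


def run_and_inspect(nodes, feeds):
    """Evaluate hypothetical (name, function, input names, output name) nodes in order.

    Prints one line per node (name -> output: dtype:shape in [min, max]) and
    stops at the first node that fails or produces a NaN, naming that node.
    """
    values = dict(feeds)
    for name, func, inputs, output in nodes:
        try:
            # Silence numpy's warning: NaNs are detected explicitly below.
            with np.errstate(invalid="ignore"):
                result = func(*(values[i] for i in inputs))
        except ValueError as exc:  # numpy reports shape mismatches as ValueError
            raise RuntimeError(f"node {name!r} failed: {exc}") from exc
        print(f"{name} -> {output}: {result.dtype}:{result.shape} "
              f"in [{result.min()}, {result.max()}]")
        if np.issubdtype(result.dtype, np.floating) and np.isnan(result).any():
            raise FloatingPointError(f"node {name!r} produced NaN in {output!r}")
        values[output] = result
    return values


# A centering step followed by a square root: the negative value yields a NaN.
nodes = [
    ("Sub", lambda x, m: x - m, ["X", "mean"], "centered"),
    ("Sqrt", np.sqrt, ["centered"], "root"),
]
try:
    run_and_inspect(nodes, {"X": np.array([[1.0, 4.0], [9.0, 16.0]]), "mean": np.array([2.0, 2.0])})
except FloatingPointError as exc:
    print("stopped:", exc)

# A matrix product with incompatible shapes: (2, 4) @ (3, 4).
try:
    run_and_inspect([("MatMul", np.matmul, ["X", "W"], "Y")], {"X": np.ones((2, 4)), "W": np.ones((3, 4))})
except RuntimeError as exc:
    print("stopped:", exc)
```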
onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)
oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
[2.675508 , 0.99604493, 4.017933 ]], dtype=float32)]
And to get a sense of the intermediate results:
oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
+C Ad_Addcst: float32:(3,) in [0.9830552339553833, 5.035177230834961]
+C Ge_Gemmcst: float32:(3, 4) in [-1.3049873113632202, 1.1359702348709106]
+C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
+I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
+ variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
+ Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
+ Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
+ Ge_Y0: float32:(2, 3) in [-10.366023063659668, 7.967348575592041]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
+ Ad_C01: float32:(2, 3) in [-4.98982572555542, 12.817853927612305]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
+ Ad_C0: float32:(2, 3) in [0.045351505279541016, 16.143783569335938]
ArgMin(Ad_C0) -> label
+ label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
+ scores: float32:(2, 3) in [0.2129589319229126, 4.017932891845703]
[array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
[2.675508 , 0.99604493, 4.017933 ]], dtype=float32)]
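The trace shows how the KMeans step is evaluated: ReduceSumSquare gives the squared norm of every scaled row, Gemm combines the rows with a (3, 4) constant, Ad_Addcst holds one value per cluster, ArgMin produces label, and Sqrt produces scores, whose two rows match the first two rows of pipe.transform(X) up to float32 rounding. This is consistent with the expansion ||x - c_k||^2 = ||x||^2 - 2 x·c_k + ||c_k||^2, where c_k is the k-th cluster center. The numpy sketch below computes distances that way and checks them against a direct computation. The data and centers are hypothetical random values, and where the factor -2 is applied is an assumption, since the trace does not print the Gemm attributes.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4)).astype(np.float32)        # hypothetical scaled inputs
centers = rng.normal(size=(3, 4)).astype(np.float32)  # hypothetical cluster centers

# ||x||^2 per row, like ReduceSumSquare with keepdims -> shape (n, 1)
x_sq = np.sum(X * X, axis=1, keepdims=True)
# -2 x.c_k for every pair, the cross term -> shape (n, k)
cross = -2.0 * X @ centers.T
# ||c_k||^2, a constant that can be precomputed once -> shape (k,)
c_sq = np.sum(centers * centers, axis=1)

sq_dist = x_sq + cross + c_sq                # broadcasting gives shape (n, k)
labels = np.argmin(sq_dist, axis=1)          # like ArgMin -> predicted cluster
scores = np.sqrt(np.maximum(sq_dist, 0.0))   # like Sqrt -> distances to centers

# Direct computation for comparison.
direct = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
print("max difference", np.abs(scores - direct).max())
print("labels", labels, np.argmin(direct, axis=1))
```

Clipping the squared distances at zero is a choice of this sketch: cancellation in the expansion can make a squared distance slightly negative, and a plain square root would then return NaN, which is the kind of problem a node-by-node trace helps locate.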
Total running time of the script: (0 minutes 0.291 seconds)