Intermediate results and investigation

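There are many reasons why a user may want more than simply running a model converted to ONNX. Intermediate results, that is the output of every node in the graph, may be needed. The ONNX graph may need to be altered to remove some nodes. Transfer learning usually involves removing the last layers of a deep neural network. Another reason is debugging. The runtime often fails to compute predictions because of a shape mismatch. In that case, it is useful to get the shape of every intermediate result. This example looks into two ways of doing that.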

Look into pipeline steps

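The first way is tricky: it overloads the methods *transform*, *predict* and *predict_proba* to keep a copy of their inputs and outputs. It then goes through every step of the pipeline. If the pipeline has *n* steps, it first converts the pipeline made of step 1 only, then the pipeline made of steps 1 and 2, then the one made of steps 1, 2 and 3, and so on.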
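The idea of recording what each step receives and returns can be sketched without any ONNX dependency. The snippet below is a minimal illustration and not the implementation used by `collect_intermediate_steps`: the helper `record_calls`, the function `run_prefix` and the two toy steps `AddOne` and `Double` are hypothetical names introduced only for this example.

```python
def record_calls(obj, method_name, cache):
    """Replace obj.<method_name> with a wrapper that stores inputs and outputs."""
    original = getattr(obj, method_name)

    def wrapper(x):
        out = original(x)
        # Keep copies so that later mutations do not alter the record.
        cache[method_name] = {"input": list(x), "output": list(out)}
        return out

    setattr(obj, method_name, wrapper)


class AddOne:
    """Toy step standing in for a transformer (hypothetical)."""

    def transform(self, x):
        return [v + 1 for v in x]


class Double:
    """Second toy step (hypothetical)."""

    def transform(self, x):
        return [v * 2 for v in x]


steps = [("add", AddOne()), ("double", Double())]
caches = {name: {} for name, _ in steps}
for name, step in steps:
    record_calls(step, "transform", caches[name])


def run_prefix(steps, n, x):
    """Run only the first n steps, as a truncated pipeline would."""
    for _, step in steps[:n]:
        x = step.transform(x)
    return x


data = [1.0, 2.0]
# One run per truncated pipeline: step 1 only, then steps 1 and 2.
prefix_outputs = [run_prefix(steps, n, data) for n in range(1, len(steps) + 1)]
print(prefix_outputs)
print(caches["double"]["transform"])
```

After the loop, each cache holds the last input and output seen by its step, which is the kind of information that can then be compared with the output of the matching truncated graph.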

import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType

The pipeline.

data = load_iris()
X = data.data

pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(3, n_init=3))])
pipe.fit(X)
Pipeline(steps=[('std', StandardScaler()),
                ('km', KMeans(n_clusters=3, n_init=3))])


The function goes through every step, overloads the *transform* method and returns an ONNX graph for each step.

steps = collect_intermediate_steps(
    pipe, "pipeline", [("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
)

We call the method transform to populate the cache kept by the overloaded *transform* methods.

pipe.transform(X)
array([[4.00404832, 0.21295824, 3.15861505],
       [4.05055769, 0.99604549, 2.72563625],
       [4.22040251, 0.65198444, 3.02188403],
       [4.22860026, 0.9034561 , 2.93043986],
       [4.12353003, 0.40215457, 3.33653691],
       [3.89643029, 1.21154793, 3.52936423],
       [4.2374443 , 0.50244932, 3.19234391],
       [3.99197553, 0.09132468, 3.03242342],
       [4.4445734 , 1.42174651, 2.9795537 ],
       [4.08705397, 0.78993078, 2.84221713],
       [3.92610748, 0.78999385, 3.3507236 ],
       [4.09865843, 0.27618123, 3.09168785],
       [4.19718995, 1.03497888, 2.85719428],
       [4.66454355, 1.33482453, 3.26547013],
       [4.13826195, 1.63865558, 3.90871872],
       [4.47633229, 2.39898792, 4.51414747],
       [4.02762963, 1.20748818, 3.63475229],
       [3.92839122, 0.21618828, 3.09288714],
       [3.72388908, 1.20986655, 3.36736664],
       [4.10521298, 0.86706182, 3.53103908],
       [3.67990695, 0.50401564, 2.8436663 ],
       [3.95222508, 0.66826437, 3.30977167],
       [4.52523323, 0.68658071, 3.63034505],
       [3.60185594, 0.47945627, 2.59973228],
       [4.00845791, 0.36345425, 3.00678098],
       [3.91688379, 0.99023912, 2.60615351],
       [3.80966594, 0.22683089, 2.86756462],
       [3.90931811, 0.2947186 , 3.0958103 ],
       [3.89828815, 0.25361098, 2.99275191],
       [4.12581898, 0.65019824, 2.92503544],
       [4.04810077, 0.80138328, 2.78137328],
       [3.58928575, 0.52309257, 2.76837135],
       [4.49874494, 1.57658655, 4.14390673],
       [4.43563509, 1.87652483, 4.2489329 ],
       [4.008642  , 0.76858489, 2.76272437],
       [4.04525625, 0.54896332, 2.9103173 ],
       [3.81211172, 0.63079314, 3.09001251],
       [4.26421417, 0.45982568, 3.44057865],
       [4.45456872, 1.2336976 , 3.06034971],
       [3.92683189, 0.14580827, 2.99422852],
       [4.02712265, 0.20261743, 3.16142101],
       [4.69480008, 2.67055552, 2.9575648 ],
       [4.4496996 , 0.90927099, 3.20355969],
       [3.71964918, 0.50081008, 2.89721622],
       [3.91143692, 0.92159916, 3.37471011],
       [4.04740147, 1.01946042, 2.70316642],
       [4.14683513, 0.86953764, 3.56280964],
       [4.26327469, 0.72275914, 3.04646993],
       [3.98021229, 0.72324305, 3.37186092],
       [3.99446269, 0.30295342, 2.94518173],
       [0.9452659 , 3.43619989, 1.8639233 ],
       [1.00829443, 2.97232682, 1.38933168],
       [0.73653572, 3.51850037, 1.6428166 ],
       [2.76204203, 3.33264308, 1.00264343],
       [1.16604995, 3.35747592, 0.86560047],
       [1.86711784, 2.77550662, 0.3750882 ],
       [1.00955989, 3.01808184, 1.56489146],
       [3.3697155 , 2.77360088, 1.55619573],
       [1.18358725, 3.21148368, 1.08067281],
       [2.48285941, 2.66294828, 0.82637993],
       [3.79967007, 3.62389817, 2.01281316],
       [1.50054672, 2.70011145, 0.76353654],
       [2.80438695, 3.53658932, 1.27727048],
       [1.34023352, 2.98813829, 0.62868121],
       [2.09655735, 2.32311723, 0.77087912],
       [1.00633966, 3.14311522, 1.43272989],
       [1.71909321, 2.68234835, 0.80192   ],
       [2.17926627, 2.63954211, 0.60569829],
       [2.40214871, 3.97369206, 1.18764767],
       [2.52511757, 2.87494798, 0.727372  ],
       [1.21113562, 3.03853641, 1.31653995],
       [1.68291281, 2.8022861 , 0.52313867],
       [1.71597913, 3.68305664, 0.75211692],
       [1.59856561, 2.96833851, 0.55292557],
       [1.33753092, 2.9760862 , 0.87815407],
       [1.06462905, 3.13002382, 1.19061026],
       [1.13996294, 3.56679427, 1.22441299],
       [0.5652633 , 3.5903606 , 1.37258261],
       [1.39763754, 2.93839428, 0.56006248],
       [2.49518379, 2.58203512, 0.81289907],
       [2.75025306, 2.99796537, 0.94324481],
       [2.82866407, 2.92597852, 1.03283946],
       [2.08201734, 2.68907313, 0.4343386 ],
       [1.48418961, 3.42215998, 0.48873673],
       [1.92943813, 2.62771445, 0.91606802],
       [1.40011111, 2.75915071, 1.69140864],
       [0.79992473, 3.30075052, 1.44311693],
       [2.2708714 , 3.73017167, 1.05036852],
       [1.91690629, 2.37943811, 0.83618809],
       [2.47017911, 2.98789866, 0.6470029 ],
       [2.32571939, 2.89079656, 0.53979211],
       [1.29304411, 2.86642713, 0.81855214],
       [2.17526444, 2.86642575, 0.43194777],
       [3.40973541, 2.96966239, 1.58383257],
       [2.10849001, 2.77003779, 0.3618706 ],
       [1.87076527, 2.38255534, 0.83187956],
       [1.85116384, 2.55559903, 0.58147273],
       [1.44451588, 2.8455521 , 0.70529895],
       [3.11774537, 2.56987887, 1.34329146],
       [1.94990512, 2.64007308, 0.41481694],
       [1.04248866, 4.24274589, 2.26819164],
       [1.57935402, 3.57067982, 0.72581017],
       [0.52274684, 4.44150237, 2.09231844],
       [0.83298461, 3.69480186, 1.12321156],
       [0.5678145 , 4.11613683, 1.68255837],
       [1.1830756 , 5.03326801, 2.72592116],
       [2.8024351 , 3.3503222 , 1.25267619],
       [0.93117407, 4.577021  , 2.18852343],
       [1.46246781, 4.363498  , 1.45283591],
       [1.4207266 , 4.79334275, 3.18264007],
       [0.47962495, 3.62749566, 1.67405555],
       [1.09881086, 3.89360823, 1.04698204],
       [0.31830999, 4.1132966 , 1.75049044],
       [1.98175664, 3.82688169, 0.92293569],
       [1.54698303, 3.91538879, 1.35721732],
       [0.68407345, 3.89835633, 1.86138575],
       [0.52205472, 3.70128288, 1.34561415],
       [2.03678461, 5.18341242, 3.80620352],
       [1.84250874, 5.58136629, 2.90217633],
       [2.43634558, 4.02615768, 1.16636059],
       [0.48150581, 4.31907679, 2.22297775],
       [1.67578773, 3.4288432 , 0.88685031],
       [1.47096547, 5.19031307, 2.72431414],
       [1.22329554, 3.64273089, 0.79101156],
       [0.47109224, 4.00723617, 2.10999425],
       [0.62558995, 4.2637671 , 2.28591141],
       [1.14490402, 3.45930032, 0.74392898],
       [0.99645552, 3.27575645, 0.98053107],
       [0.90181942, 4.05342943, 1.3282425 ],
       [0.76242411, 4.1585729 , 1.98849304],
       [1.08628479, 4.71100584, 2.22822113],
       [2.10967488, 5.12224641, 3.84302072],
       [0.93357383, 4.13401784, 1.41836425],
       [1.17526973, 3.39830644, 0.74517066],
       [1.66051938, 3.63719075, 0.76558228],
       [1.23742547, 5.08776655, 2.80545775],
       [1.04697429, 4.00416552, 2.26945032],
       [0.55013293, 3.58815834, 1.42313566],
       [1.12188023, 3.19454679, 0.93290167],
       [0.20983625, 4.09907253, 1.92136662],
       [0.5691276 , 4.28416057, 2.02737038],
       [0.49810802, 4.17402084, 2.01513279],
       [1.57935402, 3.57067982, 0.72581017],
       [0.50497262, 4.32128686, 2.19577242],
       [0.81423561, 4.3480018 , 2.37699732],
       [0.55018391, 4.1240495 , 1.77340222],
       [1.58648502, 3.97564407, 0.98294137],
       [0.49931367, 3.7539635 , 1.39731191],
       [1.06536484, 3.7969924 , 2.13822884],
       [1.18287527, 3.25638099, 0.96885287]])

We compute every step and compare the ONNX and scikit-learn outputs.

for step in steps:
    print("----------------------------")
    print(step["model"])
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]

    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# compare them to the output produced by truncated ONNX
# graphs built from the first one.
#
----------------------------
StandardScaler()
difference 4.799262827148709e-07
----------------------------
KMeans(n_clusters=3, n_init=3)
difference 1.095537650763756e-06

Python runtime to look into every node

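The Python runtime can be useful to easily look into every node of the ONNX graph. This option can be used to check when a computation fails due to NaN values or a dimension mismatch.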
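What such a node-by-node inspection does can be sketched with plain NumPy. The toy graph below is hypothetical: the function `run_traced`, the node tuples and the node names are assumptions made for this illustration and are not part of `onnx` or `skl2onnx`. It evaluates the nodes in order, prints the type and shape of each result, and stops at the first node that produces a NaN.

```python
import numpy


def run_traced(nodes, feeds):
    """Evaluate nodes in order and report each intermediate result."""
    results = dict(feeds)
    for op_name, func, input_names, output_name in nodes:
        inputs = [results[name] for name in input_names]
        # Silence NumPy warnings: invalid values are detected explicitly below.
        with numpy.errstate(invalid="ignore", divide="ignore"):
            out = func(*inputs)
        print(f"{op_name}({', '.join(input_names)}) -> {output_name}: "
              f"{out.dtype}:{out.shape}")
        if numpy.isnan(out).any():
            raise ValueError(f"node {op_name!r} produced NaN in {output_name!r}")
        results[output_name] = out
    return results


# A toy graph: (operator name, function, input names, output name).
nodes = [
    ("Sub", lambda x: x - x.mean(axis=0), ["X"], "centered"),
    ("Log", lambda x: numpy.log(x), ["centered"], "logged"),  # NaN for negative inputs
    ("ReduceSum", lambda x: x.sum(axis=1), ["logged"], "total"),
]

X = numpy.array([[1.0, 2.0], [3.0, 4.0]], dtype=numpy.float32)
try:
    run_traced(nodes, {"X": X})
    failed_at = None
except ValueError as exc:
    failed_at = str(exc)
print(failed_at)
```

The trace stops at the `Log` node, so the faulty operator is known without reading the rest of the graph.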

onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)

oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[4.0040483 , 0.21295893, 3.158615  ],
       [4.050557  , 0.99604493, 2.7256362 ]], dtype=float32)]

And to get the intermediate results.

oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})

# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
 +C Ad_Addcst: float32:(3,) in [1.0065230131149292, 5.035177230834961]
 +C Ge_Gemmcst: float32:(3, 4) in [-1.3049873113632202, 1.1674340963363647]
 +C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
 +I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
 + variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
 + Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
 + Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
 + Ge_Y0: float32:(2, 3) in [-10.366023063659668, 8.10552978515625]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
 + Ad_C01: float32:(2, 3) in [-4.98982572555542, 12.956035614013672]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
 + Ad_C0: float32:(2, 3) in [0.045351505279541016, 16.407014846801758]
ArgMin(Ad_C0) -> label
 + label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
 + scores: float32:(2, 3) in [0.2129589319229126, 4.0505571365356445]

[array([1, 1]), array([[4.0040483 , 0.21295893, 3.158615  ],
       [4.050557  , 0.99604493, 2.7256362 ]], dtype=float32)]
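The node sequence in this log is consistent with the usual expansion of the squared Euclidean distance, ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2. The NumPy sketch below mirrors that sequence under stated assumptions: the inputs and centers are made-up random values, and the factor -2 and the transposition in the `Gemm` step are inferred from the tensor shapes in the log, not read from the graph attributes.

```python
import numpy

rng = numpy.random.default_rng(0)
# Stand-ins for the scaled inputs ("variable") and the cluster centers.
x = rng.normal(size=(2, 4)).astype(numpy.float32)
centers = rng.normal(size=(3, 4)).astype(numpy.float32)  # hypothetical values

x_sq = (x ** 2).sum(axis=1, keepdims=True)  # ReduceSumSquare -> shape (2, 1)
cross = -2.0 * x @ centers.T                # Gemm (assumed alpha=-2, transB=1) -> (2, 3)
c_sq = (centers ** 2).sum(axis=1)           # constant added last -> shape (3,)

sq_dist = c_sq + (x_sq + cross)             # the two Add nodes
label = sq_dist.argmin(axis=1)              # ArgMin -> index of the closest center
scores = numpy.sqrt(sq_dist)                # Sqrt -> distance to every center

# Reference: distances computed directly from the definition.
expected = numpy.linalg.norm(x[:, None, :] - centers[None, :, :], axis=2)
print(numpy.abs(scores - expected).max())
```

Writing the distance this way needs a single matrix product instead of one subtraction per pair of point and center, which may explain why the converted graph contains a `Gemm` node.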

**Total running time of the script:** (0 minutes 0.115 seconds)
