Intermediate results and investigation
There are many reasons why a user may want more than just running a model converted to ONNX. Intermediate results, that is the output of every node in the graph, may be needed. The ONNX graph may have to be altered to remove some nodes; transfer learning, for instance, usually removes the last layers of a deep neural network. Another reason is debugging: the runtime often fails to compute predictions because of a shape mismatch, and it is then useful to get the shape of every intermediate result. This example looks into two ways of doing it.
Look into pipeline steps
The first way is a tricky one: it overloads the methods *transform*, *predict* and *predict_proba* to keep a copy of the inputs and outputs. It then goes through every step of the pipeline. If the pipeline has *n* steps, it converts the pipeline with step 1, then the pipeline with steps 1 and 2, then with steps 1, 2, 3, and so on.
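Before looking at the helper provided by skl2onnx, here is a minimal, self-contained sketch of that idea, assuming a small two-step pipeline: each truncated sub-pipeline is converted separately and the output names of its ONNX graph are printed. The loop and the variable names are illustrative only, not the implementation used below.

# Hedged sketch: convert the pipeline truncated after each step
# (illustration of the idea only, not the helper used in this example).
import numpy
from onnxruntime import InferenceSession
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from skl2onnx import to_onnx

X32 = load_iris().data.astype(numpy.float32)
pipe_steps = [("std", StandardScaler()), ("km", KMeans(3, n_init=3))]
for i in range(1, len(pipe_steps) + 1):
    # sub-pipeline made of the first i steps
    sub = Pipeline(pipe_steps[:i]).fit(X32)
    onx = to_onnx(sub, X32[:1], target_opset=17)
    sess = InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    print(i, [o.name for o in sess.get_outputs()])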
import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType
The pipeline.
data = load_iris()
X = data.data
pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(3, n_init=3))])
pipe.fit(X)
The function goes through every step, overloads the method *transform* and returns an ONNX graph for every step.
steps = collect_intermediate_steps(
    pipe, "pipeline", [("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
)
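Each entry returned by *collect_intermediate_steps* bundles a scikit-learn step with its truncated ONNX graph, under the "model" and "onnx_step" keys used in the comparison loop further below. A quick, hedged look at that structure:

# Inspect the collected steps: one truncated ONNX graph per pipeline step.
for step in steps:
    print(type(step["model"]).__name__, "->", len(step["onnx_step"].graph.node), "ONNX nodes")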
We call the method transform to populate the cache the overloaded *transform* methods keep.
pipe.transform(X)
array([[4.00404832, 0.21295824, 3.15861505],
[4.05055769, 0.99604549, 2.72563625],
[4.22040251, 0.65198444, 3.02188403],
[4.22860026, 0.9034561 , 2.93043986],
[4.12353003, 0.40215457, 3.33653691],
[3.89643029, 1.21154793, 3.52936423],
[4.2374443 , 0.50244932, 3.19234391],
[3.99197553, 0.09132468, 3.03242342],
[4.4445734 , 1.42174651, 2.9795537 ],
[4.08705397, 0.78993078, 2.84221713],
[3.92610748, 0.78999385, 3.3507236 ],
[4.09865843, 0.27618123, 3.09168785],
[4.19718995, 1.03497888, 2.85719428],
[4.66454355, 1.33482453, 3.26547013],
[4.13826195, 1.63865558, 3.90871872],
[4.47633229, 2.39898792, 4.51414747],
[4.02762963, 1.20748818, 3.63475229],
[3.92839122, 0.21618828, 3.09288714],
[3.72388908, 1.20986655, 3.36736664],
[4.10521298, 0.86706182, 3.53103908],
[3.67990695, 0.50401564, 2.8436663 ],
[3.95222508, 0.66826437, 3.30977167],
[4.52523323, 0.68658071, 3.63034505],
[3.60185594, 0.47945627, 2.59973228],
[4.00845791, 0.36345425, 3.00678098],
[3.91688379, 0.99023912, 2.60615351],
[3.80966594, 0.22683089, 2.86756462],
[3.90931811, 0.2947186 , 3.0958103 ],
[3.89828815, 0.25361098, 2.99275191],
[4.12581898, 0.65019824, 2.92503544],
[4.04810077, 0.80138328, 2.78137328],
[3.58928575, 0.52309257, 2.76837135],
[4.49874494, 1.57658655, 4.14390673],
[4.43563509, 1.87652483, 4.2489329 ],
[4.008642 , 0.76858489, 2.76272437],
[4.04525625, 0.54896332, 2.9103173 ],
[3.81211172, 0.63079314, 3.09001251],
[4.26421417, 0.45982568, 3.44057865],
[4.45456872, 1.2336976 , 3.06034971],
[3.92683189, 0.14580827, 2.99422852],
[4.02712265, 0.20261743, 3.16142101],
[4.69480008, 2.67055552, 2.9575648 ],
[4.4496996 , 0.90927099, 3.20355969],
[3.71964918, 0.50081008, 2.89721622],
[3.91143692, 0.92159916, 3.37471011],
[4.04740147, 1.01946042, 2.70316642],
[4.14683513, 0.86953764, 3.56280964],
[4.26327469, 0.72275914, 3.04646993],
[3.98021229, 0.72324305, 3.37186092],
[3.99446269, 0.30295342, 2.94518173],
[0.9452659 , 3.43619989, 1.8639233 ],
[1.00829443, 2.97232682, 1.38933168],
[0.73653572, 3.51850037, 1.6428166 ],
[2.76204203, 3.33264308, 1.00264343],
[1.16604995, 3.35747592, 0.86560047],
[1.86711784, 2.77550662, 0.3750882 ],
[1.00955989, 3.01808184, 1.56489146],
[3.3697155 , 2.77360088, 1.55619573],
[1.18358725, 3.21148368, 1.08067281],
[2.48285941, 2.66294828, 0.82637993],
[3.79967007, 3.62389817, 2.01281316],
[1.50054672, 2.70011145, 0.76353654],
[2.80438695, 3.53658932, 1.27727048],
[1.34023352, 2.98813829, 0.62868121],
[2.09655735, 2.32311723, 0.77087912],
[1.00633966, 3.14311522, 1.43272989],
[1.71909321, 2.68234835, 0.80192 ],
[2.17926627, 2.63954211, 0.60569829],
[2.40214871, 3.97369206, 1.18764767],
[2.52511757, 2.87494798, 0.727372 ],
[1.21113562, 3.03853641, 1.31653995],
[1.68291281, 2.8022861 , 0.52313867],
[1.71597913, 3.68305664, 0.75211692],
[1.59856561, 2.96833851, 0.55292557],
[1.33753092, 2.9760862 , 0.87815407],
[1.06462905, 3.13002382, 1.19061026],
[1.13996294, 3.56679427, 1.22441299],
[0.5652633 , 3.5903606 , 1.37258261],
[1.39763754, 2.93839428, 0.56006248],
[2.49518379, 2.58203512, 0.81289907],
[2.75025306, 2.99796537, 0.94324481],
[2.82866407, 2.92597852, 1.03283946],
[2.08201734, 2.68907313, 0.4343386 ],
[1.48418961, 3.42215998, 0.48873673],
[1.92943813, 2.62771445, 0.91606802],
[1.40011111, 2.75915071, 1.69140864],
[0.79992473, 3.30075052, 1.44311693],
[2.2708714 , 3.73017167, 1.05036852],
[1.91690629, 2.37943811, 0.83618809],
[2.47017911, 2.98789866, 0.6470029 ],
[2.32571939, 2.89079656, 0.53979211],
[1.29304411, 2.86642713, 0.81855214],
[2.17526444, 2.86642575, 0.43194777],
[3.40973541, 2.96966239, 1.58383257],
[2.10849001, 2.77003779, 0.3618706 ],
[1.87076527, 2.38255534, 0.83187956],
[1.85116384, 2.55559903, 0.58147273],
[1.44451588, 2.8455521 , 0.70529895],
[3.11774537, 2.56987887, 1.34329146],
[1.94990512, 2.64007308, 0.41481694],
[1.04248866, 4.24274589, 2.26819164],
[1.57935402, 3.57067982, 0.72581017],
[0.52274684, 4.44150237, 2.09231844],
[0.83298461, 3.69480186, 1.12321156],
[0.5678145 , 4.11613683, 1.68255837],
[1.1830756 , 5.03326801, 2.72592116],
[2.8024351 , 3.3503222 , 1.25267619],
[0.93117407, 4.577021 , 2.18852343],
[1.46246781, 4.363498 , 1.45283591],
[1.4207266 , 4.79334275, 3.18264007],
[0.47962495, 3.62749566, 1.67405555],
[1.09881086, 3.89360823, 1.04698204],
[0.31830999, 4.1132966 , 1.75049044],
[1.98175664, 3.82688169, 0.92293569],
[1.54698303, 3.91538879, 1.35721732],
[0.68407345, 3.89835633, 1.86138575],
[0.52205472, 3.70128288, 1.34561415],
[2.03678461, 5.18341242, 3.80620352],
[1.84250874, 5.58136629, 2.90217633],
[2.43634558, 4.02615768, 1.16636059],
[0.48150581, 4.31907679, 2.22297775],
[1.67578773, 3.4288432 , 0.88685031],
[1.47096547, 5.19031307, 2.72431414],
[1.22329554, 3.64273089, 0.79101156],
[0.47109224, 4.00723617, 2.10999425],
[0.62558995, 4.2637671 , 2.28591141],
[1.14490402, 3.45930032, 0.74392898],
[0.99645552, 3.27575645, 0.98053107],
[0.90181942, 4.05342943, 1.3282425 ],
[0.76242411, 4.1585729 , 1.98849304],
[1.08628479, 4.71100584, 2.22822113],
[2.10967488, 5.12224641, 3.84302072],
[0.93357383, 4.13401784, 1.41836425],
[1.17526973, 3.39830644, 0.74517066],
[1.66051938, 3.63719075, 0.76558228],
[1.23742547, 5.08776655, 2.80545775],
[1.04697429, 4.00416552, 2.26945032],
[0.55013293, 3.58815834, 1.42313566],
[1.12188023, 3.19454679, 0.93290167],
[0.20983625, 4.09907253, 1.92136662],
[0.5691276 , 4.28416057, 2.02737038],
[0.49810802, 4.17402084, 2.01513279],
[1.57935402, 3.57067982, 0.72581017],
[0.50497262, 4.32128686, 2.19577242],
[0.81423561, 4.3480018 , 2.37699732],
[0.55018391, 4.1240495 , 1.77340222],
[1.58648502, 3.97564407, 0.98294137],
[0.49931367, 3.7539635 , 1.39731191],
[1.06536484, 3.7969924 , 2.13822884],
[1.18287527, 3.25638099, 0.96885287]])
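The overloaded methods keep their inputs and outputs in a *_debug* attribute attached to every step; the comparison loop below reads the same attribute. A quick peek, assuming the first collected step is the scaler:

# Cache filled by the call to transform above (first two rows only).
print(steps[0]["model"]._debug.outputs["transform"][:2])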
We compute every step and compare the ONNX and scikit-learn outputs.
for step in steps:
    print("----------------------------")
    print(step["model"])
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]

    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# compare them to the output produced by truncated ONNX
# graphs built from the first one.
#
----------------------------
StandardScaler()
difference 4.799262827148709e-07
----------------------------
KMeans(n_clusters=3, n_init=3)
difference 1.095537650763756e-06
Python runtime to look into every node
The Python runtime may be useful to easily look into every node of the ONNX graph. This option can be used to check when the computation fails because of nan values or a dimension mismatch.
onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)
oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[4.0040483 , 0.21295893, 3.158615 ],
[4.050557 , 0.99604493, 2.7256362 ]], dtype=float32)]
And to see the intermediate results.
oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
+C Ad_Addcst: float32:(3,) in [1.0065230131149292, 5.035177230834961]
+C Ge_Gemmcst: float32:(3, 4) in [-1.3049873113632202, 1.1674340963363647]
+C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
+I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
+ variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
+ Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
+ Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
+ Ge_Y0: float32:(2, 3) in [-10.366023063659668, 8.10552978515625]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
+ Ad_C01: float32:(2, 3) in [-4.98982572555542, 12.956035614013672]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
+ Ad_C0: float32:(2, 3) in [0.045351505279541016, 16.407014846801758]
ArgMin(Ad_C0) -> label
+ label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
+ scores: float32:(2, 3) in [0.2129589319229126, 4.0505571365356445]
[array([1, 1]), array([[4.0040483 , 0.21295893, 3.158615 ],
[4.050557 , 0.99604493, 2.7256362 ]], dtype=float32)]
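The verbose trace above shows the shape and the range of every intermediate result, which is what makes dimension issues easy to locate. As a rough, hypothetical check, one can feed an input with the wrong number of columns and let the runtime report where it breaks:

# Hedged sketch: feed an input with a wrong shape to see where the run fails.
bad_X = X[:2, :3].astype(numpy.float32)
try:
    oinf.run(None, {"X": bad_X})
except Exception as e:
    print("failure:", type(e).__name__, e)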
**Total running time of the script:** (0 minutes 0.115 seconds)