转换具有简化运算符列表的模型

某些专用于 ONNX 的运行时并未实现所有运算符,如果可用运算符列表中缺少其中某个运算符,则转换后的模型可能无法运行。如果用户希望将某些运算符列入黑名单,则某些转换器可能会以不同的方式转换模型。

GaussianMixture

第一个根据运算符黑名单改变其行为的转换器是用于模型 GaussianMixture

import onnxruntime
import onnx
import numpy
import os
from timeit import timeit
import numpy as np
import matplotlib.pyplot as plt
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
from onnxruntime import InferenceSession
from sklearn.mixture import GaussianMixture
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skl2onnx import to_onnx

data = load_iris()
X_train, X_test = train_test_split(data.data)
model = GaussianMixture()
model.fit(X_train)
GaussianMixture()
在 Jupyter 环境中,请重新运行此单元格以显示 HTML 表示或信任该笔记本。
在 GitHub 上,HTML 表示无法渲染,请尝试使用 nbviewer.org 加载此页面。


默认转换

model_onnx = to_onnx(
    model,
    X_train[:1].astype(np.float32),
    options={id(model): {"score_samples": True}},
    target_opset=12,
)
sess = InferenceSession(
    model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)

xt = X_test[:5].astype(np.float32)
print(model.score_samples(xt))
print(sess.run(None, {"X": xt})[2])
[-1.68351474 -1.68463982 -2.27655683 -1.68930875 -3.40608478]
[[-1.6835146]
 [-1.6846399]
 [-2.276558 ]
 [-1.6893082]
 [-3.4060864]]

显示 ONNX 图。

pydot_graph = GetPydotGraph(
    model_onnx.graph,
    name=model_onnx.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"
    ),
)
pydot_graph.write_dot("mixture.dot")

os.system("dot -O -Gdpi=300 -Tpng mixture.dot")

image = plt.imread("mixture.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")
plot black op
(np.float64(-0.5), np.float64(4796.5), np.float64(8425.5), np.float64(-0.5))

不使用 ReduceLogSumExp 进行转换

参数 black_op 用于告诉转换器不要使用此运算符。让我们看看在这种情况下转换器会生成什么。

model_onnx2 = to_onnx(
    model,
    X_train[:1].astype(np.float32),
    options={id(model): {"score_samples": True}},
    black_op={"ReduceLogSumExp"},
    target_opset=12,
)
sess2 = InferenceSession(
    model_onnx2.SerializeToString(), providers=["CPUExecutionProvider"]
)

xt = X_test[:5].astype(np.float32)
print(model.score_samples(xt))
print(sess2.run(None, {"X": xt})[2])
[-1.68351474 -1.68463982 -2.27655683 -1.68930875 -3.40608478]
[[-1.6835146]
 [-1.6846399]
 [-2.276558 ]
 [-1.6893082]
 [-3.4060864]]

显示 ONNX 图。

pydot_graph = GetPydotGraph(
    model_onnx2.graph,
    name=model_onnx2.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"
    ),
)
pydot_graph.write_dot("mixture2.dot")

os.system("dot -O -Gdpi=300 -Tpng mixture2.dot")

image = plt.imread("mixture2.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")
plot black op
(np.float64(-0.5), np.float64(4921.5), np.float64(13264.5), np.float64(-0.5))

处理时间

print(
    timeit(
        stmt="sess.run(None, {'X': xt})", number=10000, globals={"sess": sess, "xt": xt}
    )
)

print(
    timeit(
        stmt="sess2.run(None, {'X': xt})",
        number=10000,
        globals={"sess2": sess2, "xt": xt},
    )
)
0.19538224400093895
0.236580953001976

使用 ReduceLogSumExp 的模型要快得多。

如果转换器无法在没有…的情况下转换

许多转换器不考虑运算符的白名单和黑名单。如果转换器在不使用黑名单运算符(或仅使用白名单运算符)的情况下无法转换,则 skl2onnx 会引发错误。

try:
    to_onnx(
        model,
        X_train[:1].astype(np.float32),
        options={id(model): {"score_samples": True}},
        black_op={"ReduceLogSumExp", "Add"},
        target_opset=12,
    )
except RuntimeError as e:
    print("Error:", e)
Error: Operator 'Add' is black listed.

此示例使用的版本

import sklearn

print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
import skl2onnx

print("onnx: ", onnx.__version__)
print("onnxruntime: ", onnxruntime.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 2.2.0
scikit-learn: 1.6.0
onnx:  1.18.0
onnxruntime:  1.21.0+cu126
skl2onnx:  1.18.0

脚本总运行时间: (0 分钟 12.799 秒)

由 Sphinx-Gallery 生成的图库