注意
转到末尾 下载完整的示例代码。
使用运算符列表转换模型¶
一些专用于 ONNX 的运行时并未实现所有运算符,如果可用运算符列表中缺少其中一个,则转换后的模型可能无法运行。一些转换器可能会根据用户想要过滤掉的运算符,以不同的方式转换模型。
GaussianMixture¶
第一个根据运算符黑名单改变其行为的转换器是针对模型 GaussianMixture 的。
import onnxruntime
import onnx
import numpy
import os
from timeit import timeit
import numpy as np
import matplotlib.pyplot as plt
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
from onnxruntime import InferenceSession
from sklearn.mixture import GaussianMixture
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skl2onnx import to_onnx
data = load_iris()
X_train, X_test = train_test_split(data.data)
model = GaussianMixture()
model.fit(X_train)
默认转换¶
model_onnx = to_onnx(
model,
X_train[:1].astype(np.float32),
options={id(model): {"score_samples": True}},
target_opset=12,
)
sess = InferenceSession(
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)
xt = X_test[:5].astype(np.float32)
print(model.score_samples(xt))
print(sess.run(None, {"X": xt})[2])
[-3.55984426 -3.81216953 -2.14642002 -1.3580866 -2.74574493]
[[-3.559843 ]
[-3.812169 ]
[-2.1464195]
[-1.3580852]
[-2.7457438]]
显示 ONNX 图。
pydot_graph = GetPydotGraph(
model_onnx.graph,
name=model_onnx.graph.name,
rankdir="TB",
node_producer=GetOpNodeProducer(
"docstring", color="yellow", fillcolor="yellow", style="filled"
),
)
pydot_graph.write_dot("mixture.dot")
os.system("dot -O -Gdpi=300 -Tpng mixture.dot")
image = plt.imread("mixture.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")

(np.float64(-0.5), np.float64(4796.5), np.float64(8425.5), np.float64(-0.5))
无 ReduceLogSumExp 的转换¶
参数 black_op 用于告知转换器不要使用此运算符。让我们看看转换器在这种情况下会生成什么。
model_onnx2 = to_onnx(
model,
X_train[:1].astype(np.float32),
options={id(model): {"score_samples": True}},
black_op={"ReduceLogSumExp"},
target_opset=12,
)
sess2 = InferenceSession(
model_onnx2.SerializeToString(), providers=["CPUExecutionProvider"]
)
xt = X_test[:5].astype(np.float32)
print(model.score_samples(xt))
print(sess2.run(None, {"X": xt})[2])
[-3.55984426 -3.81216953 -2.14642002 -1.3580866 -2.74574493]
[[-3.559843 ]
[-3.812169 ]
[-2.1464195]
[-1.3580852]
[-2.7457438]]
显示 ONNX 图。
pydot_graph = GetPydotGraph(
model_onnx2.graph,
name=model_onnx2.graph.name,
rankdir="TB",
node_producer=GetOpNodeProducer(
"docstring", color="yellow", fillcolor="yellow", style="filled"
),
)
pydot_graph.write_dot("mixture2.dot")
os.system("dot -O -Gdpi=300 -Tpng mixture2.dot")
image = plt.imread("mixture2.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")

(np.float64(-0.5), np.float64(4921.5), np.float64(13264.5), np.float64(-0.5))
处理时间¶
0.2215305229999558
0.2625728419998268
使用 ReduceLogSumExp 的模型速度更快。
如果转换器无法在不...¶
许多转换器不考虑运算符的白名单和黑名单。如果转换器在不使用黑名单运算符(或仅使用白名单运算符)的情况下转换失败,skl2onnx 将引发错误。
try:
to_onnx(
model,
X_train[:1].astype(np.float32),
options={id(model): {"score_samples": True}},
black_op={"ReduceLogSumExp", "Add"},
target_opset=12,
)
except RuntimeError as e:
print("Error:", e)
Error: Operator 'Add' is black listed.
此示例使用的版本
import sklearn
print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
import skl2onnx
print("onnx: ", onnx.__version__)
print("onnxruntime: ", onnxruntime.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 2.3.1
scikit-learn: 1.6.1
onnx: 1.19.0
onnxruntime: 1.23.0
skl2onnx: 1.19.1
脚本总运行时间: (0 分钟 17.848 秒)