注意
转到末尾 下载完整的示例代码。
绘制管道¶
除了查看其节点之外,没有其他方法可以查看存储在 ONNX 格式中的模型,例如使用onnx。此示例演示如何绘制模型并以json格式检索它。
以 JSON 格式检索模型¶
这是最简单的方法。
import skl2onnx
import onnxruntime
import sklearn
import numpy
import matplotlib.pyplot as plt
import os
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
from onnx import ModelProto
import onnx
from skl2onnx.algebra.onnx_ops import OnnxAdd, OnnxMul
onnx_fct = OnnxAdd(
OnnxMul("X", numpy.array([2], dtype=numpy.float32), op_version=12),
numpy.array([[1, 0], [0, 1]], dtype=numpy.float32),
output_names=["Y"],
op_version=12,
)
X = numpy.array([[4, 5], [-2, 3]], dtype=numpy.float32)
model = onnx_fct.to_onnx({"X": X}, target_opset=12)
print(model)
filename = "example1.onnx"
with open(filename, "wb") as f:
f.write(model.SerializeToString())
ir_version: 7
opset_import {
domain: ""
version: 12
}
producer_name: "skl2onnx"
producer_version: "1.17.0"
domain: "ai.onnx"
model_version: 0
graph {
node {
input: "X"
input: "Mu_Mulcst"
output: "Mu_C0"
name: "Mu_Mul"
op_type: "Mul"
domain: ""
}
node {
input: "Mu_C0"
input: "Ad_Addcst"
output: "Y"
name: "Ad_Add"
op_type: "Add"
domain: ""
}
name: "OnnxAdd"
initializer {
dims: 1
data_type: 1
float_data: 2
name: "Mu_Mulcst"
}
initializer {
dims: 2
dims: 2
data_type: 1
float_data: 1
float_data: 0
float_data: 0
float_data: 1
name: "Ad_Addcst"
}
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
使用 ONNX 绘制模型¶
我们使用onnx包中包含的net_drawer.py。我们使用onnx以与之前不同的方式加载模型。
我们将其转换为图形。
pydot_graph = GetPydotGraph(
model.graph,
name=model.graph.name,
rankdir="TB",
node_producer=GetOpNodeProducer("docstring"),
)
pydot_graph.write_dot("graph.dot")
然后转换为图像
os.system("dot -O -Tpng graph.dot")
0
我们显示它…
image = plt.imread("graph.dot.png")
plt.imshow(image)
plt.axis("off")
(-0.5, 431.5, 602.5, -0.5)
此示例使用的版本
print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", onnxruntime.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 1.26.4
scikit-learn: 1.6.dev0
onnx: 1.17.0
onnxruntime: 1.18.0+cu118
skl2onnx: 1.17.0
脚本总运行时间:(0 分钟 0.652 秒)