注意
转到末尾下载完整的示例代码。
日志记录与详细输出（verbose）
如果管道中包含没有关联转换器的对象，转换就会失败；如果其中某个对象由自定义转换器映射，转换也可能失败。如果错误消息不够明确，可以启用日志记录来定位问题。
训练模型
一个使用决策树分类器和鸢尾花数据集的非常基础的示例。
import logging
import numpy
import onnx
import onnxruntime as rt
import sklearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
import skl2onnx
# Load the iris data set and hold out a test split used later to check the
# predictions of the converted ONNX model.
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

# A single decision tree is enough for this demonstration; ``fit`` returns
# the estimator itself, so construction and training can be chained.
clr = DecisionTreeClassifier().fit(X_train, y_train)
print(clr)
DecisionTreeClassifier()
将模型转换为 ONNX
# One float input with a dynamic number of rows and 4 feature columns.
initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)

# Execute the serialized ONNX model with onnxruntime on the CPU.
sess = rt.InferenceSession(
    onx.SerializeToString(),
    providers=["CPUExecutionProvider"],
)
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

# The model expects float32, so cast the test features before feeding them;
# only the label output is requested and ``[0]`` picks its array.
feeds = {input_name: X_test.astype(numpy.float32)}
pred_onx = sess.run([label_name], feeds)[0]
print(pred_onx)
[0 1 1 1 2 1 2 1 0 0 0 2 2 0 2 1 0 0 0 0 2 1 2 1 2 2 2 0 1 1 0 2 0 2 1 2 2
0]
使用 verbose 参数进行转换
verbose 参数会在标准输出上打印消息，显示正在调用哪个转换器。`verbose=1` 通常会显示 skl2onnx 在转换管道时执行的步骤；`verbose=2` 及以上则保留给转换器内部的信息。
# verbose=1 prints each conversion step (parsing, topology conversion,
# converter calls) to standard output while the model is converted.
convert_sklearn(
    clr,
    initial_types=initial_type,
    target_opset=12,
    verbose=1,
)
[convert_sklearn] parse_sklearn_model
[convert_sklearn] convert_topology
[convert_operators] begin
[convert_operators] iteration 1 - n_vars=0 n_ops=2
[call_converter] call converter for 'SklearnDecisionTreeClassifier'.
[call_converter] call converter for 'SklearnZipMap'.
[convert_operators] end iter: 1 - n_vars=5
[convert_operators] iteration 2 - n_vars=5 n_ops=2
[convert_operators] end iter: 2 - n_vars=5
[convert_operators] end.
[_update_domain_version] +opset 0: name='', version=9
[_update_domain_version] +opset 1: name='ai.onnx.ml', version=1
[convert_sklearn] end
ir_version: 7
producer_name: "skl2onnx"
producer_version: "1.19.1"
domain: "ai.onnx"
model_version: 0
doc_string: ""
graph {
node {
input: "float_input"
output: "label"
output: "probabilities"
name: "TreeEnsembleClassifier"
op_type: "TreeEnsembleClassifier"
attribute {
name: "class_ids"
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
type: INTS
}
attribute {
name: "class_nodeids"
ints: 1
ints: 1
ints: 1
ints: 6
ints: 6
ints: 6
ints: 7
ints: 7
ints: 7
ints: 9
ints: 9
ints: 9
ints: 11
ints: 11
ints: 11
ints: 12
ints: 12
ints: 12
ints: 13
ints: 13
ints: 13
ints: 16
ints: 16
ints: 16
ints: 17
ints: 17
ints: 17
ints: 18
ints: 18
ints: 18
type: INTS
}
attribute {
name: "class_treeids"
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "class_weights"
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 0
floats: 1
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
type: FLOATS
}
attribute {
name: "classlabels_int64s"
ints: 0
ints: 1
ints: 2
type: INTS
}
attribute {
name: "nodes_falsenodeids"
ints: 2
ints: 0
ints: 14
ints: 13
ints: 8
ints: 7
ints: 0
ints: 0
ints: 10
ints: 0
ints: 12
ints: 0
ints: 0
ints: 0
ints: 18
ints: 17
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_featureids"
ints: 3
ints: 0
ints: 3
ints: 2
ints: 0
ints: 1
ints: 0
ints: 0
ints: 2
ints: 0
ints: 3
ints: 0
ints: 0
ints: 0
ints: 2
ints: 1
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_hitrates"
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
type: FLOATS
}
attribute {
name: "nodes_missing_value_tracks_true"
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_modes"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "LEAF"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "LEAF"
strings: "LEAF"
type: STRINGS
}
attribute {
name: "nodes_nodeids"
ints: 0
ints: 1
ints: 2
ints: 3
ints: 4
ints: 5
ints: 6
ints: 7
ints: 8
ints: 9
ints: 10
ints: 11
ints: 12
ints: 13
ints: 14
ints: 15
ints: 16
ints: 17
ints: 18
type: INTS
}
attribute {
name: "nodes_treeids"
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_truenodeids"
ints: 1
ints: 0
ints: 3
ints: 4
ints: 5
ints: 6
ints: 0
ints: 0
ints: 9
ints: 0
ints: 11
ints: 0
ints: 0
ints: 0
ints: 15
ints: 16
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_values"
floats: 0.75
floats: 0
floats: 1.75
floats: 5.35
floats: 4.95
floats: 2.45
floats: 0
floats: 0
floats: 4.95
floats: 0
floats: 1.55
floats: 0
floats: 0
floats: 0
floats: 4.85
floats: 3.1
floats: 0
floats: 0
floats: 0
type: FLOATS
}
attribute {
name: "post_transform"
s: "NONE"
type: STRING
}
domain: "ai.onnx.ml"
}
node {
input: "label"
output: "output_label"
name: "Cast"
op_type: "Cast"
attribute {
name: "to"
i: 7
type: INT
}
domain: ""
}
node {
input: "probabilities"
output: "output_probability"
name: "ZipMap"
op_type: "ZipMap"
attribute {
name: "classlabels_int64s"
ints: 0
ints: 1
ints: 2
type: INTS
}
domain: "ai.onnx.ml"
}
name: "21b6707aac84485f97acc71b370309e7"
input {
name: "float_input"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "output_label"
type {
tensor_type {
elem_type: 7
shape {
dim {
}
}
}
}
}
output {
name: "output_probability"
type {
sequence_type {
elem_type {
map_type {
key_type: 7
value_type {
tensor_type {
elem_type: 1
}
}
}
}
}
}
}
}
opset_import {
domain: ""
version: 9
}
opset_import {
domain: "ai.onnx.ml"
version: 1
}
使用日志记录进行转换
这是非常详细的日志记录。它会记录正在处理哪些运算符或变量（转换器的输出）、创建了哪些节点……在实现自定义转换器时，这些信息可能很有用。注意：仅设置日志级别并不会输出任何内容——如果没有为日志记录器配置处理器（handler），Python 只会显示 WARNING 及以上级别的消息。
# Enable the most detailed logging emitted by skl2onnx during conversion.
logger = logging.getLogger("skl2onnx")
logger.setLevel(logging.DEBUG)
# Setting the level alone prints nothing: when no handler is configured,
# Python falls back to ``logging.lastResort``, which only emits WARNING and
# above, so every DEBUG record would be silently dropped. Attach a handler
# so the messages from the skl2onnx logger are actually displayed.
log_handler = logging.StreamHandler()
logger.addHandler(log_handler)
# Kept as the final bare expression so the gallery still shows the model.
convert_sklearn(clr, initial_types=initial_type, target_opset=12)
ir_version: 7
producer_name: "skl2onnx"
producer_version: "1.19.1"
domain: "ai.onnx"
model_version: 0
doc_string: ""
graph {
node {
input: "float_input"
output: "label"
output: "probabilities"
name: "TreeEnsembleClassifier"
op_type: "TreeEnsembleClassifier"
attribute {
name: "class_ids"
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
ints: 0
ints: 1
ints: 2
type: INTS
}
attribute {
name: "class_nodeids"
ints: 1
ints: 1
ints: 1
ints: 6
ints: 6
ints: 6
ints: 7
ints: 7
ints: 7
ints: 9
ints: 9
ints: 9
ints: 11
ints: 11
ints: 11
ints: 12
ints: 12
ints: 12
ints: 13
ints: 13
ints: 13
ints: 16
ints: 16
ints: 16
ints: 17
ints: 17
ints: 17
ints: 18
ints: 18
ints: 18
type: INTS
}
attribute {
name: "class_treeids"
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "class_weights"
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
floats: 0
floats: 0
floats: 1
floats: 0
floats: 1
floats: 0
floats: 0
floats: 0
floats: 1
type: FLOATS
}
attribute {
name: "classlabels_int64s"
ints: 0
ints: 1
ints: 2
type: INTS
}
attribute {
name: "nodes_falsenodeids"
ints: 2
ints: 0
ints: 14
ints: 13
ints: 8
ints: 7
ints: 0
ints: 0
ints: 10
ints: 0
ints: 12
ints: 0
ints: 0
ints: 0
ints: 18
ints: 17
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_featureids"
ints: 3
ints: 0
ints: 3
ints: 2
ints: 0
ints: 1
ints: 0
ints: 0
ints: 2
ints: 0
ints: 3
ints: 0
ints: 0
ints: 0
ints: 2
ints: 1
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_hitrates"
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
floats: 1
type: FLOATS
}
attribute {
name: "nodes_missing_value_tracks_true"
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_modes"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "LEAF"
strings: "LEAF"
strings: "BRANCH_LEQ"
strings: "BRANCH_LEQ"
strings: "LEAF"
strings: "LEAF"
strings: "LEAF"
type: STRINGS
}
attribute {
name: "nodes_nodeids"
ints: 0
ints: 1
ints: 2
ints: 3
ints: 4
ints: 5
ints: 6
ints: 7
ints: 8
ints: 9
ints: 10
ints: 11
ints: 12
ints: 13
ints: 14
ints: 15
ints: 16
ints: 17
ints: 18
type: INTS
}
attribute {
name: "nodes_treeids"
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_truenodeids"
ints: 1
ints: 0
ints: 3
ints: 4
ints: 5
ints: 6
ints: 0
ints: 0
ints: 9
ints: 0
ints: 11
ints: 0
ints: 0
ints: 0
ints: 15
ints: 16
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "nodes_values"
floats: 0.75
floats: 0
floats: 1.75
floats: 5.35
floats: 4.95
floats: 2.45
floats: 0
floats: 0
floats: 4.95
floats: 0
floats: 1.55
floats: 0
floats: 0
floats: 0
floats: 4.85
floats: 3.1
floats: 0
floats: 0
floats: 0
type: FLOATS
}
attribute {
name: "post_transform"
s: "NONE"
type: STRING
}
domain: "ai.onnx.ml"
}
node {
input: "label"
output: "output_label"
name: "Cast"
op_type: "Cast"
attribute {
name: "to"
i: 7
type: INT
}
domain: ""
}
node {
input: "probabilities"
output: "output_probability"
name: "ZipMap"
op_type: "ZipMap"
attribute {
name: "classlabels_int64s"
ints: 0
ints: 1
ints: 2
type: INTS
}
domain: "ai.onnx.ml"
}
name: "7cb1fe8a42be413090ed7a77b076fb68"
input {
name: "float_input"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "output_label"
type {
tensor_type {
elem_type: 7
shape {
dim {
}
}
}
}
}
output {
name: "output_probability"
type {
sequence_type {
elem_type {
map_type {
key_type: 7
value_type {
tensor_type {
elem_type: 1
}
}
}
}
}
}
}
}
opset_import {
domain: ""
version: 9
}
opset_import {
domain: "ai.onnx.ml"
version: 1
}
以及如何禁用它。
# Raise the threshold above DEBUG so the detailed messages are filtered out.
logger.setLevel(logging.INFO)
convert_sklearn(
    clr,
    initial_types=initial_type,
    target_opset=12,
)
# Finally keep only warnings and errors from the skl2onnx logger.
logger.setLevel(logging.WARNING)
此示例使用的版本
# Report the version of every package this example depends on.
for _label, _module in (
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
):
    print(_label, _module.__version__)
numpy: 2.3.1
scikit-learn: 1.6.1
onnx: 1.19.0
onnxruntime: 1.23.0
skl2onnx: 1.19.1
脚本总运行时间: (0 分钟 0.155 秒)