Probabilities or raw scores

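A classifier usually returns a matrix of probabilities. By default, sklearn-onnx creates an ONNX graph that returns probabilities, but if the model implements the method decision_function, that step can be skipped and the raw scores returned instead. The option 'raw_scores' changes the default behaviour. Let's see that on a simple example.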

Train a model and convert it

import numpy
import sklearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import onnxruntime as rt
import onnx
import skl2onnx
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
from sklearn.linear_model import LogisticRegression

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LogisticRegression(max_iter=500)
clr.fit(X_train, y_train)
print(clr)

initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
LogisticRegression(max_iter=500)

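Output type

Let's confirm with onnxruntime that the probabilities are returned as a list of dictionaries, one per row, each mapping a class label to its probability.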

sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
res = sess.run(None, {"float_input": X_test.astype(numpy.float32)})
print("skl", clr.predict_proba(X_test[:1]))
print("onnx", res[1][:2])
skl [[2.18706981e-05 3.86603824e-02 9.61317747e-01]]
onnx [{0: 2.1870704586035572e-05, 1: 0.03866042196750641, 2: 0.9613177180290222}, {0: 1.1326104868203402e-05, 1: 0.06515791267156601, 2: 0.9348307847976685}]
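
The ONNX output above holds one dictionary per row, whereas scikit-learn returns a matrix, so the two cannot be compared directly. The following is a minimal sketch, assuming only numpy, of stacking such a list back into a matrix. The helper name dicts_to_matrix is hypothetical (it is not part of skl2onnx or onnxruntime), and the values are copied from the output printed above.

```python
import numpy


def dicts_to_matrix(rows):
    """Stack a list of {label: value} dicts into a 2-D array.

    Hypothetical helper: columns follow the sorted class labels
    found in the first row.
    """
    labels = sorted(rows[0])
    return numpy.array([[row[label] for label in labels] for row in rows])


# Values copied from the onnxruntime output shown above.
onnx_rows = [
    {0: 2.1870704586035572e-05, 1: 0.03866042196750641, 2: 0.9613177180290222},
    {0: 1.1326104868203402e-05, 1: 0.06515791267156601, 2: 0.9348307847976685},
]

proba_matrix = dicts_to_matrix(onnx_rows)
print(proba_matrix.shape)        # one row per sample, one column per class
print(proba_matrix.sum(axis=1))  # each row of probabilities sums to about 1
```

With the full result, the same helper could be applied to res[1] and the matrix compared against clr.predict_proba(X_test) with numpy.allclose and a float32-sized tolerance.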

Raw scores and decision_function

initial_type = [("float_input", FloatTensorType([None, 4]))]
options = {id(clr): {"raw_scores": True}}
onx2 = convert_sklearn(
    clr, initial_types=initial_type, options=options, target_opset=12
)

sess2 = rt.InferenceSession(
    onx2.SerializeToString(), providers=["CPUExecutionProvider"]
)
res2 = sess2.run(None, {"float_input": X_test.astype(numpy.float32)})
print("skl", clr.decision_function(X_test[:1]))
print("onnx", res2[1][:2])
skl [[-6.0561118   1.42131109  4.63480072]]
onnx [{0: -6.056112289428711, 1: 1.4213112592697144, 2: 4.634799957275391}, {0: -6.659489631652832, 1: 1.997969150543213, 2: 4.661520957946777}]
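
The raw scores and the probabilities printed earlier describe the same prediction. For a multiclass logistic regression in its multinomial formulation, which is assumed here, the probabilities are the softmax of the decision_function values. The sketch below checks that with numpy alone, using the scikit-learn numbers printed in this example; the softmax function is written by hand for illustration and is not taken from any of the libraries used above.

```python
import numpy


def softmax(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = numpy.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


# decision_function output for the first test row, copied from above.
raw_scores = numpy.array([[-6.0561118, 1.42131109, 4.63480072]])
# predict_proba output for the same row, copied from the earlier section.
expected_proba = numpy.array([[2.18706981e-05, 3.86603824e-02, 9.61317747e-01]])

proba_from_scores = softmax(raw_scores)
print(proba_from_scores)
print(numpy.allclose(proba_from_scores, expected_proba, atol=1e-6))
```

This is why skipping the final step is safe when only the ranking of classes matters: softmax is monotonic, so the class with the highest raw score is also the class with the highest probability.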

Versions used for this example

print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", rt.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 2.2.0
scikit-learn: 1.6.0
onnx:  1.18.0
onnxruntime:  1.21.0+cu126
skl2onnx:  1.18.0

Total running time of the script: (0 minutes 0.042 seconds)