Probabilities or raw scores
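A classifier usually returns a matrix of probabilities. By default, sklearn-onnx creates an ONNX graph that returns probabilities, but if the model implements the decision_function method, the graph can skip that step and return raw scores instead. The 'raw_scores' option switches between the two behaviours. Let's see this on a simple example.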
Train a model and convert it
import numpy
import sklearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import onnxruntime as rt
import onnx
import skl2onnx
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
from sklearn.linear_model import LogisticRegression
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LogisticRegression(max_iter=500)
clr.fit(X_train, y_train)
print(clr)
initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
LogisticRegression(max_iter=500)
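Output type
Let's use onnxruntime to confirm that the probability output is a list of dictionaries, one per sample, mapping each class label to its probability.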
sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
res = sess.run(None, {"float_input": X_test.astype(numpy.float32)})
print("skl", clr.predict_proba(X_test[:1]))
print("onnx", res[1][:2])
skl [[9.82794559e-01 1.72053489e-02 9.16403830e-08]]
onnx [{0: 0.9827945232391357, 1: 0.017205340787768364, 2: 9.164028114128087e-08}, {0: 0.00189912598580122, 1: 0.4566256105899811, 2: 0.541475236415863}]
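The list-of-dictionaries output is convenient to read but awkward to compare against the matrix returned by predict_proba. The following is a minimal sketch, using only the two rows printed above, of how such an output could be turned into a plain matrix and a list of predicted labels. The helper name dicts_to_matrix is hypothetical and not part of sklearn-onnx or onnxruntime. skl2onnx also documents a 'zipmap' option that drops the dictionaries at conversion time; that route is not shown here.

```python
# The two rows printed by the ONNX model above, copied verbatim.
onnx_probs = [
    {0: 0.9827945232391357, 1: 0.017205340787768364, 2: 9.164028114128087e-08},
    {0: 0.00189912598580122, 1: 0.4566256105899811, 2: 0.541475236415863},
]


def dicts_to_matrix(rows):
    """Hypothetical helper: convert a list of {label: value} dicts
    into (sorted labels, list of rows ordered by those labels)."""
    labels = sorted(rows[0])
    matrix = [[row[label] for label in labels] for row in rows]
    return labels, matrix


labels, matrix = dicts_to_matrix(onnx_probs)

# The predicted label of each sample is the key with the largest value.
predicted = [max(row, key=row.get) for row in onnx_probs]

print("labels:", labels)
print("matrix:", matrix)
print("predicted:", predicted)
```

Sorting the keys keeps the column order stable, which matters when the result is compared column by column with the scikit-learn output.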
Raw scores and decision_function
initial_type = [("float_input", FloatTensorType([None, 4]))]
options = {id(clr): {"raw_scores": True}}
onx2 = convert_sklearn(
clr, initial_types=initial_type, options=options, target_opset=12
)
sess2 = rt.InferenceSession(
onx2.SerializeToString(), providers=["CPUExecutionProvider"]
)
res2 = sess2.run(None, {"float_input": X_test.astype(numpy.float32)})
print("skl", clr.decision_function(X_test[:1]))
print("onnx", res2[1][:2])
skl [[ 6.74440614 2.69922635 -9.44363249]]
onnx [{0: 6.744406700134277, 1: 2.6992263793945312, 2: -9.443633079528809}, {0: -3.7117910385131836, 1: 1.770678997039795, 2: 1.9411125183105469}]
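The raw scores and the probabilities printed earlier describe the same prediction. As a rough check, assuming the LogisticRegression above was fitted as a multinomial model (the usual behaviour for a multiclass target with the default solver), applying a softmax to the raw scores should reproduce the probabilities. The sketch below uses only the numbers printed on this page and the standard library; the function name softmax is defined locally and is not taken from any of the libraries used in this example.

```python
import math

# Raw scores and probabilities for the first test sample, copied from the
# scikit-learn output printed above.
raw_scores = [6.74440614, 2.69922635, -9.44363249]
expected_probs = [9.82794559e-01, 1.72053489e-02, 9.16403830e-08]


def softmax(scores):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(raw_scores)

print("softmax(raw scores):", probs)
print("printed probabilities:", expected_probs)

# Softmax preserves the ordering, so the predicted class is the same
# whether it is read from the raw scores or from the probabilities.
print("argmax:", probs.index(max(probs)), raw_scores.index(max(raw_scores)))
```

Because the ordering is preserved, raw scores are sufficient when only the predicted class is needed; probabilities are worth computing when calibrated confidence values matter.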
Versions used for this example
print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", rt.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 1.23.5
scikit-learn: 1.4.dev0
onnx: 1.15.0
onnxruntime: 1.16.0+cu118
skl2onnx: 1.15.0
Total running time of the script: (0 minutes 0.134 seconds)