注意
转到末尾以下载完整的示例代码
比较 CDist 与 scipy
以下示例聚焦于一个特定的操作符 CDist，并比较它在 onnxruntime 和 scipy 中的执行时间。
带有 CDist 的 ONNX 图
cdist 函数计算两组向量之间的成对距离：结果的第 i 行第 j 列是 X 的第 i 行与 Y 的第 j 行之间的距离。
from pprint import pprint
from timeit import Timer
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
from pandas import DataFrame
import onnx
import onnxruntime as rt
from onnxruntime import InferenceSession
import skl2onnx
from skl2onnx.algebra.custom_ops import OnnxCDist
from skl2onnx.common.data_types import FloatTensorType
# Two small float32 inputs: every entry of X is 1 and every entry of Y is 2,
# so each pairwise Euclidean distance is sqrt(4 * (2 - 1) ** 2) == 2.
X = np.ones((2, 4), dtype=np.float32)
Y = np.full((3, 4), 2, dtype=np.float32)
print(cdist(X, Y, metric="euclidean"))
[[2. 2. 2.]
[2. 2. 2.]]
ONNX
ir_version: 8
opset_import {
domain: "com.microsoft"
version: 1
}
producer_name: "skl2onnx"
producer_version: "1.15.0"
domain: "ai.onnx"
model_version: 0
graph {
node {
input: "X"
input: "Y"
output: "Z"
name: "CD_CDist"
op_type: "CDist"
domain: "com.microsoft"
attribute {
name: "metric"
type: STRING
s: "euclidean"
}
}
name: "OnnxCDist"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "Z"
type {
tensor_type {
elem_type: 1
}
}
}
}
CDist 和 onnxruntime
我们使用 onnxruntime 计算 CDist 操作符的输出。onnxruntime 的输出类型为 float32（见下方 dtype），因此与 scipy 的结果相比可能出现微小的舍入差异（例如 1.9999999 而非 2）。
[array([[1.9999999, 1.9999999, 1.9999999],
[1.9999999, 2. , 2. ]], dtype=float32)]
基准测试
让我们比较 onnxruntime 和 scipy。
def measure_time(name, stmt, context, repeat=100, number=20):
    """Time ``stmt`` with :class:`timeit.Timer` and summarise the results.

    ``stmt`` is executed ``repeat`` times in batches of ``number`` calls,
    using ``context`` as its global namespace. Each batch duration is divided
    by ``number``, so every reported statistic is a per-call time in seconds.

    :param name: label stored unchanged in the returned dictionary
    :param stmt: Python statement to time, given as a string
    :param context: global namespace for ``stmt``; it must also contain
        ``"X"`` and ``"Y"`` (objects with a ``shape`` attribute), which are
        used to report ``nrows`` and ``ncols``
    :param repeat: number of timing batches
    :param number: number of executions of ``stmt`` per batch
    :return: dictionary with keys ``average``, ``deviation``, ``min_exec``,
        ``max_exec``, ``repeat``, ``number``, ``nrows``, ``ncols`` and ``name``
    """
    tim = Timer(stmt, globals=context)
    res = np.array(tim.repeat(repeat=repeat, number=number))
    res /= number
    return dict(
        average=np.mean(res),
        # np.std works on centred values. The former sqrt(E[x**2] - E[x]**2)
        # formula cancels catastrophically and can take the square root of a
        # slightly negative number (NaN) when all timings are nearly equal.
        deviation=np.std(res),
        min_exec=np.min(res),
        max_exec=np.max(res),
        repeat=repeat,
        number=number,
        nrows=context["X"].shape[0],
        ncols=context["Y"].shape[1],
        name=name,
    )
scipy
# Benchmark scipy's cdist (default metric) on the small X / Y pair.
time_scipy = measure_time(
    name="scipy",
    stmt="cdist(X, Y)",
    context=dict(cdist=cdist, X=X, Y=Y),
)
pprint(time_scipy)
{'average': 9.84484999992219e-06,
'deviation': 4.450414501003498e-06,
'max_exec': 4.325500000277316e-05,
'min_exec': 3.990000004705508e-06,
'name': 'scipy',
'ncols': 4,
'nrows': 2,
'number': 20,
'repeat': 100}
onnxruntime
{'average': 1.90378500000179e-05,
'deviation': 9.14980003126398e-06,
'max_exec': 5.8449999994536486e-05,
'min_exec': 1.1034999999992579e-05,
'name': 'ort',
'ncols': 4,
'nrows': 2,
'number': 20,
'repeat': 100}
更长的基准测试
# Time scipy and onnxruntime for an increasing number of rows in X.
# NOTE(review): assumes ``sess`` is the onnxruntime InferenceSession built
# from the CDist graph earlier in the example (not shown in this excerpt).
metrics = []
for dim in tqdm([10, 100, 1000, 10000]):
    # The graph fixes the second dimension of X and Y to 4, so only the
    # number of rows can change without building a new graph.
    X = np.random.randn(dim, 4).astype(np.float32)
    Y = np.random.randn(10, 4).astype(np.float32)
    time_scipy = measure_time(
        name="scipy", stmt="cdist(X, Y)", context=dict(cdist=cdist, X=X, Y=Y)
    )
    time_ort = measure_time(
        name="ort",
        stmt="sess.run(None, {'X': X, 'Y': Y})",
        context=dict(sess=sess, X=X, Y=Y),
    )
    metric = {"N": dim, "scipy": time_scipy["average"], "ort": time_ort["average"]}
    metrics.append(metric)

df = DataFrame(metrics)
# A ratio above 1 means onnxruntime took less time than scipy for that size.
df["scipy/ort"] = df["scipy"] / df["ort"]
print(df)
df.plot(x="N", y=["scipy/ort"])
0%| | 0/4 [00:00<?, ?it/s]
50%|█████ | 2/4 [00:00<00:00, 19.40it/s]
100%|██████████| 4/4 [00:02<00:00, 1.41it/s]
100%|██████████| 4/4 [00:02<00:00, 1.64it/s]
N scipy ort scipy/ort
0 10 0.000010 0.000014 0.695364
1 100 0.000012 0.000015 0.824828
2 1000 0.000095 0.000041 2.312891
3 10000 0.000741 0.000287 2.578285
本示例使用的版本
# Report the library versions used to produce the outputs above.
for label, module_version in (
    ("numpy:", np.__version__),
    ("onnx: ", onnx.__version__),
    ("onnxruntime: ", rt.__version__),
    ("skl2onnx: ", skl2onnx.__version__),
):
    print(label, module_version)
numpy: 1.23.5
onnx: 1.15.0
onnxruntime: 1.16.0+cu118
skl2onnx: 1.15.0
脚本总运行时间：(0 分 2.761 秒)