Convert a pipeline with a LightGBM regressor

The discrepancies observed when using float with the TreeEnsemble operators (see Issues when switching to float) explain why the converter for LGBMRegressor may introduce significant discrepancies even when it is used with float tensors.

lightgbm is implemented with double precision. A random forest regressor with multiple trees computes its prediction by adding the prediction of every tree. After conversion into ONNX, this summation becomes \left[\sum\right]_{i=1}^F float(T_i(x)), where F is the number of trees in the forest, T_i(x) the output of tree i and \left[\sum\right] a float addition. The discrepancy can be expressed as D(x) = |\left[\sum\right]_{i=1}^F float(T_i(x)) - \sum_{i=1}^F T_i(x)|. It grows with the number of trees in the forest.
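
The effect can be reproduced without lightgbm or ONNX: accumulating many terms with float additions drifts away from the same sum computed in double precision. Below is a minimal sketch of that phenomenon; the random array simply stands in for the F tree outputs of a single observation (illustrative values, not the model trained below):

import numpy

rng = numpy.random.RandomState(0)
tree_outputs = rng.randn(1000) * 0.01  # stand-in for the outputs T_i(x) of F=1000 trees

sum64 = tree_outputs.sum()  # summation in double precision
sum32 = numpy.float32(0.0)
for t in tree_outputs:
    sum32 += numpy.float32(t)  # float additions, the \left[\sum\right] of the formula

print("D(x) =", abs(float(sum32) - sum64))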

To reduce the impact, an option was added to split the TreeEnsembleRegressor node into several ones and to perform the summation in double precision this time. If we assume the node is split into a nodes, the discrepancy becomes D'(x) = |\sum_{k=1}^a \left[\sum\right]_{i=1}^{F/a} float(T_{ak + i}(x)) - \sum_{i=1}^F T_i(x)|.
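
Continuing the sketch above, splitting the same accumulation into a blocks, each summed with float additions and then combined in double precision, illustrates why D'(x) is smaller (a = 10 is an arbitrary choice for the illustration, not the converter's behaviour):

a = 10  # number of blocks, i.e. hypothetical TreeEnsembleRegressor nodes
total = 0.0  # double precision accumulator across blocks
for block in numpy.array_split(tree_outputs, a):
    partial = numpy.float32(0.0)
    for t in block:
        partial += numpy.float32(t)  # float additions inside each block
    total += float(partial)  # blocks combined in double precision

print("D'(x) =", abs(total - sum64))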

Train a LGBMRegressor

import packaging.version as pv
import warnings
import timeit
import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from tqdm import tqdm
from lightgbm import LGBMRegressor
from onnxruntime import InferenceSession
from skl2onnx import to_onnx, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_regressor_output_shapes,
)  # noqa
from onnxmltools import __version__ as oml_version
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm,
)  # noqa


N = 1000
X = numpy.random.randn(N, 20)
y = numpy.random.randn(N) + numpy.random.randn(N) * 100 * numpy.random.randint(
    0, 1, 1000
)

reg = LGBMRegressor(n_estimators=1000)
reg.fit(X, y)
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000276 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 5100
[LightGBM] [Info] Number of data points in the train set: 1000, number of used features: 20
[LightGBM] [Info] Start training from score 0.001127
LGBMRegressor(n_estimators=1000)


Register the converter for LGBMRegressor

The converter is implemented in onnxmltools: onnxmltools…LightGbm.py, and so is the shape calculator: onnxmltools…Regressor.py.

def skl2onnx_convert_lightgbm(scope, operator, container):
    options = scope.get_options(operator.raw_operator)
    if "split" in options:
        if pv.Version(oml_version) < pv.Version("1.9.2"):
            warnings.warn(
                "Option split was released in version 1.9.2 but %s is "
                "installed. It will be ignored." % oml_version
            )
        operator.split = options["split"]
    else:
        operator.split = None
    convert_lightgbm(scope, operator, container)


update_registered_converter(
    LGBMRegressor,
    "LightGbmLGBMRegressor",
    calculate_linear_regressor_output_shapes,
    skl2onnx_convert_lightgbm,
    options={"split": None},
)

Conversion

We convert the same model following two scenarios, one with a single TreeEnsembleRegressor node, the other with several. The parameter split is the number of trees per TreeEnsembleRegressor node.

model_onnx = to_onnx(
    reg, X[:1].astype(numpy.float32), target_opset={"": 14, "ai.onnx.ml": 2}
)
model_onnx_split = to_onnx(
    reg,
    X[:1].astype(numpy.float32),
    target_opset={"": 14, "ai.onnx.ml": 2},
    options={"split": 100},
)
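
To see what the option changes in the graph itself, one can count the operator types of both models with the ONNX protobuf API; the split version is expected to contain several TreeEnsembleRegressor nodes (a quick structural check):

from collections import Counter

print("no split:", Counter(node.op_type for node in model_onnx.graph.node))
print("split   :", Counter(node.op_type for node in model_onnx_split.graph.node))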

Discrepancies

sess = InferenceSession(
    model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)
sess_split = InferenceSession(
    model_onnx_split.SerializeToString(), providers=["CPUExecutionProvider"]
)

X32 = X.astype(numpy.float32)
expected = reg.predict(X32)
got = sess.run(None, {"X": X32})[0].ravel()
got_split = sess_split.run(None, {"X": X32})[0].ravel()

disp = numpy.abs(got - expected).sum()
disp_split = numpy.abs(got_split - expected).sum()

print("sum of discrepancies 1 node", disp)
print("sum of discrepancies split node", disp_split, "ratio:", disp / disp_split)
sum of discrepancies 1 node 0.00020644343950206685
sum of discrepancies split node 4.144931004458315e-05 ratio: 4.980624268052108

The sum of discrepancies is reduced by a factor of 4 to 5. The maximum discrepancy is also much lower.

disc = numpy.abs(got - expected).max()
disc_split = numpy.abs(got_split - expected).max()

print("max discrepancies 1 node", disc)
print("max discrepancies split node", disc_split, "ratio:", disc / disc_split)
max discrepancies 1 node 1.985479140209634e-06
max discrepancies split node 2.6622454418756547e-07 ratio: 7.457911689805682

Processing time

The processing time is slower, but not by much.

print(
    "processing time no split",
    timeit.timeit(lambda: sess.run(None, {"X": X32})[0], number=150),
)
print(
    "processing time split",
    timeit.timeit(lambda: sess_split.run(None, {"X": X32})[0], number=150),
)
processing time no split 2.342391199999838
processing time split 2.7244762999998784

Split influence

Let's see how the discrepancies evolve with the parameter split.

res = []
for i in tqdm(list(range(20, 170, 20)) + [200, 300, 400, 500]):
    model_onnx_split = to_onnx(
        reg,
        X[:1].astype(numpy.float32),
        target_opset={"": 14, "ai.onnx.ml": 2},
        options={"split": i},
    )
    sess_split = InferenceSession(
        model_onnx_split.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    got_split = sess_split.run(None, {"X": X32})[0].ravel()
    disc_split = numpy.abs(got_split - expected).max()
    res.append(dict(split=i, disc=disc_split))

df = DataFrame(res).set_index("split")
df["baseline"] = disc
print(df)
100%|██████████| 12/12 [00:18<00:00,  1.54s/it]
               disc  baseline
split
20     2.560464e-07  0.000002
40     1.937818e-07  0.000002
60     2.614565e-07  0.000002
80     2.560464e-07  0.000002
100    2.662245e-07  0.000002
120    4.614585e-07  0.000002
140    4.614585e-07  0.000002
160    4.614585e-07  0.000002
200    5.745647e-07  0.000002
300    8.129833e-07  0.000002
400    1.289820e-06  0.000002
500    1.031805e-06  0.000002

Graph.

_, ax = plt.subplots(1, 1)
df.plot(
    title="Sum of discrepancies against split\n" "split = number of tree per node",
    ax=ax,
)

# plt.show()
[Figure: Sum of discrepancies against split (split = number of trees per node)]

Total running time of the script: (0 minutes 27.317 seconds)

Gallery generated by Sphinx-Gallery