注意
转到末尾 下载完整示例代码。
WOE转换器¶
WOE 是 Evidence Weights 的缩写。它检查特征 X 是否属于一系列区域(区间)。结果是包含该特征的每个区间的标签。
一个简单示例¶
X 是由前十个整数构成的向量。WOETransformer
类检查其中每个整数是否属于两个区间,]1, 3[(左右开)和 [5, 7](左右闭)。第一个区间关联权重 55,第二个区间关联权重 107。
import os
import numpy as np
import pandas as pd
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
from onnxruntime import InferenceSession
import matplotlib.pyplot as plt
from skl2onnx import to_onnx
from skl2onnx.sklapi import WOETransformer
# automatically registers the converter for WOETransformer
import skl2onnx.sklapi.register # noqa: F401
X = np.arange(10).astype(np.float32).reshape((-1, 1))
intervals = [[(1.0, 3.0, False, False), (5.0, 7.0, True, True)]]
weights = [[55, 107]]
woe1 = WOETransformer(intervals, onehot=False, weights=weights)
woe1.fit(X)
prd = woe1.transform(X)
df = pd.DataFrame({"X": X.ravel(), "woe": prd.ravel()})
df
One Hot¶
转换器输出一列权重。但它也可以为每个区间返回一列。
woe2 = WOETransformer(intervals, onehot=True, weights=weights)
woe2.fit(X)
prd = woe2.transform(X)
df = pd.DataFrame(prd)
df.columns = ["I1", "I2"]
df["X"] = X
df
在这种情况下,可以省略权重。输出是二元的。
woe = WOETransformer(intervals, onehot=True)
woe.fit(X)
prd = woe.transform(X)
df = pd.DataFrame(prd)
df.columns = ["I1", "I2"]
df["X"] = X
df
转换为 ONNX¶
skl2onnx 为所有情况实现了转换器。
onehot=False
onx1 = to_onnx(woe1, X)
sess = InferenceSession(onx1.SerializeToString(), providers=["CPUExecutionProvider"])
print(sess.run(None, {"X": X})[0])
[[ 0.]
[ 0.]
[ 55.]
[ 0.]
[ 0.]
[107.]
[107.]
[107.]
[ 0.]
[ 0.]]
onehot=True
onx2 = to_onnx(woe2, X)
sess = InferenceSession(onx2.SerializeToString(), providers=["CPUExecutionProvider"])
print(sess.run(None, {"X": X})[0])
[[ 0. 0.]
[ 0. 0.]
[ 55. 0.]
[ 0. 0.]
[ 0. 0.]
[ 0. 107.]
[ 0. 107.]
[ 0. 107.]
[ 0. 0.]
[ 0. 0.]]
ONNX 图¶
onehot=False
pydot_graph = GetPydotGraph(
onx1.graph,
name=onx1.graph.name,
rankdir="TB",
node_producer=GetOpNodeProducer(
"docstring", color="yellow", fillcolor="yellow", style="filled"
),
)
pydot_graph.write_dot("woe1.dot")
os.system("dot -O -Gdpi=300 -Tpng woe1.dot")
image = plt.imread("woe1.dot.png")
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis("off")

(np.float64(-0.5), np.float64(2674.5), np.float64(3321.5), np.float64(-0.5))
onehot=True
pydot_graph = GetPydotGraph(
onx2.graph,
name=onx2.graph.name,
rankdir="TB",
node_producer=GetOpNodeProducer(
"docstring", color="yellow", fillcolor="yellow", style="filled"
),
)
pydot_graph.write_dot("woe2.dot")
os.system("dot -O -Gdpi=300 -Tpng woe2.dot")
image = plt.imread("woe2.dot.png")
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis("off")

(np.float64(-0.5), np.float64(2743.5), np.float64(5696.5), np.float64(-0.5))
半开区间/半直线¶
区间可能只有一个端点被定义,另一个是无穷大。
intervals = [[(-np.inf, 3.0, True, True), (5.0, np.inf, True, True)]]
weights = [[55, 107]]
woe1 = WOETransformer(intervals, onehot=False, weights=weights)
woe1.fit(X)
prd = woe1.transform(X)
df = pd.DataFrame({"X": X.ravel(), "woe": prd.ravel()})
df
并使用相同的指令转换为 ONNX。
onxinf = to_onnx(woe1, X)
sess = InferenceSession(onxinf.SerializeToString(), providers=["CPUExecutionProvider"])
print(sess.run(None, {"X": X})[0])
[[ 55.]
[ 55.]
[ 55.]
[ 55.]
[ 0.]
[107.]
[107.]
[107.]
[107.]
[107.]]
脚本总运行时间: (0 分 2.590 秒)