使用 ONNX 的 TfIdfVectorizer

此示例的灵感来自 scikit-learn 示例“具有异构数据源的列转换器”(Column Transformer with Heterogeneous Data Sources),该示例构建了一个用于文本分类的管道。

使用 TfidfVectorizer 训练管道

它复现了 scikit-learn 文档中的同一个管道,但只保留了 ONNX 无需实现自定义转换器(converter)即可支持的部分。首先获取数据。

import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as rt
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import StringTensorType

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import fetch_20newsgroups

try:
    from sklearn.datasets._twenty_newsgroups import (
        strip_newsgroup_footer,
        strip_newsgroup_quoting,
    )
except ImportError:
    # scikit-learn < 0.24
    from sklearn.datasets.twenty_newsgroups import (
        strip_newsgroup_footer,
        strip_newsgroup_quoting,
    )
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.metrics import classification_report
from sklearn.linear_model import LogisticRegression


# Only two categories are kept so that the example runs quickly.
categories = ["alt.atheism", "talk.religion.misc"]

# Load the "train" and "test" subsets of the 20 newsgroups dataset,
# restricted to ``categories`` and using the same ``random_state`` for both.
train, test = (
    fetch_20newsgroups(random_state=1, subset=split, categories=categories)
    for split in ("train", "test")
)

第一个转换从数据中提取两个字段(主题和正文)。我们将它移出管道,并假设数据由两个文本列组成。

class SubjectBodyExtractor(BaseEstimator, TransformerMixin):
    """Extract the subject and the body of usenet posts in a single pass.

    Takes a sequence of raw post strings and returns a 2-D numpy array of
    ``dtype=object`` with one row per post: column 0 holds the subject and
    column 1 holds the body.
    """

    def fit(self, x, y=None):
        """Learn nothing (the transformer keeps no state) and return ``self``."""
        return self

    def transform(self, posts):
        """Split every post into ``[subject, body]``.

        Parameters
        ----------
        posts : sequence of str
            Raw posts; the headers are separated from the body by the first
            blank line (``"\\n\\n"``).

        Returns
        -------
        numpy.ndarray of shape (len(posts), 2), dtype=object
            Column 0: the text following ``"Subject:"`` on the first header
            line that starts with it, or ``""`` when there is none.
            Column 1: the body passed through ``strip_newsgroup_footer`` and
            then ``strip_newsgroup_quoting``.
        """
        features = np.empty(shape=(len(posts), 2), dtype=object)
        for i, text in enumerate(posts):
            # Everything before the first blank line is treated as headers.
            headers, _, body = text.partition("\n\n")
            features[i, 0] = self._extract_subject(headers)
            features[i, 1] = strip_newsgroup_quoting(strip_newsgroup_footer(body))
        return features

    @staticmethod
    def _extract_subject(headers, prefix="Subject:"):
        """Return the text after ``prefix`` on the first matching header line.

        Returns ``""`` when no line of ``headers`` starts with ``prefix``.
        """
        for line in headers.split("\n"):
            if line.startswith(prefix):
                # Whitespace right after the colon is deliberately kept.
                return line[len(prefix):]
        return ""


# Turn the raw posts of both subsets into two-column [subject, body] arrays.
train_data, test_data = (
    SubjectBodyExtractor().fit_transform(split.data)
    for split in (train, test)
)

除了删除需要自定义转换器的特征(body_stats)之外,该管道与原示例几乎相同。

# Same pipeline as the scikit-learn example, minus the branch that would
# require a custom ONNX converter.
pipeline = Pipeline(
    [
        (
            "union",
            ColumnTransformer(
                [
                    # Column 0 of the input holds the subject.
                    ("subject", TfidfVectorizer(min_df=50, max_features=500), 0),
                    # Column 1 holds the body: TF-IDF followed by a
                    # 50-component TruncatedSVD.
                    (
                        "body_bow",
                        Pipeline(
                            [
                                ("tfidf", TfidfVectorizer()),
                                ("best", TruncatedSVD(n_components=50)),
                            ]
                        ),
                        1,
                    ),
                    # Removed from the original example as
                    # it requires a custom converter.
                    # ('body_stats', Pipeline([
                    #   ('stats', TextStats()),  # returns a list of dicts
                    #   ('vect', DictVectorizer()),  # list of dicts -> feature matrix
                    # ]), 1),
                ],
                transformer_weights={
                    "subject": 0.8,
                    "body_bow": 0.5,
                    # 'body_stats': 1.0,
                },
            ),
        ),
        # Use a LogisticRegression classifier on the combined features.
        # Instead of LinearSVC (not fully ready in onnxruntime).
        ("logreg", LogisticRegression()),
    ]
)

pipeline.fit(train_data, train.target)
# classification_report takes (y_true, y_pred): passing the predictions first
# swaps precision and recall and makes "support" count predicted labels.
print(classification_report(test.target, pipeline.predict(test_data)))
              precision    recall  f1-score   support

           0       0.69      0.78      0.73       285
           1       0.75      0.66      0.70       285

    accuracy                           0.72       570
   macro avg       0.72      0.72      0.71       570
weighted avg       0.72      0.72      0.71       570

ONNX 转换

如果分词器来自 spaCy、gensim 或 nltk,就很难完全复现相同的分词行为。scikit-learn 默认的分词器基于正则表达式,对它的支持目前仍在实现中。当前的实现只使用一个分隔符列表,该列表可在变量 seps 中定义。

# scikit-learn's default regex tokenizer is not reproduced by the converter,
# so the ONNX tokenizer is given an explicit list of separators instead.
# The order of the list is kept exactly as in the original example.
# NOTE(review): "?", "(", ")", "[" and "]" are backslash-escaped but "." is
# not; if the runtime treats separators as regular expressions, confirm the
# unescaped "." is intended.
tfidf_separators = [
    " ", ".", "\\?", ",", ";", ":", "!", "\\(",
    "\\)", "\n", '"', "'", "-", "\\[", "\\]", "@",
]
seps = {TfidfVectorizer: {"separators": tfidf_separators}}

# Convert the fitted pipeline; the single input "input" is a string tensor
# with an unspecified number of rows and two columns (subject, body).
model_onnx = convert_sklearn(
    pipeline,
    "tfidf",
    initial_types=[("input", StringTensorType([None, 2]))],
    options=seps,
    target_opset=12,
)

然后保存模型。

# Serialize the ONNX model to bytes and write it to disk.
serialized_model = model_onnx.SerializeToString()
with open("pipeline_tfidf.onnx", "wb") as model_file:
    model_file.write(serialized_model)

使用 onnxruntime 进行预测。

# Run the exported model with onnxruntime on the first training sample.
sess = rt.InferenceSession(
    "pipeline_tfidf.onnx", providers=["CPUExecutionProvider"]
)
first_sample = train_data[:1]
print("---", first_sample[0])
inputs = {"input": first_sample}
pred_onx = sess.run(None, inputs)
# The first two outputs are printed as the label and the probabilities.
predicted_label, predicted_proba = pred_onx[0], pred_onx[1]
print("predict", predicted_label)
print("predict_proba", predicted_proba)
--- [" Re: Jews can't hide from keith@cco."
 'Deletions...\n\nSo, you consider the german poster\'s remark anti-semitic?  Perhaps you\nimply that anyone in Germany who doesn\'t agree with israely policy in a\nnazi?  Pray tell, how does it even qualify as "casual anti-semitism"? \nIf the term doesn\'t apply, why then bring it up?\n\nYour own bigotry is shining through.  \n-- ']
predict [1]
predict_proba [{0: 0.4396112561225891, 1: 0.5603887438774109}]

与 scikit-learn 的预测进行对比。

# Feed the same sample to the scikit-learn pipeline for comparison.
first_sample = train_data[:1]
for predict_method in (pipeline.predict, pipeline.predict_proba):
    print(predict_method(first_sample))
[0]
[[0.72374074 0.27625926]]

两者的预测存在差异,因为分词方式并不完全相同。这部分工作仍在进行中。

显示 ONNX 图

最后,让我们看看 sklearn-onnx 转换得到的图。

# Render the ONNX graph to a Graphviz ".dot" file, then to PNG with ``dot``.
pydot_graph = GetPydotGraph(
    model_onnx.graph,
    name=model_onnx.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"
    ),
)
pydot_graph.write_dot("pipeline_tfidf.dot")

# Call Graphviz directly (no shell). check=True raises as soon as ``dot``
# fails or is missing, instead of letting plt.imread fail later on a PNG
# that was never written. ``-O`` derives the output name from the input,
# hence "pipeline_tfidf.dot.png" below.
subprocess.run(
    ["dot", "-O", "-Gdpi=300", "-Tpng", "pipeline_tfidf.dot"], check=True
)

image = plt.imread("pipeline_tfidf.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")
plot tfidfvectorizer
(-0.5, 4939.5, 11475.5, -0.5)

脚本的总运行时间:(0 分 14.922 秒)

由 Sphinx-Gallery 生成的示例库