ONNX 模型中心

ONNX 模型中心是一种简单快捷的方式,可以从 ONNX 模型动物园 开始使用最先进的预训练 ONNX 模型。此外,这还使研究人员和模型开发人员有机会与更广泛的社区共享其预训练模型。

安装

ONNX 模型中心在 ONNX 1.11.0 之后可用。

基本用法

ONNX 模型中心能够从任何 Git 存储库下载、列出和查询训练过的模型,并默认为官方的 ONNX 模型动物园。在本节中,我们将演示一些基本功能。

首先请使用以下方法导入中心:

from onnx import hub

按名称下载模型:

load 函数将默认搜索模型动物园中名称匹配的最新模型,将此模型下载到本地缓存,并将模型加载到 ModelProto 对象中,以便与 ONNX 运行时一起使用。

model = hub.load("resnet50")

从自定义存储库下载:

任何具有正确结构的存储库都可以成为 ONNX 模型中心。要从其他中心下载,或指定主模型中心上的特定分支或提交,可以提供 repo 参数

model = hub.load("resnet50", repo="onnx/models:771185265efbdc049fb223bd68ab1aeb1aecde76")

列出和检查模型:

模型中心提供用于查询模型动物园的 API,以了解有关可用模型的更多信息。这不会下载模型,而只是返回与给定参数匹配的模型信息

# List all models in the onnx/models:main repo
all_models = hub.list_models()

# List all versions/opsets of a specific model
mnist_models = hub.list_models(model="mnist")

# List all models matching a given "tag"
vision_models = hub.list_models(tags=["vision"])

还可以使用 get_model_info 函数检查模型的元数据,然后再下载。

print(hub.get_model_info(model="mnist", opset=8))

这将打印类似以下内容:

ModelInfo(
    model=MNIST,
    opset=8,
    path=vision/classification/mnist/model/mnist-8.onnx,
    metadata={
     'model_sha': '2f06e72de813a8635c9bc0397ac447a601bdbfa7df4bebc278723b958831c9bf',
     'model_bytes': 26454,
     'tags': ['vision', 'classification', 'mnist'],
     'io_ports': {
        'inputs': [{'name': 'Input3', 'shape': [1, 1, 28, 28], 'type': 'tensor(float)'}],
        'outputs': [{'name': 'Plus214_Output_0', 'shape': [1, 10], 'type': 'tensor(float)'}]},
     'model_with_data_path': 'vision/classification/mnist/model/mnist-8.tar.gz',
     'model_with_data_sha': '1dd098b0fe8bc750585eefc02013c37be1a1cae2bdba0191ccdb8e8518b3a882',
     'model_with_data_bytes': 25962}
)

本地缓存

ONNX 模型中心在可配置的位置本地缓存下载的模型,以便后续对 hub.load 的调用不需要网络连接。

默认缓存位置

中心客户端按以下顺序查找以下默认缓存位置:

  1. $ONNX_HOME/hub 如果定义了 ONNX_HOME 环境变量

  2. $XDG_CACHE_HOME/hub 如果定义了 XDG_CACHE_HOME 环境变量

  3. ~/.cache/onnx/hub 其中 ~ 是用户主目录

设置缓存位置

要手动设置缓存位置,请使用:

hub.set_dir("my/cache/directory")

此外,可以使用以下方法检查缓存位置:

print(hub.get_dir())

其他缓存详细信息

要清除模型缓存,只需使用 Python 实用程序(如 shutilos)删除缓存目录即可。此外,可以使用 force_reload 选项覆盖缓存的模型。

model = hub.load("resnet50", force_reload=True)

出于完整性考虑,我们包含了此标志,但请注意,缓存中的模型使用 sha256 哈希进行区分,因此在正常使用情况下不需要 force_reload 标志。最后,我们注意到模型缓存目录结构将镜像 model_path 字段在清单中指定的目录结构,但文件名将使用模型 SHA256 哈希进行区分。

这样,模型缓存对人类可读,可以区分模型的多个版本,并且如果不同 Hub 中的模型具有相同的名称和哈希值,则可以重用缓存的模型。

架构

ONNX Hub Architecture

ONNX Hub 由两个主要组件组成:客户端和服务器。客户端代码目前包含在 onnx 包中,可以指向以托管 ONNX_HUB_MANIFEST.json 形式存在的服务器,该服务器位于 Github 存储库中,例如 ONNX 模型动物园中的那个。此清单文件是一个 JSON 文档,列出了所有模型及其元数据,旨在与编程语言无关。一个格式良好的模型清单条目的示例如下所示

{
 "model": "BERT-Squad",
 "model_path": "text/machine_comprehension/bert-squad/model/bertsquad-8.onnx",
 "onnx_version": "1.3",
 "opset_version": 8,
 "metadata": {
     "model_sha": "cad65b9807a5e0393e4f84331f9a0c5c844d9cc736e39781a80f9c48ca39447c",
     "model_bytes": 435882893,
     "tags": ["text", "machine comprehension", "bert-squad"],
     "io_ports": {
         "inputs": [
             {
                 "name": "unique_ids_raw_output___9:0",
                 "shape": ["unk__475"],
                 "type": "tensor(int64)"
             },
             {
                 "name": "segment_ids:0",
                 "shape": ["unk__476", 256],
                 "type": "tensor(int64)"
             },
             {
                 "name": "input_mask:0",
                 "shape": ["unk__477", 256],
                 "type": "tensor(int64)"
             },
             {
                 "name": "input_ids:0",
                 "shape": ["unk__478", 256],
                 "type": "tensor(int64)"
             }
         ],
         "outputs": [
             {
                 "name": "unstack:1",
                 "shape": ["unk__479", 256],
                 "type": "tensor(float)"
             },
             {
                 "name": "unstack:0",
                 "shape": ["unk__480", 256],
                 "type": "tensor(float)"
             },
             {
                 "name": "unique_ids:0",
                 "shape": ["unk__481"],
                 "type": "tensor(int64)"
             }
         ]
     },
     "model_with_data_path": "text/machine_comprehension/bert-squad/model/bertsquad-8.tar.gz",
     "model_with_data_sha": "c8c6c7e0ab9e1333b86e8415a9d990b2570f9374f80be1c1cb72f182d266f666",
     "model_with_data_bytes": 403400046
 }
}

这些重要的字段是

  • model:用于查询的模型名称

  • model_path:存储在 Git LFS 中的模型的相对路径。

  • onnx_version:模型的 ONNX 版本

  • opset_version:操作集的版本。如果未指定,客户端将下载最新的操作集。

  • metadata/model_sha:可选的模型 sha 规范,用于增强下载安全性

  • metadata/tags:可选的高级标签,帮助用户按给定类型查找模型

metadata 字段中的所有其他字段对于客户端都是可选的,但为用户提供了重要的详细信息。

添加到 ONNX 模型 Hub

贡献官方模型

将模型添加到官方 onnx/models 版本模型 Hub 的最简单方法是遵循 这些指南 来贡献您的模型。贡献后,确保您的模型在其 README.md 中包含一个 Markdown 表格(示例)。模型 Hub 清单生成器将从这些 Markdown 表格中提取信息。要运行生成器

git clone https://github.com/onnx/models.git
git lfs pull --include="*" --exclude=""
cd models/workflow_scripts
python generate_onnx_hub_manifest.py

生成新的清单后,将其添加到拉取请求中并提交到 onnx/models

托管您自己的 ONNX 模型 Hub

要托管您自己的模型 Hub,请将 ONNX_HUB_MANIFEST.json 添加到 Github 存储库的顶层(示例)。至少,您的清单条目应包含本文档的 架构部分 中提到的字段。提交后,请检查您是否可以使用本文档的“从自定义存储库下载”部分下载模型。

如有任何问题,请提出

  • 对于 ONNX 模型问题或 SHA 不匹配问题,请在 [模型动物园]/(https://github.com/onnx/models/issues) 中提出问题。

  • 有关 ONNX 模型 Hub 用法的其他问题/问题,请在此存储库中提出问题 this repo