ONNX 模型中心¶
ONNX 模型中心提供了一种简单快捷的方式,可以快速上手来自 ONNX 模型库 的最先进的预训练 ONNX 模型。此外,这还为研究人员和模型开发者提供了机会,让他们能够与更广泛的社区分享他们的预训练模型。
安装¶
ONNX 模型中心在 ONNX 1.11.0 版本之后可用。
基本用法¶
ONNX 模型中心能够从任何 git 仓库下载、列出和查询已训练的模型,默认情况下使用官方的 ONNX 模型库。在本节中,我们将演示一些基本功能。
首先,请使用以下命令导入模型中心:
from onnx import hub
按名称下载模型:¶
load
函数将默认搜索模型库,查找名称匹配的最新模型,将其下载到本地缓存,并将模型加载到 ModelProto
对象中,以便与 ONNX Runtime 一起使用。
model = hub.load("resnet50")
从自定义仓库下载:¶
任何具有适当结构的仓库都可以成为 ONNX 模型中心。要从其他模型中心下载,或指定主模型中心上的特定分支或提交,可以提供 repo
参数。
model = hub.load("resnet50", repo="onnx/models:771185265efbdc049fb223bd68ab1aeb1aecde76")
列出和检查模型:¶
模型中心提供 API 来查询模型库,以了解有关可用模型的更多信息。这不会下载模型,而只是返回与给定参数匹配的模型的信息。
# List all models in the onnx/models:main repo
all_models = hub.list_models()
# List all versions/opsets of a specific model
mnist_models = hub.list_models(model="mnist")
# List all models matching a given "tag"
vision_models = hub.list_models(tags=["vision"])
也可以使用 get_model_info
函数在下载模型之前检查其元数据。
print(hub.get_model_info(model="mnist", opset=8))
这将打印类似如下的内容:
ModelInfo(
model=MNIST,
opset=8,
path=vision/classification/mnist/model/mnist-8.onnx,
metadata={
'model_sha': '2f06e72de813a8635c9bc0397ac447a601bdbfa7df4bebc278723b958831c9bf',
'model_bytes': 26454,
'tags': ['vision', 'classification', 'mnist'],
'io_ports': {
'inputs': [{'name': 'Input3', 'shape': [1, 1, 28, 28], 'type': 'tensor(float)'}],
'outputs': [{'name': 'Plus214_Output_0', 'shape': [1, 10], 'type': 'tensor(float)'}]},
'model_with_data_path': 'vision/classification/mnist/model/mnist-8.tar.gz',
'model_with_data_sha': '1dd098b0fe8bc750585eefc02013c37be1a1cae2bdba0191ccdb8e8518b3a882',
'model_with_data_bytes': 25962}
)
本地缓存¶
ONNX 模型中心会在一个可配置的位置本地缓存下载的模型,以便后续调用 hub.load
时无需网络连接。
默认缓存位置¶
模型中心客户端按以下顺序查找默认缓存位置:
如果定义了
ONNX_HOME
环境变量,则为$ONNX_HOME/hub
。如果定义了
XDG_CACHE_HOME
环境变量,则为$XDG_CACHE_HOME/hub
。对于
~
为用户主目录的情况,为~/.cache/onnx/hub
。
设置缓存位置¶
要手动设置缓存位置,请使用:
hub.set_dir("my/cache/directory")
此外,您还可以使用以下命令检查缓存位置:
print(hub.get_dir())
其他缓存详细信息¶
要清除模型缓存,只需使用 Python 实用程序(如 shutil
或 os
)删除缓存目录即可。此外,还可以选择使用 force_reload
选项覆盖缓存的模型。
model = hub.load("resnet50", force_reload=True)
我们包含此标志是为了完整性,但请注意,缓存中的模型是通过 sha256 哈希值进行区分的,因此在正常使用中并不需要 force_reload
标志。最后,我们注意到模型缓存目录结构将镜像 manifest
的 model_path
字段指定的目录结构,但文件名会使用模型的 SHA256 哈希值进行区分。
这样,模型缓存就可以被人类读取,可以区分多个版本的模型,并且可以在不同的模型中心之间重用具有相同名称和哈希值的模型。
架构¶
ONNX Hub 由两个主要组件组成:客户端和服务器。客户端代码目前包含在 onnx
包中,可以指向一个服务器,该服务器形式为一个托管的 ONNX_HUB_MANIFEST.json
文件,位于 github 仓库中,例如 ONNX 模型库中的那个。此 manifest 文件是一个 JSON 文档,列出了所有模型及其元数据,并且设计为与编程语言无关。一个格式良好的模型 manifest 条目的示例如下:
{
"model": "BERT-Squad",
"model_path": "text/machine_comprehension/bert-squad/model/bertsquad-8.onnx",
"onnx_version": "1.3",
"opset_version": 8,
"metadata": {
"model_sha": "cad65b9807a5e0393e4f84331f9a0c5c844d9cc736e39781a80f9c48ca39447c",
"model_bytes": 435882893,
"tags": ["text", "machine comprehension", "bert-squad"],
"io_ports": {
"inputs": [
{
"name": "unique_ids_raw_output___9:0",
"shape": ["unk__475"],
"type": "tensor(int64)"
},
{
"name": "segment_ids:0",
"shape": ["unk__476", 256],
"type": "tensor(int64)"
},
{
"name": "input_mask:0",
"shape": ["unk__477", 256],
"type": "tensor(int64)"
},
{
"name": "input_ids:0",
"shape": ["unk__478", 256],
"type": "tensor(int64)"
}
],
"outputs": [
{
"name": "unstack:1",
"shape": ["unk__479", 256],
"type": "tensor(float)"
},
{
"name": "unstack:0",
"shape": ["unk__480", 256],
"type": "tensor(float)"
},
{
"name": "unique_ids:0",
"shape": ["unk__481"],
"type": "tensor(int64)"
}
]
},
"model_with_data_path": "text/machine_comprehension/bert-squad/model/bertsquad-8.tar.gz",
"model_with_data_sha": "c8c6c7e0ab9e1333b86e8415a9d990b2570f9374f80be1c1cb72f182d266f666",
"model_with_data_bytes": 403400046
}
}
这些重要字段是:
model
:用于查询的模型名称。model_path
:存储在 Git LFS 中的模型的相对路径。onnx_version
:模型的 ONNX 版本。opset_version
:opset 的版本。如果未指定,客户端将下载最新版本的 opset。metadata/model_sha
:可选的模型 SHA 规范,用于提高下载安全性。metadata/tags
:可选的高级标签,用于帮助用户按给定类型查找模型。
metadata
字段中的所有其他字段对客户端来说是可选的,但为用户提供了重要的详细信息。
添加到 ONNX 模型中心¶
贡献官方模型¶
将模型添加到官方 onnx/models
版本模型中心的最简单方法是遵循 这些指南 来贡献您的模型。贡献后,请确保您的模型在其 README.md
文件中有一个 markdown 表格(示例)。模型中心 manifest 生成器将从这些 markdown 表格中提取信息。要运行生成器:
git clone https://github.com/onnx/models.git
git lfs pull --include="*" --exclude=""
cd models/workflow_scripts
python generate_onnx_hub_manifest.py
生成新 manifest 后,请将其作为 pull request 提交到 onnx/models
。
托管您自己的 ONNX 模型中心¶
要托管您自己的模型中心,请在您的 github 仓库的顶层添加一个 ONNX_HUB_MANIFEST.json
文件(示例)。至少,您的 manifest 条目应包含本文档 架构部分 中提到的字段。提交后,请检查您是否可以使用本文档的“从自定义仓库下载”部分下载模型。
如有问题,请提出¶
对于 ONNX 模型问题或 SHA 匹配问题,请在 [模型库](https://github.com/onnx/models/issues) 中提出问题。
有关 ONNX 模型中心使用方法等其他问题/疑虑,请在此仓库中提出问题:https://github.com/onnx/onnx/issues。