序列化¶
保存模型和任何 Proto 类¶
此 ONNX 图需要序列化为一个连续的内存缓冲区。方法 SerializeToString
在每个 ONNX 对象中都可用。
with open("model.onnx", "wb") as f:
f.write(onnx_model.SerializeToString())
此方法具有以下签名。
每个 Proto 类都实现了方法 SerializeToString
。因此,以下代码适用于页面 Protos 中描述的任何类。
with open("proto.pb", "wb") as f:
f.write(proto.SerializeToString())
以下示例演示如何保存一个 NodeProto.
from onnx import NodeProto
node = NodeProto()
node.name = "example-type-proto"
node.op_type = "Add"
node.input.extend(["X", "Y"])
node.output.extend(["Z"])
with open("node.pb", "wb") as f:
f.write(node.SerializeToString())
加载模型¶
以下函数仅自动加载类 ModelProto。下一节将展示如何恢复任何其他 proto 类。
- onnx.load(f: IO[bytes] | str | PathLike, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str | None = None, load_external_data: bool = True) ModelProto ¶
将序列化的 ModelProto 加载到内存中。
- 参数:
f – 可以是文件类对象(具有“read”函数)或包含文件名字符串/PathLike
format – 序列化格式。当未指定时,它从文件扩展名推断(当
f
是路径时)。如果未指定 _且_f
不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。load_external_data – 是否加载外部数据。如果数据在模型的同一目录下,则设置为 True。如果不是,用户需要使用目录调用
load_external_data_for_model()
来加载外部数据。
- 返回值:
加载到内存中的 ModelProto。
from onnx import load
onnx_model = load("model.onnx")
或
from onnx import load
with open("model.onnx", "rb") as f:
onnx_model = load(f)
下一个函数从字节数组中执行相同的操作。
- onnx.load_model_from_string(s: bytes | str, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str = 'protobuf') ModelProto [source]¶
加载包含序列化的 ModelProto 的二进制字符串(字节)。
- 参数:
s – 一个字符串,包含序列化的 ModelProto
format – 序列化格式。当未指定时,它从文件扩展名推断(当
f
是路径时)。如果未指定 _且_f
不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。
- 返回值:
加载到内存中的 ModelProto。
加载 Proto¶
这里“Proto”指的是任何包含数据的类型,包括模型、张量、稀疏张量以及Protos页面列出的任何类。用户必须知道需要还原的数据类型,然后调用方法ParseFromString
。 protobuf 不存储有关保存数据的类类型的任何信息。因此,必须在还原对象之前知道该类。
以下示例显示了如何还原 NodeProto。
from onnx import NodeProto
tp2 = NodeProto()
with open("node.pb", "rb") as f:
content = f.read()
tp2.ParseFromString(content)
print(tp2)
input: "X"
input: "Y"
output: "Z"
name: "example-type-proto"
op_type: "Add"
对于 TensorProto,存在一个快捷方法。
- onnx.load_tensor_from_string(s: bytes, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str = 'protobuf') TensorProto [source]¶
加载包含序列化 TensorProto 的二进制字符串(字节)。
- 参数:
s – 包含序列化 TensorProto 的字符串
format – 序列化格式。当未指定时,它从文件扩展名推断(当
f
是路径时)。如果未指定 _且_f
不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。
- 返回值:
加载到内存中的 TensorProto。