序列化¶
保存模型和任何 Proto 类¶
此 ONNX 图需要序列化为一个连续的内存缓冲区。方法 SerializeToString
在每个 ONNX 对象中都可用。
with open("model.onnx", "wb") as f:
f.write(onnx_model.SerializeToString())
此方法具有以下签名。
每个 Proto 类都实现了方法 SerializeToString
。因此,以下代码适用于页面 Protos 中描述的任何类。
with open("proto.pb", "wb") as f:
f.write(proto.SerializeToString())
下一个示例展示了如何保存 NodeProto。
from onnx import NodeProto
node = NodeProto()
node.name = "example-type-proto"
node.op_type = "Add"
node.input.extend(["X", "Y"])
node.output.extend(["Z"])
with open("node.pb", "wb") as f:
f.write(node.SerializeToString())
加载模型¶
以下函数仅自动化加载 ModelProto 类。下一节将展示如何恢复任何其他 proto 类。
- onnx.load(f: IO[bytes] | str | PathLike, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str | None = None, load_external_data: bool = True) ModelProto ¶
将序列化的 ModelProto 加载到内存中。
- 参数:
f – 可以是文件类对象(具有“read”函数)或包含文件名的字符串/PathLike
format – 序列化格式。未指定时,当
f
是路径时,从文件扩展名推断。如果未指定 _并且_f
不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。load_external_data – 是否加载外部数据。如果数据与模型在同一目录下,则设置为 True。如果不是,用户需要调用
load_external_data_for_model()
并指定目录来加载外部数据。
- 返回:
加载到内存中的 ModelProto。
from onnx import load
onnx_model = load("model.onnx")
Or
from onnx import load
with open("model.onnx", "rb") as f:
onnx_model = load(f)
下一个函数从字节数组执行相同的操作。
- onnx.load_model_from_string(s: bytes | str, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str = 'protobuf') ModelProto [source]¶
加载包含序列化 ModelProto 的二进制字符串(字节)。
- 参数:
s – 包含序列化 ModelProto 的字符串
format – 序列化格式。未指定时,当
f
是路径时,从文件扩展名推断。如果未指定 _并且_f
不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。
- 返回:
加载到内存中的 ModelProto。
加载 Proto¶
Proto 在此指包含数据(包括模型、张量、稀疏张量以及页面 Protos 中列出的任何类)的任何类型。用户必须知道他需要恢复的数据类型,然后调用方法 ParseFromString
。protobuf 不存储有关保存数据类的任何信息。因此,在恢复对象之前必须知道该类。
下一个示例展示了如何恢复 NodeProto。
from onnx import NodeProto
tp2 = NodeProto()
with open("node.pb", "rb") as f:
content = f.read()
tp2.ParseFromString(content)
print(tp2)
input: "X"
input: "Y"
output: "Z"
name: "example-type-proto"
op_type: "Add"
TensorProto 存在一个快捷方式
- onnx.load_tensor_from_string(s: bytes, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str = 'protobuf') TensorProto [source]¶
加载包含序列化 TensorProto 的二进制字符串(字节)。
- 参数:
s – 包含序列化 TensorProto 的字符串
format – 序列化格式。未指定时,当
f
是路径时,从文件扩展名推断。如果未指定 _并且_f
不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。
- 返回:
加载到内存中的 TensorProto。